Add subcription feature
This commit is contained in:
@@ -2,10 +2,11 @@
|
||||
Subscription management API for GDPR-compliant cancellation and reactivation
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path, Body
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from uuid import UUID
|
||||
from typing import Optional, Dict, Any, List
|
||||
import structlog
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
@@ -17,6 +18,8 @@ from app.core.database import get_db
|
||||
from app.models.tenants import Subscription, Tenant
|
||||
from app.services.subscription_limit_service import SubscriptionLimitService
|
||||
from app.services.subscription_service import SubscriptionService
|
||||
from app.services.subscription_orchestration_service import SubscriptionOrchestrationService
|
||||
from app.services.payment_service import PaymentService
|
||||
from shared.clients.stripe_client import StripeProvider
|
||||
from app.core.config import settings
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
@@ -134,9 +137,9 @@ async def cancel_subscription(
|
||||
5. Gateway enforces read-only mode for 'pending_cancellation' and 'inactive' statuses
|
||||
"""
|
||||
try:
|
||||
# Use service layer instead of direct database access
|
||||
subscription_service = SubscriptionService(db)
|
||||
result = await subscription_service.cancel_subscription(
|
||||
# Use orchestration service for complete workflow
|
||||
orchestration_service = SubscriptionOrchestrationService(db)
|
||||
result = await orchestration_service.orchestrate_subscription_cancellation(
|
||||
request.tenant_id,
|
||||
request.reason
|
||||
)
|
||||
@@ -195,9 +198,9 @@ async def reactivate_subscription(
|
||||
- inactive (after effective date)
|
||||
"""
|
||||
try:
|
||||
# Use service layer instead of direct database access
|
||||
subscription_service = SubscriptionService(db)
|
||||
result = await subscription_service.reactivate_subscription(
|
||||
# Use orchestration service for complete workflow
|
||||
orchestration_service = SubscriptionOrchestrationService(db)
|
||||
result = await orchestration_service.orchestrate_subscription_reactivation(
|
||||
request.tenant_id,
|
||||
request.plan
|
||||
)
|
||||
@@ -296,9 +299,10 @@ async def get_tenant_invoices(
|
||||
Get invoice history for a tenant from Stripe
|
||||
"""
|
||||
try:
|
||||
# Use service layer instead of direct database access
|
||||
# Use service layer for invoice retrieval
|
||||
subscription_service = SubscriptionService(db)
|
||||
invoices_data = await subscription_service.get_tenant_invoices(tenant_id)
|
||||
payment_service = PaymentService()
|
||||
invoices_data = await subscription_service.get_tenant_invoices(tenant_id, payment_service)
|
||||
|
||||
# Transform to response format
|
||||
invoices = []
|
||||
@@ -592,14 +596,25 @@ async def validate_plan_upgrade(
|
||||
async def upgrade_subscription_plan(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
new_plan: str = Query(..., description="New plan name"),
|
||||
billing_cycle: Optional[str] = Query(None, description="Billing cycle (monthly/yearly)"),
|
||||
immediate_change: bool = Query(True, description="Apply change immediately"),
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
limit_service: SubscriptionLimitService = Depends(get_subscription_limit_service),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Upgrade subscription plan for a tenant"""
|
||||
"""
|
||||
Upgrade subscription plan for a tenant.
|
||||
|
||||
This endpoint:
|
||||
1. Validates the upgrade is allowed
|
||||
2. Calculates proration costs
|
||||
3. Updates subscription in Stripe
|
||||
4. Updates local database
|
||||
5. Invalidates caches and tokens
|
||||
"""
|
||||
|
||||
try:
|
||||
# First validate the upgrade
|
||||
# Step 1: Validate the upgrade
|
||||
validation = await limit_service.validate_plan_upgrade(str(tenant_id), new_plan)
|
||||
if not validation.get("can_upgrade", False):
|
||||
raise HTTPException(
|
||||
@@ -607,10 +622,8 @@ async def upgrade_subscription_plan(
|
||||
detail=validation.get("reason", "Cannot upgrade to this plan")
|
||||
)
|
||||
|
||||
# Use SubscriptionService for the upgrade
|
||||
# Step 2: Get current subscription to determine billing cycle
|
||||
subscription_service = SubscriptionService(db)
|
||||
|
||||
# Get current subscription
|
||||
current_subscription = await subscription_service.get_subscription_by_tenant_id(tenant_id)
|
||||
if not current_subscription:
|
||||
raise HTTPException(
|
||||
@@ -618,19 +631,23 @@ async def upgrade_subscription_plan(
|
||||
detail="No active subscription found for this tenant"
|
||||
)
|
||||
|
||||
# Update the subscription plan using service layer
|
||||
# Note: This should be enhanced in SubscriptionService to handle plan upgrades
|
||||
# For now, we'll use the repository directly but this should be moved to service layer
|
||||
from app.repositories.subscription_repository import SubscriptionRepository
|
||||
from app.models.tenants import Subscription as SubscriptionModel
|
||||
|
||||
subscription_repo = SubscriptionRepository(SubscriptionModel, db)
|
||||
updated_subscription = await subscription_repo.update_subscription_plan(
|
||||
str(current_subscription.id),
|
||||
new_plan
|
||||
# Use current billing cycle if not provided
|
||||
if not billing_cycle:
|
||||
billing_cycle = current_subscription.billing_interval or "monthly"
|
||||
|
||||
# Step 3: Use orchestration service for the upgrade
|
||||
from app.services.subscription_orchestration_service import SubscriptionOrchestrationService
|
||||
orchestration_service = SubscriptionOrchestrationService(db)
|
||||
|
||||
upgrade_result = await orchestration_service.orchestrate_plan_upgrade(
|
||||
tenant_id=str(tenant_id),
|
||||
new_plan=new_plan,
|
||||
proration_behavior="create_prorations",
|
||||
immediate_change=immediate_change,
|
||||
billing_cycle=billing_cycle
|
||||
)
|
||||
|
||||
# Invalidate subscription cache to ensure immediate availability of new tier
|
||||
# Step 4: Invalidate subscription cache
|
||||
try:
|
||||
from app.services.subscription_cache import get_subscription_cache_service
|
||||
import shared.redis_utils
|
||||
@@ -647,8 +664,7 @@ async def upgrade_subscription_plan(
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(cache_error))
|
||||
|
||||
# SECURITY: Invalidate all existing tokens for this tenant
|
||||
# Forces users to re-authenticate and get new JWT with updated tier
|
||||
# Step 5: Invalidate all existing tokens for this tenant
|
||||
try:
|
||||
redis_client = await get_redis_client()
|
||||
if redis_client:
|
||||
@@ -656,7 +672,7 @@ async def upgrade_subscription_plan(
|
||||
await redis_client.set(
|
||||
f"tenant:{tenant_id}:subscription_changed_at",
|
||||
str(changed_timestamp),
|
||||
ex=86400 # 24 hour TTL
|
||||
ex=86400
|
||||
)
|
||||
logger.info("Set subscription change timestamp for token invalidation",
|
||||
tenant_id=tenant_id,
|
||||
@@ -666,7 +682,7 @@ async def upgrade_subscription_plan(
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(token_error))
|
||||
|
||||
# Also publish event for real-time notification
|
||||
# Step 6: Publish event for real-time notification
|
||||
try:
|
||||
from shared.messaging import UnifiedEventPublisher
|
||||
event_publisher = UnifiedEventPublisher()
|
||||
@@ -693,9 +709,9 @@ async def upgrade_subscription_plan(
|
||||
"message": f"Plan successfully upgraded to {new_plan}",
|
||||
"old_plan": current_subscription.plan,
|
||||
"new_plan": new_plan,
|
||||
"new_monthly_price": updated_subscription.monthly_price,
|
||||
"proration_details": upgrade_result.get("proration_details"),
|
||||
"validation": validation,
|
||||
"requires_token_refresh": True # Signal to frontend
|
||||
"requires_token_refresh": True
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
@@ -707,16 +723,130 @@ async def upgrade_subscription_plan(
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to upgrade subscription plan"
|
||||
detail=f"Failed to upgrade subscription plan: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/v1/subscriptions/{tenant_id}/change-billing-cycle")
|
||||
async def change_billing_cycle(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
new_billing_cycle: str = Query(..., description="New billing cycle (monthly/yearly)"),
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Change billing cycle for a tenant's subscription.
|
||||
|
||||
This endpoint:
|
||||
1. Validates the tenant has an active subscription
|
||||
2. Calculates proration costs
|
||||
3. Updates subscription in Stripe
|
||||
4. Updates local database
|
||||
5. Returns proration details to user
|
||||
"""
|
||||
|
||||
try:
|
||||
# Validate billing cycle parameter
|
||||
if new_billing_cycle not in ["monthly", "yearly"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Billing cycle must be 'monthly' or 'yearly'"
|
||||
)
|
||||
|
||||
# Get current subscription
|
||||
subscription_service = SubscriptionService(db)
|
||||
current_subscription = await subscription_service.get_subscription_by_tenant_id(tenant_id)
|
||||
|
||||
if not current_subscription:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No active subscription found for this tenant"
|
||||
)
|
||||
|
||||
# Check if already on requested billing cycle
|
||||
current_cycle = current_subscription.billing_interval or "monthly"
|
||||
if current_cycle == new_billing_cycle:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Subscription is already on {new_billing_cycle} billing"
|
||||
)
|
||||
|
||||
# Use orchestration service for the billing cycle change
|
||||
from app.services.subscription_orchestration_service import SubscriptionOrchestrationService
|
||||
orchestration_service = SubscriptionOrchestrationService(db)
|
||||
|
||||
change_result = await orchestration_service.orchestrate_billing_cycle_change(
|
||||
tenant_id=str(tenant_id),
|
||||
new_billing_cycle=new_billing_cycle,
|
||||
immediate_change=True
|
||||
)
|
||||
|
||||
# Invalidate subscription cache
|
||||
try:
|
||||
from app.services.subscription_cache import get_subscription_cache_service
|
||||
import shared.redis_utils
|
||||
|
||||
redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
|
||||
cache_service = get_subscription_cache_service(redis_client)
|
||||
await cache_service.invalidate_subscription_cache(str(tenant_id))
|
||||
|
||||
logger.info("Subscription cache invalidated after billing cycle change",
|
||||
tenant_id=str(tenant_id),
|
||||
new_billing_cycle=new_billing_cycle)
|
||||
except Exception as cache_error:
|
||||
logger.error("Failed to invalidate subscription cache",
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(cache_error))
|
||||
|
||||
# Publish event for real-time notification
|
||||
try:
|
||||
from shared.messaging import UnifiedEventPublisher
|
||||
event_publisher = UnifiedEventPublisher()
|
||||
await event_publisher.publish_business_event(
|
||||
event_type="subscription.billing_cycle_changed",
|
||||
tenant_id=str(tenant_id),
|
||||
data={
|
||||
"tenant_id": str(tenant_id),
|
||||
"old_billing_cycle": current_cycle,
|
||||
"new_billing_cycle": new_billing_cycle,
|
||||
"action": "billing_cycle_change"
|
||||
}
|
||||
)
|
||||
logger.info("Published billing cycle change event",
|
||||
tenant_id=str(tenant_id))
|
||||
except Exception as event_error:
|
||||
logger.error("Failed to publish billing cycle change event",
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(event_error))
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Billing cycle changed to {new_billing_cycle}",
|
||||
"old_billing_cycle": current_cycle,
|
||||
"new_billing_cycle": new_billing_cycle,
|
||||
"proration_details": change_result.get("proration_details")
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to change billing cycle",
|
||||
tenant_id=str(tenant_id),
|
||||
new_billing_cycle=new_billing_cycle,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to change billing cycle: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/v1/subscriptions/register-with-subscription")
|
||||
async def register_with_subscription(
|
||||
user_data: dict = Depends(get_current_user_dep),
|
||||
plan_id: str = Query(..., description="Plan ID to subscribe to"),
|
||||
plan_id: str = Query(..., description="Plan ID to subscribe to (starter, professional, enterprise)"),
|
||||
payment_method_id: str = Query(..., description="Payment method ID from frontend"),
|
||||
use_trial: bool = Query(False, description="Whether to use trial period for pilot users"),
|
||||
coupon_code: Optional[str] = Query(None, description="Coupon code to apply (e.g., PILOT2025)"),
|
||||
billing_interval: str = Query("monthly", description="Billing interval (monthly or yearly)"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Process user registration with subscription creation"""
|
||||
@@ -729,7 +859,9 @@ async def register_with_subscription(
|
||||
user_data.get('tenant_id'),
|
||||
plan_id,
|
||||
payment_method_id,
|
||||
14 if use_trial else None
|
||||
None, # Trial period handled by coupon logic
|
||||
billing_interval,
|
||||
coupon_code # Pass coupon code for trial period determination
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -745,6 +877,127 @@ async def register_with_subscription(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/v1/subscriptions/{tenant_id}/create")
|
||||
async def create_subscription_endpoint(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
plan_id: str = Query(..., description="Plan ID (starter, professional, enterprise)"),
|
||||
payment_method_id: str = Query(..., description="Payment method ID from frontend"),
|
||||
billing_interval: str = Query("monthly", description="Billing interval (monthly or yearly)"),
|
||||
trial_period_days: Optional[int] = Query(None, description="Trial period in days"),
|
||||
coupon_code: Optional[str] = Query(None, description="Optional coupon code"),
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Create a new subscription for a tenant using orchestration service
|
||||
|
||||
This endpoint orchestrates the complete subscription creation workflow
|
||||
including payment provider integration and tenant updates.
|
||||
"""
|
||||
try:
|
||||
# Prepare user data for orchestration service
|
||||
user_data = {
|
||||
'user_id': current_user.get('sub'),
|
||||
'email': current_user.get('email'),
|
||||
'full_name': current_user.get('name', 'Unknown User'),
|
||||
'tenant_id': tenant_id
|
||||
}
|
||||
|
||||
# Use orchestration service for complete workflow
|
||||
orchestration_service = SubscriptionOrchestrationService(db)
|
||||
|
||||
result = await orchestration_service.orchestrate_subscription_creation(
|
||||
tenant_id,
|
||||
user_data,
|
||||
plan_id,
|
||||
payment_method_id,
|
||||
billing_interval,
|
||||
coupon_code
|
||||
)
|
||||
|
||||
logger.info("subscription_created_via_orchestration",
|
||||
tenant_id=tenant_id,
|
||||
plan_id=plan_id,
|
||||
billing_interval=billing_interval,
|
||||
coupon_applied=result.get("coupon_applied", False))
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Subscription created successfully",
|
||||
"data": result
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create subscription via API",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id,
|
||||
plan_id=plan_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create subscription"
|
||||
)
|
||||
|
||||
class CreateForRegistrationRequest(BaseModel):
|
||||
"""Request model for create-for-registration endpoint"""
|
||||
user_data: dict = Field(..., description="User data for subscription creation")
|
||||
plan_id: str = Field(..., description="Plan ID (starter, professional, enterprise)")
|
||||
payment_method_id: str = Field(..., description="Payment method ID from frontend")
|
||||
billing_interval: str = Field("monthly", description="Billing interval (monthly or yearly)")
|
||||
coupon_code: Optional[str] = Field(None, description="Optional coupon code")
|
||||
|
||||
|
||||
@router.post("/api/v1/subscriptions/create-for-registration")
|
||||
async def create_subscription_for_registration(
|
||||
request: CreateForRegistrationRequest = Body(..., description="Subscription creation request"),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Create a tenant-independent subscription during user registration
|
||||
|
||||
This endpoint creates a subscription that is not linked to any tenant.
|
||||
The subscription will be linked to a tenant during the onboarding flow.
|
||||
|
||||
This is used during the new registration flow where users register
|
||||
and pay before creating their tenant/bakery.
|
||||
"""
|
||||
try:
|
||||
logger.info("Creating tenant-independent subscription for registration",
|
||||
user_id=request.user_data.get('user_id'),
|
||||
plan_id=request.plan_id)
|
||||
|
||||
# Use orchestration service for tenant-independent subscription creation
|
||||
orchestration_service = SubscriptionOrchestrationService(db)
|
||||
|
||||
result = await orchestration_service.create_tenant_independent_subscription(
|
||||
request.user_data,
|
||||
request.plan_id,
|
||||
request.payment_method_id,
|
||||
request.billing_interval,
|
||||
request.coupon_code
|
||||
)
|
||||
|
||||
logger.info("Tenant-independent subscription created successfully",
|
||||
user_id=request.user_data.get('user_id'),
|
||||
subscription_id=result["subscription_id"],
|
||||
plan_id=request.plan_id)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Tenant-independent subscription created successfully",
|
||||
"data": result
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create tenant-independent subscription",
|
||||
error=str(e),
|
||||
user_id=request.user_data.get('user_id'),
|
||||
plan_id=request.plan_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to create tenant-independent subscription"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/v1/subscriptions/{tenant_id}/update-payment-method")
|
||||
async def update_payment_method(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
@@ -813,3 +1066,314 @@ async def update_payment_method(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred while updating payment method"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# NEW SUBSCRIPTION UPDATE ENDPOINTS WITH PRORATION SUPPORT
|
||||
# ============================================================================
|
||||
|
||||
class SubscriptionChangePreviewRequest(BaseModel):
|
||||
"""Request model for subscription change preview"""
|
||||
new_plan: str = Field(..., description="New plan name (starter, professional, enterprise) or 'same' for billing cycle changes")
|
||||
proration_behavior: str = Field("create_prorations", description="Proration behavior (create_prorations, none, always_invoice)")
|
||||
billing_cycle: str = Field("monthly", description="Billing cycle for the new plan (monthly, yearly)")
|
||||
|
||||
|
||||
class SubscriptionChangePreviewResponse(BaseModel):
|
||||
"""Response model for subscription change preview"""
|
||||
success: bool
|
||||
current_plan: str
|
||||
current_billing_cycle: str
|
||||
current_price: float
|
||||
new_plan: str
|
||||
new_billing_cycle: str
|
||||
new_price: float
|
||||
proration_details: Dict[str, Any]
|
||||
current_plan_features: List[str]
|
||||
new_plan_features: List[str]
|
||||
change_type: str
|
||||
|
||||
|
||||
@router.post("/api/v1/subscriptions/{tenant_id}/preview-change", response_model=SubscriptionChangePreviewResponse)
|
||||
async def preview_subscription_change(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
request: SubscriptionChangePreviewRequest = Body(...),
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Preview the cost impact of a subscription change
|
||||
|
||||
This endpoint allows users to see the proration details before confirming a subscription change.
|
||||
It shows the cost difference, credits, and other financial impacts of changing plans or billing cycles.
|
||||
"""
|
||||
try:
|
||||
# Use SubscriptionService for preview
|
||||
subscription_service = SubscriptionService(db)
|
||||
|
||||
# Create payment service for proration calculation
|
||||
payment_service = PaymentService()
|
||||
result = await subscription_service.preview_subscription_change(
|
||||
tenant_id,
|
||||
request.new_plan,
|
||||
request.proration_behavior,
|
||||
request.billing_cycle,
|
||||
payment_service
|
||||
)
|
||||
|
||||
logger.info("subscription_change_previewed",
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user.get("user_id"),
|
||||
new_plan=request.new_plan,
|
||||
proration_amount=result["proration_details"].get("net_amount", 0))
|
||||
|
||||
return SubscriptionChangePreviewResponse(**result)
|
||||
|
||||
except ValidationError as ve:
|
||||
logger.error("preview_subscription_change_validation_failed",
|
||||
error=str(ve), tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(ve)
|
||||
)
|
||||
except DatabaseError as de:
|
||||
logger.error("preview_subscription_change_failed",
|
||||
error=str(de), tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to preview subscription change"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("preview_subscription_change_unexpected_error",
|
||||
error=str(e), tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred while previewing subscription change"
|
||||
)
|
||||
|
||||
|
||||
class SubscriptionPlanUpdateRequest(BaseModel):
|
||||
"""Request model for subscription plan update"""
|
||||
new_plan: str = Field(..., description="New plan name (starter, professional, enterprise)")
|
||||
proration_behavior: str = Field("create_prorations", description="Proration behavior (create_prorations, none, always_invoice)")
|
||||
immediate_change: bool = Field(False, description="Whether to apply changes immediately or at period end")
|
||||
billing_cycle: str = Field("monthly", description="Billing cycle for the new plan (monthly, yearly)")
|
||||
|
||||
|
||||
class SubscriptionPlanUpdateResponse(BaseModel):
|
||||
"""Response model for subscription plan update"""
|
||||
success: bool
|
||||
message: str
|
||||
old_plan: str
|
||||
new_plan: str
|
||||
proration_details: Dict[str, Any]
|
||||
immediate_change: bool
|
||||
new_status: str
|
||||
new_period_end: str
|
||||
|
||||
|
||||
@router.post("/api/v1/subscriptions/{tenant_id}/update-plan", response_model=SubscriptionPlanUpdateResponse)
|
||||
async def update_subscription_plan(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
request: SubscriptionPlanUpdateRequest = Body(...),
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Update subscription plan with proration support
|
||||
|
||||
This endpoint allows users to change their subscription plan with proper proration handling.
|
||||
It supports both immediate changes and changes that take effect at the end of the billing period.
|
||||
"""
|
||||
try:
|
||||
# Use orchestration service for complete plan upgrade workflow
|
||||
orchestration_service = SubscriptionOrchestrationService(db)
|
||||
|
||||
result = await orchestration_service.orchestrate_plan_upgrade(
|
||||
tenant_id,
|
||||
request.new_plan,
|
||||
request.proration_behavior,
|
||||
request.immediate_change,
|
||||
request.billing_cycle
|
||||
)
|
||||
|
||||
logger.info("subscription_plan_updated",
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user.get("user_id"),
|
||||
old_plan=result["old_plan"],
|
||||
new_plan=result["new_plan"],
|
||||
proration_amount=result["proration_details"].get("net_amount", 0),
|
||||
immediate_change=request.immediate_change)
|
||||
|
||||
return SubscriptionPlanUpdateResponse(**result)
|
||||
|
||||
except ValidationError as ve:
|
||||
logger.error("update_subscription_plan_validation_failed",
|
||||
error=str(ve), tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(ve)
|
||||
)
|
||||
except DatabaseError as de:
|
||||
logger.error("update_subscription_plan_failed",
|
||||
error=str(de), tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update subscription plan"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("update_subscription_plan_unexpected_error",
|
||||
error=str(e), tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred while updating subscription plan"
|
||||
)
|
||||
|
||||
|
||||
class BillingCycleChangeRequest(BaseModel):
|
||||
"""Request model for billing cycle change"""
|
||||
new_billing_cycle: str = Field(..., description="New billing cycle (monthly, yearly)")
|
||||
proration_behavior: str = Field("create_prorations", description="Proration behavior (create_prorations, none, always_invoice)")
|
||||
|
||||
|
||||
class BillingCycleChangeResponse(BaseModel):
|
||||
"""Response model for billing cycle change"""
|
||||
success: bool
|
||||
message: str
|
||||
old_billing_cycle: str
|
||||
new_billing_cycle: str
|
||||
proration_details: Dict[str, Any]
|
||||
new_status: str
|
||||
new_period_end: str
|
||||
|
||||
|
||||
@router.post("/api/v1/subscriptions/{tenant_id}/change-billing-cycle", response_model=BillingCycleChangeResponse)
|
||||
async def change_billing_cycle(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
request: BillingCycleChangeRequest = Body(...),
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Change billing cycle (monthly ↔ yearly) for a subscription
|
||||
|
||||
This endpoint allows users to switch between monthly and yearly billing cycles.
|
||||
It handles proration and creates appropriate charges or credits.
|
||||
"""
|
||||
try:
|
||||
# Use orchestration service for complete billing cycle change workflow
|
||||
orchestration_service = SubscriptionOrchestrationService(db)
|
||||
|
||||
result = await orchestration_service.orchestrate_billing_cycle_change(
|
||||
tenant_id,
|
||||
request.new_billing_cycle,
|
||||
request.proration_behavior
|
||||
)
|
||||
|
||||
logger.info("subscription_billing_cycle_changed",
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user.get("user_id"),
|
||||
old_billing_cycle=result["old_billing_cycle"],
|
||||
new_billing_cycle=result["new_billing_cycle"],
|
||||
proration_amount=result["proration_details"].get("net_amount", 0))
|
||||
|
||||
return BillingCycleChangeResponse(**result)
|
||||
|
||||
except ValidationError as ve:
|
||||
logger.error("change_billing_cycle_validation_failed",
|
||||
error=str(ve), tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(ve)
|
||||
)
|
||||
except DatabaseError as de:
|
||||
logger.error("change_billing_cycle_failed",
|
||||
error=str(de), tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to change billing cycle"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("change_billing_cycle_unexpected_error",
|
||||
error=str(e), tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred while changing billing cycle"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# COUPON REDEMPTION ENDPOINTS
|
||||
# ============================================================================
|
||||
|
||||
class CouponRedemptionRequest(BaseModel):
|
||||
"""Request model for coupon redemption"""
|
||||
coupon_code: str = Field(..., description="Coupon code to redeem")
|
||||
base_trial_days: int = Field(14, description="Base trial days without coupon")
|
||||
|
||||
class CouponRedemptionResponse(BaseModel):
|
||||
"""Response model for coupon redemption"""
|
||||
success: bool
|
||||
coupon_applied: bool
|
||||
discount: Optional[Dict[str, Any]] = None
|
||||
message: str
|
||||
error: Optional[str] = None
|
||||
|
||||
@router.post("/api/v1/subscriptions/{tenant_id}/redeem-coupon", response_model=CouponRedemptionResponse)
|
||||
async def redeem_coupon(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
request: CouponRedemptionRequest = Body(...),
|
||||
current_user: dict = Depends(get_current_user_dep),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Redeem a coupon for a tenant
|
||||
|
||||
This endpoint handles the complete coupon redemption workflow including
|
||||
validation, redemption, and tenant updates.
|
||||
"""
|
||||
try:
|
||||
# Use orchestration service for complete coupon redemption workflow
|
||||
orchestration_service = SubscriptionOrchestrationService(db)
|
||||
|
||||
result = await orchestration_service.orchestrate_coupon_redemption(
|
||||
tenant_id,
|
||||
request.coupon_code,
|
||||
request.base_trial_days
|
||||
)
|
||||
|
||||
logger.info("coupon_redeemed",
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user.get("user_id"),
|
||||
coupon_code=request.coupon_code,
|
||||
success=result["success"])
|
||||
|
||||
return CouponRedemptionResponse(
|
||||
success=result["success"],
|
||||
coupon_applied=result.get("coupon_applied", False),
|
||||
discount=result.get("discount"),
|
||||
message=result.get("message", "Coupon redemption processed"),
|
||||
error=result.get("error")
|
||||
)
|
||||
|
||||
except ValidationError as ve:
|
||||
logger.error("coupon_redemption_validation_failed",
|
||||
error=str(ve), tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(ve)
|
||||
)
|
||||
except DatabaseError as de:
|
||||
logger.error("coupon_redemption_failed", error=str(de), tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to redeem coupon"
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("coupon_redemption_failed", error=str(e), tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="An unexpected error occurred while redeeming coupon"
|
||||
)
|
||||
|
||||
@@ -11,7 +11,8 @@ from app.schemas.tenants import (
|
||||
ChildTenantCreate,
|
||||
BulkChildTenantsCreate,
|
||||
BulkChildTenantsResponse,
|
||||
ChildTenantResponse
|
||||
ChildTenantResponse,
|
||||
TenantHierarchyResponse
|
||||
)
|
||||
from app.services.tenant_service import EnhancedTenantService
|
||||
from app.repositories.tenant_repository import TenantRepository
|
||||
@@ -219,6 +220,115 @@ async def get_tenant_children_count(
|
||||
)
|
||||
|
||||
|
||||
@router.get(route_builder.build_base_route("{tenant_id}/hierarchy", include_tenant_prefix=False), response_model=TenantHierarchyResponse)
|
||||
@track_endpoint_metrics("tenant_hierarchy")
|
||||
async def get_tenant_hierarchy(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""
|
||||
Get tenant hierarchy information.
|
||||
|
||||
Returns hierarchy metadata for a tenant including:
|
||||
- Tenant type (standalone, parent, child)
|
||||
- Parent tenant ID (if this is a child)
|
||||
- Hierarchy path (materialized path)
|
||||
- Number of child tenants (for parent tenants)
|
||||
- Hierarchy level (depth in the tree)
|
||||
|
||||
This endpoint is used by the authentication layer for hierarchical access control
|
||||
and by enterprise features for network management.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Get tenant hierarchy request received",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id"),
|
||||
user_type=current_user.get("type", "user"),
|
||||
is_service=current_user.get("type") == "service"
|
||||
)
|
||||
|
||||
# Get tenant from database
|
||||
from app.models.tenants import Tenant
|
||||
async with tenant_service.database_manager.get_session() as session:
|
||||
tenant_repo = TenantRepository(Tenant, session)
|
||||
|
||||
# Get the tenant
|
||||
tenant = await tenant_repo.get(str(tenant_id))
|
||||
if not tenant:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Tenant {tenant_id} not found"
|
||||
)
|
||||
|
||||
# Skip access check for service-to-service calls
|
||||
is_service_call = current_user.get("type") == "service"
|
||||
if not is_service_call:
|
||||
# Verify user has access to this tenant
|
||||
access_info = await tenant_service.verify_user_access(current_user["user_id"], str(tenant_id))
|
||||
if not access_info.has_access:
|
||||
logger.warning(
|
||||
"Access denied to tenant for hierarchy query",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Service-to-service call - bypassing access check",
|
||||
service=current_user.get("service"),
|
||||
tenant_id=str(tenant_id)
|
||||
)
|
||||
|
||||
# Get child count if this is a parent tenant
|
||||
child_count = 0
|
||||
if tenant.tenant_type in ["parent", "standalone"]:
|
||||
child_count = await tenant_repo.get_child_tenant_count(str(tenant_id))
|
||||
|
||||
# Calculate hierarchy level from hierarchy_path
|
||||
hierarchy_level = 0
|
||||
if tenant.hierarchy_path:
|
||||
# hierarchy_path format: "parent_id" or "parent_id.child_id" or "parent_id.child_id.grandchild_id"
|
||||
hierarchy_level = tenant.hierarchy_path.count('.')
|
||||
|
||||
# Build response
|
||||
hierarchy_info = TenantHierarchyResponse(
|
||||
tenant_id=str(tenant.id),
|
||||
tenant_type=tenant.tenant_type or "standalone",
|
||||
parent_tenant_id=str(tenant.parent_tenant_id) if tenant.parent_tenant_id else None,
|
||||
hierarchy_path=tenant.hierarchy_path,
|
||||
child_count=child_count,
|
||||
hierarchy_level=hierarchy_level
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Get tenant hierarchy successful",
|
||||
tenant_id=str(tenant_id),
|
||||
tenant_type=tenant.tenant_type,
|
||||
parent_tenant_id=str(tenant.parent_tenant_id) if tenant.parent_tenant_id else None,
|
||||
child_count=child_count,
|
||||
hierarchy_level=hierarchy_level
|
||||
)
|
||||
|
||||
return hierarchy_info
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Get tenant hierarchy failed",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id"),
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Get tenant hierarchy failed"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/v1/tenants/{tenant_id}/bulk-children", response_model=BulkChildTenantsResponse)
|
||||
@track_endpoint_metrics("bulk_create_child_tenants")
|
||||
async def bulk_create_child_tenants(
|
||||
|
||||
@@ -22,6 +22,8 @@ from shared.auth.decorators import (
|
||||
get_current_user_dep,
|
||||
require_admin_role_dep
|
||||
)
|
||||
from app.core.database import get_db
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from shared.auth.access_control import owner_role_required, admin_role_required
|
||||
from shared.routing.route_builder import RouteBuilder
|
||||
from shared.database.base import create_database_manager
|
||||
@@ -94,7 +96,6 @@ def get_payment_service():
|
||||
logger.error("Failed to create payment service", error=str(e))
|
||||
raise HTTPException(status_code=500, detail="Payment service initialization failed")
|
||||
|
||||
# ============================================================================
|
||||
# TENANT REGISTRATION & ACCESS OPERATIONS
|
||||
# ============================================================================
|
||||
|
||||
@@ -103,81 +104,142 @@ async def register_bakery(
|
||||
bakery_data: BakeryRegistration,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service),
|
||||
payment_service: PaymentService = Depends(get_payment_service)
|
||||
payment_service: PaymentService = Depends(get_payment_service),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Register a new bakery/tenant with enhanced validation and features"""
|
||||
|
||||
try:
|
||||
# Validate coupon if provided
|
||||
# Initialize variables to avoid UnboundLocalError
|
||||
coupon_validation = None
|
||||
if bakery_data.coupon_code:
|
||||
from app.core.config import settings
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
|
||||
success = None
|
||||
discount = None
|
||||
error = None
|
||||
|
||||
async with database_manager.get_session() as session:
|
||||
# Temp tenant ID for validation (will be replaced with actual after creation)
|
||||
temp_tenant_id = f"temp_{current_user['user_id']}"
|
||||
|
||||
coupon_validation = payment_service.validate_coupon_code(
|
||||
bakery_data.coupon_code,
|
||||
temp_tenant_id,
|
||||
session
|
||||
)
|
||||
|
||||
if not coupon_validation["valid"]:
|
||||
logger.warning(
|
||||
"Invalid coupon code provided during registration",
|
||||
coupon_code=bakery_data.coupon_code,
|
||||
error=coupon_validation["error_message"]
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=coupon_validation["error_message"]
|
||||
)
|
||||
|
||||
# Create bakery/tenant
|
||||
# Create bakery/tenant first
|
||||
result = await tenant_service.create_bakery(
|
||||
bakery_data,
|
||||
current_user["user_id"]
|
||||
)
|
||||
|
||||
# CRITICAL: Create default subscription for new tenant
|
||||
try:
|
||||
from app.repositories.subscription_repository import SubscriptionRepository
|
||||
from app.models.tenants import Subscription
|
||||
from datetime import datetime, timedelta, timezone
|
||||
tenant_id = result.id
|
||||
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
|
||||
async with database_manager.get_session() as session:
|
||||
subscription_repo = SubscriptionRepository(Subscription, session)
|
||||
# NEW ARCHITECTURE: Check if we need to link an existing subscription
|
||||
if bakery_data.link_existing_subscription and bakery_data.subscription_id:
|
||||
logger.info("Linking existing subscription to new tenant",
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=bakery_data.subscription_id,
|
||||
user_id=current_user["user_id"])
|
||||
|
||||
# Create starter subscription with 14-day trial
|
||||
trial_end_date = datetime.now(timezone.utc) + timedelta(days=14)
|
||||
next_billing_date = trial_end_date
|
||||
try:
|
||||
# Import subscription service for linking
|
||||
from app.services.subscription_service import SubscriptionService
|
||||
|
||||
await subscription_repo.create_subscription(
|
||||
tenant_id=str(result.id),
|
||||
plan="starter",
|
||||
status="active",
|
||||
billing_cycle="monthly",
|
||||
next_billing_date=next_billing_date,
|
||||
trial_ends_at=trial_end_date
|
||||
subscription_service = SubscriptionService(db)
|
||||
|
||||
# Link the subscription to the tenant
|
||||
linking_result = await subscription_service.link_subscription_to_tenant(
|
||||
subscription_id=bakery_data.subscription_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user["user_id"]
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
"Default subscription created for new tenant",
|
||||
tenant_id=str(result.id),
|
||||
plan="starter",
|
||||
trial_days=14
|
||||
)
|
||||
except Exception as subscription_error:
|
||||
logger.error(
|
||||
"Failed to create default subscription for tenant",
|
||||
tenant_id=str(result.id),
|
||||
error=str(subscription_error)
|
||||
logger.info("Subscription linked successfully during tenant registration",
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=bakery_data.subscription_id)
|
||||
|
||||
except Exception as linking_error:
|
||||
logger.error("Error linking subscription during tenant registration",
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=bakery_data.subscription_id,
|
||||
error=str(linking_error))
|
||||
# Don't fail tenant creation if subscription linking fails
|
||||
# The subscription can be linked later manually
|
||||
|
||||
elif bakery_data.coupon_code:
|
||||
# If no subscription but coupon provided, just validate and redeem coupon
|
||||
coupon_validation = payment_service.validate_coupon_code(
|
||||
bakery_data.coupon_code,
|
||||
tenant_id,
|
||||
db
|
||||
)
|
||||
# Don't fail tenant creation if subscription creation fails
|
||||
|
||||
if not coupon_validation["valid"]:
|
||||
logger.warning(
|
||||
"Invalid coupon code provided during registration",
|
||||
coupon_code=bakery_data.coupon_code,
|
||||
error=coupon_validation["error_message"]
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=coupon_validation["error_message"]
|
||||
)
|
||||
|
||||
# Redeem coupon
|
||||
success, discount, error = payment_service.redeem_coupon(
|
||||
bakery_data.coupon_code,
|
||||
tenant_id,
|
||||
db
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Coupon redeemed during registration",
|
||||
coupon_code=bakery_data.coupon_code,
|
||||
tenant_id=tenant_id)
|
||||
else:
|
||||
logger.warning("Failed to redeem coupon during registration",
|
||||
coupon_code=bakery_data.coupon_code,
|
||||
error=error)
|
||||
else:
|
||||
# No subscription plan provided - check if tenant already has a subscription
|
||||
# (from new registration flow where subscription is created first)
|
||||
try:
|
||||
from app.repositories.subscription_repository import SubscriptionRepository
|
||||
from app.models.tenants import Subscription
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from app.core.config import settings
|
||||
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
|
||||
async with database_manager.get_session() as session:
|
||||
subscription_repo = SubscriptionRepository(Subscription, session)
|
||||
|
||||
# Check if tenant already has an active subscription
|
||||
existing_subscription = await subscription_repo.get_by_tenant_id(str(result.id))
|
||||
|
||||
if existing_subscription:
|
||||
logger.info(
|
||||
"Tenant already has an active subscription, skipping default subscription creation",
|
||||
tenant_id=str(result.id),
|
||||
existing_plan=existing_subscription.plan,
|
||||
subscription_id=str(existing_subscription.id)
|
||||
)
|
||||
else:
|
||||
# Create starter subscription with 14-day trial
|
||||
trial_end_date = datetime.now(timezone.utc) + timedelta(days=14)
|
||||
next_billing_date = trial_end_date
|
||||
|
||||
await subscription_repo.create_subscription({
|
||||
"tenant_id": str(result.id),
|
||||
"plan": "starter",
|
||||
"status": "trial",
|
||||
"billing_cycle": "monthly",
|
||||
"next_billing_date": next_billing_date,
|
||||
"trial_ends_at": trial_end_date
|
||||
})
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
"Default free trial subscription created for new tenant",
|
||||
tenant_id=str(result.id),
|
||||
plan="starter",
|
||||
trial_days=14
|
||||
)
|
||||
except Exception as subscription_error:
|
||||
logger.error(
|
||||
"Failed to create default subscription for tenant",
|
||||
tenant_id=str(result.id),
|
||||
error=str(subscription_error)
|
||||
)
|
||||
|
||||
# If coupon was validated, redeem it now with actual tenant_id
|
||||
if coupon_validation and coupon_validation["valid"]:
|
||||
@@ -1068,9 +1130,101 @@ async def upgrade_subscription_plan(
|
||||
@router.post(route_builder.build_base_route("subscriptions/register-with-subscription", include_tenant_prefix=False))
|
||||
async def register_with_subscription(
|
||||
user_data: Dict[str, Any],
|
||||
plan_id: str = Query(..., description="Plan ID to subscribe to"),
|
||||
plan_id: str = Query(..., description="Plan ID to subscribe to (starter, professional, enterprise)"),
|
||||
payment_method_id: str = Query(..., description="Payment method ID from frontend"),
|
||||
use_trial: bool = Query(False, description="Whether to use trial period for pilot users"),
|
||||
coupon_code: str = Query(None, description="Coupon code for discounts or trial periods"),
|
||||
billing_interval: str = Query("monthly", description="Billing interval (monthly or yearly)"),
|
||||
payment_service: PaymentService = Depends(get_payment_service)
|
||||
):
|
||||
"""Process user registration with subscription creation"""
|
||||
|
||||
@router.post(route_builder.build_base_route("payment-customers/create", include_tenant_prefix=False))
|
||||
async def create_payment_customer(
|
||||
user_data: Dict[str, Any],
|
||||
payment_method_id: Optional[str] = Query(None, description="Optional payment method ID"),
|
||||
payment_service: PaymentService = Depends(get_payment_service)
|
||||
):
|
||||
"""
|
||||
Create a payment customer in the payment provider
|
||||
|
||||
This endpoint is designed for service-to-service communication from auth service
|
||||
during user registration. It creates a payment customer that can be used later
|
||||
for subscription creation.
|
||||
|
||||
Args:
|
||||
user_data: User data including email, name, etc.
|
||||
payment_method_id: Optional payment method ID to attach
|
||||
|
||||
Returns:
|
||||
Dictionary with payment customer details
|
||||
"""
|
||||
try:
|
||||
logger.info("Creating payment customer via service-to-service call",
|
||||
email=user_data.get('email'),
|
||||
user_id=user_data.get('user_id'))
|
||||
|
||||
# Step 1: Create payment customer
|
||||
customer = await payment_service.create_customer(user_data)
|
||||
logger.info("Payment customer created successfully",
|
||||
customer_id=customer.id,
|
||||
email=customer.email)
|
||||
|
||||
# Step 2: Attach payment method if provided
|
||||
payment_method_details = None
|
||||
if payment_method_id:
|
||||
try:
|
||||
payment_method = await payment_service.update_payment_method(
|
||||
customer.id,
|
||||
payment_method_id
|
||||
)
|
||||
payment_method_details = {
|
||||
"id": payment_method.id,
|
||||
"type": payment_method.type,
|
||||
"brand": payment_method.brand,
|
||||
"last4": payment_method.last4,
|
||||
"exp_month": payment_method.exp_month,
|
||||
"exp_year": payment_method.exp_year
|
||||
}
|
||||
logger.info("Payment method attached to customer",
|
||||
customer_id=customer.id,
|
||||
payment_method_id=payment_method.id)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to attach payment method to customer",
|
||||
customer_id=customer.id,
|
||||
error=str(e),
|
||||
payment_method_id=payment_method_id)
|
||||
# Continue without attached payment method
|
||||
|
||||
# Step 3: Return comprehensive result
|
||||
return {
|
||||
"success": True,
|
||||
"payment_customer_id": customer.id,
|
||||
"payment_method": payment_method_details,
|
||||
"customer": {
|
||||
"id": customer.id,
|
||||
"email": customer.email,
|
||||
"name": customer.name,
|
||||
"created_at": customer.created_at.isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create payment customer via service-to-service call",
|
||||
error=str(e),
|
||||
email=user_data.get('email'),
|
||||
user_id=user_data.get('user_id'))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to create payment customer: {str(e)}"
|
||||
)
|
||||
|
||||
@router.post(route_builder.build_base_route("subscriptions/register-with-subscription", include_tenant_prefix=False))
|
||||
async def register_with_subscription(
|
||||
user_data: Dict[str, Any],
|
||||
plan_id: str = Query(..., description="Plan ID to subscribe to (starter, professional, enterprise)"),
|
||||
payment_method_id: str = Query(..., description="Payment method ID from frontend"),
|
||||
coupon_code: str = Query(None, description="Coupon code for discounts or trial periods"),
|
||||
billing_interval: str = Query("monthly", description="Billing interval (monthly or yearly)"),
|
||||
payment_service: PaymentService = Depends(get_payment_service)
|
||||
):
|
||||
"""Process user registration with subscription creation"""
|
||||
@@ -1080,7 +1234,8 @@ async def register_with_subscription(
|
||||
user_data,
|
||||
plan_id,
|
||||
payment_method_id,
|
||||
use_trial
|
||||
coupon_code,
|
||||
billing_interval
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -1095,6 +1250,61 @@ async def register_with_subscription(
|
||||
detail="Failed to register with subscription"
|
||||
)
|
||||
|
||||
@router.post(route_builder.build_base_route("subscriptions/link", include_tenant_prefix=False))
|
||||
async def link_subscription_to_tenant(
|
||||
tenant_id: str = Query(..., description="Tenant ID to link subscription to"),
|
||||
subscription_id: str = Query(..., description="Subscription ID to link"),
|
||||
user_id: str = Query(..., description="User ID performing the linking"),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Link a pending subscription to a tenant
|
||||
|
||||
This endpoint completes the registration flow by associating the subscription
|
||||
created during registration with the tenant created during onboarding.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID to link to
|
||||
subscription_id: Subscription ID to link
|
||||
user_id: User ID performing the linking (for validation)
|
||||
|
||||
Returns:
|
||||
Dictionary with linking results
|
||||
"""
|
||||
try:
|
||||
logger.info("Linking subscription to tenant",
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
user_id=user_id)
|
||||
|
||||
# Link subscription to tenant
|
||||
result = await tenant_service.link_subscription_to_tenant(
|
||||
tenant_id, subscription_id, user_id
|
||||
)
|
||||
|
||||
logger.info("Subscription linked to tenant successfully",
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
user_id=user_id)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Subscription linked to tenant successfully",
|
||||
"data": result
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to link subscription to tenant",
|
||||
error=str(e),
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
user_id=user_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to link subscription to tenant"
|
||||
)
|
||||
|
||||
|
||||
async def _invalidate_tenant_tokens(tenant_id: str, redis_client):
|
||||
"""
|
||||
|
||||
@@ -1,36 +1,37 @@
|
||||
"""
|
||||
Webhook endpoints for handling payment provider events
|
||||
These endpoints receive events from payment providers like Stripe
|
||||
All event processing is handled by SubscriptionOrchestrationService
|
||||
"""
|
||||
|
||||
import structlog
|
||||
import stripe
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from typing import Dict, Any
|
||||
from datetime import datetime
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.payment_service import PaymentService
|
||||
from app.services.subscription_orchestration_service import SubscriptionOrchestrationService
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from app.models.tenants import Subscription, Tenant
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
|
||||
def get_payment_service():
|
||||
|
||||
def get_subscription_orchestration_service(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> SubscriptionOrchestrationService:
|
||||
"""Dependency injection for SubscriptionOrchestrationService"""
|
||||
try:
|
||||
return PaymentService()
|
||||
return SubscriptionOrchestrationService(db)
|
||||
except Exception as e:
|
||||
logger.error("Failed to create payment service", error=str(e))
|
||||
raise HTTPException(status_code=500, detail="Payment service initialization failed")
|
||||
logger.error("Failed to create subscription orchestration service", error=str(e))
|
||||
raise HTTPException(status_code=500, detail="Service initialization failed")
|
||||
|
||||
|
||||
@router.post("/webhooks/stripe")
|
||||
async def stripe_webhook(
|
||||
request: Request,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
payment_service: PaymentService = Depends(get_payment_service)
|
||||
orchestration_service: SubscriptionOrchestrationService = Depends(get_subscription_orchestration_service)
|
||||
):
|
||||
"""
|
||||
Stripe webhook endpoint to handle payment events
|
||||
@@ -74,39 +75,14 @@ async def stripe_webhook(
|
||||
event_type=event_type,
|
||||
event_id=event.get('id'))
|
||||
|
||||
# Process different types of events
|
||||
if event_type == 'checkout.session.completed':
|
||||
# Handle successful checkout
|
||||
await handle_checkout_completed(event_data, db)
|
||||
# Use orchestration service to handle the event
|
||||
result = await orchestration_service.handle_payment_webhook(event_type, event_data)
|
||||
|
||||
elif event_type == 'customer.subscription.created':
|
||||
# Handle new subscription
|
||||
await handle_subscription_created(event_data, db)
|
||||
logger.info("Webhook event processed via orchestration service",
|
||||
event_type=event_type,
|
||||
actions_taken=result.get("actions_taken", []))
|
||||
|
||||
elif event_type == 'customer.subscription.updated':
|
||||
# Handle subscription update
|
||||
await handle_subscription_updated(event_data, db)
|
||||
|
||||
elif event_type == 'customer.subscription.deleted':
|
||||
# Handle subscription cancellation
|
||||
await handle_subscription_deleted(event_data, db)
|
||||
|
||||
elif event_type == 'invoice.payment_succeeded':
|
||||
# Handle successful payment
|
||||
await handle_payment_succeeded(event_data, db)
|
||||
|
||||
elif event_type == 'invoice.payment_failed':
|
||||
# Handle failed payment
|
||||
await handle_payment_failed(event_data, db)
|
||||
|
||||
elif event_type == 'customer.subscription.trial_will_end':
|
||||
# Handle trial ending soon (3 days before)
|
||||
await handle_trial_will_end(event_data, db)
|
||||
|
||||
else:
|
||||
logger.info("Unhandled webhook event type", event_type=event_type)
|
||||
|
||||
return {"success": True, "event_type": event_type}
|
||||
return {"success": True, "event_type": event_type, "actions_taken": result.get("actions_taken", [])}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -116,260 +92,3 @@ async def stripe_webhook(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Webhook processing error"
|
||||
)
|
||||
|
||||
|
||||
async def handle_checkout_completed(session: Dict[str, Any], db: AsyncSession):
|
||||
"""Handle successful checkout session completion"""
|
||||
logger.info("Processing checkout.session.completed",
|
||||
session_id=session.get('id'))
|
||||
|
||||
customer_id = session.get('customer')
|
||||
subscription_id = session.get('subscription')
|
||||
|
||||
if customer_id and subscription_id:
|
||||
# Update tenant with subscription info
|
||||
query = select(Tenant).where(Tenant.stripe_customer_id == customer_id)
|
||||
result = await db.execute(query)
|
||||
tenant = result.scalar_one_or_none()
|
||||
|
||||
if tenant:
|
||||
logger.info("Checkout completed for tenant",
|
||||
tenant_id=str(tenant.id),
|
||||
subscription_id=subscription_id)
|
||||
|
||||
|
||||
async def handle_subscription_created(subscription: Dict[str, Any], db: AsyncSession):
|
||||
"""Handle new subscription creation"""
|
||||
logger.info("Processing customer.subscription.created",
|
||||
subscription_id=subscription.get('id'))
|
||||
|
||||
customer_id = subscription.get('customer')
|
||||
subscription_id = subscription.get('id')
|
||||
status_value = subscription.get('status')
|
||||
|
||||
# Find tenant by customer ID
|
||||
query = select(Tenant).where(Tenant.stripe_customer_id == customer_id)
|
||||
result = await db.execute(query)
|
||||
tenant = result.scalar_one_or_none()
|
||||
|
||||
if tenant:
|
||||
logger.info("Subscription created for tenant",
|
||||
tenant_id=str(tenant.id),
|
||||
subscription_id=subscription_id,
|
||||
status=status_value)
|
||||
|
||||
|
||||
async def handle_subscription_updated(subscription: Dict[str, Any], db: AsyncSession):
|
||||
"""Handle subscription updates (status changes, plan changes, etc.)"""
|
||||
subscription_id = subscription.get('id')
|
||||
status_value = subscription.get('status')
|
||||
|
||||
logger.info("Processing customer.subscription.updated",
|
||||
subscription_id=subscription_id,
|
||||
new_status=status_value)
|
||||
|
||||
# Find subscription in database
|
||||
query = select(Subscription).where(Subscription.subscription_id == subscription_id)
|
||||
result = await db.execute(query)
|
||||
db_subscription = result.scalar_one_or_none()
|
||||
|
||||
if db_subscription:
|
||||
# Update subscription status
|
||||
db_subscription.status = status_value
|
||||
db_subscription.current_period_end = datetime.fromtimestamp(
|
||||
subscription.get('current_period_end')
|
||||
)
|
||||
|
||||
# Update active status based on Stripe status
|
||||
if status_value == 'active':
|
||||
db_subscription.is_active = True
|
||||
elif status_value in ['canceled', 'past_due', 'unpaid']:
|
||||
db_subscription.is_active = False
|
||||
|
||||
await db.commit()
|
||||
|
||||
# Invalidate cache
|
||||
try:
|
||||
from app.services.subscription_cache import get_subscription_cache_service
|
||||
import shared.redis_utils
|
||||
|
||||
redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
|
||||
cache_service = get_subscription_cache_service(redis_client)
|
||||
await cache_service.invalidate_subscription_cache(str(db_subscription.tenant_id))
|
||||
except Exception as cache_error:
|
||||
logger.error("Failed to invalidate cache", error=str(cache_error))
|
||||
|
||||
logger.info("Subscription updated in database",
|
||||
subscription_id=subscription_id,
|
||||
tenant_id=str(db_subscription.tenant_id))
|
||||
|
||||
|
||||
async def handle_subscription_deleted(subscription: Dict[str, Any], db: AsyncSession):
|
||||
"""Handle subscription cancellation/deletion"""
|
||||
subscription_id = subscription.get('id')
|
||||
|
||||
logger.info("Processing customer.subscription.deleted",
|
||||
subscription_id=subscription_id)
|
||||
|
||||
# Find subscription in database
|
||||
query = select(Subscription).where(Subscription.subscription_id == subscription_id)
|
||||
result = await db.execute(query)
|
||||
db_subscription = result.scalar_one_or_none()
|
||||
|
||||
if db_subscription:
|
||||
db_subscription.status = 'canceled'
|
||||
db_subscription.is_active = False
|
||||
db_subscription.canceled_at = datetime.utcnow()
|
||||
|
||||
await db.commit()
|
||||
|
||||
# Invalidate cache
|
||||
try:
|
||||
from app.services.subscription_cache import get_subscription_cache_service
|
||||
import shared.redis_utils
|
||||
|
||||
redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
|
||||
cache_service = get_subscription_cache_service(redis_client)
|
||||
await cache_service.invalidate_subscription_cache(str(db_subscription.tenant_id))
|
||||
except Exception as cache_error:
|
||||
logger.error("Failed to invalidate cache", error=str(cache_error))
|
||||
|
||||
logger.info("Subscription canceled in database",
|
||||
subscription_id=subscription_id,
|
||||
tenant_id=str(db_subscription.tenant_id))
|
||||
|
||||
|
||||
async def handle_payment_succeeded(invoice: Dict[str, Any], db: AsyncSession):
|
||||
"""Handle successful invoice payment"""
|
||||
invoice_id = invoice.get('id')
|
||||
subscription_id = invoice.get('subscription')
|
||||
|
||||
logger.info("Processing invoice.payment_succeeded",
|
||||
invoice_id=invoice_id,
|
||||
subscription_id=subscription_id)
|
||||
|
||||
if subscription_id:
|
||||
# Find subscription and ensure it's active
|
||||
query = select(Subscription).where(Subscription.subscription_id == subscription_id)
|
||||
result = await db.execute(query)
|
||||
db_subscription = result.scalar_one_or_none()
|
||||
|
||||
if db_subscription:
|
||||
db_subscription.status = 'active'
|
||||
db_subscription.is_active = True
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info("Payment succeeded, subscription activated",
|
||||
subscription_id=subscription_id,
|
||||
tenant_id=str(db_subscription.tenant_id))
|
||||
|
||||
|
||||
async def handle_payment_failed(invoice: Dict[str, Any], db: AsyncSession):
|
||||
"""Handle failed invoice payment"""
|
||||
invoice_id = invoice.get('id')
|
||||
subscription_id = invoice.get('subscription')
|
||||
customer_id = invoice.get('customer')
|
||||
|
||||
logger.error("Processing invoice.payment_failed",
|
||||
invoice_id=invoice_id,
|
||||
subscription_id=subscription_id,
|
||||
customer_id=customer_id)
|
||||
|
||||
if subscription_id:
|
||||
# Find subscription and mark as past_due
|
||||
query = select(Subscription).where(Subscription.subscription_id == subscription_id)
|
||||
result = await db.execute(query)
|
||||
db_subscription = result.scalar_one_or_none()
|
||||
|
||||
if db_subscription:
|
||||
db_subscription.status = 'past_due'
|
||||
db_subscription.is_active = False
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.warning("Payment failed, subscription marked past_due",
|
||||
subscription_id=subscription_id,
|
||||
tenant_id=str(db_subscription.tenant_id))
|
||||
|
||||
# TODO: Send notification to user about payment failure
|
||||
# You can integrate with your notification service here
|
||||
|
||||
|
||||
async def handle_trial_will_end(subscription: Dict[str, Any], db: AsyncSession):
|
||||
"""Handle notification that trial will end in 3 days"""
|
||||
subscription_id = subscription.get('id')
|
||||
trial_end = subscription.get('trial_end')
|
||||
|
||||
logger.info("Processing customer.subscription.trial_will_end",
|
||||
subscription_id=subscription_id,
|
||||
trial_end_timestamp=trial_end)
|
||||
|
||||
# Find subscription
|
||||
query = select(Subscription).where(Subscription.subscription_id == subscription_id)
|
||||
result = await db.execute(query)
|
||||
db_subscription = result.scalar_one_or_none()
|
||||
|
||||
if db_subscription:
|
||||
logger.info("Trial ending soon for subscription",
|
||||
subscription_id=subscription_id,
|
||||
tenant_id=str(db_subscription.tenant_id))
|
||||
|
||||
# TODO: Send notification to user about trial ending soon
|
||||
# You can integrate with your notification service here
|
||||
|
||||
@router.post("/webhooks/generic")
|
||||
async def generic_webhook(
|
||||
request: Request,
|
||||
payment_service: PaymentService = Depends(get_payment_service)
|
||||
):
|
||||
"""
|
||||
Generic webhook endpoint that can handle events from any payment provider
|
||||
"""
|
||||
try:
|
||||
# Get the payload
|
||||
payload = await request.json()
|
||||
|
||||
# Log the event for debugging
|
||||
logger.info("Received generic webhook", payload=payload)
|
||||
|
||||
# Process the event based on its type
|
||||
event_type = payload.get('type', 'unknown')
|
||||
event_data = payload.get('data', {})
|
||||
|
||||
# Process different types of events
|
||||
if event_type == 'subscription.created':
|
||||
# Handle new subscription
|
||||
logger.info("Processing new subscription event", subscription_id=event_data.get('id'))
|
||||
# Update database with new subscription
|
||||
elif event_type == 'subscription.updated':
|
||||
# Handle subscription update
|
||||
logger.info("Processing subscription update event", subscription_id=event_data.get('id'))
|
||||
# Update database with subscription changes
|
||||
elif event_type == 'subscription.deleted':
|
||||
# Handle subscription cancellation
|
||||
logger.info("Processing subscription cancellation event", subscription_id=event_data.get('id'))
|
||||
# Update database with cancellation
|
||||
elif event_type == 'payment.succeeded':
|
||||
# Handle successful payment
|
||||
logger.info("Processing successful payment event", payment_id=event_data.get('id'))
|
||||
# Update payment status in database
|
||||
elif event_type == 'payment.failed':
|
||||
# Handle failed payment
|
||||
logger.info("Processing failed payment event", payment_id=event_data.get('id'))
|
||||
# Update payment status and notify user
|
||||
elif event_type == 'invoice.created':
|
||||
# Handle new invoice
|
||||
logger.info("Processing new invoice event", invoice_id=event_data.get('id'))
|
||||
# Store invoice information
|
||||
else:
|
||||
logger.warning("Unknown event type received", event_type=event_type)
|
||||
|
||||
return {"success": True}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error processing generic webhook", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Webhook error"
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user