376 lines
14 KiB
Python
376 lines
14 KiB
Python
"""
|
|
Webhook endpoints for handling payment provider events
|
|
These endpoints receive events from payment providers like Stripe
|
|
"""
|
|
|
|
import structlog
|
|
import stripe
|
|
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
|
from typing import Dict, Any
|
|
from datetime import datetime
|
|
|
|
from app.services.payment_service import PaymentService
|
|
from app.core.config import settings
|
|
from app.core.database import get_db
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
from app.models.tenants import Subscription, Tenant
|
|
|
|
# FastAPI router onto which the webhook endpoints below are registered.
router = APIRouter()

# structlog logger shared by every endpoint and event handler in this module.
logger = structlog.get_logger()
|
|
|
|
def get_payment_service():
    """FastAPI dependency that constructs a new ``PaymentService``.

    Any exception raised by the constructor is logged and converted into an
    HTTP 500 so the request fails with a clean error response.
    """
    try:
        service = PaymentService()
    except Exception as exc:
        logger.error("Failed to create payment service", error=str(exc))
        raise HTTPException(
            status_code=500,
            detail="Payment service initialization failed",
        )
    return service
|
|
|
|
@router.post("/webhooks/stripe")
async def stripe_webhook(
    request: Request,
    db: AsyncSession = Depends(get_db),
    payment_service: PaymentService = Depends(get_payment_service)
):
    """
    Stripe webhook endpoint to handle payment events.

    Verifies the ``stripe-signature`` header against
    ``settings.STRIPE_WEBHOOK_SECRET`` over the raw request body, then passes
    the event's ``data.object`` to the matching ``handle_*`` coroutine in this
    module. Event types with no handler are logged and still acknowledged.

    ``payment_service`` is not used in the body. NOTE(review): it is presumably
    kept so the dependency is initialised per request — confirm it is needed.

    Returns:
        ``{"success": True, "event_type": <event type>}``.

    Raises:
        HTTPException: 400 if the signature header is missing, the signature
            does not verify, or the payload cannot be parsed; 500 if anything
            else fails while processing a verified event.
    """
    # Verified event type -> coroutine taking (event data object, db session).
    event_handlers = {
        'checkout.session.completed': handle_checkout_completed,
        'customer.subscription.created': handle_subscription_created,
        'customer.subscription.updated': handle_subscription_updated,
        'customer.subscription.deleted': handle_subscription_deleted,
        'invoice.payment_succeeded': handle_payment_succeeded,
        'invoice.payment_failed': handle_payment_failed,
        # Trial ending soon (3 days before)
        'customer.subscription.trial_will_end': handle_trial_will_end,
    }

    try:
        # The signature is computed over the exact bytes Stripe sent, so the
        # raw body is used rather than re-serialised JSON.
        payload = await request.body()
        sig_header = request.headers.get('stripe-signature')

        if not sig_header:
            logger.error("Missing stripe-signature header")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Missing signature header"
            )

        try:
            event = stripe.Webhook.construct_event(
                payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
            )
        except stripe.error.SignatureVerificationError as e:
            logger.error("Invalid webhook signature", error=str(e))
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid signature"
            ) from e
        except ValueError as e:
            logger.error("Invalid payload", error=str(e))
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid payload"
            ) from e

        event_type = event['type']
        event_data = event['data']['object']

        logger.info("Processing Stripe webhook event",
                    event_type=event_type,
                    event_id=event.get('id'))

        handler = event_handlers.get(event_type)
        if handler is None:
            logger.info("Unhandled webhook event type", event_type=event_type)
        else:
            await handler(event_data, db)

        return {"success": True, "event_type": event_type}

    except HTTPException:
        # Client errors from above already carry the right status and detail.
        raise

    except Exception as e:
        logger.error("Error processing Stripe webhook", error=str(e), exc_info=True)
        # A failure here is on our side (handler/DB error, missing secret, ...),
        # so report 500 rather than blaming the request with a 400.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Webhook processing error"
        ) from e
|
|
|
|
|
|
async def handle_checkout_completed(session: Dict[str, Any], db: AsyncSession):
    """Handle successful checkout session completion.

    Looks up the tenant owning the session's Stripe customer and logs the
    subscription the checkout produced. No database rows are modified here.

    Args:
        session: The ``data.object`` of a ``checkout.session.completed`` event.
        db: Async database session used for the tenant lookup.
    """
    logger.info("Processing checkout.session.completed",
                session_id=session.get('id'))

    customer_id = session.get('customer')
    subscription_id = session.get('subscription')

    # Only sessions carrying both ids can be linked to a tenant subscription
    # (presumably non-subscription checkouts lack ``subscription``).
    if not (customer_id and subscription_id):
        return

    query = select(Tenant).where(Tenant.stripe_customer_id == customer_id)
    result = await db.execute(query)
    tenant = result.scalar_one_or_none()

    if tenant is None:
        # Previously dropped silently; surface events we cannot correlate.
        logger.warning("Checkout completed for unknown Stripe customer",
                       customer_id=customer_id,
                       subscription_id=subscription_id)
        return

    logger.info("Checkout completed for tenant",
                tenant_id=str(tenant.id),
                subscription_id=subscription_id)
|
|
|
|
|
|
async def handle_subscription_created(subscription: Dict[str, Any], db: AsyncSession):
    """Handle new subscription creation.

    Resolves the tenant that owns the subscription's Stripe customer and logs
    the new subscription's id and status. No database rows are written here.

    Args:
        subscription: The ``data.object`` of a
            ``customer.subscription.created`` event.
        db: Async database session used for the tenant lookup.
    """
    subscription_id = subscription.get('id')
    customer_id = subscription.get('customer')
    status_value = subscription.get('status')

    logger.info("Processing customer.subscription.created",
                subscription_id=subscription_id)

    # Without a customer id the query below would compare against NULL
    # (``IS NULL``) and could match tenants that have no Stripe customer.
    if not customer_id:
        logger.warning("Subscription created without a customer id",
                       subscription_id=subscription_id)
        return

    query = select(Tenant).where(Tenant.stripe_customer_id == customer_id)
    result = await db.execute(query)
    tenant = result.scalar_one_or_none()

    if tenant is None:
        logger.warning("Subscription created for unknown Stripe customer",
                       customer_id=customer_id,
                       subscription_id=subscription_id)
        return

    logger.info("Subscription created for tenant",
                tenant_id=str(tenant.id),
                subscription_id=subscription_id,
                status=status_value)
|
|
|
|
|
|
async def handle_subscription_updated(subscription: Dict[str, Any], db: AsyncSession):
    """Handle subscription updates (status changes, plan changes, etc.).

    Copies the Stripe status and current period end onto the stored
    ``Subscription`` row and sets ``is_active`` for statuses with a clear
    meaning. It then commits and invalidates the tenant's subscription cache.
    Cache failures are logged and do not fail the webhook.

    Args:
        subscription: The ``data.object`` of a
            ``customer.subscription.updated`` event.
        db: Async database session; committed when a row is updated.
    """
    from datetime import timezone

    subscription_id = subscription.get('id')
    status_value = subscription.get('status')

    logger.info("Processing customer.subscription.updated",
                subscription_id=subscription_id,
                new_status=status_value)

    query = select(Subscription).where(Subscription.subscription_id == subscription_id)
    result = await db.execute(query)
    db_subscription = result.scalar_one_or_none()

    if db_subscription is None:
        logger.warning("Subscription update for unknown subscription",
                       subscription_id=subscription_id)
        return

    db_subscription.status = status_value

    period_end = subscription.get('current_period_end')
    if period_end is None:
        # Newer Stripe API versions report the billing period on subscription
        # items instead of the subscription itself.
        items = (subscription.get('items') or {}).get('data') or []
        if items:
            period_end = items[0].get('current_period_end')
    if period_end is not None:
        # Store as naive UTC so the value does not depend on the server's
        # local timezone (plain fromtimestamp() converts to local time).
        db_subscription.current_period_end = datetime.fromtimestamp(
            period_end, tz=timezone.utc
        ).replace(tzinfo=None)
    # else: keep the previously stored value rather than crash on None.

    # Other statuses (e.g. 'trialing', 'incomplete') leave is_active unchanged.
    if status_value == 'active':
        db_subscription.is_active = True
    elif status_value in ('canceled', 'past_due', 'unpaid',
                          'incomplete_expired', 'paused'):
        db_subscription.is_active = False

    await db.commit()

    # Best-effort cache invalidation; the database update is already committed.
    # NOTE(review): initialize_redis is called per event — confirm it reuses a
    # pooled client rather than opening a new connection each time.
    try:
        from app.services.subscription_cache import get_subscription_cache_service
        import shared.redis_utils

        redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
        cache_service = get_subscription_cache_service(redis_client)
        await cache_service.invalidate_subscription_cache(str(db_subscription.tenant_id))
    except Exception as cache_error:
        logger.error("Failed to invalidate cache", error=str(cache_error))

    logger.info("Subscription updated in database",
                subscription_id=subscription_id,
                tenant_id=str(db_subscription.tenant_id))
|
|
|
|
|
|
async def handle_subscription_deleted(subscription: Dict[str, Any], db: AsyncSession):
    """Handle subscription cancellation/deletion.

    Marks the stored ``Subscription`` row as ``canceled`` and inactive, and
    records the processing time (naive UTC) in ``canceled_at``. It then
    commits and invalidates the tenant's subscription cache. Cache failures
    are logged and do not fail the webhook.

    Args:
        subscription: The ``data.object`` of a
            ``customer.subscription.deleted`` event.
        db: Async database session; committed when a row is updated.
    """
    from datetime import timezone

    subscription_id = subscription.get('id')

    logger.info("Processing customer.subscription.deleted",
                subscription_id=subscription_id)

    query = select(Subscription).where(Subscription.subscription_id == subscription_id)
    result = await db.execute(query)
    db_subscription = result.scalar_one_or_none()

    if db_subscription is None:
        logger.warning("Subscription deletion for unknown subscription",
                       subscription_id=subscription_id)
        return

    db_subscription.status = 'canceled'
    db_subscription.is_active = False
    # Same naive-UTC value as datetime.utcnow(), which is deprecated since 3.12.
    db_subscription.canceled_at = datetime.now(timezone.utc).replace(tzinfo=None)

    await db.commit()

    # Best-effort cache invalidation; the database update is already committed.
    try:
        from app.services.subscription_cache import get_subscription_cache_service
        import shared.redis_utils

        redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
        cache_service = get_subscription_cache_service(redis_client)
        await cache_service.invalidate_subscription_cache(str(db_subscription.tenant_id))
    except Exception as cache_error:
        logger.error("Failed to invalidate cache", error=str(cache_error))

    logger.info("Subscription canceled in database",
                subscription_id=subscription_id,
                tenant_id=str(db_subscription.tenant_id))
|
|
|
|
|
|
async def handle_payment_succeeded(invoice: Dict[str, Any], db: AsyncSession):
    """Handle successful invoice payment.

    Marks the invoice's subscription ``active`` with ``is_active=True``,
    commits, and invalidates the tenant's subscription cache, as the other
    mutating handlers in this module do. A subscription already stored as
    ``canceled`` is left untouched, because ``canceled`` is terminal in
    Stripe. Invoices with no subscription are ignored.

    Args:
        invoice: The ``data.object`` of an ``invoice.payment_succeeded`` event.
        db: Async database session; committed when a row is updated.
    """
    invoice_id = invoice.get('id')
    subscription_id = invoice.get('subscription')
    if not subscription_id:
        # Newer Stripe API versions moved this to
        # parent.subscription_details.subscription.
        parent = invoice.get('parent') or {}
        details = parent.get('subscription_details') or {}
        subscription_id = details.get('subscription')

    logger.info("Processing invoice.payment_succeeded",
                invoice_id=invoice_id,
                subscription_id=subscription_id)

    if not subscription_id:
        return

    query = select(Subscription).where(Subscription.subscription_id == subscription_id)
    result = await db.execute(query)
    db_subscription = result.scalar_one_or_none()

    if db_subscription is None:
        logger.warning("Payment succeeded for unknown subscription",
                       invoice_id=invoice_id,
                       subscription_id=subscription_id)
        return

    # Event order is not guaranteed; a late invoice event must not resurrect
    # a subscription that has already been canceled.
    if db_subscription.status == 'canceled':
        logger.warning("Ignoring payment for canceled subscription",
                       invoice_id=invoice_id,
                       subscription_id=subscription_id,
                       tenant_id=str(db_subscription.tenant_id))
        return

    db_subscription.status = 'active'
    db_subscription.is_active = True

    await db.commit()

    # is_active changed, so drop any cached copy (best effort).
    try:
        from app.services.subscription_cache import get_subscription_cache_service
        import shared.redis_utils

        redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
        cache_service = get_subscription_cache_service(redis_client)
        await cache_service.invalidate_subscription_cache(str(db_subscription.tenant_id))
    except Exception as cache_error:
        logger.error("Failed to invalidate cache", error=str(cache_error))

    logger.info("Payment succeeded, subscription activated",
                subscription_id=subscription_id,
                tenant_id=str(db_subscription.tenant_id))
|
|
|
|
|
|
async def handle_payment_failed(invoice: Dict[str, Any], db: AsyncSession):
    """Handle failed invoice payment.

    Marks the invoice's subscription ``past_due`` with ``is_active=False``,
    commits, and invalidates the tenant's subscription cache, as the other
    mutating handlers in this module do. A subscription already stored as
    ``canceled`` keeps that terminal status. Invoices with no subscription
    are ignored.

    Args:
        invoice: The ``data.object`` of an ``invoice.payment_failed`` event.
        db: Async database session; committed when a row is updated.
    """
    invoice_id = invoice.get('id')
    subscription_id = invoice.get('subscription')
    if not subscription_id:
        # Newer Stripe API versions moved this to
        # parent.subscription_details.subscription.
        parent = invoice.get('parent') or {}
        details = parent.get('subscription_details') or {}
        subscription_id = details.get('subscription')
    customer_id = invoice.get('customer')

    logger.error("Processing invoice.payment_failed",
                 invoice_id=invoice_id,
                 subscription_id=subscription_id,
                 customer_id=customer_id)

    if not subscription_id:
        return

    query = select(Subscription).where(Subscription.subscription_id == subscription_id)
    result = await db.execute(query)
    db_subscription = result.scalar_one_or_none()

    if db_subscription is None:
        logger.warning("Payment failed for unknown subscription",
                       invoice_id=invoice_id,
                       subscription_id=subscription_id)
        return

    # Don't overwrite the terminal 'canceled' status with 'past_due' when a
    # failure event arrives after the cancellation.
    if db_subscription.status == 'canceled':
        logger.warning("Ignoring payment failure for canceled subscription",
                       invoice_id=invoice_id,
                       subscription_id=subscription_id,
                       tenant_id=str(db_subscription.tenant_id))
        return

    # NOTE(review): Stripe may report 'incomplete' (not 'past_due') when the
    # very first invoice fails — confirm 'past_due' is the desired local value.
    db_subscription.status = 'past_due'
    db_subscription.is_active = False

    await db.commit()

    # is_active changed, so drop any cached copy (best effort).
    try:
        from app.services.subscription_cache import get_subscription_cache_service
        import shared.redis_utils

        redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
        cache_service = get_subscription_cache_service(redis_client)
        await cache_service.invalidate_subscription_cache(str(db_subscription.tenant_id))
    except Exception as cache_error:
        logger.error("Failed to invalidate cache", error=str(cache_error))

    logger.warning("Payment failed, subscription marked past_due",
                   subscription_id=subscription_id,
                   tenant_id=str(db_subscription.tenant_id))

    # TODO: Send notification to user about payment failure
    # You can integrate with your notification service here
|
|
|
|
|
|
async def handle_trial_will_end(subscription: Dict[str, Any], db: AsyncSession):
    """Handle notification that trial will end in 3 days.

    Looks up the stored ``Subscription`` and logs the upcoming trial end.
    Nothing is written to the database yet; see the TODO below.

    Args:
        subscription: The ``data.object`` of a
            ``customer.subscription.trial_will_end`` event.
        db: Async database session used for the lookup.
    """
    subscription_id = subscription.get('id')
    trial_end = subscription.get('trial_end')

    logger.info("Processing customer.subscription.trial_will_end",
                subscription_id=subscription_id,
                trial_end_timestamp=trial_end)

    query = select(Subscription).where(Subscription.subscription_id == subscription_id)
    result = await db.execute(query)
    db_subscription = result.scalar_one_or_none()

    if db_subscription is None:
        # Previously dropped silently; surface events we cannot correlate.
        logger.warning("Trial ending for unknown subscription",
                       subscription_id=subscription_id)
        return

    logger.info("Trial ending soon for subscription",
                subscription_id=subscription_id,
                tenant_id=str(db_subscription.tenant_id))

    # TODO: Send notification to user about trial ending soon
    # You can integrate with your notification service here
|
|
|
|
@router.post("/webhooks/generic")
async def generic_webhook(
    request: Request,
    payment_service: PaymentService = Depends(get_payment_service)
):
    """
    Generic webhook endpoint that can handle events from any payment provider.

    Expects a JSON object with a ``type`` string and an optional ``data``
    object. Recognised event types are logged together with ``data.id``;
    no database changes are made yet. Unrecognised types are logged as a
    warning and still acknowledged.

    NOTE(review): unlike the Stripe endpoint, no signature or authentication
    check is performed here, so every field of the payload is untrusted.

    Returns:
        ``{"success": True}``.

    Raises:
        HTTPException: 400 if the body is not valid JSON, is not a JSON
            object, or processing fails for any other reason.
    """
    # event type -> (log message, log field name that carries data.id)
    event_log_specs = {
        'subscription.created': ("Processing new subscription event", 'subscription_id'),
        'subscription.updated': ("Processing subscription update event", 'subscription_id'),
        'subscription.deleted': ("Processing subscription cancellation event", 'subscription_id'),
        'payment.succeeded': ("Processing successful payment event", 'payment_id'),
        'payment.failed': ("Processing failed payment event", 'payment_id'),
        'invoice.created': ("Processing new invoice event", 'invoice_id'),
    }

    try:
        payload = await request.json()
        if not isinstance(payload, dict):
            raise ValueError("Webhook payload must be a JSON object")

        event_type = payload.get('type', 'unknown')
        event_data = payload.get('data') or {}
        if not isinstance(event_data, dict):
            raise ValueError("Webhook 'data' must be a JSON object")

        # Log identifying fields only: the body is unauthenticated, so logging
        # it verbatim would let any caller write arbitrary content (and any PII
        # it carries) into the logs.
        logger.info("Received generic webhook",
                    event_type=event_type,
                    event_id=payload.get('id'))

        spec = event_log_specs.get(event_type) if isinstance(event_type, str) else None
        if spec is None:
            logger.warning("Unknown event type received", event_type=event_type)
        else:
            message, id_field = spec
            # TODO: apply the corresponding database change for each type.
            logger.info(message, **{id_field: event_data.get('id')})

        return {"success": True}

    except Exception as e:
        logger.error("Error processing generic webhook", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Webhook error"
        ) from e
|