"""
Subscription Repository
Repository for subscription operations
"""

import json
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any

import structlog
from sqlalchemy import select, text, and_
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.tenants import Subscription
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError

from .base import TenantBaseRepository

# Module-level structlog logger used by SubscriptionRepository for
# success (info) and failure (error) events.
logger = structlog.get_logger()


class SubscriptionRepository(TenantBaseRepository):
    """Repository for subscription operations.

    Builds on the generic helpers inherited from ``TenantBaseRepository``
    (``create``, ``update``, ``get_multi``, ``_validate_tenant_data``) and adds
    subscription-specific rules (at most one active subscription per tenant,
    plan-derived pricing/limits) plus a few raw-SQL reporting and maintenance
    queries against the ``subscriptions`` table.
    """

    # Static pricing and limits per plan. The keys are also the set of valid
    # plan names accepted by update_subscription_plan; insertion order is the
    # order shown in its validation error message.
    _PLAN_CONFIGS: Dict[str, Dict[str, Any]] = {
        "basic": {
            "monthly_price": 29.99,
            "max_users": 2,
            "max_locations": 1,
            "max_products": 50
        },
        "professional": {
            "monthly_price": 79.99,
            "max_users": 10,
            "max_locations": 3,
            "max_products": 200
        },
        "enterprise": {
            "monthly_price": 199.99,
            "max_users": 50,
            "max_locations": 10,
            "max_products": 1000
        }
    }

    def __init__(self, model_class, session: AsyncSession, cache_ttl: Optional[int] = 600):
        """Initialise the repository.

        Args:
            model_class: ORM model class handled by this repository.
            session: Async SQLAlchemy session used for all queries.
            cache_ttl: Cache lifetime in seconds, forwarded to the base class.
        """
        # Subscriptions are relatively stable, medium cache time (10 minutes)
        super().__init__(model_class, session, cache_ttl)

    async def create_subscription(self, subscription_data: Dict[str, Any]) -> Subscription:
        """Create a new subscription with validation.

        Args:
            subscription_data: Column values for the new subscription; must
                contain ``tenant_id`` and ``plan``. Pricing/limit fields that
                are absent are filled from the plan configuration, ``status``
                defaults to ``"active"``, ``billing_cycle`` to ``"monthly"``,
                and ``next_billing_date`` to now + 365 days for a ``"yearly"``
                cycle or now + 30 days otherwise. The caller's dict is not
                modified.

        Returns:
            The subscription returned by the base ``create`` helper.

        Raises:
            ValidationError: Required fields are missing or invalid.
            DuplicateRecordError: The tenant already has an active subscription.
            DatabaseError: Any other failure (original exception chained).
        """
        try:
            validation_result = self._validate_tenant_data(
                subscription_data,
                ["tenant_id", "plan"]
            )
            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid subscription data: {validation_result['errors']}")

            existing_subscription = await self.get_active_subscription(
                subscription_data["tenant_id"]
            )
            if existing_subscription:
                raise DuplicateRecordError("Tenant already has an active subscription")

            # Work on a shallow copy so the defaults below are not written
            # back into the caller's dict.
            data = dict(subscription_data)

            # Values supplied explicitly by the caller win over plan defaults.
            for key, value in self._get_plan_configuration(data["plan"]).items():
                data.setdefault(key, value)

            data.setdefault("status", "active")
            data.setdefault("billing_cycle", "monthly")
            if "next_billing_date" not in data:
                cycle_days = 365 if data["billing_cycle"] == "yearly" else 30
                data["next_billing_date"] = datetime.utcnow() + timedelta(days=cycle_days)

            subscription = await self.create(data)

            logger.info("Subscription created successfully",
                        subscription_id=subscription.id,
                        tenant_id=subscription.tenant_id,
                        plan=subscription.plan,
                        monthly_price=subscription.monthly_price)

            return subscription

        except (ValidationError, DuplicateRecordError):
            raise
        except Exception as e:
            logger.error("Failed to create subscription",
                         tenant_id=subscription_data.get("tenant_id"),
                         plan=subscription_data.get("plan"),
                         error=str(e))
            raise DatabaseError(f"Failed to create subscription: {str(e)}") from e

    async def get_active_subscription(self, tenant_id: str) -> Optional[Subscription]:
        """Return the tenant's most recently created active subscription.

        Args:
            tenant_id: Tenant whose subscription is looked up.

        Returns:
            The newest subscription with status ``"active"``, or ``None``.

        Raises:
            DatabaseError: The lookup failed (original exception chained).
        """
        try:
            subscriptions = await self.get_multi(
                filters={
                    "tenant_id": tenant_id,
                    "status": "active"
                },
                limit=1,
                order_by="created_at",
                order_desc=True
            )
            return subscriptions[0] if subscriptions else None
        except Exception as e:
            logger.error("Failed to get active subscription",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get subscription: {str(e)}") from e

    async def get_tenant_subscriptions(
        self,
        tenant_id: str,
        include_inactive: bool = False
    ) -> List[Subscription]:
        """Return a tenant's subscriptions, newest first.

        Args:
            tenant_id: Tenant whose subscriptions are listed.
            include_inactive: When False (default) only ``"active"``
                subscriptions are returned; when True every status is included.

        Raises:
            DatabaseError: The query failed (original exception chained).
        """
        try:
            filters = {"tenant_id": tenant_id}
            if not include_inactive:
                filters["status"] = "active"

            return await self.get_multi(
                filters=filters,
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get tenant subscriptions",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get subscriptions: {str(e)}") from e

    async def update_subscription_plan(
        self,
        subscription_id: str,
        new_plan: str
    ) -> Optional[Subscription]:
        """Switch a subscription to ``new_plan`` and apply its pricing/limits.

        Args:
            subscription_id: Subscription to update.
            new_plan: One of the keys of ``_PLAN_CONFIGS``.

        Returns:
            Whatever the base ``update`` helper returns for the updated row.

        Raises:
            ValidationError: ``new_plan`` is not a known plan.
            DatabaseError: The update failed (original exception chained).
        """
        try:
            valid_plans = list(self._PLAN_CONFIGS)
            if new_plan not in valid_plans:
                raise ValidationError(f"Invalid plan. Must be one of: {valid_plans}")

            plan_config = self._get_plan_configuration(new_plan)

            update_data = {
                "plan": new_plan,
                "monthly_price": plan_config["monthly_price"],
                "max_users": plan_config["max_users"],
                "max_locations": plan_config["max_locations"],
                "max_products": plan_config["max_products"],
                "updated_at": datetime.utcnow()
            }

            updated_subscription = await self.update(subscription_id, update_data)

            logger.info("Subscription plan updated",
                        subscription_id=subscription_id,
                        new_plan=new_plan,
                        new_price=plan_config["monthly_price"])

            return updated_subscription

        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to update subscription plan",
                         subscription_id=subscription_id,
                         new_plan=new_plan,
                         error=str(e))
            raise DatabaseError(f"Failed to update plan: {str(e)}") from e

    async def cancel_subscription(
        self,
        subscription_id: str,
        reason: Optional[str] = None
    ) -> Optional[Subscription]:
        """Set a subscription's status to ``"cancelled"``.

        Args:
            subscription_id: Subscription to cancel.
            reason: Free-text reason; only logged, not persisted here.

        Raises:
            DatabaseError: The update failed (original exception chained).
        """
        try:
            update_data = {
                "status": "cancelled",
                "updated_at": datetime.utcnow()
            }

            updated_subscription = await self.update(subscription_id, update_data)

            logger.info("Subscription cancelled",
                        subscription_id=subscription_id,
                        reason=reason)

            return updated_subscription

        except Exception as e:
            logger.error("Failed to cancel subscription",
                         subscription_id=subscription_id,
                         error=str(e))
            raise DatabaseError(f"Failed to cancel subscription: {str(e)}") from e

    async def suspend_subscription(
        self,
        subscription_id: str,
        reason: Optional[str] = None
    ) -> Optional[Subscription]:
        """Set a subscription's status to ``"suspended"``.

        Args:
            subscription_id: Subscription to suspend.
            reason: Free-text reason; only logged, not persisted here.

        Raises:
            DatabaseError: The update failed (original exception chained).
        """
        try:
            update_data = {
                "status": "suspended",
                "updated_at": datetime.utcnow()
            }

            updated_subscription = await self.update(subscription_id, update_data)

            logger.info("Subscription suspended",
                        subscription_id=subscription_id,
                        reason=reason)

            return updated_subscription

        except Exception as e:
            logger.error("Failed to suspend subscription",
                         subscription_id=subscription_id,
                         error=str(e))
            raise DatabaseError(f"Failed to suspend subscription: {str(e)}") from e

    async def reactivate_subscription(
        self,
        subscription_id: str
    ) -> Optional[Subscription]:
        """Reactivate a cancelled or suspended subscription.

        Sets status back to ``"active"`` and resets ``next_billing_date`` to
        now + 30 days.

        Raises:
            DatabaseError: The update failed (original exception chained).
        """
        try:
            # NOTE(review): always 30 days, regardless of billing_cycle;
            # create_subscription uses 365 days for yearly cycles — confirm
            # whether yearly subscriptions should be handled here too.
            next_billing_date = datetime.utcnow() + timedelta(days=30)

            update_data = {
                "status": "active",
                "next_billing_date": next_billing_date,
                "updated_at": datetime.utcnow()
            }

            updated_subscription = await self.update(subscription_id, update_data)

            logger.info("Subscription reactivated",
                        subscription_id=subscription_id,
                        next_billing_date=next_billing_date)

            return updated_subscription

        except Exception as e:
            logger.error("Failed to reactivate subscription",
                         subscription_id=subscription_id,
                         error=str(e))
            raise DatabaseError(f"Failed to reactivate subscription: {str(e)}") from e

    async def get_subscriptions_due_for_billing(
        self,
        days_ahead: int = 3
    ) -> List[Subscription]:
        """Return active subscriptions whose next billing date is within N days.

        Includes overdue subscriptions (any ``next_billing_date`` up to
        now + ``days_ahead``), ordered by billing date ascending, across all
        tenants.

        Best-effort: on any failure the error is logged and an empty list is
        returned instead of raising.
        """
        try:
            cutoff_date = datetime.utcnow() + timedelta(days=days_ahead)

            query_text = """
                SELECT * FROM subscriptions
                WHERE status = 'active'
                AND next_billing_date <= :cutoff_date
                ORDER BY next_billing_date ASC
            """

            result = await self.session.execute(text(query_text), {
                "cutoff_date": cutoff_date
            })

            # Raw rows are turned into (session-detached) model instances.
            return [self.model(**dict(row._mapping)) for row in result.fetchall()]

        except Exception as e:
            logger.error("Failed to get subscriptions due for billing",
                         days_ahead=days_ahead,
                         error=str(e))
            return []

    async def update_billing_date(
        self,
        subscription_id: str,
        next_billing_date: datetime
    ) -> Optional[Subscription]:
        """Set ``next_billing_date`` on a subscription.

        Raises:
            DatabaseError: The update failed (original exception chained).
        """
        try:
            updated_subscription = await self.update(subscription_id, {
                "next_billing_date": next_billing_date,
                "updated_at": datetime.utcnow()
            })

            logger.info("Subscription billing date updated",
                        subscription_id=subscription_id,
                        next_billing_date=next_billing_date)

            return updated_subscription

        except Exception as e:
            logger.error("Failed to update billing date",
                         subscription_id=subscription_id,
                         error=str(e))
            raise DatabaseError(f"Failed to update billing date: {str(e)}") from e

    async def get_subscription_statistics(self) -> Dict[str, Any]:
        """Return aggregate subscription statistics across all tenants.

        Returns:
            A dict with active counts per plan, counts per status (all
            statuses), total/average ``monthly_price`` of active
            subscriptions, the active count, and the number of active
            subscriptions with ``next_billing_date`` within the next 30 days.
            On any failure the error is logged and zeroed values are returned.
        """
        try:
            # Get counts by plan
            plan_query = text("""
                SELECT plan, COUNT(*) as count
                FROM subscriptions
                WHERE status = 'active'
                GROUP BY plan
                ORDER BY count DESC
            """)
            result = await self.session.execute(plan_query)
            subscriptions_by_plan = {row.plan: row.count for row in result.fetchall()}

            # Get counts by status
            status_query = text("""
                SELECT status, COUNT(*) as count
                FROM subscriptions
                GROUP BY status
                ORDER BY count DESC
            """)
            result = await self.session.execute(status_query)
            subscriptions_by_status = {row.status: row.count for row in result.fetchall()}

            # Get revenue statistics
            revenue_query = text("""
                SELECT
                    SUM(monthly_price) as total_monthly_revenue,
                    AVG(monthly_price) as avg_monthly_price,
                    COUNT(*) as total_active_subscriptions
                FROM subscriptions
                WHERE status = 'active'
            """)
            revenue_result = await self.session.execute(revenue_query)
            revenue_row = revenue_result.fetchone()

            # Count upcoming billings in SQL (same predicate as
            # get_subscriptions_due_for_billing) instead of materialising
            # every due row as a model instance just to take its length.
            upcoming_query = text("""
                SELECT COUNT(*) as count
                FROM subscriptions
                WHERE status = 'active'
                AND next_billing_date <= :cutoff_date
            """)
            upcoming_result = await self.session.execute(
                upcoming_query,
                {"cutoff_date": datetime.utcnow() + timedelta(days=30)}
            )
            upcoming_billing = int(upcoming_result.scalar() or 0)

            return {
                "subscriptions_by_plan": subscriptions_by_plan,
                "subscriptions_by_status": subscriptions_by_status,
                "total_monthly_revenue": float(revenue_row.total_monthly_revenue or 0),
                "avg_monthly_price": float(revenue_row.avg_monthly_price or 0),
                "total_active_subscriptions": int(revenue_row.total_active_subscriptions or 0),
                "upcoming_billing_30d": upcoming_billing
            }

        except Exception as e:
            logger.error("Failed to get subscription statistics", error=str(e))
            return {
                "subscriptions_by_plan": {},
                "subscriptions_by_status": {},
                "total_monthly_revenue": 0.0,
                "avg_monthly_price": 0.0,
                "total_active_subscriptions": 0,
                "upcoming_billing_30d": 0
            }

    async def cleanup_old_subscriptions(self, days_old: int = 730) -> int:
        """Delete cancelled or suspended subscriptions not updated for ``days_old`` days.

        The default of 730 days is roughly two years. Rows are selected by
        ``updated_at`` older than now - ``days_old`` across all tenants. No
        commit is issued here.

        Returns:
            The driver-reported number of deleted rows (``rowcount``).

        Raises:
            DatabaseError: The delete failed (original exception chained).
        """
        try:
            cutoff_date = datetime.utcnow() - timedelta(days=days_old)

            query_text = """
                DELETE FROM subscriptions
                WHERE status IN ('cancelled', 'suspended')
                AND updated_at < :cutoff_date
            """

            result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
            deleted_count = result.rowcount

            logger.info("Cleaned up old subscriptions",
                        deleted_count=deleted_count,
                        days_old=days_old)

            return deleted_count

        except Exception as e:
            logger.error("Failed to cleanup old subscriptions",
                         error=str(e))
            raise DatabaseError(f"Cleanup failed: {str(e)}") from e

    def _get_plan_configuration(self, plan: str) -> Dict[str, Any]:
        """Return the pricing/limit configuration for ``plan``.

        Unknown plan names fall back to the ``"basic"`` configuration. A new
        dict is returned on every call so callers can modify it without
        affecting ``_PLAN_CONFIGS``.
        """
        return dict(self._PLAN_CONFIGS.get(plan, self._PLAN_CONFIGS["basic"]))