REFACTOR - Database logic
This commit is contained in:
420
services/tenant/app/repositories/subscription_repository.py
Normal file
420
services/tenant/app/repositories/subscription_repository.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""
|
||||
Subscription Repository
|
||||
Repository for subscription operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, text, and_
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
import json
|
||||
|
||||
from .base import TenantBaseRepository
|
||||
from app.models.tenants import Subscription
|
||||
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class SubscriptionRepository(TenantBaseRepository):
|
||||
"""Repository for subscription operations"""
|
||||
|
||||
def __init__(self, model_class, session: AsyncSession, cache_ttl: Optional[int] = 600):
|
||||
# Subscriptions are relatively stable, medium cache time (10 minutes)
|
||||
super().__init__(model_class, session, cache_ttl)
|
||||
|
||||
async def create_subscription(self, subscription_data: Dict[str, Any]) -> Subscription:
|
||||
"""Create a new subscription with validation"""
|
||||
try:
|
||||
# Validate subscription data
|
||||
validation_result = self._validate_tenant_data(
|
||||
subscription_data,
|
||||
["tenant_id", "plan"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid subscription data: {validation_result['errors']}")
|
||||
|
||||
# Check for existing active subscription
|
||||
existing_subscription = await self.get_active_subscription(
|
||||
subscription_data["tenant_id"]
|
||||
)
|
||||
|
||||
if existing_subscription:
|
||||
raise DuplicateRecordError(f"Tenant already has an active subscription")
|
||||
|
||||
# Set default values based on plan
|
||||
plan_config = self._get_plan_configuration(subscription_data["plan"])
|
||||
|
||||
# Set defaults from plan config
|
||||
for key, value in plan_config.items():
|
||||
if key not in subscription_data:
|
||||
subscription_data[key] = value
|
||||
|
||||
# Set default subscription values
|
||||
if "status" not in subscription_data:
|
||||
subscription_data["status"] = "active"
|
||||
if "billing_cycle" not in subscription_data:
|
||||
subscription_data["billing_cycle"] = "monthly"
|
||||
if "next_billing_date" not in subscription_data:
|
||||
# Set next billing date based on cycle
|
||||
if subscription_data["billing_cycle"] == "yearly":
|
||||
subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=365)
|
||||
else:
|
||||
subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=30)
|
||||
|
||||
# Create subscription
|
||||
subscription = await self.create(subscription_data)
|
||||
|
||||
logger.info("Subscription created successfully",
|
||||
subscription_id=subscription.id,
|
||||
tenant_id=subscription.tenant_id,
|
||||
plan=subscription.plan,
|
||||
monthly_price=subscription.monthly_price)
|
||||
|
||||
return subscription
|
||||
|
||||
except (ValidationError, DuplicateRecordError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create subscription",
|
||||
tenant_id=subscription_data.get("tenant_id"),
|
||||
plan=subscription_data.get("plan"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create subscription: {str(e)}")
|
||||
|
||||
async def get_active_subscription(self, tenant_id: str) -> Optional[Subscription]:
|
||||
"""Get active subscription for tenant"""
|
||||
try:
|
||||
subscriptions = await self.get_multi(
|
||||
filters={
|
||||
"tenant_id": tenant_id,
|
||||
"status": "active"
|
||||
},
|
||||
limit=1,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return subscriptions[0] if subscriptions else None
|
||||
except Exception as e:
|
||||
logger.error("Failed to get active subscription",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get subscription: {str(e)}")
|
||||
|
||||
async def get_tenant_subscriptions(
|
||||
self,
|
||||
tenant_id: str,
|
||||
include_inactive: bool = False
|
||||
) -> List[Subscription]:
|
||||
"""Get all subscriptions for a tenant"""
|
||||
try:
|
||||
filters = {"tenant_id": tenant_id}
|
||||
|
||||
if not include_inactive:
|
||||
filters["status"] = "active"
|
||||
|
||||
return await self.get_multi(
|
||||
filters=filters,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenant subscriptions",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get subscriptions: {str(e)}")
|
||||
|
||||
async def update_subscription_plan(
|
||||
self,
|
||||
subscription_id: str,
|
||||
new_plan: str
|
||||
) -> Optional[Subscription]:
|
||||
"""Update subscription plan and pricing"""
|
||||
try:
|
||||
valid_plans = ["basic", "professional", "enterprise"]
|
||||
if new_plan not in valid_plans:
|
||||
raise ValidationError(f"Invalid plan. Must be one of: {valid_plans}")
|
||||
|
||||
# Get new plan configuration
|
||||
plan_config = self._get_plan_configuration(new_plan)
|
||||
|
||||
# Update subscription with new plan details
|
||||
update_data = {
|
||||
"plan": new_plan,
|
||||
"monthly_price": plan_config["monthly_price"],
|
||||
"max_users": plan_config["max_users"],
|
||||
"max_locations": plan_config["max_locations"],
|
||||
"max_products": plan_config["max_products"],
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
updated_subscription = await self.update(subscription_id, update_data)
|
||||
|
||||
logger.info("Subscription plan updated",
|
||||
subscription_id=subscription_id,
|
||||
new_plan=new_plan,
|
||||
new_price=plan_config["monthly_price"])
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to update subscription plan",
|
||||
subscription_id=subscription_id,
|
||||
new_plan=new_plan,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update plan: {str(e)}")
|
||||
|
||||
async def cancel_subscription(
|
||||
self,
|
||||
subscription_id: str,
|
||||
reason: str = None
|
||||
) -> Optional[Subscription]:
|
||||
"""Cancel a subscription"""
|
||||
try:
|
||||
update_data = {
|
||||
"status": "cancelled",
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
updated_subscription = await self.update(subscription_id, update_data)
|
||||
|
||||
logger.info("Subscription cancelled",
|
||||
subscription_id=subscription_id,
|
||||
reason=reason)
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cancel subscription",
|
||||
subscription_id=subscription_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to cancel subscription: {str(e)}")
|
||||
|
||||
async def suspend_subscription(
|
||||
self,
|
||||
subscription_id: str,
|
||||
reason: str = None
|
||||
) -> Optional[Subscription]:
|
||||
"""Suspend a subscription"""
|
||||
try:
|
||||
update_data = {
|
||||
"status": "suspended",
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
updated_subscription = await self.update(subscription_id, update_data)
|
||||
|
||||
logger.info("Subscription suspended",
|
||||
subscription_id=subscription_id,
|
||||
reason=reason)
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to suspend subscription",
|
||||
subscription_id=subscription_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to suspend subscription: {str(e)}")
|
||||
|
||||
async def reactivate_subscription(
|
||||
self,
|
||||
subscription_id: str
|
||||
) -> Optional[Subscription]:
|
||||
"""Reactivate a cancelled or suspended subscription"""
|
||||
try:
|
||||
# Reset billing date when reactivating
|
||||
next_billing_date = datetime.utcnow() + timedelta(days=30)
|
||||
|
||||
update_data = {
|
||||
"status": "active",
|
||||
"next_billing_date": next_billing_date,
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
updated_subscription = await self.update(subscription_id, update_data)
|
||||
|
||||
logger.info("Subscription reactivated",
|
||||
subscription_id=subscription_id,
|
||||
next_billing_date=next_billing_date)
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to reactivate subscription",
|
||||
subscription_id=subscription_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to reactivate subscription: {str(e)}")
|
||||
|
||||
async def get_subscriptions_due_for_billing(
|
||||
self,
|
||||
days_ahead: int = 3
|
||||
) -> List[Subscription]:
|
||||
"""Get subscriptions that need billing in the next N days"""
|
||||
try:
|
||||
cutoff_date = datetime.utcnow() + timedelta(days=days_ahead)
|
||||
|
||||
query_text = """
|
||||
SELECT * FROM subscriptions
|
||||
WHERE status = 'active'
|
||||
AND next_billing_date <= :cutoff_date
|
||||
ORDER BY next_billing_date ASC
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {
|
||||
"cutoff_date": cutoff_date
|
||||
})
|
||||
|
||||
subscriptions = []
|
||||
for row in result.fetchall():
|
||||
record_dict = dict(row._mapping)
|
||||
subscription = self.model(**record_dict)
|
||||
subscriptions.append(subscription)
|
||||
|
||||
return subscriptions
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get subscriptions due for billing",
|
||||
days_ahead=days_ahead,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def update_billing_date(
|
||||
self,
|
||||
subscription_id: str,
|
||||
next_billing_date: datetime
|
||||
) -> Optional[Subscription]:
|
||||
"""Update next billing date for subscription"""
|
||||
try:
|
||||
updated_subscription = await self.update(subscription_id, {
|
||||
"next_billing_date": next_billing_date,
|
||||
"updated_at": datetime.utcnow()
|
||||
})
|
||||
|
||||
logger.info("Subscription billing date updated",
|
||||
subscription_id=subscription_id,
|
||||
next_billing_date=next_billing_date)
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update billing date",
|
||||
subscription_id=subscription_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update billing date: {str(e)}")
|
||||
|
||||
async def get_subscription_statistics(self) -> Dict[str, Any]:
|
||||
"""Get subscription statistics"""
|
||||
try:
|
||||
# Get counts by plan
|
||||
plan_query = text("""
|
||||
SELECT plan, COUNT(*) as count
|
||||
FROM subscriptions
|
||||
WHERE status = 'active'
|
||||
GROUP BY plan
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(plan_query)
|
||||
subscriptions_by_plan = {row.plan: row.count for row in result.fetchall()}
|
||||
|
||||
# Get counts by status
|
||||
status_query = text("""
|
||||
SELECT status, COUNT(*) as count
|
||||
FROM subscriptions
|
||||
GROUP BY status
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(status_query)
|
||||
subscriptions_by_status = {row.status: row.count for row in result.fetchall()}
|
||||
|
||||
# Get revenue statistics
|
||||
revenue_query = text("""
|
||||
SELECT
|
||||
SUM(monthly_price) as total_monthly_revenue,
|
||||
AVG(monthly_price) as avg_monthly_price,
|
||||
COUNT(*) as total_active_subscriptions
|
||||
FROM subscriptions
|
||||
WHERE status = 'active'
|
||||
""")
|
||||
|
||||
revenue_result = await self.session.execute(revenue_query)
|
||||
revenue_row = revenue_result.fetchone()
|
||||
|
||||
# Get upcoming billing count
|
||||
thirty_days_ahead = datetime.utcnow() + timedelta(days=30)
|
||||
upcoming_billing = len(await self.get_subscriptions_due_for_billing(30))
|
||||
|
||||
return {
|
||||
"subscriptions_by_plan": subscriptions_by_plan,
|
||||
"subscriptions_by_status": subscriptions_by_status,
|
||||
"total_monthly_revenue": float(revenue_row.total_monthly_revenue or 0),
|
||||
"avg_monthly_price": float(revenue_row.avg_monthly_price or 0),
|
||||
"total_active_subscriptions": int(revenue_row.total_active_subscriptions or 0),
|
||||
"upcoming_billing_30d": upcoming_billing
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get subscription statistics", error=str(e))
|
||||
return {
|
||||
"subscriptions_by_plan": {},
|
||||
"subscriptions_by_status": {},
|
||||
"total_monthly_revenue": 0.0,
|
||||
"avg_monthly_price": 0.0,
|
||||
"total_active_subscriptions": 0,
|
||||
"upcoming_billing_30d": 0
|
||||
}
|
||||
|
||||
async def cleanup_old_subscriptions(self, days_old: int = 730) -> int:
|
||||
"""Clean up very old cancelled subscriptions (2 years)"""
|
||||
try:
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days_old)
|
||||
|
||||
query_text = """
|
||||
DELETE FROM subscriptions
|
||||
WHERE status IN ('cancelled', 'suspended')
|
||||
AND updated_at < :cutoff_date
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up old subscriptions",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup old subscriptions",
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Cleanup failed: {str(e)}")
|
||||
|
||||
def _get_plan_configuration(self, plan: str) -> Dict[str, Any]:
|
||||
"""Get configuration for a subscription plan"""
|
||||
plan_configs = {
|
||||
"basic": {
|
||||
"monthly_price": 29.99,
|
||||
"max_users": 2,
|
||||
"max_locations": 1,
|
||||
"max_products": 50
|
||||
},
|
||||
"professional": {
|
||||
"monthly_price": 79.99,
|
||||
"max_users": 10,
|
||||
"max_locations": 3,
|
||||
"max_products": 200
|
||||
},
|
||||
"enterprise": {
|
||||
"monthly_price": 199.99,
|
||||
"max_users": 50,
|
||||
"max_locations": 10,
|
||||
"max_products": 1000
|
||||
}
|
||||
}
|
||||
|
||||
return plan_configs.get(plan, plan_configs["basic"])
|
||||
Reference in New Issue
Block a user