Files
bakery-ia/services/tenant/app/repositories/subscription_repository.py
2026-01-11 21:40:04 +01:00

505 lines
20 KiB
Python

"""
Subscription Repository
Repository for subscription operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, text, and_
from datetime import datetime, timedelta
import structlog
import json
from .base import TenantBaseRepository
from app.models.tenants import Subscription
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
from shared.subscription.plans import SubscriptionPlanMetadata, QuotaLimits, PlanPricing
logger = structlog.get_logger()
class SubscriptionRepository(TenantBaseRepository):
"""Repository for subscription operations"""
def __init__(self, model_class, session: AsyncSession, cache_ttl: Optional[int] = 600):
    """Set up the repository for ``model_class`` on the given session.

    Args:
        model_class: Model class this repository operates on.
        session: Async SQLAlchemy session used for all queries.
        cache_ttl: Cache lifetime forwarded to the base repository.
            Subscriptions are relatively stable, so the default is a
            medium value (10 minutes).
    """
    super().__init__(model_class, session, cache_ttl)
async def create_subscription(self, subscription_data: Dict[str, Any]) -> Subscription:
    """Create a new subscription, filling in plan-derived defaults.

    Defaults are applied to a shallow copy, so the caller's dictionary is
    left untouched.

    Args:
        subscription_data: Field values for the new subscription. Must
            contain ``tenant_id`` and ``plan``. Any of ``monthly_price``,
            ``max_users``, ``max_locations``, ``max_products``,
            ``features``, ``status``, ``billing_cycle`` and
            ``next_billing_date`` that are missing are filled in.

    Returns:
        The created subscription.

    Raises:
        ValidationError: If the required fields fail validation.
        DuplicateRecordError: If the tenant already has an active
            subscription.
        DatabaseError: If anything else goes wrong.
    """
    try:
        validation_result = self._validate_tenant_data(
            subscription_data,
            ["tenant_id", "plan"]
        )
        if not validation_result["is_valid"]:
            raise ValidationError(f"Invalid subscription data: {validation_result['errors']}")

        # Work on a copy so the defaults below never leak into the
        # dictionary owned by the caller.
        data = dict(subscription_data)

        # Only one active subscription per tenant is allowed.
        existing_subscription = await self.get_active_subscription(data["tenant_id"])
        if existing_subscription:
            raise DuplicateRecordError("Tenant already has an active subscription")

        # Plan-derived defaults come from the centralized plan configuration.
        plan = data["plan"]
        plan_info = SubscriptionPlanMetadata.get_plan_info(plan)

        if "monthly_price" not in data:
            billing_cycle = data.get("billing_cycle", "monthly")
            data["monthly_price"] = float(PlanPricing.get_price(plan, billing_cycle))

        for field_name, quota_key in (
            ("max_users", "MAX_USERS"),
            ("max_locations", "MAX_LOCATIONS"),
            ("max_products", "MAX_PRODUCTS"),
        ):
            if field_name not in data:
                # Any falsy limit is stored as -1.
                # NOTE(review): presumably -1 means "unlimited" - confirm.
                data[field_name] = QuotaLimits.get_limit(quota_key, plan) or -1

        if "features" not in data:
            data["features"] = {
                feature: True for feature in plan_info.get("features", [])
            }

        # Generic subscription defaults.
        data.setdefault("status", "active")
        data.setdefault("billing_cycle", "monthly")

        if "next_billing_date" not in data:
            # One billing period from now: 365 days for yearly, else 30.
            period_days = 365 if data["billing_cycle"] == "yearly" else 30
            data["next_billing_date"] = datetime.utcnow() + timedelta(days=period_days)

        subscription = await self.create(data)

        logger.info("Subscription created successfully",
                    subscription_id=subscription.id,
                    tenant_id=subscription.tenant_id,
                    plan=subscription.plan,
                    monthly_price=subscription.monthly_price)
        return subscription

    except (ValidationError, DuplicateRecordError):
        # Domain errors are passed through unchanged.
        raise
    except Exception as e:
        logger.error("Failed to create subscription",
                     tenant_id=subscription_data.get("tenant_id"),
                     plan=subscription_data.get("plan"),
                     error=str(e))
        raise DatabaseError(f"Failed to create subscription: {str(e)}") from e
async def get_by_tenant_id(self, tenant_id: str) -> Optional[Subscription]:
    """Return one subscription of a tenant, regardless of status.

    The lookup asks for a single row ordered by ``created_at``
    descending.

    Args:
        tenant_id: Identifier of the tenant to look up.

    Returns:
        The first matching subscription, or ``None`` if there is none.

    Raises:
        DatabaseError: If the lookup fails.
    """
    try:
        matches = await self.get_multi(
            filters={"tenant_id": tenant_id},
            limit=1,
            order_by="created_at",
            order_desc=True,
        )
        if not matches:
            return None
        return matches[0]
    except Exception as exc:
        logger.error(
            "Failed to get subscription by tenant ID",
            tenant_id=tenant_id,
            error=str(exc),
        )
        raise DatabaseError(f"Failed to get subscription: {exc!s}")
async def get_by_stripe_id(self, stripe_subscription_id: str) -> Optional[Subscription]:
    """Return the subscription linked to a Stripe subscription ID.

    The lookup asks for a single row ordered by ``created_at``
    descending.

    Args:
        stripe_subscription_id: Stripe-side subscription identifier.

    Returns:
        The first matching subscription, or ``None`` if there is none.

    Raises:
        DatabaseError: If the lookup fails.
    """
    try:
        matches = await self.get_multi(
            filters={"stripe_subscription_id": stripe_subscription_id},
            limit=1,
            order_by="created_at",
            order_desc=True,
        )
        if not matches:
            return None
        return matches[0]
    except Exception as exc:
        logger.error(
            "Failed to get subscription by Stripe ID",
            stripe_subscription_id=stripe_subscription_id,
            error=str(exc),
        )
        raise DatabaseError(f"Failed to get subscription: {exc!s}")
async def get_active_subscription(self, tenant_id: str) -> Optional[Subscription]:
    """Return a tenant's subscription whose status is ``"active"``.

    The lookup asks for a single row ordered by ``created_at``
    descending.

    Args:
        tenant_id: Identifier of the tenant to look up.

    Returns:
        The first active subscription, or ``None`` if the tenant has no
        active one.

    Raises:
        DatabaseError: If the lookup fails.
    """
    try:
        active_filter = {"tenant_id": tenant_id, "status": "active"}
        matches = await self.get_multi(
            filters=active_filter,
            limit=1,
            order_by="created_at",
            order_desc=True,
        )
        if not matches:
            return None
        return matches[0]
    except Exception as exc:
        logger.error(
            "Failed to get active subscription",
            tenant_id=tenant_id,
            error=str(exc),
        )
        raise DatabaseError(f"Failed to get subscription: {exc!s}")
async def get_tenant_subscriptions(
    self,
    tenant_id: str,
    include_inactive: bool = False
) -> List[Subscription]:
    """List a tenant's subscriptions, ordered by ``created_at`` descending.

    Args:
        tenant_id: Identifier of the tenant to look up.
        include_inactive: When false (the default) only subscriptions
            with status ``"active"`` are returned; when true no status
            filter is applied.

    Returns:
        The matching subscriptions.

    Raises:
        DatabaseError: If the lookup fails.
    """
    try:
        if include_inactive:
            criteria = {"tenant_id": tenant_id}
        else:
            criteria = {"tenant_id": tenant_id, "status": "active"}

        found = await self.get_multi(
            filters=criteria,
            order_by="created_at",
            order_desc=True,
        )
        return found
    except Exception as exc:
        logger.error(
            "Failed to get tenant subscriptions",
            tenant_id=tenant_id,
            error=str(exc),
        )
        raise DatabaseError(f"Failed to get subscriptions: {exc!s}")
async def update_subscription_plan(
    self,
    subscription_id: str,
    new_plan: str,
    billing_cycle: str = "monthly"
) -> Optional[Subscription]:
    """Move a subscription to another plan and refresh plan-derived fields.

    Price, quota limits and the feature map are re-read from the
    centralized plan configuration, then the tenant's subscription cache
    is invalidated.

    Args:
        subscription_id: ID of the subscription to change.
        new_plan: One of ``"starter"``, ``"professional"`` or
            ``"enterprise"``.
        billing_cycle: Billing cycle used to look up the price and stored
            on the subscription.

    Returns:
        Whatever the base ``update`` returns for the subscription.

    Raises:
        ValidationError: If the plan name is not allowed or the
            subscription does not exist.
        DatabaseError: If anything else fails.
    """
    try:
        allowed_plans = ["starter", "professional", "enterprise"]
        if new_plan not in allowed_plans:
            raise ValidationError(f"Invalid plan. Must be one of: {allowed_plans}")

        # Loaded up front: its tenant_id is needed for cache invalidation.
        current = await self.get_by_id(subscription_id)
        if not current:
            raise ValidationError(f"Subscription {subscription_id} not found")

        plan_info = SubscriptionPlanMetadata.get_plan_info(new_plan)
        new_price = float(PlanPricing.get_price(new_plan, billing_cycle))

        changes: Dict[str, Any] = {
            "plan": new_plan,
            "monthly_price": new_price,
            "billing_cycle": billing_cycle,
        }
        for column, quota_key in (
            ("max_users", "MAX_USERS"),
            ("max_locations", "MAX_LOCATIONS"),
            ("max_products", "MAX_PRODUCTS"),
        ):
            # A falsy limit is stored as -1.
            changes[column] = QuotaLimits.get_limit(quota_key, new_plan) or -1
        changes["features"] = dict.fromkeys(plan_info.get("features", []), True)
        changes["updated_at"] = datetime.utcnow()

        result = await self.update(subscription_id, changes)

        await self._invalidate_cache(str(current.tenant_id))

        logger.info(
            "Subscription plan updated",
            subscription_id=subscription_id,
            new_plan=new_plan,
            new_price=new_price,
        )
        return result
    except ValidationError:
        raise
    except Exception as exc:
        logger.error(
            "Failed to update subscription plan",
            subscription_id=subscription_id,
            new_plan=new_plan,
            error=str(exc),
        )
        raise DatabaseError(f"Failed to update plan: {exc!s}")
async def cancel_subscription(
    self,
    subscription_id: str,
    reason: Optional[str] = None
) -> Optional[Subscription]:
    """Mark a subscription as cancelled.

    Sets ``status`` to ``"cancelled"``, refreshes ``updated_at`` and
    invalidates the tenant's subscription cache.

    Args:
        subscription_id: ID of the subscription to cancel.
        reason: Optional free-text reason; it is only written to the log.

    Returns:
        Whatever the base ``update`` returns for the subscription.

    Raises:
        ValidationError: If no subscription with this ID exists.
        DatabaseError: If the update fails for any other reason.
    """
    try:
        # Loaded up front: its tenant_id is needed for cache invalidation.
        subscription = await self.get_by_id(subscription_id)
        if not subscription:
            raise ValidationError(f"Subscription {subscription_id} not found")

        update_data = {
            "status": "cancelled",
            "updated_at": datetime.utcnow()
        }
        updated_subscription = await self.update(subscription_id, update_data)

        await self._invalidate_cache(str(subscription.tenant_id))

        logger.info("Subscription cancelled",
                    subscription_id=subscription_id,
                    reason=reason)
        return updated_subscription

    except ValidationError:
        # Keep "not found" distinguishable from infrastructure failures,
        # as create_subscription and update_subscription_plan do, instead
        # of re-wrapping it as a DatabaseError.
        raise
    except Exception as e:
        logger.error("Failed to cancel subscription",
                     subscription_id=subscription_id,
                     error=str(e))
        raise DatabaseError(f"Failed to cancel subscription: {str(e)}") from e
async def suspend_subscription(
    self,
    subscription_id: str,
    reason: Optional[str] = None
) -> Optional[Subscription]:
    """Mark a subscription as suspended.

    Sets ``status`` to ``"suspended"``, refreshes ``updated_at`` and
    invalidates the tenant's subscription cache.

    Args:
        subscription_id: ID of the subscription to suspend.
        reason: Optional free-text reason; it is only written to the log.

    Returns:
        Whatever the base ``update`` returns for the subscription.

    Raises:
        ValidationError: If no subscription with this ID exists.
        DatabaseError: If the update fails for any other reason.
    """
    try:
        # Loaded up front: its tenant_id is needed for cache invalidation.
        subscription = await self.get_by_id(subscription_id)
        if not subscription:
            raise ValidationError(f"Subscription {subscription_id} not found")

        update_data = {
            "status": "suspended",
            "updated_at": datetime.utcnow()
        }
        updated_subscription = await self.update(subscription_id, update_data)

        await self._invalidate_cache(str(subscription.tenant_id))

        logger.info("Subscription suspended",
                    subscription_id=subscription_id,
                    reason=reason)
        return updated_subscription

    except ValidationError:
        # Keep "not found" distinguishable from infrastructure failures,
        # as create_subscription and update_subscription_plan do, instead
        # of re-wrapping it as a DatabaseError.
        raise
    except Exception as e:
        logger.error("Failed to suspend subscription",
                     subscription_id=subscription_id,
                     error=str(e))
        raise DatabaseError(f"Failed to suspend subscription: {str(e)}") from e
async def reactivate_subscription(
    self,
    subscription_id: str
) -> Optional[Subscription]:
    """Set a subscription back to active and restart its billing date.

    Sets ``status`` to ``"active"``, moves ``next_billing_date`` to 30
    days from now, refreshes ``updated_at`` and invalidates the tenant's
    subscription cache.

    Args:
        subscription_id: ID of the subscription to reactivate.

    Returns:
        Whatever the base ``update`` returns for the subscription.

    Raises:
        ValidationError: If no subscription with this ID exists.
        DatabaseError: If the update fails for any other reason.
    """
    try:
        # Loaded up front: its tenant_id is needed for cache invalidation.
        subscription = await self.get_by_id(subscription_id)
        if not subscription:
            raise ValidationError(f"Subscription {subscription_id} not found")

        # NOTE(review): always 30 days, even for yearly billing cycles,
        # whereas create_subscription uses 365 for "yearly" - confirm
        # this is intended.
        next_billing_date = datetime.utcnow() + timedelta(days=30)

        update_data = {
            "status": "active",
            "next_billing_date": next_billing_date,
            "updated_at": datetime.utcnow()
        }
        updated_subscription = await self.update(subscription_id, update_data)

        await self._invalidate_cache(str(subscription.tenant_id))

        logger.info("Subscription reactivated",
                    subscription_id=subscription_id,
                    next_billing_date=next_billing_date)
        return updated_subscription

    except ValidationError:
        # Keep "not found" distinguishable from infrastructure failures,
        # as create_subscription and update_subscription_plan do, instead
        # of re-wrapping it as a DatabaseError.
        raise
    except Exception as e:
        logger.error("Failed to reactivate subscription",
                     subscription_id=subscription_id,
                     error=str(e))
        raise DatabaseError(f"Failed to reactivate subscription: {str(e)}") from e
async def get_subscriptions_due_for_billing(
    self,
    days_ahead: int = 3
) -> List[Subscription]:
    """List active subscriptions whose billing date falls within N days.

    Selects rows with status ``'active'`` and ``next_billing_date`` at or
    before now + ``days_ahead`` days, earliest billing date first.

    Args:
        days_ahead: Size of the look-ahead window in days.

    Returns:
        The due subscriptions. On any error the failure is logged and an
        empty list is returned instead of raising.
    """
    try:
        billing_horizon = datetime.utcnow() + timedelta(days=days_ahead)

        query_text = """
SELECT * FROM subscriptions
WHERE status = 'active'
AND next_billing_date <= :cutoff_date
ORDER BY next_billing_date ASC
"""
        result = await self.session.execute(
            text(query_text),
            {"cutoff_date": billing_horizon},
        )

        # NOTE(review): instances are built through the model constructor
        # from raw rows rather than loaded by the ORM - presumably they are
        # not tracked by the session; confirm callers only read them.
        return [self.model(**dict(row._mapping)) for row in result.fetchall()]
    except Exception as exc:
        logger.error(
            "Failed to get subscriptions due for billing",
            days_ahead=days_ahead,
            error=str(exc),
        )
        return []
async def update_billing_date(
    self,
    subscription_id: str,
    next_billing_date: datetime
) -> Optional[Subscription]:
    """Store a new next-billing date on a subscription.

    ``updated_at`` is refreshed in the same update.

    Args:
        subscription_id: ID of the subscription to change.
        next_billing_date: The new billing date to store.

    Returns:
        Whatever the base ``update`` returns for the subscription.

    Raises:
        DatabaseError: If the update fails.
    """
    try:
        changes = {
            "next_billing_date": next_billing_date,
            "updated_at": datetime.utcnow(),
        }
        result = await self.update(subscription_id, changes)

        logger.info(
            "Subscription billing date updated",
            subscription_id=subscription_id,
            next_billing_date=next_billing_date,
        )
        return result
    except Exception as exc:
        logger.error(
            "Failed to update billing date",
            subscription_id=subscription_id,
            error=str(exc),
        )
        raise DatabaseError(f"Failed to update billing date: {exc!s}")
async def get_subscription_statistics(self) -> Dict[str, Any]:
    """Collect aggregate statistics over all subscriptions.

    Returns:
        A dictionary with these keys:

        * ``subscriptions_by_plan`` - active-subscription count per plan.
        * ``subscriptions_by_status`` - subscription count per status.
        * ``total_monthly_revenue`` - sum of ``monthly_price`` over
          active subscriptions.
        * ``avg_monthly_price`` - average ``monthly_price`` over active
          subscriptions.
        * ``total_active_subscriptions`` - number of active subscriptions.
        * ``upcoming_billing_30d`` - number of subscriptions returned by
          ``get_subscriptions_due_for_billing(30)``.

        On any error the failure is logged and the same structure is
        returned with empty dictionaries and zero values.
    """
    try:
        # Counts of active subscriptions per plan.
        plan_query = text("""
SELECT plan, COUNT(*) as count
FROM subscriptions
WHERE status = 'active'
GROUP BY plan
ORDER BY count DESC
""")
        result = await self.session.execute(plan_query)
        # Rows are unpacked positionally: attribute access via ``row.count``
        # would resolve to the tuple-style ``Row.count()`` method rather
        # than the column labelled "count".
        subscriptions_by_plan = {
            plan: plan_count for plan, plan_count in result.fetchall()
        }

        # Counts of all subscriptions per status.
        status_query = text("""
SELECT status, COUNT(*) as count
FROM subscriptions
GROUP BY status
ORDER BY count DESC
""")
        result = await self.session.execute(status_query)
        subscriptions_by_status = {
            status: status_count for status, status_count in result.fetchall()
        }

        # Revenue aggregates over active subscriptions.
        revenue_query = text("""
SELECT
SUM(monthly_price) as total_monthly_revenue,
AVG(monthly_price) as avg_monthly_price,
COUNT(*) as total_active_subscriptions
FROM subscriptions
WHERE status = 'active'
""")
        revenue_result = await self.session.execute(revenue_query)
        revenue_row = revenue_result.fetchone()

        # Subscriptions due for billing within the next 30 days.
        upcoming_billing = len(await self.get_subscriptions_due_for_billing(30))

        return {
            "subscriptions_by_plan": subscriptions_by_plan,
            "subscriptions_by_status": subscriptions_by_status,
            # SUM/AVG yield NULL when no rows match, hence the "or 0".
            "total_monthly_revenue": float(revenue_row.total_monthly_revenue or 0),
            "avg_monthly_price": float(revenue_row.avg_monthly_price or 0),
            "total_active_subscriptions": int(revenue_row.total_active_subscriptions or 0),
            "upcoming_billing_30d": upcoming_billing
        }

    except Exception as e:
        logger.error("Failed to get subscription statistics", error=str(e))
        return {
            "subscriptions_by_plan": {},
            "subscriptions_by_status": {},
            "total_monthly_revenue": 0.0,
            "avg_monthly_price": 0.0,
            "total_active_subscriptions": 0,
            "upcoming_billing_30d": 0
        }
async def cleanup_old_subscriptions(self, days_old: int = 730) -> int:
    """Delete cancelled and suspended subscriptions untouched for a long time.

    Rows whose status is ``'cancelled'`` or ``'suspended'`` and whose
    ``updated_at`` is older than ``days_old`` days are removed.

    Args:
        days_old: Minimum age in days; the default of 730 is two years.

    Returns:
        The row count reported for the DELETE statement.

    Raises:
        DatabaseError: If the deletion fails.
    """
    try:
        age_threshold = datetime.utcnow() - timedelta(days=days_old)

        query_text = """
DELETE FROM subscriptions
WHERE status IN ('cancelled', 'suspended')
AND updated_at < :cutoff_date
"""
        outcome = await self.session.execute(
            text(query_text),
            {"cutoff_date": age_threshold},
        )
        removed = outcome.rowcount

        logger.info(
            "Cleaned up old subscriptions",
            deleted_count=removed,
            days_old=days_old,
        )
        return removed
    except Exception as exc:
        logger.error(
            "Failed to cleanup old subscriptions",
            error=str(exc),
        )
        raise DatabaseError(f"Cleanup failed: {exc!s}")
async def _invalidate_cache(self, tenant_id: str) -> None:
    """Drop the cached subscription data of one tenant (best effort).

    Any failure is logged as a warning and swallowed, so cache problems
    never break the calling operation.

    Args:
        tenant_id: Tenant ID
    """
    try:
        # Imported at call time rather than at module level.
        # NOTE(review): presumably to avoid a circular import - confirm.
        from app.services.subscription_cache import get_subscription_cache_service

        await get_subscription_cache_service().invalidate_subscription_cache(tenant_id)

        logger.debug(
            "Invalidated subscription cache from repository",
            tenant_id=tenant_id,
        )
    except Exception as exc:
        logger.warning(
            "Failed to invalidate cache (non-critical)",
            tenant_id=tenant_id,
            error=str(exc),
        )