Files
bakery-ia/services/tenant/app/repositories/subscription_repository.py

505 lines
20 KiB
Python
Raw Normal View History

2025-08-08 09:08:41 +02:00
"""
Subscription Repository
Repository for subscription operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, text, and_
from datetime import datetime, timedelta
import structlog
import json
from .base import TenantBaseRepository
from app.models.tenants import Subscription
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
2025-10-29 06:58:05 +01:00
from shared.subscription.plans import SubscriptionPlanMetadata, QuotaLimits, PlanPricing
2025-08-08 09:08:41 +02:00
# Module-level structlog logger used by every repository method in this module.
logger = structlog.get_logger()
class SubscriptionRepository(TenantBaseRepository):
    """Repository for subscription operations.

    Builds on the generic helpers inherited from ``TenantBaseRepository``
    (``create``, ``update``, ``get_by_id``, ``get_multi``) to provide
    subscription lookups, lifecycle transitions (plan change, cancel,
    suspend, reactivate), billing queries and statistics. Plan-dependent
    defaults (price, quotas, features) come from the centralized
    ``shared.subscription.plans`` configuration.
    """

    # (subscription column, QuotaLimits key) pairs filled from the plan config.
    _QUOTA_COLUMNS = (
        ("max_users", "MAX_USERS"),
        ("max_locations", "MAX_LOCATIONS"),
        ("max_products", "MAX_PRODUCTS"),
    )

    def __init__(self, model_class, session: AsyncSession, cache_ttl: Optional[int] = 600):
        # Subscriptions are relatively stable, medium cache time (10 minutes)
        super().__init__(model_class, session, cache_ttl)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _plan_limit(quota_name: str, plan: str) -> int:
        """Return the plan's quota ``quota_name``, using -1 when no limit is defined.

        Only ``None`` is mapped to -1; an explicit limit of 0 is kept as 0
        (``get_limit(...) or -1`` would have turned it into -1).
        """
        limit = QuotaLimits.get_limit(quota_name, plan)
        return -1 if limit is None else limit

    @staticmethod
    def _plan_features(plan_info: Dict[str, Any]) -> Dict[str, bool]:
        """Build the ``features`` mapping (feature name -> True) from plan metadata."""
        return {feature: True for feature in plan_info.get("features", [])}

    @staticmethod
    def _next_billing_date(billing_cycle: Optional[str]) -> datetime:
        """Return the next billing date counted from now.

        ``"yearly"`` adds 365 days; any other value adds 30 days. Uses naive
        UTC, like the other timestamps this repository writes.
        """
        days = 365 if billing_cycle == "yearly" else 30
        return datetime.utcnow() + timedelta(days=days)

    async def _get_latest(self, filters: Dict[str, Any]) -> Optional[Subscription]:
        """Return the most recently created subscription matching ``filters``, or None."""
        subscriptions = await self.get_multi(
            filters=filters,
            limit=1,
            order_by="created_at",
            order_desc=True
        )
        return subscriptions[0] if subscriptions else None

    # ------------------------------------------------------------------
    # Creation and lookups
    # ------------------------------------------------------------------

    async def create_subscription(self, subscription_data: Dict[str, Any]) -> Subscription:
        """Create a new subscription with validation and plan-derived defaults.

        Args:
            subscription_data: Column values for the new subscription. Must
                contain ``tenant_id`` and ``plan``. Missing ``monthly_price``,
                quota columns, ``features``, ``status``, ``billing_cycle`` and
                ``next_billing_date`` are filled in from the centralized plan
                configuration. The caller's dict is not modified.

        Returns:
            The created Subscription.

        Raises:
            ValidationError: Required fields are missing or invalid.
            DuplicateRecordError: The tenant already has an active subscription.
            DatabaseError: Any other failure while creating the record.
        """
        try:
            # Work on a copy so the defaults filled in below don't leak back
            # into the caller's dict.
            subscription_data = dict(subscription_data)

            validation_result = self._validate_tenant_data(
                subscription_data,
                ["tenant_id", "plan"]
            )
            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid subscription data: {validation_result['errors']}")

            existing_subscription = await self.get_active_subscription(
                subscription_data["tenant_id"]
            )
            if existing_subscription:
                raise DuplicateRecordError("Tenant already has an active subscription")

            plan = subscription_data["plan"]
            plan_info = SubscriptionPlanMetadata.get_plan_info(plan)

            # Resolve the billing cycle first: price and next billing date depend on it.
            billing_cycle = subscription_data.setdefault("billing_cycle", "monthly")

            if "monthly_price" not in subscription_data:
                subscription_data["monthly_price"] = float(
                    PlanPricing.get_price(plan, billing_cycle)
                )
            for column, quota_name in self._QUOTA_COLUMNS:
                if column not in subscription_data:
                    subscription_data[column] = self._plan_limit(quota_name, plan)
            if "features" not in subscription_data:
                subscription_data["features"] = self._plan_features(plan_info)

            subscription_data.setdefault("status", "active")
            if "next_billing_date" not in subscription_data:
                subscription_data["next_billing_date"] = self._next_billing_date(billing_cycle)

            subscription = await self.create(subscription_data)
            logger.info("Subscription created successfully",
                        subscription_id=subscription.id,
                        tenant_id=subscription.tenant_id,
                        plan=subscription.plan,
                        monthly_price=subscription.monthly_price)
            return subscription
        except (ValidationError, DuplicateRecordError):
            raise
        except Exception as e:
            logger.error("Failed to create subscription",
                         tenant_id=subscription_data.get("tenant_id"),
                         plan=subscription_data.get("plan"),
                         error=str(e))
            raise DatabaseError(f"Failed to create subscription: {str(e)}") from e

    async def get_by_tenant_id(self, tenant_id: str) -> Optional[Subscription]:
        """Return the tenant's most recently created subscription (any status), or None.

        Raises:
            DatabaseError: The lookup failed.
        """
        try:
            return await self._get_latest({"tenant_id": tenant_id})
        except Exception as e:
            logger.error("Failed to get subscription by tenant ID",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get subscription: {str(e)}") from e

    async def get_by_stripe_id(self, stripe_subscription_id: str) -> Optional[Subscription]:
        """Return the most recent subscription with the given Stripe subscription ID, or None.

        Raises:
            DatabaseError: The lookup failed.
        """
        try:
            return await self._get_latest({"stripe_subscription_id": stripe_subscription_id})
        except Exception as e:
            logger.error("Failed to get subscription by Stripe ID",
                         stripe_subscription_id=stripe_subscription_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get subscription: {str(e)}") from e

    async def get_active_subscription(self, tenant_id: str) -> Optional[Subscription]:
        """Return the tenant's most recently created ``active`` subscription, or None.

        Raises:
            DatabaseError: The lookup failed.
        """
        try:
            return await self._get_latest({"tenant_id": tenant_id, "status": "active"})
        except Exception as e:
            logger.error("Failed to get active subscription",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get subscription: {str(e)}") from e

    async def get_tenant_subscriptions(
        self,
        tenant_id: str,
        include_inactive: bool = False
    ) -> List[Subscription]:
        """Return a tenant's subscriptions, newest first.

        Args:
            tenant_id: Tenant whose subscriptions to list.
            include_inactive: When False (default) only ``active`` ones are returned.

        Raises:
            DatabaseError: The lookup failed.
        """
        try:
            filters: Dict[str, Any] = {"tenant_id": tenant_id}
            if not include_inactive:
                filters["status"] = "active"
            return await self.get_multi(
                filters=filters,
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get tenant subscriptions",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get subscriptions: {str(e)}") from e

    # ------------------------------------------------------------------
    # Lifecycle transitions
    # ------------------------------------------------------------------

    async def update_subscription_plan(
        self,
        subscription_id: str,
        new_plan: str,
        billing_cycle: str = "monthly"
    ) -> Optional[Subscription]:
        """Switch a subscription to ``new_plan`` and reset price, quotas and features.

        Args:
            subscription_id: Subscription to update.
            new_plan: One of ``starter``, ``professional``, ``enterprise``.
            billing_cycle: Billing cycle used to look up the price; also stored.

        Returns:
            The updated subscription as returned by ``update``.

        Raises:
            ValidationError: Unknown plan or subscription not found.
            DatabaseError: Any other failure.
        """
        try:
            valid_plans = ["starter", "professional", "enterprise"]
            if new_plan not in valid_plans:
                raise ValidationError(f"Invalid plan. Must be one of: {valid_plans}")

            # Fetch first: tenant_id is needed for cache invalidation.
            subscription = await self.get_by_id(subscription_id)
            if not subscription:
                raise ValidationError(f"Subscription {subscription_id} not found")

            plan_info = SubscriptionPlanMetadata.get_plan_info(new_plan)
            update_data: Dict[str, Any] = {
                "plan": new_plan,
                "monthly_price": float(PlanPricing.get_price(new_plan, billing_cycle)),
                "billing_cycle": billing_cycle,
            }
            for column, quota_name in self._QUOTA_COLUMNS:
                update_data[column] = self._plan_limit(quota_name, new_plan)
            update_data["features"] = self._plan_features(plan_info)
            update_data["updated_at"] = datetime.utcnow()

            updated_subscription = await self.update(subscription_id, update_data)
            await self._invalidate_cache(str(subscription.tenant_id))

            logger.info("Subscription plan updated",
                        subscription_id=subscription_id,
                        new_plan=new_plan,
                        new_price=update_data["monthly_price"])
            return updated_subscription
        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to update subscription plan",
                         subscription_id=subscription_id,
                         new_plan=new_plan,
                         error=str(e))
            raise DatabaseError(f"Failed to update plan: {str(e)}") from e

    async def cancel_subscription(
        self,
        subscription_id: str,
        reason: Optional[str] = None
    ) -> Optional[Subscription]:
        """Set a subscription's status to ``cancelled``.

        Args:
            subscription_id: Subscription to cancel.
            reason: Optional free-text reason; it is logged, not stored.

        Returns:
            The updated subscription as returned by ``update``.

        Raises:
            ValidationError: The subscription does not exist.
            DatabaseError: Any other failure.
        """
        try:
            # Fetch first: tenant_id is needed for cache invalidation.
            subscription = await self.get_by_id(subscription_id)
            if not subscription:
                raise ValidationError(f"Subscription {subscription_id} not found")

            updated_subscription = await self.update(subscription_id, {
                "status": "cancelled",
                "updated_at": datetime.utcnow()
            })
            await self._invalidate_cache(str(subscription.tenant_id))

            logger.info("Subscription cancelled",
                        subscription_id=subscription_id,
                        reason=reason)
            return updated_subscription
        except ValidationError:
            # Not-found is a caller error; surface it unwrapped, like
            # update_subscription_plan does.
            raise
        except Exception as e:
            logger.error("Failed to cancel subscription",
                         subscription_id=subscription_id,
                         error=str(e))
            raise DatabaseError(f"Failed to cancel subscription: {str(e)}") from e

    async def suspend_subscription(
        self,
        subscription_id: str,
        reason: Optional[str] = None
    ) -> Optional[Subscription]:
        """Set a subscription's status to ``suspended``.

        Args:
            subscription_id: Subscription to suspend.
            reason: Optional free-text reason; it is logged, not stored.

        Returns:
            The updated subscription as returned by ``update``.

        Raises:
            ValidationError: The subscription does not exist.
            DatabaseError: Any other failure.
        """
        try:
            # Fetch first: tenant_id is needed for cache invalidation.
            subscription = await self.get_by_id(subscription_id)
            if not subscription:
                raise ValidationError(f"Subscription {subscription_id} not found")

            updated_subscription = await self.update(subscription_id, {
                "status": "suspended",
                "updated_at": datetime.utcnow()
            })
            await self._invalidate_cache(str(subscription.tenant_id))

            logger.info("Subscription suspended",
                        subscription_id=subscription_id,
                        reason=reason)
            return updated_subscription
        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to suspend subscription",
                         subscription_id=subscription_id,
                         error=str(e))
            raise DatabaseError(f"Failed to suspend subscription: {str(e)}") from e

    async def reactivate_subscription(
        self,
        subscription_id: str
    ) -> Optional[Subscription]:
        """Set a cancelled or suspended subscription back to ``active``.

        The next billing date is reset from now according to the
        subscription's own billing cycle (365 days for ``yearly``, otherwise
        30), matching what ``create_subscription`` uses.

        Returns:
            The updated subscription as returned by ``update``.

        Raises:
            ValidationError: The subscription does not exist.
            DatabaseError: Any other failure.
        """
        try:
            # Fetch first: tenant_id (cache) and billing_cycle (next date) are needed.
            subscription = await self.get_by_id(subscription_id)
            if not subscription:
                raise ValidationError(f"Subscription {subscription_id} not found")

            next_billing_date = self._next_billing_date(subscription.billing_cycle)
            updated_subscription = await self.update(subscription_id, {
                "status": "active",
                "next_billing_date": next_billing_date,
                "updated_at": datetime.utcnow()
            })
            await self._invalidate_cache(str(subscription.tenant_id))

            logger.info("Subscription reactivated",
                        subscription_id=subscription_id,
                        next_billing_date=next_billing_date)
            return updated_subscription
        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to reactivate subscription",
                         subscription_id=subscription_id,
                         error=str(e))
            raise DatabaseError(f"Failed to reactivate subscription: {str(e)}") from e

    # ------------------------------------------------------------------
    # Billing
    # ------------------------------------------------------------------

    async def get_subscriptions_due_for_billing(
        self,
        days_ahead: int = 3
    ) -> List[Subscription]:
        """Return active subscriptions whose next billing date is within ``days_ahead`` days.

        Overdue subscriptions (billing date already past) are included. The
        results are ordered by next billing date, earliest first. They are
        loaded through the ORM, so they are attached to the session. On
        failure the error is logged and an empty list is returned.
        """
        try:
            cutoff_date = datetime.utcnow() + timedelta(days=days_ahead)
            query = (
                select(self.model)
                .where(
                    and_(
                        self.model.status == "active",
                        self.model.next_billing_date <= cutoff_date,
                    )
                )
                .order_by(self.model.next_billing_date.asc())
            )
            result = await self.session.execute(query)
            return list(result.scalars().all())
        except Exception as e:
            logger.error("Failed to get subscriptions due for billing",
                         days_ahead=days_ahead,
                         error=str(e))
            return []

    async def update_billing_date(
        self,
        subscription_id: str,
        next_billing_date: datetime
    ) -> Optional[Subscription]:
        """Set a subscription's next billing date and invalidate its tenant's cache.

        Raises:
            DatabaseError: The update failed.
        """
        try:
            updated_subscription = await self.update(subscription_id, {
                "next_billing_date": next_billing_date,
                "updated_at": datetime.utcnow()
            })
            # Keep the cache consistent with the other mutators in this repository.
            if updated_subscription is not None:
                await self._invalidate_cache(str(updated_subscription.tenant_id))

            logger.info("Subscription billing date updated",
                        subscription_id=subscription_id,
                        next_billing_date=next_billing_date)
            return updated_subscription
        except Exception as e:
            logger.error("Failed to update billing date",
                         subscription_id=subscription_id,
                         error=str(e))
            raise DatabaseError(f"Failed to update billing date: {str(e)}") from e

    # ------------------------------------------------------------------
    # Reporting and maintenance
    # ------------------------------------------------------------------

    async def get_subscription_statistics(self) -> Dict[str, Any]:
        """Aggregate subscription counts and revenue figures.

        Returns:
            Dict with these keys:

            - ``subscriptions_by_plan``: active subscriptions only.
            - ``subscriptions_by_status``: all subscriptions.
            - ``total_monthly_revenue`` / ``avg_monthly_price``: sum and
              average of ``monthly_price`` over active subscriptions.
            - ``total_active_subscriptions``.
            - ``upcoming_billing_30d``: active subscriptions due within 30
              days, including overdue ones.

            On any error the same keys are returned with zero/empty values.
        """
        try:
            plan_query = text("""
                SELECT plan, COUNT(*) as count
                FROM subscriptions
                WHERE status = 'active'
                GROUP BY plan
                ORDER BY count DESC
            """)
            result = await self.session.execute(plan_query)
            subscriptions_by_plan = {row.plan: row.count for row in result.fetchall()}

            status_query = text("""
                SELECT status, COUNT(*) as count
                FROM subscriptions
                GROUP BY status
                ORDER BY count DESC
            """)
            result = await self.session.execute(status_query)
            subscriptions_by_status = {row.status: row.count for row in result.fetchall()}

            revenue_query = text("""
                SELECT
                    SUM(monthly_price) as total_monthly_revenue,
                    AVG(monthly_price) as avg_monthly_price,
                    COUNT(*) as total_active_subscriptions
                FROM subscriptions
                WHERE status = 'active'
            """)
            revenue_result = await self.session.execute(revenue_query)
            # An aggregate without GROUP BY always yields exactly one row;
            # SUM/AVG are NULL when no rows match, hence the "or 0" below.
            revenue_row = revenue_result.fetchone()

            upcoming_billing = len(await self.get_subscriptions_due_for_billing(30))

            return {
                "subscriptions_by_plan": subscriptions_by_plan,
                "subscriptions_by_status": subscriptions_by_status,
                "total_monthly_revenue": float(revenue_row.total_monthly_revenue or 0),
                "avg_monthly_price": float(revenue_row.avg_monthly_price or 0),
                "total_active_subscriptions": int(revenue_row.total_active_subscriptions or 0),
                "upcoming_billing_30d": upcoming_billing
            }
        except Exception as e:
            logger.error("Failed to get subscription statistics", error=str(e))
            return {
                "subscriptions_by_plan": {},
                "subscriptions_by_status": {},
                "total_monthly_revenue": 0.0,
                "avg_monthly_price": 0.0,
                "total_active_subscriptions": 0,
                "upcoming_billing_30d": 0
            }

    async def cleanup_old_subscriptions(self, days_old: int = 730) -> int:
        """Delete cancelled or suspended subscriptions not updated for ``days_old`` days.

        The default of 730 days is about two years. This method does not
        commit the session itself; presumably the caller's unit of work
        commits (confirm against callers). No cache invalidation is
        performed for the deleted rows.

        Returns:
            Number of rows deleted.

        Raises:
            DatabaseError: The delete failed.
        """
        try:
            cutoff_date = datetime.utcnow() - timedelta(days=days_old)
            query_text = """
                DELETE FROM subscriptions
                WHERE status IN ('cancelled', 'suspended')
                AND updated_at < :cutoff_date
            """
            result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
            deleted_count = result.rowcount

            logger.info("Cleaned up old subscriptions",
                        deleted_count=deleted_count,
                        days_old=days_old)
            return deleted_count
        except Exception as e:
            logger.error("Failed to cleanup old subscriptions",
                         error=str(e))
            raise DatabaseError(f"Cleanup failed: {str(e)}") from e

    async def _invalidate_cache(self, tenant_id: str) -> None:
        """
        Invalidate subscription cache for a tenant (best effort).

        Failures are logged as warnings and never propagated, so a cache
        problem cannot fail the database operation that triggered it.

        Args:
            tenant_id: Tenant ID
        """
        try:
            # Imported lazily, presumably to avoid a repository <-> service import cycle.
            from app.services.subscription_cache import get_subscription_cache_service
            cache_service = get_subscription_cache_service()
            await cache_service.invalidate_subscription_cache(tenant_id)
            logger.debug("Invalidated subscription cache from repository",
                         tenant_id=tenant_id)
        except Exception as e:
            logger.warning("Failed to invalidate cache (non-critical)",
                           tenant_id=tenant_id, error=str(e))