"""
Subscription Repository
Repository for subscription operations
"""

from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, text, and_
from datetime import datetime, timedelta
import structlog
import json

from .base import TenantBaseRepository
from app.models.tenants import Subscription
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
from shared.subscription.plans import SubscriptionPlanMetadata, QuotaLimits, PlanPricing

logger = structlog.get_logger()


class SubscriptionRepository(TenantBaseRepository):
    """Repository for subscription operations.

    Builds on the generic helpers inherited from ``TenantBaseRepository``
    (``create``, ``update``, ``get_by_id``, ``get_multi``,
    ``_validate_tenant_data``) and adds plan-derived defaults from the
    centralized plan configuration, status transitions, and invalidation of
    the per-tenant subscription cache.

    Timestamps written by this class are naive UTC values from
    ``datetime.utcnow()``.
    NOTE(review): ``datetime.utcnow`` is deprecated since Python 3.12; moving to
    timezone-aware datetimes requires confirming the column types first.
    """

    # Plans accepted by update_subscription_plan.
    _VALID_PLANS = ("starter", "professional", "enterprise")

    # (subscription column, QuotaLimits key) pairs for plan-derived quotas.
    _QUOTA_FIELDS = (
        ("max_users", "MAX_USERS"),
        ("max_locations", "MAX_LOCATIONS"),
        ("max_products", "MAX_PRODUCTS"),
    )

    def __init__(self, model_class, session: AsyncSession, cache_ttl: Optional[int] = 600):
        # Subscriptions are relatively stable, medium cache time (10 minutes)
        super().__init__(model_class, session, cache_ttl)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _calculate_next_billing_date(
        billing_cycle: Optional[str], start: Optional[datetime] = None
    ) -> datetime:
        """Return the date one billing period after ``start``.

        ``start`` defaults to the current naive UTC time. A ``"yearly"`` cycle
        adds 365 days; any other value (including ``None``) adds 30 days.
        """
        if start is None:
            start = datetime.utcnow()
        days = 365 if billing_cycle == "yearly" else 30
        return start + timedelta(days=days)

    @staticmethod
    def _quota_limit(quota_key: str, plan: str) -> Any:
        """Look up a quota limit for ``plan``; a falsy result is stored as -1.

        NOTE(review): a configured limit of 0 is also turned into -1 by the
        ``or``; confirm that is intended.
        """
        return QuotaLimits.get_limit(quota_key, plan) or -1

    @staticmethod
    def _plan_features(plan_info: Dict[str, Any]) -> Dict[str, bool]:
        """Turn the plan's feature list into a ``{feature: True}`` mapping."""
        return {feature: True for feature in plan_info.get("features", [])}

    def _apply_plan_defaults(self, subscription_data: Dict[str, Any], default_status: str) -> None:
        """Fill missing plan-derived and lifecycle fields of ``subscription_data`` in place.

        Values already present are never overwritten. Fills: ``billing_cycle``
        ("monthly"), ``monthly_price`` (from ``PlanPricing`` for that cycle),
        quota limits, ``features``, ``status`` (``default_status``) and
        ``next_billing_date`` (one billing period from now).
        """
        plan = subscription_data["plan"]
        plan_info = SubscriptionPlanMetadata.get_plan_info(plan)

        billing_cycle = subscription_data.setdefault("billing_cycle", "monthly")

        if "monthly_price" not in subscription_data:
            subscription_data["monthly_price"] = float(PlanPricing.get_price(plan, billing_cycle))

        for field, quota_key in self._QUOTA_FIELDS:
            if field not in subscription_data:
                subscription_data[field] = self._quota_limit(quota_key, plan)

        if "features" not in subscription_data:
            subscription_data["features"] = self._plan_features(plan_info)

        subscription_data.setdefault("status", default_status)

        if "next_billing_date" not in subscription_data:
            subscription_data["next_billing_date"] = self._calculate_next_billing_date(billing_cycle)

    async def _get_latest(self, filters: Dict[str, Any]) -> Optional[Subscription]:
        """Return the most recently created subscription matching ``filters``, or None."""
        subscriptions = await self.get_multi(
            filters=filters,
            limit=1,
            order_by="created_at",
            order_desc=True
        )
        return subscriptions[0] if subscriptions else None

    async def _get_existing_or_raise(self, subscription_id: str) -> Subscription:
        """Fetch a subscription by primary key.

        Raises:
            ValidationError: if no subscription with ``subscription_id`` exists.
        """
        subscription = await self.get_by_id(subscription_id)
        if not subscription:
            raise ValidationError(f"Subscription {subscription_id} not found")
        return subscription

    async def _invalidate_subscription_cache(self, subscription: Optional[Subscription]) -> None:
        """Invalidate the cache of the tenant owning ``subscription``.

        Does nothing when the subscription is None or not linked to a tenant
        (``tenant_id`` is None). This avoids invalidating a bogus "None" key.
        """
        tenant_id = getattr(subscription, "tenant_id", None)
        if tenant_id is not None:
            await self._invalidate_cache(str(tenant_id))

    # ------------------------------------------------------------------
    # Tenant-linked subscription operations
    # ------------------------------------------------------------------

    async def create_subscription(self, subscription_data: Dict[str, Any]) -> Subscription:
        """Create a new subscription with validation.

        Missing fields are filled from the centralized plan configuration
        (price, quota limits, features) and with lifecycle defaults
        (``status="active"``, ``billing_cycle="monthly"``, next billing date
        one billing period from now). The caller's dict is not modified.

        Args:
            subscription_data: Column values for the new row; must include
                ``tenant_id`` and ``plan``.

        Returns:
            The created subscription.

        Raises:
            ValidationError: if the required fields fail validation.
            DuplicateRecordError: if the tenant already has an active subscription.
            DatabaseError: on any other failure.
        """
        try:
            # Work on a copy so defaults are not written back into the caller's dict.
            subscription_data = dict(subscription_data)

            validation_result = self._validate_tenant_data(
                subscription_data,
                ["tenant_id", "plan"]
            )
            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid subscription data: {validation_result['errors']}")

            # A tenant may hold only one active subscription at a time.
            existing_subscription = await self.get_active_subscription(
                subscription_data["tenant_id"]
            )
            if existing_subscription:
                raise DuplicateRecordError("Tenant already has an active subscription")

            self._apply_plan_defaults(subscription_data, default_status="active")

            subscription = await self.create(subscription_data)

            logger.info("Subscription created successfully",
                        subscription_id=subscription.id,
                        tenant_id=subscription.tenant_id,
                        plan=subscription.plan,
                        monthly_price=subscription.monthly_price)

            return subscription

        except (ValidationError, DuplicateRecordError):
            raise
        except Exception as e:
            logger.error("Failed to create subscription",
                         tenant_id=subscription_data.get("tenant_id"),
                         plan=subscription_data.get("plan"),
                         error=str(e))
            raise DatabaseError(f"Failed to create subscription: {str(e)}") from e

    async def get_by_tenant_id(self, tenant_id: str) -> Optional[Subscription]:
        """Return the most recently created subscription for a tenant (any status), or None."""
        try:
            return await self._get_latest({"tenant_id": tenant_id})
        except Exception as e:
            logger.error("Failed to get subscription by tenant ID",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get subscription: {str(e)}") from e

    async def get_by_provider_id(self, subscription_id: str) -> Optional[Subscription]:
        """Return the newest subscription whose ``subscription_id`` column matches, or None.

        ``subscription_id`` is the payment provider's subscription identifier,
        not this table's primary key.
        """
        try:
            return await self._get_latest({"subscription_id": subscription_id})
        except Exception as e:
            logger.error("Failed to get subscription by provider ID",
                         subscription_id=subscription_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get subscription: {str(e)}") from e

    async def get_active_subscription(self, tenant_id: str) -> Optional[Subscription]:
        """Return the newest subscription with status ``"active"`` for a tenant, or None."""
        try:
            return await self._get_latest({"tenant_id": tenant_id, "status": "active"})
        except Exception as e:
            logger.error("Failed to get active subscription",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get subscription: {str(e)}") from e

    async def get_tenant_subscriptions(
        self,
        tenant_id: str,
        include_inactive: bool = False
    ) -> List[Subscription]:
        """Return a tenant's subscriptions, newest first.

        Args:
            tenant_id: Tenant whose subscriptions are listed.
            include_inactive: When False (default), only ``"active"`` ones are returned.
        """
        try:
            filters: Dict[str, Any] = {"tenant_id": tenant_id}
            if not include_inactive:
                filters["status"] = "active"

            return await self.get_multi(
                filters=filters,
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get tenant subscriptions",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get subscriptions: {str(e)}") from e

    async def update_subscription_plan(
        self,
        subscription_id: str,
        new_plan: str,
        billing_cycle: str = "monthly"
    ) -> Optional[Subscription]:
        """Switch a subscription to ``new_plan`` using centralized plan configuration.

        Overwrites price, billing cycle, quota limits and features, then
        invalidates the owning tenant's cache.

        Raises:
            ValidationError: if ``new_plan`` is unknown or the subscription does not exist.
            DatabaseError: on any other failure.
        """
        try:
            if new_plan not in self._VALID_PLANS:
                raise ValidationError(f"Invalid plan. Must be one of: {list(self._VALID_PLANS)}")

            # Fetched first so the owning tenant's cache can be invalidated afterwards.
            subscription = await self._get_existing_or_raise(subscription_id)

            plan_info = SubscriptionPlanMetadata.get_plan_info(new_plan)

            update_data = {
                "plan": new_plan,
                "monthly_price": float(PlanPricing.get_price(new_plan, billing_cycle)),
                "billing_cycle": billing_cycle,
                **{
                    field: self._quota_limit(quota_key, new_plan)
                    for field, quota_key in self._QUOTA_FIELDS
                },
                "features": self._plan_features(plan_info),
                "updated_at": datetime.utcnow()
            }

            updated_subscription = await self.update(subscription_id, update_data)

            await self._invalidate_subscription_cache(subscription)

            logger.info("Subscription plan updated",
                        subscription_id=subscription_id,
                        new_plan=new_plan,
                        new_price=update_data["monthly_price"])

            return updated_subscription

        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to update subscription plan",
                         subscription_id=subscription_id,
                         new_plan=new_plan,
                         error=str(e))
            raise DatabaseError(f"Failed to update plan: {str(e)}") from e

    async def cancel_subscription(
        self,
        subscription_id: str,
        reason: Optional[str] = None
    ) -> Optional[Subscription]:
        """Set a subscription's status to ``"cancelled"``.

        ``reason`` is only logged; it is not stored on the subscription.

        Raises:
            ValidationError: if the subscription does not exist.
            DatabaseError: on any other failure.
        """
        try:
            subscription = await self._get_existing_or_raise(subscription_id)

            updated_subscription = await self.update(subscription_id, {
                "status": "cancelled",
                "updated_at": datetime.utcnow()
            })

            await self._invalidate_subscription_cache(subscription)

            logger.info("Subscription cancelled",
                        subscription_id=subscription_id,
                        reason=reason)

            return updated_subscription

        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to cancel subscription",
                         subscription_id=subscription_id,
                         error=str(e))
            raise DatabaseError(f"Failed to cancel subscription: {str(e)}") from e

    async def suspend_subscription(
        self,
        subscription_id: str,
        reason: Optional[str] = None
    ) -> Optional[Subscription]:
        """Set a subscription's status to ``"suspended"``.

        ``reason`` is only logged; it is not stored on the subscription.

        Raises:
            ValidationError: if the subscription does not exist.
            DatabaseError: on any other failure.
        """
        try:
            subscription = await self._get_existing_or_raise(subscription_id)

            updated_subscription = await self.update(subscription_id, {
                "status": "suspended",
                "updated_at": datetime.utcnow()
            })

            await self._invalidate_subscription_cache(subscription)

            logger.info("Subscription suspended",
                        subscription_id=subscription_id,
                        reason=reason)

            return updated_subscription

        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to suspend subscription",
                         subscription_id=subscription_id,
                         error=str(e))
            raise DatabaseError(f"Failed to suspend subscription: {str(e)}") from e

    async def reactivate_subscription(
        self,
        subscription_id: str
    ) -> Optional[Subscription]:
        """Set a subscription back to ``"active"`` and restart its billing period.

        The next billing date is one billing period from now: 365 days for a
        ``"yearly"`` cycle, 30 days otherwise (same rule as at creation).

        Raises:
            ValidationError: if the subscription does not exist.
            DatabaseError: on any other failure.
        """
        try:
            subscription = await self._get_existing_or_raise(subscription_id)

            next_billing_date = self._calculate_next_billing_date(
                getattr(subscription, "billing_cycle", None)
            )

            updated_subscription = await self.update(subscription_id, {
                "status": "active",
                "next_billing_date": next_billing_date,
                "updated_at": datetime.utcnow()
            })

            await self._invalidate_subscription_cache(subscription)

            logger.info("Subscription reactivated",
                        subscription_id=subscription_id,
                        next_billing_date=next_billing_date)

            return updated_subscription

        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to reactivate subscription",
                         subscription_id=subscription_id,
                         error=str(e))
            raise DatabaseError(f"Failed to reactivate subscription: {str(e)}") from e

    async def get_subscriptions_due_for_billing(
        self,
        days_ahead: int = 3
    ) -> List[Subscription]:
        """Return active subscriptions with a next billing date on or before now + ``days_ahead``.

        Overdue subscriptions (billing date already past) are included. Results
        are ordered by next billing date, earliest first. Errors are logged and
        an empty list is returned (best-effort).
        """
        try:
            cutoff_date = datetime.utcnow() + timedelta(days=days_ahead)
            model = self.model

            # ORM select so the returned objects are session-managed instances
            # rather than detached copies rebuilt from raw rows.
            query = (
                select(model)
                .where(and_(model.status == "active", model.next_billing_date <= cutoff_date))
                .order_by(model.next_billing_date.asc())
            )
            result = await self.session.execute(query)
            return list(result.scalars().all())

        except Exception as e:
            logger.error("Failed to get subscriptions due for billing",
                         days_ahead=days_ahead,
                         error=str(e))
            return []

    async def update_billing_date(
        self,
        subscription_id: str,
        next_billing_date: datetime
    ) -> Optional[Subscription]:
        """Set a subscription's next billing date and invalidate its tenant's cache."""
        try:
            updated_subscription = await self.update(subscription_id, {
                "next_billing_date": next_billing_date,
                "updated_at": datetime.utcnow()
            })

            # Keep cached subscription data consistent, like the other mutators.
            await self._invalidate_subscription_cache(updated_subscription)

            logger.info("Subscription billing date updated",
                        subscription_id=subscription_id,
                        next_billing_date=next_billing_date)

            return updated_subscription

        except Exception as e:
            logger.error("Failed to update billing date",
                         subscription_id=subscription_id,
                         error=str(e))
            raise DatabaseError(f"Failed to update billing date: {str(e)}") from e

    async def get_subscription_statistics(self) -> Dict[str, Any]:
        """Return aggregate subscription statistics.

        Keys: active counts per plan, counts per status (all statuses),
        SUM/AVG of ``monthly_price`` and the count over active subscriptions,
        and the number of active subscriptions due for billing within 30 days.
        On any error, logs it and returns the same keys with zero/empty values.
        """
        try:
            plan_query = text("""
                SELECT plan, COUNT(*) as count
                FROM subscriptions
                WHERE status = 'active'
                GROUP BY plan
                ORDER BY count DESC
            """)
            result = await self.session.execute(plan_query)
            subscriptions_by_plan = {row.plan: row.count for row in result.fetchall()}

            status_query = text("""
                SELECT status, COUNT(*) as count
                FROM subscriptions
                GROUP BY status
                ORDER BY count DESC
            """)
            result = await self.session.execute(status_query)
            subscriptions_by_status = {row.status: row.count for row in result.fetchall()}

            revenue_query = text("""
                SELECT
                    SUM(monthly_price) as total_monthly_revenue,
                    AVG(monthly_price) as avg_monthly_price,
                    COUNT(*) as total_active_subscriptions
                FROM subscriptions
                WHERE status = 'active'
            """)
            revenue_result = await self.session.execute(revenue_query)
            revenue_row = revenue_result.fetchone()

            upcoming_billing = len(await self.get_subscriptions_due_for_billing(30))

            return {
                "subscriptions_by_plan": subscriptions_by_plan,
                "subscriptions_by_status": subscriptions_by_status,
                "total_monthly_revenue": float(revenue_row.total_monthly_revenue or 0),
                "avg_monthly_price": float(revenue_row.avg_monthly_price or 0),
                "total_active_subscriptions": int(revenue_row.total_active_subscriptions or 0),
                "upcoming_billing_30d": upcoming_billing
            }

        except Exception as e:
            logger.error("Failed to get subscription statistics", error=str(e))
            return {
                "subscriptions_by_plan": {},
                "subscriptions_by_status": {},
                "total_monthly_revenue": 0.0,
                "avg_monthly_price": 0.0,
                "total_active_subscriptions": 0,
                "upcoming_billing_30d": 0
            }

    async def cleanup_old_subscriptions(self, days_old: int = 730) -> int:
        """Delete cancelled or suspended subscriptions not updated for ``days_old`` days.

        The default of 730 days is roughly two years. Returns the number of
        deleted rows. The DELETE is executed on the session; no commit is
        issued here.
        """
        try:
            cutoff_date = datetime.utcnow() - timedelta(days=days_old)

            query_text = """
                DELETE FROM subscriptions
                WHERE status IN ('cancelled', 'suspended')
                AND updated_at < :cutoff_date
            """

            result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
            deleted_count = result.rowcount

            logger.info("Cleaned up old subscriptions",
                        deleted_count=deleted_count,
                        days_old=days_old)

            return deleted_count

        except Exception as e:
            logger.error("Failed to cleanup old subscriptions",
                         error=str(e))
            raise DatabaseError(f"Cleanup failed: {str(e)}") from e

    async def _invalidate_cache(self, tenant_id: str) -> None:
        """
        Invalidate subscription cache for a tenant (best-effort).

        Failures are logged as warnings and never propagated.

        Args:
            tenant_id: Tenant ID
        """
        try:
            # Imported lazily, inside the call.
            from app.services.subscription_cache import get_subscription_cache_service

            cache_service = get_subscription_cache_service()
            await cache_service.invalidate_subscription_cache(tenant_id)

            logger.debug("Invalidated subscription cache from repository",
                         tenant_id=tenant_id)

        except Exception as e:
            logger.warning("Failed to invalidate cache (non-critical)",
                           tenant_id=tenant_id, error=str(e))

    # ========================================================================
    # TENANT-INDEPENDENT SUBSCRIPTION METHODS (New Architecture)
    # ========================================================================

    async def create_tenant_independent_subscription(
        self,
        subscription_data: Dict[str, Any]
    ) -> Subscription:
        """Create a subscription not linked to any tenant (for registration flow).

        Requires ``user_id``, ``plan``, ``subscription_id`` and ``customer_id``,
        and rejects a truthy ``tenant_id``. The row is stored with
        ``tenant_id=None``, ``is_tenant_linked=False`` and
        ``tenant_linking_status="pending"``. Its default status is
        ``"pending_tenant_linking"``. Other defaults come from the plan
        configuration, as in ``create_subscription``. The caller's dict is not
        modified.

        Raises:
            ValidationError: if validation fails or ``tenant_id`` is provided.
            DatabaseError: on any other failure.
        """
        try:
            # Work on a copy so defaults are not written back into the caller's dict.
            subscription_data = dict(subscription_data)

            required_fields = ["user_id", "plan", "subscription_id", "customer_id"]
            validation_result = self._validate_tenant_data(subscription_data, required_fields)

            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid subscription data: {validation_result['errors']}")

            if subscription_data.get("tenant_id"):
                raise ValidationError("tenant_id should not be provided for tenant-independent subscriptions")

            subscription_data["tenant_id"] = None
            subscription_data["is_tenant_linked"] = False
            subscription_data["tenant_linking_status"] = "pending"
            subscription_data["linked_at"] = None

            self._apply_plan_defaults(subscription_data, default_status="pending_tenant_linking")

            subscription = await self.create(subscription_data)

            logger.info("Tenant-independent subscription created successfully",
                        subscription_id=subscription.id,
                        user_id=subscription.user_id,
                        plan=subscription.plan,
                        monthly_price=subscription.monthly_price)

            return subscription

        except (ValidationError, DuplicateRecordError):
            raise
        except Exception as e:
            logger.error("Failed to create tenant-independent subscription",
                         user_id=subscription_data.get("user_id"),
                         plan=subscription_data.get("plan"),
                         error=str(e))
            raise DatabaseError(f"Failed to create tenant-independent subscription: {str(e)}") from e

    async def get_pending_tenant_linking_subscriptions(self) -> List[Subscription]:
        """Return all subscriptions still waiting to be linked to a tenant, newest first."""
        try:
            return await self.get_multi(
                filters={
                    "tenant_linking_status": "pending",
                    "is_tenant_linked": False
                },
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get pending tenant linking subscriptions", error=str(e))
            raise DatabaseError(f"Failed to get pending subscriptions: {str(e)}") from e

    async def get_pending_subscriptions_by_user(self, user_id: str) -> List[Subscription]:
        """Return a user's subscriptions still waiting for tenant linking, newest first."""
        try:
            return await self.get_multi(
                filters={
                    "user_id": user_id,
                    "tenant_linking_status": "pending",
                    "is_tenant_linked": False
                },
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get pending subscriptions by user",
                         user_id=user_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get pending subscriptions: {str(e)}") from e

    async def link_subscription_to_tenant(
        self,
        subscription_id: str,
        tenant_id: str,
        user_id: str
    ) -> Subscription:
        """Link a pending subscription to a tenant and activate it.

        Eligibility is decided by the model's ``can_be_linked_to_tenant(user_id)``.
        On success, sets ``tenant_id``, ``is_tenant_linked=True``,
        ``tenant_linking_status="completed"``, ``linked_at`` and
        ``status="active"``, then invalidates the tenant's cache.

        Raises:
            ValidationError: if the subscription does not exist or cannot be linked.
            DatabaseError: on any other failure.
        """
        try:
            subscription = await self._get_existing_or_raise(subscription_id)

            if not subscription.can_be_linked_to_tenant(user_id):
                raise ValidationError(
                    f"Subscription {subscription_id} cannot be linked to tenant by user {user_id}. "
                    f"Current status: {subscription.tenant_linking_status}, "
                    f"User: {subscription.user_id}, "
                    f"Already linked: {subscription.is_tenant_linked}"
                )

            # One timestamp so linked_at and updated_at agree exactly.
            now = datetime.utcnow()
            update_data = {
                "tenant_id": tenant_id,
                "is_tenant_linked": True,
                "tenant_linking_status": "completed",
                "linked_at": now,
                "status": "active",  # Activate subscription when linked to tenant
                "updated_at": now
            }

            updated_subscription = await self.update(subscription_id, update_data)

            await self._invalidate_cache(tenant_id)

            logger.info("Subscription linked to tenant successfully",
                        subscription_id=subscription_id,
                        tenant_id=tenant_id,
                        user_id=user_id)

            return updated_subscription

        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to link subscription to tenant",
                         subscription_id=subscription_id,
                         tenant_id=tenant_id,
                         user_id=user_id,
                         error=str(e))
            raise DatabaseError(f"Failed to link subscription to tenant: {str(e)}") from e

    async def cleanup_orphaned_subscriptions(self, days_old: int = 30) -> int:
        """Delete still-pending, never-linked subscriptions created more than ``days_old`` days ago.

        Returns the number of deleted rows. The DELETE is executed on the
        session; no commit is issued here.
        """
        try:
            cutoff_date = datetime.utcnow() - timedelta(days=days_old)

            query_text = """
                DELETE FROM subscriptions
                WHERE tenant_linking_status = 'pending'
                AND is_tenant_linked = FALSE
                AND created_at < :cutoff_date
            """

            result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
            deleted_count = result.rowcount

            logger.info("Cleaned up orphaned subscriptions",
                        deleted_count=deleted_count,
                        days_old=days_old)

            return deleted_count

        except Exception as e:
            logger.error("Failed to cleanup orphaned subscriptions", error=str(e))
            raise DatabaseError(f"Cleanup failed: {str(e)}") from e