""" Subscription Repository Repository for subscription operations """ from typing import Optional, List, Dict, Any from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, text, and_ from datetime import datetime, timedelta import structlog import json from .base import TenantBaseRepository from app.models.tenants import Subscription from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError logger = structlog.get_logger() class SubscriptionRepository(TenantBaseRepository): """Repository for subscription operations""" def __init__(self, model_class, session: AsyncSession, cache_ttl: Optional[int] = 600): # Subscriptions are relatively stable, medium cache time (10 minutes) super().__init__(model_class, session, cache_ttl) async def create_subscription(self, subscription_data: Dict[str, Any]) -> Subscription: """Create a new subscription with validation""" try: # Validate subscription data validation_result = self._validate_tenant_data( subscription_data, ["tenant_id", "plan"] ) if not validation_result["is_valid"]: raise ValidationError(f"Invalid subscription data: {validation_result['errors']}") # Check for existing active subscription existing_subscription = await self.get_active_subscription( subscription_data["tenant_id"] ) if existing_subscription: raise DuplicateRecordError(f"Tenant already has an active subscription") # Set default values based on plan plan_config = self._get_plan_configuration(subscription_data["plan"]) # Set defaults from plan config for key, value in plan_config.items(): if key not in subscription_data: subscription_data[key] = value # Set default subscription values if "status" not in subscription_data: subscription_data["status"] = "active" if "billing_cycle" not in subscription_data: subscription_data["billing_cycle"] = "monthly" if "next_billing_date" not in subscription_data: # Set next billing date based on cycle if subscription_data["billing_cycle"] == "yearly": subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=365) else: subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=30) # Create subscription subscription = await self.create(subscription_data) logger.info("Subscription created successfully", subscription_id=subscription.id, tenant_id=subscription.tenant_id, plan=subscription.plan, monthly_price=subscription.monthly_price) return subscription except (ValidationError, DuplicateRecordError): raise except Exception as e: logger.error("Failed to create subscription", tenant_id=subscription_data.get("tenant_id"), plan=subscription_data.get("plan"), error=str(e)) raise DatabaseError(f"Failed to create subscription: {str(e)}") async def get_active_subscription(self, tenant_id: str) -> Optional[Subscription]: """Get active subscription for tenant""" try: subscriptions = await self.get_multi( filters={ "tenant_id": tenant_id, "status": "active" }, limit=1, order_by="created_at", order_desc=True ) return subscriptions[0] if subscriptions else None except Exception as e: logger.error("Failed to get active subscription", tenant_id=tenant_id, error=str(e)) raise DatabaseError(f"Failed to get subscription: {str(e)}") async def get_tenant_subscriptions( self, tenant_id: str, include_inactive: bool = False ) -> List[Subscription]: """Get all subscriptions for a tenant""" try: filters = {"tenant_id": tenant_id} if not include_inactive: filters["status"] = "active" return await self.get_multi( filters=filters, order_by="created_at", order_desc=True ) except Exception as e: logger.error("Failed to get tenant subscriptions", tenant_id=tenant_id, error=str(e)) raise DatabaseError(f"Failed to get subscriptions: {str(e)}") async def update_subscription_plan( self, subscription_id: str, new_plan: str ) -> Optional[Subscription]: """Update subscription plan and pricing""" try: valid_plans = ["basic", "professional", "enterprise"] if new_plan not in valid_plans: raise ValidationError(f"Invalid plan. Must be one of: {valid_plans}") # Get new plan configuration plan_config = self._get_plan_configuration(new_plan) # Update subscription with new plan details update_data = { "plan": new_plan, "monthly_price": plan_config["monthly_price"], "max_users": plan_config["max_users"], "max_locations": plan_config["max_locations"], "max_products": plan_config["max_products"], "updated_at": datetime.utcnow() } updated_subscription = await self.update(subscription_id, update_data) logger.info("Subscription plan updated", subscription_id=subscription_id, new_plan=new_plan, new_price=plan_config["monthly_price"]) return updated_subscription except ValidationError: raise except Exception as e: logger.error("Failed to update subscription plan", subscription_id=subscription_id, new_plan=new_plan, error=str(e)) raise DatabaseError(f"Failed to update plan: {str(e)}") async def cancel_subscription( self, subscription_id: str, reason: str = None ) -> Optional[Subscription]: """Cancel a subscription""" try: update_data = { "status": "cancelled", "updated_at": datetime.utcnow() } updated_subscription = await self.update(subscription_id, update_data) logger.info("Subscription cancelled", subscription_id=subscription_id, reason=reason) return updated_subscription except Exception as e: logger.error("Failed to cancel subscription", subscription_id=subscription_id, error=str(e)) raise DatabaseError(f"Failed to cancel subscription: {str(e)}") async def suspend_subscription( self, subscription_id: str, reason: str = None ) -> Optional[Subscription]: """Suspend a subscription""" try: update_data = { "status": "suspended", "updated_at": datetime.utcnow() } updated_subscription = await self.update(subscription_id, update_data) logger.info("Subscription suspended", subscription_id=subscription_id, reason=reason) return updated_subscription except Exception as e: logger.error("Failed to suspend subscription", subscription_id=subscription_id, error=str(e)) raise DatabaseError(f"Failed to suspend subscription: {str(e)}") async def reactivate_subscription( self, subscription_id: str ) -> Optional[Subscription]: """Reactivate a cancelled or suspended subscription""" try: # Reset billing date when reactivating next_billing_date = datetime.utcnow() + timedelta(days=30) update_data = { "status": "active", "next_billing_date": next_billing_date, "updated_at": datetime.utcnow() } updated_subscription = await self.update(subscription_id, update_data) logger.info("Subscription reactivated", subscription_id=subscription_id, next_billing_date=next_billing_date) return updated_subscription except Exception as e: logger.error("Failed to reactivate subscription", subscription_id=subscription_id, error=str(e)) raise DatabaseError(f"Failed to reactivate subscription: {str(e)}") async def get_subscriptions_due_for_billing( self, days_ahead: int = 3 ) -> List[Subscription]: """Get subscriptions that need billing in the next N days""" try: cutoff_date = datetime.utcnow() + timedelta(days=days_ahead) query_text = """ SELECT * FROM subscriptions WHERE status = 'active' AND next_billing_date <= :cutoff_date ORDER BY next_billing_date ASC """ result = await self.session.execute(text(query_text), { "cutoff_date": cutoff_date }) subscriptions = [] for row in result.fetchall(): record_dict = dict(row._mapping) subscription = self.model(**record_dict) subscriptions.append(subscription) return subscriptions except Exception as e: logger.error("Failed to get subscriptions due for billing", days_ahead=days_ahead, error=str(e)) return [] async def update_billing_date( self, subscription_id: str, next_billing_date: datetime ) -> Optional[Subscription]: """Update next billing date for subscription""" try: updated_subscription = await self.update(subscription_id, { "next_billing_date": next_billing_date, "updated_at": datetime.utcnow() }) logger.info("Subscription billing date updated", subscription_id=subscription_id, next_billing_date=next_billing_date) return updated_subscription except Exception as e: logger.error("Failed to update billing date", subscription_id=subscription_id, error=str(e)) raise DatabaseError(f"Failed to update billing date: {str(e)}") async def get_subscription_statistics(self) -> Dict[str, Any]: """Get subscription statistics""" try: # Get counts by plan plan_query = text(""" SELECT plan, COUNT(*) as count FROM subscriptions WHERE status = 'active' GROUP BY plan ORDER BY count DESC """) result = await self.session.execute(plan_query) subscriptions_by_plan = {row.plan: row.count for row in result.fetchall()} # Get counts by status status_query = text(""" SELECT status, COUNT(*) as count FROM subscriptions GROUP BY status ORDER BY count DESC """) result = await self.session.execute(status_query) subscriptions_by_status = {row.status: row.count for row in result.fetchall()} # Get revenue statistics revenue_query = text(""" SELECT SUM(monthly_price) as total_monthly_revenue, AVG(monthly_price) as avg_monthly_price, COUNT(*) as total_active_subscriptions FROM subscriptions WHERE status = 'active' """) revenue_result = await self.session.execute(revenue_query) revenue_row = revenue_result.fetchone() # Get upcoming billing count thirty_days_ahead = datetime.utcnow() + timedelta(days=30) upcoming_billing = len(await self.get_subscriptions_due_for_billing(30)) return { "subscriptions_by_plan": subscriptions_by_plan, "subscriptions_by_status": subscriptions_by_status, "total_monthly_revenue": float(revenue_row.total_monthly_revenue or 0), "avg_monthly_price": float(revenue_row.avg_monthly_price or 0), "total_active_subscriptions": int(revenue_row.total_active_subscriptions or 0), "upcoming_billing_30d": upcoming_billing } except Exception as e: logger.error("Failed to get subscription statistics", error=str(e)) return { "subscriptions_by_plan": {}, "subscriptions_by_status": {}, "total_monthly_revenue": 0.0, "avg_monthly_price": 0.0, "total_active_subscriptions": 0, "upcoming_billing_30d": 0 } async def cleanup_old_subscriptions(self, days_old: int = 730) -> int: """Clean up very old cancelled subscriptions (2 years)""" try: cutoff_date = datetime.utcnow() - timedelta(days=days_old) query_text = """ DELETE FROM subscriptions WHERE status IN ('cancelled', 'suspended') AND updated_at < :cutoff_date """ result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date}) deleted_count = result.rowcount logger.info("Cleaned up old subscriptions", deleted_count=deleted_count, days_old=days_old) return deleted_count except Exception as e: logger.error("Failed to cleanup old subscriptions", error=str(e)) raise DatabaseError(f"Cleanup failed: {str(e)}") def _get_plan_configuration(self, plan: str) -> Dict[str, Any]: """Get configuration for a subscription plan""" plan_configs = { "basic": { "monthly_price": 29.99, "max_users": 2, "max_locations": 1, "max_products": 50 }, "professional": { "monthly_price": 79.99, "max_users": 10, "max_locations": 3, "max_products": 200 }, "enterprise": { "monthly_price": 199.99, "max_users": 50, "max_locations": 10, "max_products": 1000 } } return plan_configs.get(plan, plan_configs["basic"])