""" Tenant Repository Repository for tenant operations """ from typing import Optional, List, Dict, Any from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, text, and_ from datetime import datetime, timedelta import structlog import uuid from .base import TenantBaseRepository from app.models.tenants import Tenant, Subscription from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError logger = structlog.get_logger() class TenantRepository(TenantBaseRepository): """Repository for tenant operations""" def __init__(self, model_class, session: AsyncSession, cache_ttl: Optional[int] = 600): # Tenants are relatively stable, longer cache time (10 minutes) super().__init__(model_class, session, cache_ttl) async def create_tenant(self, tenant_data: Dict[str, Any]) -> Tenant: """Create a new tenant with validation""" try: # Validate tenant data validation_result = self._validate_tenant_data( tenant_data, ["name", "address", "postal_code", "owner_id"] ) if not validation_result["is_valid"]: raise ValidationError(f"Invalid tenant data: {validation_result['errors']}") # Generate subdomain if not provided if "subdomain" not in tenant_data or not tenant_data["subdomain"]: subdomain = await self._generate_unique_subdomain(tenant_data["name"]) tenant_data["subdomain"] = subdomain else: # Check if provided subdomain is unique existing_tenant = await self.get_by_subdomain(tenant_data["subdomain"]) if existing_tenant: raise DuplicateRecordError(f"Subdomain {tenant_data['subdomain']} already exists") # Set default values if "business_type" not in tenant_data: tenant_data["business_type"] = "bakery" if "city" not in tenant_data: tenant_data["city"] = "Madrid" if "is_active" not in tenant_data: tenant_data["is_active"] = True if "ml_model_trained" not in tenant_data: tenant_data["ml_model_trained"] = False # Create tenant tenant = await self.create(tenant_data) logger.info("Tenant created successfully", tenant_id=tenant.id, name=tenant.name, subdomain=tenant.subdomain, owner_id=tenant.owner_id) return tenant except (ValidationError, DuplicateRecordError): raise except Exception as e: logger.error("Failed to create tenant", name=tenant_data.get("name"), error=str(e)) raise DatabaseError(f"Failed to create tenant: {str(e)}") async def get_by_subdomain(self, subdomain: str) -> Optional[Tenant]: """Get tenant by subdomain""" try: return await self.get_by_field("subdomain", subdomain) except Exception as e: logger.error("Failed to get tenant by subdomain", subdomain=subdomain, error=str(e)) raise DatabaseError(f"Failed to get tenant: {str(e)}") async def get_tenants_by_owner(self, owner_id: str) -> List[Tenant]: """Get all tenants owned by a user""" try: return await self.get_multi( filters={"owner_id": owner_id, "is_active": True}, order_by="created_at", order_desc=True ) except Exception as e: logger.error("Failed to get tenants by owner", owner_id=owner_id, error=str(e)) raise DatabaseError(f"Failed to get tenants: {str(e)}") async def get_active_tenants(self, skip: int = 0, limit: int = 100) -> List[Tenant]: """Get all active tenants""" return await self.get_active_records(skip=skip, limit=limit) async def search_tenants( self, search_term: str, business_type: str = None, city: str = None, skip: int = 0, limit: int = 50 ) -> List[Tenant]: """Search tenants by name, address, or other criteria""" try: # Build search conditions conditions = ["is_active = true"] params = {"skip": skip, "limit": limit} # Add text search conditions.append("(LOWER(name) LIKE LOWER(:search_term) OR LOWER(address) LIKE LOWER(:search_term))") params["search_term"] = f"%{search_term}%" # Add business type filter if business_type: conditions.append("business_type = :business_type") params["business_type"] = business_type # Add city filter if city: conditions.append("LOWER(city) = LOWER(:city)") params["city"] = city query_text = f""" SELECT * FROM tenants WHERE {' AND '.join(conditions)} ORDER BY name ASC LIMIT :limit OFFSET :skip """ result = await self.session.execute(text(query_text), params) tenants = [] for row in result.fetchall(): record_dict = dict(row._mapping) tenant = self.model(**record_dict) tenants.append(tenant) return tenants except Exception as e: logger.error("Failed to search tenants", search_term=search_term, error=str(e)) return [] async def update_tenant_model_status( self, tenant_id: str, ml_model_trained: bool, last_training_date: datetime = None ) -> Optional[Tenant]: """Update tenant model training status""" try: update_data = { "ml_model_trained": ml_model_trained, "updated_at": datetime.utcnow() } if last_training_date: update_data["last_training_date"] = last_training_date elif ml_model_trained: update_data["last_training_date"] = datetime.utcnow() updated_tenant = await self.update(tenant_id, update_data) logger.info("Tenant model status updated", tenant_id=tenant_id, ml_model_trained=ml_model_trained, last_training_date=last_training_date) return updated_tenant except Exception as e: logger.error("Failed to update tenant model status", tenant_id=tenant_id, error=str(e)) raise DatabaseError(f"Failed to update model status: {str(e)}") async def get_tenants_by_location( self, latitude: float, longitude: float, radius_km: float = 10.0, limit: int = 50 ) -> List[Tenant]: """Get tenants within a geographic radius""" try: # Using Haversine formula for distance calculation query_text = """ SELECT *, (6371 * acos( cos(radians(:latitude)) * cos(radians(latitude)) * cos(radians(longitude) - radians(:longitude)) + sin(radians(:latitude)) * sin(radians(latitude)) )) AS distance_km FROM tenants WHERE is_active = true AND latitude IS NOT NULL AND longitude IS NOT NULL HAVING distance_km <= :radius_km ORDER BY distance_km ASC LIMIT :limit """ result = await self.session.execute(text(query_text), { "latitude": latitude, "longitude": longitude, "radius_km": radius_km, "limit": limit }) tenants = [] for row in result.fetchall(): # Create tenant object (excluding the calculated distance_km field) record_dict = dict(row._mapping) record_dict.pop("distance_km", None) # Remove calculated field tenant = self.model(**record_dict) tenants.append(tenant) return tenants except Exception as e: logger.error("Failed to get tenants by location", latitude=latitude, longitude=longitude, radius_km=radius_km, error=str(e)) return [] async def get_tenant_statistics(self) -> Dict[str, Any]: """Get global tenant statistics""" try: # Get basic counts total_tenants = await self.count() active_tenants = await self.count(filters={"is_active": True}) # Get tenants by business type business_type_query = text(""" SELECT business_type, COUNT(*) as count FROM tenants WHERE is_active = true GROUP BY business_type ORDER BY count DESC """) result = await self.session.execute(business_type_query) business_type_stats = {row.business_type: row.count for row in result.fetchall()} # Get tenants by subscription tier - now from subscriptions table tier_query = text(""" SELECT s.plan as subscription_tier, COUNT(*) as count FROM tenants t LEFT JOIN subscriptions s ON t.id = s.tenant_id AND s.status = 'active' WHERE t.is_active = true GROUP BY s.plan ORDER BY count DESC """) tier_result = await self.session.execute(tier_query) tier_stats = {} for row in tier_result.fetchall(): tier = row.subscription_tier if row.subscription_tier else "no_subscription" tier_stats[tier] = row.count # Get model training statistics model_query = text(""" SELECT COUNT(CASE WHEN ml_model_trained = true THEN 1 END) as trained_count, COUNT(CASE WHEN ml_model_trained = false THEN 1 END) as untrained_count, AVG(EXTRACT(EPOCH FROM (NOW() - last_training_date))/86400) as avg_days_since_training FROM tenants WHERE is_active = true """) model_result = await self.session.execute(model_query) model_row = model_result.fetchone() # Get recent registrations (last 30 days) thirty_days_ago = datetime.utcnow() - timedelta(days=30) recent_registrations = await self.count(filters={ "created_at": f">= '{thirty_days_ago.isoformat()}'" }) return { "total_tenants": total_tenants, "active_tenants": active_tenants, "inactive_tenants": total_tenants - active_tenants, "tenants_by_business_type": business_type_stats, "tenants_by_subscription": tier_stats, "model_training": { "trained_tenants": int(model_row.trained_count or 0), "untrained_tenants": int(model_row.untrained_count or 0), "avg_days_since_training": float(model_row.avg_days_since_training or 0) } if model_row else { "trained_tenants": 0, "untrained_tenants": 0, "avg_days_since_training": 0.0 }, "recent_registrations_30d": recent_registrations } except Exception as e: logger.error("Failed to get tenant statistics", error=str(e)) return { "total_tenants": 0, "active_tenants": 0, "inactive_tenants": 0, "tenants_by_business_type": {}, "tenants_by_subscription": {}, "model_training": { "trained_tenants": 0, "untrained_tenants": 0, "avg_days_since_training": 0.0 }, "recent_registrations_30d": 0 } async def _generate_unique_subdomain(self, name: str) -> str: """Generate a unique subdomain from tenant name""" try: # Clean the name to create a subdomain subdomain = name.lower().replace(' ', '-') # Remove accents subdomain = subdomain.replace('á', 'a').replace('é', 'e').replace('í', 'i').replace('ó', 'o').replace('ú', 'u') subdomain = subdomain.replace('ñ', 'n') # Keep only alphanumeric and hyphens subdomain = ''.join(c for c in subdomain if c.isalnum() or c == '-') # Remove multiple consecutive hyphens while '--' in subdomain: subdomain = subdomain.replace('--', '-') # Remove leading/trailing hyphens subdomain = subdomain.strip('-') # Ensure minimum length if len(subdomain) < 3: subdomain = f"tenant-{subdomain}" # Check if subdomain exists existing_tenant = await self.get_by_subdomain(subdomain) if not existing_tenant: return subdomain # If it exists, add a unique suffix counter = 1 while True: candidate = f"{subdomain}-{counter}" existing_tenant = await self.get_by_subdomain(candidate) if not existing_tenant: return candidate counter += 1 # Prevent infinite loop if counter > 9999: return f"{subdomain}-{uuid.uuid4().hex[:6]}" except Exception as e: logger.error("Failed to generate unique subdomain", name=name, error=str(e)) # Fallback to UUID-based subdomain return f"tenant-{uuid.uuid4().hex[:8]}" async def deactivate_tenant(self, tenant_id: str) -> Optional[Tenant]: """Deactivate a tenant""" return await self.deactivate_record(tenant_id) async def activate_tenant(self, tenant_id: str) -> Optional[Tenant]: """Activate a tenant""" return await self.activate_record(tenant_id) async def get_child_tenants(self, parent_tenant_id: str) -> List[Tenant]: """Get all child tenants for a parent tenant""" try: return await self.get_multi( filters={"parent_tenant_id": parent_tenant_id, "is_active": True}, order_by="created_at", order_desc=False ) except Exception as e: logger.error("Failed to get child tenants", parent_tenant_id=parent_tenant_id, error=str(e)) raise DatabaseError(f"Failed to get child tenants: {str(e)}") async def get(self, record_id: Any) -> Optional[Tenant]: """Get tenant by ID - alias for get_by_id for compatibility""" return await self.get_by_id(record_id) async def get_child_tenant_count(self, parent_tenant_id: str) -> int: """Get count of child tenants for a parent tenant""" try: child_tenants = await self.get_child_tenants(parent_tenant_id) return len(child_tenants) except Exception as e: logger.error("Failed to get child tenant count", parent_tenant_id=parent_tenant_id, error=str(e)) return 0 async def get_user_tenants_with_hierarchy(self, user_id: str) -> List[Dict[str, Any]]: """ Get all tenants a user has access to, organized in hierarchy. Returns parent tenants with their children nested. """ try: # Get all tenants where user is owner or member query_text = """ SELECT DISTINCT t.* FROM tenants t LEFT JOIN tenant_members tm ON t.id = tm.tenant_id WHERE (t.owner_id = :user_id OR tm.user_id = :user_id) AND t.is_active = true ORDER BY t.tenant_type DESC, t.created_at ASC """ result = await self.session.execute(text(query_text), {"user_id": user_id}) tenants = [] for row in result.fetchall(): record_dict = dict(row._mapping) tenant = self.model(**record_dict) tenants.append(tenant) # Organize into hierarchy tenant_hierarchy = [] parent_map = {} # First pass: collect all parent/standalone tenants for tenant in tenants: if tenant.tenant_type in ['parent', 'standalone']: tenant_dict = { 'id': str(tenant.id), 'name': tenant.name, 'subdomain': tenant.subdomain, 'tenant_type': tenant.tenant_type, 'business_type': tenant.business_type, 'business_model': tenant.business_model, 'city': tenant.city, 'is_active': tenant.is_active, 'children': [] if tenant.tenant_type == 'parent' else None } tenant_hierarchy.append(tenant_dict) parent_map[str(tenant.id)] = tenant_dict # Second pass: attach children to their parents for tenant in tenants: if tenant.tenant_type == 'child' and tenant.parent_tenant_id: parent_id = str(tenant.parent_tenant_id) if parent_id in parent_map: child_dict = { 'id': str(tenant.id), 'name': tenant.name, 'subdomain': tenant.subdomain, 'tenant_type': 'child', 'parent_tenant_id': parent_id, 'city': tenant.city, 'is_active': tenant.is_active } parent_map[parent_id]['children'].append(child_dict) return tenant_hierarchy except Exception as e: logger.error("Failed to get user tenants with hierarchy", user_id=user_id, error=str(e)) return [] async def get_tenants_by_session_id(self, session_id: str) -> List[Tenant]: """ Get tenants associated with a specific demo session using the demo_session_id field. """ try: return await self.get_multi( filters={ "demo_session_id": session_id, "is_active": True }, order_by="created_at", order_desc=True ) except Exception as e: logger.error("Failed to get tenants by session ID", session_id=session_id, error=str(e)) raise DatabaseError(f"Failed to get tenants by session ID: {str(e)}") async def get_professional_demo_tenants(self, session_id: str) -> List[Tenant]: """ Get professional demo tenants filtered by session. Args: session_id: Required demo session ID to filter tenants Returns: List of professional demo tenants for this specific session """ try: filters = { "business_model": "professional_bakery", "is_demo": True, "is_active": True, "demo_session_id": session_id # Always filter by session } return await self.get_multi( filters=filters, order_by="created_at", order_desc=True ) except Exception as e: logger.error("Failed to get professional demo tenants", session_id=session_id, error=str(e)) raise DatabaseError(f"Failed to get professional demo tenants: {str(e)}") async def get_enterprise_demo_tenants(self, session_id: str) -> List[Tenant]: """ Get enterprise demo tenants (parent and children) filtered by session. Args: session_id: Required demo session ID to filter tenants Returns: List of enterprise demo tenants (1 parent + 3 children) for this specific session """ try: # Get enterprise demo parent tenants for this session parent_tenants = await self.get_multi( filters={ "tenant_type": "parent", "is_demo": True, "is_active": True, "demo_session_id": session_id # Always filter by session }, order_by="created_at", order_desc=True ) # Get child tenants for the enterprise demo session child_tenants = await self.get_multi( filters={ "tenant_type": "child", "is_demo": True, "is_active": True, "demo_session_id": session_id # Always filter by session }, order_by="created_at", order_desc=True ) # Combine parent and child tenants return parent_tenants + child_tenants except Exception as e: logger.error("Failed to get enterprise demo tenants", session_id=session_id, error=str(e)) raise DatabaseError(f"Failed to get enterprise demo tenants: {str(e)}") async def get_by_customer_id(self, customer_id: str) -> Optional[Tenant]: """ Get tenant by Stripe customer ID Args: customer_id: Stripe customer ID Returns: Tenant object if found, None otherwise """ try: # Find tenant by joining with subscriptions table # Tenant doesn't have customer_id directly, so we need to find via subscription query = select(Tenant).join( Subscription, Subscription.tenant_id == Tenant.id ).where(Subscription.customer_id == customer_id) result = await self.session.execute(query) tenant = result.scalar_one_or_none() if tenant: logger.debug("Found tenant by customer_id", customer_id=customer_id, tenant_id=str(tenant.id)) return tenant else: logger.debug("No tenant found for customer_id", customer_id=customer_id) return None except Exception as e: logger.error("Error getting tenant by customer_id", customer_id=customer_id, error=str(e)) raise DatabaseError(f"Failed to get tenant by customer_id: {str(e)}") async def get_user_primary_tenant(self, user_id: str) -> Optional[Tenant]: """ Get the primary tenant for a user (the tenant they own) Args: user_id: User ID to find primary tenant for Returns: Tenant object if found, None otherwise """ try: logger.debug("Getting primary tenant for user", user_id=user_id) # Query for tenant where user is the owner query = select(Tenant).where(Tenant.owner_id == user_id) result = await self.session.execute(query) tenant = result.scalar_one_or_none() if tenant: logger.debug("Found primary tenant for user", user_id=user_id, tenant_id=str(tenant.id)) return tenant else: logger.debug("No primary tenant found for user", user_id=user_id) return None except Exception as e: logger.error("Error getting primary tenant for user", user_id=user_id, error=str(e)) raise DatabaseError(f"Failed to get primary tenant for user: {str(e)}") async def get_any_user_tenant(self, user_id: str) -> Optional[Tenant]: """ Get any tenant that the user has access to (via tenant_members) Args: user_id: User ID to find accessible tenants for Returns: Tenant object if found, None otherwise """ try: logger.debug("Getting any accessible tenant for user", user_id=user_id) # Query for tenant members where user has access from app.models.tenants import TenantMember query = select(Tenant).join( TenantMember, Tenant.id == TenantMember.tenant_id ).where(TenantMember.user_id == user_id) result = await self.session.execute(query) tenant = result.scalar_one_or_none() if tenant: logger.debug("Found accessible tenant for user", user_id=user_id, tenant_id=str(tenant.id)) return tenant else: logger.debug("No accessible tenants found for user", user_id=user_id) return None except Exception as e: logger.error("Error getting accessible tenant for user", user_id=user_id, error=str(e)) raise DatabaseError(f"Failed to get accessible tenant for user: {str(e)}")