Initial commit - production deployment
This commit is contained in:
680
services/tenant/app/repositories/tenant_repository.py
Normal file
680
services/tenant/app/repositories/tenant_repository.py
Normal file
@@ -0,0 +1,680 @@
|
||||
"""
|
||||
Tenant Repository
|
||||
Repository for tenant operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, text, and_
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from .base import TenantBaseRepository
|
||||
from app.models.tenants import Tenant, Subscription
|
||||
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TenantRepository(TenantBaseRepository):
|
||||
"""Repository for tenant operations"""
|
||||
|
||||
def __init__(self, model_class, session: AsyncSession, cache_ttl: Optional[int] = 600):
|
||||
# Tenants are relatively stable, longer cache time (10 minutes)
|
||||
super().__init__(model_class, session, cache_ttl)
|
||||
|
||||
async def create_tenant(self, tenant_data: Dict[str, Any]) -> Tenant:
|
||||
"""Create a new tenant with validation"""
|
||||
try:
|
||||
# Validate tenant data
|
||||
validation_result = self._validate_tenant_data(
|
||||
tenant_data,
|
||||
["name", "address", "postal_code", "owner_id"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid tenant data: {validation_result['errors']}")
|
||||
|
||||
# Generate subdomain if not provided
|
||||
if "subdomain" not in tenant_data or not tenant_data["subdomain"]:
|
||||
subdomain = await self._generate_unique_subdomain(tenant_data["name"])
|
||||
tenant_data["subdomain"] = subdomain
|
||||
else:
|
||||
# Check if provided subdomain is unique
|
||||
existing_tenant = await self.get_by_subdomain(tenant_data["subdomain"])
|
||||
if existing_tenant:
|
||||
raise DuplicateRecordError(f"Subdomain {tenant_data['subdomain']} already exists")
|
||||
|
||||
# Set default values
|
||||
if "business_type" not in tenant_data:
|
||||
tenant_data["business_type"] = "bakery"
|
||||
if "city" not in tenant_data:
|
||||
tenant_data["city"] = "Madrid"
|
||||
if "is_active" not in tenant_data:
|
||||
tenant_data["is_active"] = True
|
||||
if "ml_model_trained" not in tenant_data:
|
||||
tenant_data["ml_model_trained"] = False
|
||||
|
||||
# Create tenant
|
||||
tenant = await self.create(tenant_data)
|
||||
|
||||
logger.info("Tenant created successfully",
|
||||
tenant_id=tenant.id,
|
||||
name=tenant.name,
|
||||
subdomain=tenant.subdomain,
|
||||
owner_id=tenant.owner_id)
|
||||
|
||||
return tenant
|
||||
|
||||
except (ValidationError, DuplicateRecordError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create tenant",
|
||||
name=tenant_data.get("name"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create tenant: {str(e)}")
|
||||
|
||||
async def get_by_subdomain(self, subdomain: str) -> Optional[Tenant]:
|
||||
"""Get tenant by subdomain"""
|
||||
try:
|
||||
return await self.get_by_field("subdomain", subdomain)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenant by subdomain",
|
||||
subdomain=subdomain,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get tenant: {str(e)}")
|
||||
|
||||
async def get_tenants_by_owner(self, owner_id: str) -> List[Tenant]:
|
||||
"""Get all tenants owned by a user"""
|
||||
try:
|
||||
return await self.get_multi(
|
||||
filters={"owner_id": owner_id, "is_active": True},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenants by owner",
|
||||
owner_id=owner_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get tenants: {str(e)}")
|
||||
|
||||
async def get_active_tenants(self, skip: int = 0, limit: int = 100) -> List[Tenant]:
|
||||
"""Get all active tenants"""
|
||||
return await self.get_active_records(skip=skip, limit=limit)
|
||||
|
||||
async def search_tenants(
|
||||
self,
|
||||
search_term: str,
|
||||
business_type: str = None,
|
||||
city: str = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50
|
||||
) -> List[Tenant]:
|
||||
"""Search tenants by name, address, or other criteria"""
|
||||
try:
|
||||
# Build search conditions
|
||||
conditions = ["is_active = true"]
|
||||
params = {"skip": skip, "limit": limit}
|
||||
|
||||
# Add text search
|
||||
conditions.append("(LOWER(name) LIKE LOWER(:search_term) OR LOWER(address) LIKE LOWER(:search_term))")
|
||||
params["search_term"] = f"%{search_term}%"
|
||||
|
||||
# Add business type filter
|
||||
if business_type:
|
||||
conditions.append("business_type = :business_type")
|
||||
params["business_type"] = business_type
|
||||
|
||||
# Add city filter
|
||||
if city:
|
||||
conditions.append("LOWER(city) = LOWER(:city)")
|
||||
params["city"] = city
|
||||
|
||||
query_text = f"""
|
||||
SELECT * FROM tenants
|
||||
WHERE {' AND '.join(conditions)}
|
||||
ORDER BY name ASC
|
||||
LIMIT :limit OFFSET :skip
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), params)
|
||||
|
||||
tenants = []
|
||||
for row in result.fetchall():
|
||||
record_dict = dict(row._mapping)
|
||||
tenant = self.model(**record_dict)
|
||||
tenants.append(tenant)
|
||||
|
||||
return tenants
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to search tenants",
|
||||
search_term=search_term,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def update_tenant_model_status(
|
||||
self,
|
||||
tenant_id: str,
|
||||
ml_model_trained: bool,
|
||||
last_training_date: datetime = None
|
||||
) -> Optional[Tenant]:
|
||||
"""Update tenant model training status"""
|
||||
try:
|
||||
update_data = {
|
||||
"ml_model_trained": ml_model_trained,
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
if last_training_date:
|
||||
update_data["last_training_date"] = last_training_date
|
||||
elif ml_model_trained:
|
||||
update_data["last_training_date"] = datetime.utcnow()
|
||||
|
||||
updated_tenant = await self.update(tenant_id, update_data)
|
||||
|
||||
logger.info("Tenant model status updated",
|
||||
tenant_id=tenant_id,
|
||||
ml_model_trained=ml_model_trained,
|
||||
last_training_date=last_training_date)
|
||||
|
||||
return updated_tenant
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update tenant model status",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update model status: {str(e)}")
|
||||
|
||||
|
||||
async def get_tenants_by_location(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
radius_km: float = 10.0,
|
||||
limit: int = 50
|
||||
) -> List[Tenant]:
|
||||
"""Get tenants within a geographic radius"""
|
||||
try:
|
||||
# Using Haversine formula for distance calculation
|
||||
query_text = """
|
||||
SELECT *,
|
||||
(6371 * acos(
|
||||
cos(radians(:latitude)) *
|
||||
cos(radians(latitude)) *
|
||||
cos(radians(longitude) - radians(:longitude)) +
|
||||
sin(radians(:latitude)) *
|
||||
sin(radians(latitude))
|
||||
)) AS distance_km
|
||||
FROM tenants
|
||||
WHERE is_active = true
|
||||
AND latitude IS NOT NULL
|
||||
AND longitude IS NOT NULL
|
||||
HAVING distance_km <= :radius_km
|
||||
ORDER BY distance_km ASC
|
||||
LIMIT :limit
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {
|
||||
"latitude": latitude,
|
||||
"longitude": longitude,
|
||||
"radius_km": radius_km,
|
||||
"limit": limit
|
||||
})
|
||||
|
||||
tenants = []
|
||||
for row in result.fetchall():
|
||||
# Create tenant object (excluding the calculated distance_km field)
|
||||
record_dict = dict(row._mapping)
|
||||
record_dict.pop("distance_km", None) # Remove calculated field
|
||||
tenant = self.model(**record_dict)
|
||||
tenants.append(tenant)
|
||||
|
||||
return tenants
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenants by location",
|
||||
latitude=latitude,
|
||||
longitude=longitude,
|
||||
radius_km=radius_km,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def get_tenant_statistics(self) -> Dict[str, Any]:
|
||||
"""Get global tenant statistics"""
|
||||
try:
|
||||
# Get basic counts
|
||||
total_tenants = await self.count()
|
||||
active_tenants = await self.count(filters={"is_active": True})
|
||||
|
||||
# Get tenants by business type
|
||||
business_type_query = text("""
|
||||
SELECT business_type, COUNT(*) as count
|
||||
FROM tenants
|
||||
WHERE is_active = true
|
||||
GROUP BY business_type
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(business_type_query)
|
||||
business_type_stats = {row.business_type: row.count for row in result.fetchall()}
|
||||
|
||||
# Get tenants by subscription tier - now from subscriptions table
|
||||
tier_query = text("""
|
||||
SELECT s.plan as subscription_tier, COUNT(*) as count
|
||||
FROM tenants t
|
||||
LEFT JOIN subscriptions s ON t.id = s.tenant_id AND s.status = 'active'
|
||||
WHERE t.is_active = true
|
||||
GROUP BY s.plan
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
tier_result = await self.session.execute(tier_query)
|
||||
tier_stats = {}
|
||||
for row in tier_result.fetchall():
|
||||
tier = row.subscription_tier if row.subscription_tier else "no_subscription"
|
||||
tier_stats[tier] = row.count
|
||||
|
||||
# Get model training statistics
|
||||
model_query = text("""
|
||||
SELECT
|
||||
COUNT(CASE WHEN ml_model_trained = true THEN 1 END) as trained_count,
|
||||
COUNT(CASE WHEN ml_model_trained = false THEN 1 END) as untrained_count,
|
||||
AVG(EXTRACT(EPOCH FROM (NOW() - last_training_date))/86400) as avg_days_since_training
|
||||
FROM tenants
|
||||
WHERE is_active = true
|
||||
""")
|
||||
|
||||
model_result = await self.session.execute(model_query)
|
||||
model_row = model_result.fetchone()
|
||||
|
||||
# Get recent registrations (last 30 days)
|
||||
thirty_days_ago = datetime.utcnow() - timedelta(days=30)
|
||||
recent_registrations = await self.count(filters={
|
||||
"created_at": f">= '{thirty_days_ago.isoformat()}'"
|
||||
})
|
||||
|
||||
return {
|
||||
"total_tenants": total_tenants,
|
||||
"active_tenants": active_tenants,
|
||||
"inactive_tenants": total_tenants - active_tenants,
|
||||
"tenants_by_business_type": business_type_stats,
|
||||
"tenants_by_subscription": tier_stats,
|
||||
"model_training": {
|
||||
"trained_tenants": int(model_row.trained_count or 0),
|
||||
"untrained_tenants": int(model_row.untrained_count or 0),
|
||||
"avg_days_since_training": float(model_row.avg_days_since_training or 0)
|
||||
} if model_row else {
|
||||
"trained_tenants": 0,
|
||||
"untrained_tenants": 0,
|
||||
"avg_days_since_training": 0.0
|
||||
},
|
||||
"recent_registrations_30d": recent_registrations
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenant statistics", error=str(e))
|
||||
return {
|
||||
"total_tenants": 0,
|
||||
"active_tenants": 0,
|
||||
"inactive_tenants": 0,
|
||||
"tenants_by_business_type": {},
|
||||
"tenants_by_subscription": {},
|
||||
"model_training": {
|
||||
"trained_tenants": 0,
|
||||
"untrained_tenants": 0,
|
||||
"avg_days_since_training": 0.0
|
||||
},
|
||||
"recent_registrations_30d": 0
|
||||
}
|
||||
|
||||
async def _generate_unique_subdomain(self, name: str) -> str:
|
||||
"""Generate a unique subdomain from tenant name"""
|
||||
try:
|
||||
# Clean the name to create a subdomain
|
||||
subdomain = name.lower().replace(' ', '-')
|
||||
# Remove accents
|
||||
subdomain = subdomain.replace('á', 'a').replace('é', 'e').replace('í', 'i').replace('ó', 'o').replace('ú', 'u')
|
||||
subdomain = subdomain.replace('ñ', 'n')
|
||||
# Keep only alphanumeric and hyphens
|
||||
subdomain = ''.join(c for c in subdomain if c.isalnum() or c == '-')
|
||||
# Remove multiple consecutive hyphens
|
||||
while '--' in subdomain:
|
||||
subdomain = subdomain.replace('--', '-')
|
||||
# Remove leading/trailing hyphens
|
||||
subdomain = subdomain.strip('-')
|
||||
|
||||
# Ensure minimum length
|
||||
if len(subdomain) < 3:
|
||||
subdomain = f"tenant-{subdomain}"
|
||||
|
||||
# Check if subdomain exists
|
||||
existing_tenant = await self.get_by_subdomain(subdomain)
|
||||
if not existing_tenant:
|
||||
return subdomain
|
||||
|
||||
# If it exists, add a unique suffix
|
||||
counter = 1
|
||||
while True:
|
||||
candidate = f"{subdomain}-{counter}"
|
||||
existing_tenant = await self.get_by_subdomain(candidate)
|
||||
if not existing_tenant:
|
||||
return candidate
|
||||
counter += 1
|
||||
|
||||
# Prevent infinite loop
|
||||
if counter > 9999:
|
||||
return f"{subdomain}-{uuid.uuid4().hex[:6]}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to generate unique subdomain",
|
||||
name=name,
|
||||
error=str(e))
|
||||
# Fallback to UUID-based subdomain
|
||||
return f"tenant-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
async def deactivate_tenant(self, tenant_id: str) -> Optional[Tenant]:
|
||||
"""Deactivate a tenant"""
|
||||
return await self.deactivate_record(tenant_id)
|
||||
|
||||
async def activate_tenant(self, tenant_id: str) -> Optional[Tenant]:
|
||||
"""Activate a tenant"""
|
||||
return await self.activate_record(tenant_id)
|
||||
|
||||
async def get_child_tenants(self, parent_tenant_id: str) -> List[Tenant]:
|
||||
"""Get all child tenants for a parent tenant"""
|
||||
try:
|
||||
return await self.get_multi(
|
||||
filters={"parent_tenant_id": parent_tenant_id, "is_active": True},
|
||||
order_by="created_at",
|
||||
order_desc=False
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get child tenants",
|
||||
parent_tenant_id=parent_tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get child tenants: {str(e)}")
|
||||
|
||||
async def get(self, record_id: Any) -> Optional[Tenant]:
|
||||
"""Get tenant by ID - alias for get_by_id for compatibility"""
|
||||
return await self.get_by_id(record_id)
|
||||
|
||||
async def get_child_tenant_count(self, parent_tenant_id: str) -> int:
|
||||
"""Get count of child tenants for a parent tenant"""
|
||||
try:
|
||||
child_tenants = await self.get_child_tenants(parent_tenant_id)
|
||||
return len(child_tenants)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get child tenant count",
|
||||
parent_tenant_id=parent_tenant_id,
|
||||
error=str(e))
|
||||
return 0
|
||||
|
||||
async def get_user_tenants_with_hierarchy(self, user_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all tenants a user has access to, organized in hierarchy.
|
||||
Returns parent tenants with their children nested.
|
||||
"""
|
||||
try:
|
||||
# Get all tenants where user is owner or member
|
||||
query_text = """
|
||||
SELECT DISTINCT t.*
|
||||
FROM tenants t
|
||||
LEFT JOIN tenant_members tm ON t.id = tm.tenant_id
|
||||
WHERE (t.owner_id = :user_id OR tm.user_id = :user_id)
|
||||
AND t.is_active = true
|
||||
ORDER BY t.tenant_type DESC, t.created_at ASC
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"user_id": user_id})
|
||||
|
||||
tenants = []
|
||||
for row in result.fetchall():
|
||||
record_dict = dict(row._mapping)
|
||||
tenant = self.model(**record_dict)
|
||||
tenants.append(tenant)
|
||||
|
||||
# Organize into hierarchy
|
||||
tenant_hierarchy = []
|
||||
parent_map = {}
|
||||
|
||||
# First pass: collect all parent/standalone tenants
|
||||
for tenant in tenants:
|
||||
if tenant.tenant_type in ['parent', 'standalone']:
|
||||
tenant_dict = {
|
||||
'id': str(tenant.id),
|
||||
'name': tenant.name,
|
||||
'subdomain': tenant.subdomain,
|
||||
'tenant_type': tenant.tenant_type,
|
||||
'business_type': tenant.business_type,
|
||||
'business_model': tenant.business_model,
|
||||
'city': tenant.city,
|
||||
'is_active': tenant.is_active,
|
||||
'children': [] if tenant.tenant_type == 'parent' else None
|
||||
}
|
||||
tenant_hierarchy.append(tenant_dict)
|
||||
parent_map[str(tenant.id)] = tenant_dict
|
||||
|
||||
# Second pass: attach children to their parents
|
||||
for tenant in tenants:
|
||||
if tenant.tenant_type == 'child' and tenant.parent_tenant_id:
|
||||
parent_id = str(tenant.parent_tenant_id)
|
||||
if parent_id in parent_map:
|
||||
child_dict = {
|
||||
'id': str(tenant.id),
|
||||
'name': tenant.name,
|
||||
'subdomain': tenant.subdomain,
|
||||
'tenant_type': 'child',
|
||||
'parent_tenant_id': parent_id,
|
||||
'city': tenant.city,
|
||||
'is_active': tenant.is_active
|
||||
}
|
||||
parent_map[parent_id]['children'].append(child_dict)
|
||||
|
||||
return tenant_hierarchy
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get user tenants with hierarchy",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def get_tenants_by_session_id(self, session_id: str) -> List[Tenant]:
|
||||
"""
|
||||
Get tenants associated with a specific demo session using the demo_session_id field.
|
||||
"""
|
||||
try:
|
||||
return await self.get_multi(
|
||||
filters={
|
||||
"demo_session_id": session_id,
|
||||
"is_active": True
|
||||
},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenants by session ID",
|
||||
session_id=session_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get tenants by session ID: {str(e)}")
|
||||
|
||||
async def get_professional_demo_tenants(self, session_id: str) -> List[Tenant]:
|
||||
"""
|
||||
Get professional demo tenants filtered by session.
|
||||
|
||||
Args:
|
||||
session_id: Required demo session ID to filter tenants
|
||||
|
||||
Returns:
|
||||
List of professional demo tenants for this specific session
|
||||
"""
|
||||
try:
|
||||
filters = {
|
||||
"business_model": "professional_bakery",
|
||||
"is_demo": True,
|
||||
"is_active": True,
|
||||
"demo_session_id": session_id # Always filter by session
|
||||
}
|
||||
|
||||
return await self.get_multi(
|
||||
filters=filters,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get professional demo tenants",
|
||||
session_id=session_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get professional demo tenants: {str(e)}")
|
||||
|
||||
async def get_enterprise_demo_tenants(self, session_id: str) -> List[Tenant]:
|
||||
"""
|
||||
Get enterprise demo tenants (parent and children) filtered by session.
|
||||
|
||||
Args:
|
||||
session_id: Required demo session ID to filter tenants
|
||||
|
||||
Returns:
|
||||
List of enterprise demo tenants (1 parent + 3 children) for this specific session
|
||||
"""
|
||||
try:
|
||||
# Get enterprise demo parent tenants for this session
|
||||
parent_tenants = await self.get_multi(
|
||||
filters={
|
||||
"tenant_type": "parent",
|
||||
"is_demo": True,
|
||||
"is_active": True,
|
||||
"demo_session_id": session_id # Always filter by session
|
||||
},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
|
||||
# Get child tenants for the enterprise demo session
|
||||
child_tenants = await self.get_multi(
|
||||
filters={
|
||||
"tenant_type": "child",
|
||||
"is_demo": True,
|
||||
"is_active": True,
|
||||
"demo_session_id": session_id # Always filter by session
|
||||
},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
|
||||
# Combine parent and child tenants
|
||||
return parent_tenants + child_tenants
|
||||
except Exception as e:
|
||||
logger.error("Failed to get enterprise demo tenants",
|
||||
session_id=session_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get enterprise demo tenants: {str(e)}")
|
||||
|
||||
async def get_by_customer_id(self, customer_id: str) -> Optional[Tenant]:
|
||||
"""
|
||||
Get tenant by Stripe customer ID
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
|
||||
Returns:
|
||||
Tenant object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Find tenant by joining with subscriptions table
|
||||
# Tenant doesn't have customer_id directly, so we need to find via subscription
|
||||
query = select(Tenant).join(
|
||||
Subscription, Subscription.tenant_id == Tenant.id
|
||||
).where(Subscription.customer_id == customer_id)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
tenant = result.scalar_one_or_none()
|
||||
|
||||
if tenant:
|
||||
logger.debug("Found tenant by customer_id",
|
||||
customer_id=customer_id,
|
||||
tenant_id=str(tenant.id))
|
||||
return tenant
|
||||
else:
|
||||
logger.debug("No tenant found for customer_id",
|
||||
customer_id=customer_id)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting tenant by customer_id",
|
||||
customer_id=customer_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get tenant by customer_id: {str(e)}")
|
||||
|
||||
async def get_user_primary_tenant(self, user_id: str) -> Optional[Tenant]:
|
||||
"""
|
||||
Get the primary tenant for a user (the tenant they own)
|
||||
|
||||
Args:
|
||||
user_id: User ID to find primary tenant for
|
||||
|
||||
Returns:
|
||||
Tenant object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
logger.debug("Getting primary tenant for user", user_id=user_id)
|
||||
|
||||
# Query for tenant where user is the owner
|
||||
query = select(Tenant).where(Tenant.owner_id == user_id)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
tenant = result.scalar_one_or_none()
|
||||
|
||||
if tenant:
|
||||
logger.debug("Found primary tenant for user",
|
||||
user_id=user_id,
|
||||
tenant_id=str(tenant.id))
|
||||
return tenant
|
||||
else:
|
||||
logger.debug("No primary tenant found for user", user_id=user_id)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting primary tenant for user",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get primary tenant for user: {str(e)}")
|
||||
|
||||
async def get_any_user_tenant(self, user_id: str) -> Optional[Tenant]:
|
||||
"""
|
||||
Get any tenant that the user has access to (via tenant_members)
|
||||
|
||||
Args:
|
||||
user_id: User ID to find accessible tenants for
|
||||
|
||||
Returns:
|
||||
Tenant object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
logger.debug("Getting any accessible tenant for user", user_id=user_id)
|
||||
|
||||
# Query for tenant members where user has access
|
||||
from app.models.tenants import TenantMember
|
||||
|
||||
query = select(Tenant).join(
|
||||
TenantMember, Tenant.id == TenantMember.tenant_id
|
||||
).where(TenantMember.user_id == user_id)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
tenant = result.scalar_one_or_none()
|
||||
|
||||
if tenant:
|
||||
logger.debug("Found accessible tenant for user",
|
||||
user_id=user_id,
|
||||
tenant_id=str(tenant.id))
|
||||
return tenant
|
||||
else:
|
||||
logger.debug("No accessible tenants found for user", user_id=user_id)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting accessible tenant for user",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get accessible tenant for user: {str(e)}")
|
||||
Reference in New Issue
Block a user