Initial commit - production deployment
This commit is contained in:
16
services/tenant/app/repositories/__init__.py
Normal file
16
services/tenant/app/repositories/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Tenant Service Repositories
|
||||
Repository implementations for tenant service
|
||||
"""
|
||||
|
||||
from .base import TenantBaseRepository
|
||||
from .tenant_repository import TenantRepository
|
||||
from .tenant_member_repository import TenantMemberRepository
|
||||
from .subscription_repository import SubscriptionRepository
|
||||
|
||||
__all__ = [
|
||||
"TenantBaseRepository",
|
||||
"TenantRepository",
|
||||
"TenantMemberRepository",
|
||||
"SubscriptionRepository"
|
||||
]
|
||||
234
services/tenant/app/repositories/base.py
Normal file
234
services/tenant/app/repositories/base.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""
|
||||
Base Repository for Tenant Service
|
||||
Service-specific repository base class with tenant management utilities
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any, Type
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
import json
|
||||
|
||||
from shared.database.repository import BaseRepository
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TenantBaseRepository(BaseRepository):
|
||||
"""Base repository for tenant service with common tenant operations"""
|
||||
|
||||
def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 600):
|
||||
# Tenant data is relatively stable, medium cache time (10 minutes)
|
||||
super().__init__(model, session, cache_ttl)
|
||||
|
||||
async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List:
|
||||
"""Get records by tenant ID"""
|
||||
if hasattr(self.model, 'tenant_id'):
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={"tenant_id": tenant_id},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return await self.get_multi(skip=skip, limit=limit)
|
||||
|
||||
async def get_by_user_id(self, user_id: str, skip: int = 0, limit: int = 100) -> List:
|
||||
"""Get records by user ID (for cross-service references)"""
|
||||
if hasattr(self.model, 'user_id'):
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={"user_id": user_id},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
elif hasattr(self.model, 'owner_id'):
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={"owner_id": user_id},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return []
|
||||
|
||||
async def get_active_records(self, skip: int = 0, limit: int = 100) -> List:
|
||||
"""Get active records (if model has is_active field)"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={"is_active": True},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return await self.get_multi(skip=skip, limit=limit)
|
||||
|
||||
async def deactivate_record(self, record_id: Any) -> Optional:
|
||||
"""Deactivate a record instead of deleting it"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.update(record_id, {"is_active": False})
|
||||
return await self.delete(record_id)
|
||||
|
||||
async def activate_record(self, record_id: Any) -> Optional:
|
||||
"""Activate a record"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.update(record_id, {"is_active": True})
|
||||
return await self.get_by_id(record_id)
|
||||
|
||||
async def cleanup_old_records(self, days_old: int = 365) -> int:
|
||||
"""Clean up old tenant records (very conservative - 1 year)"""
|
||||
try:
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days_old)
|
||||
table_name = self.model.__tablename__
|
||||
|
||||
# Only delete inactive records that are very old
|
||||
conditions = [
|
||||
"created_at < :cutoff_date"
|
||||
]
|
||||
|
||||
if hasattr(self.model, 'is_active'):
|
||||
conditions.append("is_active = false")
|
||||
|
||||
query_text = f"""
|
||||
DELETE FROM {table_name}
|
||||
WHERE {' AND '.join(conditions)}
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info(f"Cleaned up old {self.model.__name__} records",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup old records",
|
||||
model=self.model.__name__,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Cleanup failed: {str(e)}")
|
||||
|
||||
async def get_statistics_by_tenant(self, tenant_id: str) -> Dict[str, Any]:
|
||||
"""Get statistics for a tenant"""
|
||||
try:
|
||||
table_name = self.model.__tablename__
|
||||
|
||||
# Get basic counts
|
||||
total_records = await self.count(filters={"tenant_id": tenant_id})
|
||||
|
||||
# Get active records if applicable
|
||||
active_records = total_records
|
||||
if hasattr(self.model, 'is_active'):
|
||||
active_records = await self.count(filters={
|
||||
"tenant_id": tenant_id,
|
||||
"is_active": True
|
||||
})
|
||||
|
||||
# Get recent activity (records in last 7 days)
|
||||
seven_days_ago = datetime.utcnow() - timedelta(days=7)
|
||||
recent_query = text(f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {table_name}
|
||||
WHERE tenant_id = :tenant_id
|
||||
AND created_at >= :seven_days_ago
|
||||
""")
|
||||
|
||||
result = await self.session.execute(recent_query, {
|
||||
"tenant_id": tenant_id,
|
||||
"seven_days_ago": seven_days_ago
|
||||
})
|
||||
recent_records = result.scalar() or 0
|
||||
|
||||
return {
|
||||
"total_records": total_records,
|
||||
"active_records": active_records,
|
||||
"inactive_records": total_records - active_records,
|
||||
"recent_records_7d": recent_records
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenant statistics",
|
||||
model=self.model.__name__,
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"total_records": 0,
|
||||
"active_records": 0,
|
||||
"inactive_records": 0,
|
||||
"recent_records_7d": 0
|
||||
}
|
||||
|
||||
def _validate_tenant_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
|
||||
"""Validate tenant-related data"""
|
||||
errors = []
|
||||
|
||||
for field in required_fields:
|
||||
if field not in data or not data[field]:
|
||||
errors.append(f"Missing required field: {field}")
|
||||
|
||||
# Validate tenant_id format if present
|
||||
if "tenant_id" in data and data["tenant_id"]:
|
||||
tenant_id = data["tenant_id"]
|
||||
if not isinstance(tenant_id, str) or len(tenant_id) < 1:
|
||||
errors.append("Invalid tenant_id format")
|
||||
|
||||
# Validate user_id format if present
|
||||
if "user_id" in data and data["user_id"]:
|
||||
user_id = data["user_id"]
|
||||
if not isinstance(user_id, str) or len(user_id) < 1:
|
||||
errors.append("Invalid user_id format")
|
||||
|
||||
# Validate owner_id format if present
|
||||
if "owner_id" in data and data["owner_id"]:
|
||||
owner_id = data["owner_id"]
|
||||
if not isinstance(owner_id, str) or len(owner_id) < 1:
|
||||
errors.append("Invalid owner_id format")
|
||||
|
||||
# Validate email format if present
|
||||
if "email" in data and data["email"]:
|
||||
email = data["email"]
|
||||
if "@" not in email or "." not in email.split("@")[-1]:
|
||||
errors.append("Invalid email format")
|
||||
|
||||
# Validate phone format if present (basic validation)
|
||||
if "phone" in data and data["phone"]:
|
||||
phone = data["phone"]
|
||||
if not isinstance(phone, str) or len(phone) < 9:
|
||||
errors.append("Invalid phone format")
|
||||
|
||||
# Validate coordinates if present
|
||||
if "latitude" in data and data["latitude"] is not None:
|
||||
try:
|
||||
lat = float(data["latitude"])
|
||||
if lat < -90 or lat > 90:
|
||||
errors.append("Invalid latitude - must be between -90 and 90")
|
||||
except (ValueError, TypeError):
|
||||
errors.append("Invalid latitude format")
|
||||
|
||||
if "longitude" in data and data["longitude"] is not None:
|
||||
try:
|
||||
lng = float(data["longitude"])
|
||||
if lng < -180 or lng > 180:
|
||||
errors.append("Invalid longitude - must be between -180 and 180")
|
||||
except (ValueError, TypeError):
|
||||
errors.append("Invalid longitude format")
|
||||
|
||||
# Validate JSON fields
|
||||
json_fields = ["permissions"]
|
||||
for field in json_fields:
|
||||
if field in data and data[field]:
|
||||
if isinstance(data[field], str):
|
||||
try:
|
||||
json.loads(data[field])
|
||||
except json.JSONDecodeError:
|
||||
errors.append(f"Invalid JSON format in {field}")
|
||||
|
||||
return {
|
||||
"is_valid": len(errors) == 0,
|
||||
"errors": errors
|
||||
}
|
||||
326
services/tenant/app/repositories/coupon_repository.py
Normal file
326
services/tenant/app/repositories/coupon_repository.py
Normal file
@@ -0,0 +1,326 @@
|
||||
"""
|
||||
Repository for coupon data access and validation
|
||||
"""
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.models.coupon import CouponModel, CouponRedemptionModel
|
||||
from shared.subscription.coupons import (
|
||||
Coupon,
|
||||
CouponRedemption,
|
||||
CouponValidationResult,
|
||||
DiscountType,
|
||||
calculate_trial_end_date,
|
||||
format_discount_description
|
||||
)
|
||||
|
||||
|
||||
class CouponRepository:
|
||||
"""Data access layer for coupon operations"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_coupon_by_code(self, code: str) -> Optional[Coupon]:
|
||||
"""
|
||||
Retrieve coupon by code.
|
||||
Returns None if not found.
|
||||
"""
|
||||
result = await self.db.execute(
|
||||
select(CouponModel).where(CouponModel.code == code.upper())
|
||||
)
|
||||
coupon_model = result.scalar_one_or_none()
|
||||
|
||||
if not coupon_model:
|
||||
return None
|
||||
|
||||
return self._model_to_dataclass(coupon_model)
|
||||
|
||||
async def validate_coupon(
|
||||
self,
|
||||
code: str,
|
||||
tenant_id: str
|
||||
) -> CouponValidationResult:
|
||||
"""
|
||||
Validate a coupon code for a specific tenant.
|
||||
Checks: existence, validity, redemption limits, and if tenant already used it.
|
||||
"""
|
||||
# Get coupon
|
||||
coupon = await self.get_coupon_by_code(code)
|
||||
if not coupon:
|
||||
return CouponValidationResult(
|
||||
valid=False,
|
||||
coupon=None,
|
||||
error_message="Código de cupón inválido",
|
||||
discount_preview=None
|
||||
)
|
||||
|
||||
# Check if coupon can be redeemed
|
||||
can_redeem, reason = coupon.can_be_redeemed()
|
||||
if not can_redeem:
|
||||
error_messages = {
|
||||
"Coupon is inactive": "Este cupón no está activo",
|
||||
"Coupon is not yet valid": "Este cupón aún no es válido",
|
||||
"Coupon has expired": "Este cupón ha expirado",
|
||||
"Coupon has reached maximum redemptions": "Este cupón ha alcanzado su límite de usos"
|
||||
}
|
||||
return CouponValidationResult(
|
||||
valid=False,
|
||||
coupon=coupon,
|
||||
error_message=error_messages.get(reason, reason),
|
||||
discount_preview=None
|
||||
)
|
||||
|
||||
# Check if tenant already redeemed this coupon
|
||||
result = await self.db.execute(
|
||||
select(CouponRedemptionModel).where(
|
||||
and_(
|
||||
CouponRedemptionModel.tenant_id == tenant_id,
|
||||
CouponRedemptionModel.coupon_code == code.upper()
|
||||
)
|
||||
)
|
||||
)
|
||||
existing_redemption = result.scalar_one_or_none()
|
||||
|
||||
if existing_redemption:
|
||||
return CouponValidationResult(
|
||||
valid=False,
|
||||
coupon=coupon,
|
||||
error_message="Ya has utilizado este cupón",
|
||||
discount_preview=None
|
||||
)
|
||||
|
||||
# Generate discount preview
|
||||
discount_preview = self._generate_discount_preview(coupon)
|
||||
|
||||
return CouponValidationResult(
|
||||
valid=True,
|
||||
coupon=coupon,
|
||||
error_message=None,
|
||||
discount_preview=discount_preview
|
||||
)
|
||||
|
||||
async def redeem_coupon(
|
||||
self,
|
||||
code: str,
|
||||
tenant_id: Optional[str],
|
||||
base_trial_days: int = 0
|
||||
) -> tuple[bool, Optional[CouponRedemption], Optional[str]]:
|
||||
"""
|
||||
Redeem a coupon for a tenant.
|
||||
For tenant-independent registrations, tenant_id can be None initially.
|
||||
Returns (success, redemption, error_message)
|
||||
"""
|
||||
# For tenant-independent registrations, skip tenant validation
|
||||
if tenant_id:
|
||||
# Validate first
|
||||
validation = await self.validate_coupon(code, tenant_id)
|
||||
if not validation.valid:
|
||||
return False, None, validation.error_message
|
||||
coupon = validation.coupon
|
||||
else:
|
||||
# Just get the coupon and validate its general availability
|
||||
coupon = await self.get_coupon_by_code(code)
|
||||
if not coupon:
|
||||
return False, None, "Código de cupón inválido"
|
||||
|
||||
# Check if coupon can be redeemed
|
||||
can_redeem, reason = coupon.can_be_redeemed()
|
||||
if not can_redeem:
|
||||
error_messages = {
|
||||
"Coupon is inactive": "Este cupón no está activo",
|
||||
"Coupon is not yet valid": "Este cupón aún no es válido",
|
||||
"Coupon has expired": "Este cupón ha expirado",
|
||||
"Coupon has reached maximum redemptions": "Este cupón ha alcanzado su límite de usos"
|
||||
}
|
||||
return False, None, error_messages.get(reason, reason)
|
||||
|
||||
# Calculate discount applied
|
||||
discount_applied = self._calculate_discount_applied(
|
||||
coupon,
|
||||
base_trial_days
|
||||
)
|
||||
|
||||
# Only create redemption record if tenant_id is provided
|
||||
# For tenant-independent subscriptions, skip redemption record creation
|
||||
if tenant_id:
|
||||
# Create redemption record
|
||||
redemption_model = CouponRedemptionModel(
|
||||
tenant_id=tenant_id,
|
||||
coupon_code=code.upper(),
|
||||
redeemed_at=datetime.now(timezone.utc),
|
||||
discount_applied=discount_applied,
|
||||
extra_data={
|
||||
"coupon_type": coupon.discount_type.value,
|
||||
"coupon_value": coupon.discount_value
|
||||
}
|
||||
)
|
||||
|
||||
self.db.add(redemption_model)
|
||||
|
||||
# Increment coupon redemption count
|
||||
result = await self.db.execute(
|
||||
select(CouponModel).where(CouponModel.code == code.upper())
|
||||
)
|
||||
coupon_model = result.scalar_one_or_none()
|
||||
if coupon_model:
|
||||
coupon_model.current_redemptions += 1
|
||||
|
||||
try:
|
||||
await self.db.commit()
|
||||
await self.db.refresh(redemption_model)
|
||||
|
||||
redemption = CouponRedemption(
|
||||
id=str(redemption_model.id),
|
||||
tenant_id=redemption_model.tenant_id,
|
||||
coupon_code=redemption_model.coupon_code,
|
||||
redeemed_at=redemption_model.redeemed_at,
|
||||
discount_applied=redemption_model.discount_applied,
|
||||
extra_data=redemption_model.extra_data
|
||||
)
|
||||
|
||||
return True, redemption, None
|
||||
|
||||
except Exception as e:
|
||||
await self.db.rollback()
|
||||
return False, None, f"Error al aplicar el cupón: {str(e)}"
|
||||
else:
|
||||
# For tenant-independent subscriptions, return discount without creating redemption
|
||||
# The redemption will be created when the tenant is linked
|
||||
redemption = CouponRedemption(
|
||||
id="pending", # Temporary ID
|
||||
tenant_id="pending", # Will be set during tenant linking
|
||||
coupon_code=code.upper(),
|
||||
redeemed_at=datetime.now(timezone.utc),
|
||||
discount_applied=discount_applied,
|
||||
extra_data={
|
||||
"coupon_type": coupon.discount_type.value,
|
||||
"coupon_value": coupon.discount_value
|
||||
}
|
||||
)
|
||||
return True, redemption, None
|
||||
|
||||
async def get_redemption_by_tenant_and_code(
|
||||
self,
|
||||
tenant_id: str,
|
||||
code: str
|
||||
) -> Optional[CouponRedemption]:
|
||||
"""Get existing redemption for tenant and coupon code"""
|
||||
result = await self.db.execute(
|
||||
select(CouponRedemptionModel).where(
|
||||
and_(
|
||||
CouponRedemptionModel.tenant_id == tenant_id,
|
||||
CouponRedemptionModel.coupon_code == code.upper()
|
||||
)
|
||||
)
|
||||
)
|
||||
redemption_model = result.scalar_one_or_none()
|
||||
|
||||
if not redemption_model:
|
||||
return None
|
||||
|
||||
return CouponRedemption(
|
||||
id=str(redemption_model.id),
|
||||
tenant_id=redemption_model.tenant_id,
|
||||
coupon_code=redemption_model.coupon_code,
|
||||
redeemed_at=redemption_model.redeemed_at,
|
||||
discount_applied=redemption_model.discount_applied,
|
||||
extra_data=redemption_model.extra_data
|
||||
)
|
||||
|
||||
async def get_coupon_usage_stats(self, code: str) -> Optional[dict]:
|
||||
"""Get usage statistics for a coupon"""
|
||||
result = await self.db.execute(
|
||||
select(CouponModel).where(CouponModel.code == code.upper())
|
||||
)
|
||||
coupon_model = result.scalar_one_or_none()
|
||||
|
||||
if not coupon_model:
|
||||
return None
|
||||
|
||||
count_result = await self.db.execute(
|
||||
select(CouponRedemptionModel).where(
|
||||
CouponRedemptionModel.coupon_code == code.upper()
|
||||
)
|
||||
)
|
||||
redemptions_count = len(count_result.scalars().all())
|
||||
|
||||
return {
|
||||
"code": coupon_model.code,
|
||||
"current_redemptions": coupon_model.current_redemptions,
|
||||
"max_redemptions": coupon_model.max_redemptions,
|
||||
"redemptions_remaining": (
|
||||
coupon_model.max_redemptions - coupon_model.current_redemptions
|
||||
if coupon_model.max_redemptions
|
||||
else None
|
||||
),
|
||||
"active": coupon_model.active,
|
||||
"valid_from": coupon_model.valid_from.isoformat(),
|
||||
"valid_until": coupon_model.valid_until.isoformat() if coupon_model.valid_until else None
|
||||
}
|
||||
|
||||
def _model_to_dataclass(self, model: CouponModel) -> Coupon:
|
||||
"""Convert SQLAlchemy model to dataclass"""
|
||||
return Coupon(
|
||||
id=str(model.id),
|
||||
code=model.code,
|
||||
discount_type=DiscountType(model.discount_type),
|
||||
discount_value=model.discount_value,
|
||||
max_redemptions=model.max_redemptions,
|
||||
current_redemptions=model.current_redemptions,
|
||||
valid_from=model.valid_from,
|
||||
valid_until=model.valid_until,
|
||||
active=model.active,
|
||||
created_at=model.created_at,
|
||||
extra_data=model.extra_data
|
||||
)
|
||||
|
||||
def _generate_discount_preview(self, coupon: Coupon) -> dict:
|
||||
"""Generate a preview of the discount to be applied"""
|
||||
description = format_discount_description(coupon)
|
||||
|
||||
preview = {
|
||||
"description": description,
|
||||
"discount_type": coupon.discount_type.value,
|
||||
"discount_value": coupon.discount_value
|
||||
}
|
||||
|
||||
if coupon.discount_type == DiscountType.TRIAL_EXTENSION:
|
||||
trial_end = calculate_trial_end_date(0, coupon.discount_value)
|
||||
preview["trial_end_date"] = trial_end.isoformat()
|
||||
preview["total_trial_days"] = 0 + coupon.discount_value
|
||||
|
||||
return preview
|
||||
|
||||
def _calculate_discount_applied(
|
||||
self,
|
||||
coupon: Coupon,
|
||||
base_trial_days: int
|
||||
) -> dict:
|
||||
"""Calculate the actual discount that will be applied"""
|
||||
discount = {
|
||||
"type": coupon.discount_type.value,
|
||||
"value": coupon.discount_value,
|
||||
"description": format_discount_description(coupon)
|
||||
}
|
||||
|
||||
if coupon.discount_type == DiscountType.TRIAL_EXTENSION:
|
||||
total_trial_days = base_trial_days + coupon.discount_value
|
||||
trial_end = calculate_trial_end_date(base_trial_days, coupon.discount_value)
|
||||
|
||||
discount["base_trial_days"] = base_trial_days
|
||||
discount["extension_days"] = coupon.discount_value
|
||||
discount["total_trial_days"] = total_trial_days
|
||||
discount["trial_end_date"] = trial_end.isoformat()
|
||||
|
||||
elif coupon.discount_type == DiscountType.PERCENTAGE:
|
||||
discount["percentage_off"] = coupon.discount_value
|
||||
|
||||
elif coupon.discount_type == DiscountType.FIXED_AMOUNT:
|
||||
discount["amount_off_cents"] = coupon.discount_value
|
||||
discount["amount_off_euros"] = coupon.discount_value / 100
|
||||
|
||||
return discount
|
||||
283
services/tenant/app/repositories/event_repository.py
Normal file
283
services/tenant/app/repositories/event_repository.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
Event Repository
|
||||
Data access layer for events
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import date, datetime
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, or_, func
|
||||
from uuid import UUID
|
||||
import structlog
|
||||
|
||||
from app.models.events import Event, EventTemplate
|
||||
from shared.database.repository import BaseRepository
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class EventRepository(BaseRepository[Event]):
|
||||
"""Repository for event management"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
super().__init__(Event, session)
|
||||
|
||||
async def get_events_by_date_range(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
event_types: List[str] = None,
|
||||
confirmed_only: bool = False
|
||||
) -> List[Event]:
|
||||
"""
|
||||
Get events within a date range.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
start_date: Start date (inclusive)
|
||||
end_date: End date (inclusive)
|
||||
event_types: Optional filter by event types
|
||||
confirmed_only: Only return confirmed events
|
||||
|
||||
Returns:
|
||||
List of Event objects
|
||||
"""
|
||||
try:
|
||||
query = select(Event).where(
|
||||
and_(
|
||||
Event.tenant_id == tenant_id,
|
||||
Event.event_date >= start_date,
|
||||
Event.event_date <= end_date
|
||||
)
|
||||
)
|
||||
|
||||
if event_types:
|
||||
query = query.where(Event.event_type.in_(event_types))
|
||||
|
||||
if confirmed_only:
|
||||
query = query.where(Event.is_confirmed == True)
|
||||
|
||||
query = query.order_by(Event.event_date)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
events = result.scalars().all()
|
||||
|
||||
logger.debug("Retrieved events by date range",
|
||||
tenant_id=str(tenant_id),
|
||||
start_date=start_date.isoformat(),
|
||||
end_date=end_date.isoformat(),
|
||||
count=len(events))
|
||||
|
||||
return list(events)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get events by date range",
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def get_events_for_date(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
event_date: date
|
||||
) -> List[Event]:
|
||||
"""
|
||||
Get all events for a specific date.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
event_date: Date to get events for
|
||||
|
||||
Returns:
|
||||
List of Event objects
|
||||
"""
|
||||
try:
|
||||
query = select(Event).where(
|
||||
and_(
|
||||
Event.tenant_id == tenant_id,
|
||||
Event.event_date == event_date
|
||||
)
|
||||
).order_by(Event.start_time)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
events = result.scalars().all()
|
||||
|
||||
return list(events)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get events for date",
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def get_upcoming_events(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
days_ahead: int = 30,
|
||||
limit: int = 100
|
||||
) -> List[Event]:
|
||||
"""
|
||||
Get upcoming events.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
days_ahead: Number of days to look ahead
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of upcoming Event objects
|
||||
"""
|
||||
try:
|
||||
from datetime import date, timedelta
|
||||
|
||||
today = date.today()
|
||||
future_date = today + timedelta(days=days_ahead)
|
||||
|
||||
query = select(Event).where(
|
||||
and_(
|
||||
Event.tenant_id == tenant_id,
|
||||
Event.event_date >= today,
|
||||
Event.event_date <= future_date
|
||||
)
|
||||
).order_by(Event.event_date).limit(limit)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
events = result.scalars().all()
|
||||
|
||||
return list(events)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get upcoming events",
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def create_event(self, event_data: Dict[str, Any]) -> Event:
|
||||
"""Create a new event"""
|
||||
try:
|
||||
event = Event(**event_data)
|
||||
self.session.add(event)
|
||||
await self.session.flush()
|
||||
|
||||
logger.info("Created event",
|
||||
event_id=str(event.id),
|
||||
event_name=event.event_name,
|
||||
event_date=event.event_date.isoformat())
|
||||
|
||||
return event
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create event", error=str(e))
|
||||
raise
|
||||
|
||||
async def update_event_actual_impact(
|
||||
self,
|
||||
event_id: UUID,
|
||||
actual_impact_multiplier: float,
|
||||
actual_sales_increase_percent: float
|
||||
) -> Optional[Event]:
|
||||
"""
|
||||
Update event with actual impact after it occurs.
|
||||
|
||||
Args:
|
||||
event_id: Event UUID
|
||||
actual_impact_multiplier: Actual demand multiplier observed
|
||||
actual_sales_increase_percent: Actual sales increase percentage
|
||||
|
||||
Returns:
|
||||
Updated Event or None
|
||||
"""
|
||||
try:
|
||||
event = await self.get(event_id)
|
||||
if not event:
|
||||
return None
|
||||
|
||||
event.actual_impact_multiplier = actual_impact_multiplier
|
||||
event.actual_sales_increase_percent = actual_sales_increase_percent
|
||||
|
||||
await self.session.flush()
|
||||
|
||||
logger.info("Updated event actual impact",
|
||||
event_id=str(event_id),
|
||||
actual_multiplier=actual_impact_multiplier)
|
||||
|
||||
return event
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update event actual impact",
|
||||
event_id=str(event_id),
|
||||
error=str(e))
|
||||
return None
|
||||
|
||||
async def get_events_by_type(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
event_type: str,
|
||||
limit: int = 100
|
||||
) -> List[Event]:
|
||||
"""Get events by type"""
|
||||
try:
|
||||
query = select(Event).where(
|
||||
and_(
|
||||
Event.tenant_id == tenant_id,
|
||||
Event.event_type == event_type
|
||||
)
|
||||
).order_by(Event.event_date.desc()).limit(limit)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
events = result.scalars().all()
|
||||
|
||||
return list(events)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get events by type",
|
||||
tenant_id=str(tenant_id),
|
||||
event_type=event_type,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
|
||||
class EventTemplateRepository(BaseRepository[EventTemplate]):
|
||||
"""Repository for event template management"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
super().__init__(EventTemplate, session)
|
||||
|
||||
async def get_active_templates(self, tenant_id: UUID) -> List[EventTemplate]:
|
||||
"""Get all active event templates for a tenant"""
|
||||
try:
|
||||
query = select(EventTemplate).where(
|
||||
and_(
|
||||
EventTemplate.tenant_id == tenant_id,
|
||||
EventTemplate.is_active == True
|
||||
)
|
||||
).order_by(EventTemplate.template_name)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
templates = result.scalars().all()
|
||||
|
||||
return list(templates)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get active templates",
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def create_template(self, template_data: Dict[str, Any]) -> EventTemplate:
|
||||
"""Create a new event template"""
|
||||
try:
|
||||
template = EventTemplate(**template_data)
|
||||
self.session.add(template)
|
||||
await self.session.flush()
|
||||
|
||||
logger.info("Created event template",
|
||||
template_id=str(template.id),
|
||||
template_name=template.template_name)
|
||||
|
||||
return template
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create event template", error=str(e))
|
||||
raise
|
||||
812
services/tenant/app/repositories/subscription_repository.py
Normal file
812
services/tenant/app/repositories/subscription_repository.py
Normal file
@@ -0,0 +1,812 @@
|
||||
"""
|
||||
Subscription Repository
|
||||
Repository for subscription operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, text, and_
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
import json
|
||||
|
||||
from .base import TenantBaseRepository
|
||||
from app.models.tenants import Subscription
|
||||
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
|
||||
from shared.subscription.plans import SubscriptionPlanMetadata, QuotaLimits, PlanPricing
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class SubscriptionRepository(TenantBaseRepository):
|
||||
"""Repository for subscription operations"""
|
||||
|
||||
def __init__(self, model_class, session: AsyncSession, cache_ttl: Optional[int] = 600):
|
||||
# Subscriptions are relatively stable, medium cache time (10 minutes)
|
||||
super().__init__(model_class, session, cache_ttl)
|
||||
|
||||
async def create_subscription(self, subscription_data: Dict[str, Any]) -> Subscription:
|
||||
"""Create a new subscription with validation"""
|
||||
try:
|
||||
# Validate subscription data
|
||||
validation_result = self._validate_tenant_data(
|
||||
subscription_data,
|
||||
["tenant_id", "plan"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid subscription data: {validation_result['errors']}")
|
||||
|
||||
# Check for existing active subscription
|
||||
existing_subscription = await self.get_active_subscription(
|
||||
subscription_data["tenant_id"]
|
||||
)
|
||||
|
||||
if existing_subscription:
|
||||
raise DuplicateRecordError(f"Tenant already has an active subscription")
|
||||
|
||||
# Set default values based on plan from centralized configuration
|
||||
plan = subscription_data["plan"]
|
||||
plan_info = SubscriptionPlanMetadata.get_plan_info(plan)
|
||||
|
||||
# Set defaults from centralized plan configuration
|
||||
if "monthly_price" not in subscription_data:
|
||||
billing_cycle = subscription_data.get("billing_cycle", "monthly")
|
||||
subscription_data["monthly_price"] = float(
|
||||
PlanPricing.get_price(plan, billing_cycle)
|
||||
)
|
||||
|
||||
if "max_users" not in subscription_data:
|
||||
subscription_data["max_users"] = QuotaLimits.get_limit('MAX_USERS', plan) or -1
|
||||
|
||||
if "max_locations" not in subscription_data:
|
||||
subscription_data["max_locations"] = QuotaLimits.get_limit('MAX_LOCATIONS', plan) or -1
|
||||
|
||||
if "max_products" not in subscription_data:
|
||||
subscription_data["max_products"] = QuotaLimits.get_limit('MAX_PRODUCTS', plan) or -1
|
||||
|
||||
if "features" not in subscription_data:
|
||||
subscription_data["features"] = {
|
||||
feature: True for feature in plan_info.get("features", [])
|
||||
}
|
||||
|
||||
# Set default subscription values
|
||||
if "status" not in subscription_data:
|
||||
subscription_data["status"] = "active"
|
||||
if "billing_cycle" not in subscription_data:
|
||||
subscription_data["billing_cycle"] = "monthly"
|
||||
if "next_billing_date" not in subscription_data:
|
||||
# Set next billing date based on cycle
|
||||
if subscription_data["billing_cycle"] == "yearly":
|
||||
subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=365)
|
||||
else:
|
||||
subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=30)
|
||||
|
||||
# Check if subscription with this subscription_id already exists to prevent duplicates
|
||||
if subscription_data.get('subscription_id'):
|
||||
existing_subscription = await self.get_by_provider_id(subscription_data['subscription_id'])
|
||||
if existing_subscription:
|
||||
# Update the existing subscription instead of creating a duplicate
|
||||
updated_subscription = await self.update(str(existing_subscription.id), subscription_data)
|
||||
|
||||
logger.info("Existing subscription updated",
|
||||
subscription_id=subscription_data['subscription_id'],
|
||||
tenant_id=subscription_data.get('tenant_id'),
|
||||
plan=subscription_data.get('plan'))
|
||||
|
||||
return updated_subscription
|
||||
|
||||
# Create subscription
|
||||
subscription = await self.create(subscription_data)
|
||||
|
||||
logger.info("Subscription created successfully",
|
||||
subscription_id=subscription.id,
|
||||
tenant_id=subscription.tenant_id,
|
||||
plan=subscription.plan,
|
||||
monthly_price=subscription.monthly_price)
|
||||
|
||||
return subscription
|
||||
|
||||
except (ValidationError, DuplicateRecordError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create subscription",
|
||||
tenant_id=subscription_data.get("tenant_id"),
|
||||
plan=subscription_data.get("plan"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create subscription: {str(e)}")
|
||||
|
||||
async def get_by_tenant_id(self, tenant_id: str) -> Optional[Subscription]:
|
||||
"""Get subscription by tenant ID"""
|
||||
try:
|
||||
subscriptions = await self.get_multi(
|
||||
filters={
|
||||
"tenant_id": tenant_id
|
||||
},
|
||||
limit=1,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return subscriptions[0] if subscriptions else None
|
||||
except Exception as e:
|
||||
logger.error("Failed to get subscription by tenant ID",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get subscription: {str(e)}")
|
||||
|
||||
async def get_by_provider_id(self, subscription_id: str) -> Optional[Subscription]:
|
||||
"""Get subscription by payment provider subscription ID"""
|
||||
try:
|
||||
subscriptions = await self.get_multi(
|
||||
filters={
|
||||
"subscription_id": subscription_id
|
||||
},
|
||||
limit=1,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return subscriptions[0] if subscriptions else None
|
||||
except Exception as e:
|
||||
logger.error("Failed to get subscription by provider ID",
|
||||
subscription_id=subscription_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get subscription: {str(e)}")
|
||||
|
||||
async def get_active_subscription(self, tenant_id: str) -> Optional[Subscription]:
|
||||
"""Get active subscription for tenant"""
|
||||
try:
|
||||
subscriptions = await self.get_multi(
|
||||
filters={
|
||||
"tenant_id": tenant_id,
|
||||
"status": "active"
|
||||
},
|
||||
limit=1,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return subscriptions[0] if subscriptions else None
|
||||
except Exception as e:
|
||||
logger.error("Failed to get active subscription",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get subscription: {str(e)}")
|
||||
|
||||
async def get_tenant_subscriptions(
|
||||
self,
|
||||
tenant_id: str,
|
||||
include_inactive: bool = False
|
||||
) -> List[Subscription]:
|
||||
"""Get all subscriptions for a tenant"""
|
||||
try:
|
||||
filters = {"tenant_id": tenant_id}
|
||||
|
||||
if not include_inactive:
|
||||
filters["status"] = "active"
|
||||
|
||||
return await self.get_multi(
|
||||
filters=filters,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenant subscriptions",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get subscriptions: {str(e)}")
|
||||
|
||||
async def update_subscription_plan(
|
||||
self,
|
||||
subscription_id: str,
|
||||
new_plan: str,
|
||||
billing_cycle: str = "monthly"
|
||||
) -> Optional[Subscription]:
|
||||
"""Update subscription plan and pricing using centralized configuration"""
|
||||
try:
|
||||
valid_plans = ["starter", "professional", "enterprise"]
|
||||
if new_plan not in valid_plans:
|
||||
raise ValidationError(f"Invalid plan. Must be one of: {valid_plans}")
|
||||
|
||||
# Get current subscription to find tenant_id for cache invalidation
|
||||
subscription = await self.get_by_id(subscription_id)
|
||||
if not subscription:
|
||||
raise ValidationError(f"Subscription {subscription_id} not found")
|
||||
|
||||
# Get new plan configuration from centralized source
|
||||
plan_info = SubscriptionPlanMetadata.get_plan_info(new_plan)
|
||||
|
||||
# Update subscription with new plan details
|
||||
update_data = {
|
||||
"plan": new_plan,
|
||||
"monthly_price": float(PlanPricing.get_price(new_plan, billing_cycle)),
|
||||
"billing_cycle": billing_cycle,
|
||||
"max_users": QuotaLimits.get_limit('MAX_USERS', new_plan) or -1,
|
||||
"max_locations": QuotaLimits.get_limit('MAX_LOCATIONS', new_plan) or -1,
|
||||
"max_products": QuotaLimits.get_limit('MAX_PRODUCTS', new_plan) or -1,
|
||||
"features": {feature: True for feature in plan_info.get("features", [])},
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
updated_subscription = await self.update(subscription_id, update_data)
|
||||
|
||||
# Invalidate cache
|
||||
await self._invalidate_cache(str(subscription.tenant_id))
|
||||
|
||||
logger.info("Subscription plan updated",
|
||||
subscription_id=subscription_id,
|
||||
new_plan=new_plan,
|
||||
new_price=update_data["monthly_price"])
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to update subscription plan",
|
||||
subscription_id=subscription_id,
|
||||
new_plan=new_plan,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update plan: {str(e)}")
|
||||
|
||||
async def update_subscription_status(
|
||||
self,
|
||||
subscription_id: str,
|
||||
status: str,
|
||||
additional_data: Dict[str, Any] = None
|
||||
) -> Optional[Subscription]:
|
||||
"""Update subscription status with optional additional data"""
|
||||
try:
|
||||
# Get subscription to find tenant_id for cache invalidation
|
||||
subscription = await self.get_by_id(subscription_id)
|
||||
if not subscription:
|
||||
raise ValidationError(f"Subscription {subscription_id} not found")
|
||||
|
||||
update_data = {
|
||||
"status": status,
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
# Merge additional data if provided
|
||||
if additional_data:
|
||||
update_data.update(additional_data)
|
||||
|
||||
updated_subscription = await self.update(subscription_id, update_data)
|
||||
|
||||
# Invalidate cache
|
||||
if subscription.tenant_id:
|
||||
await self._invalidate_cache(str(subscription.tenant_id))
|
||||
|
||||
logger.info("Subscription status updated",
|
||||
subscription_id=subscription_id,
|
||||
new_status=status,
|
||||
additional_data=additional_data)
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to update subscription status",
|
||||
subscription_id=subscription_id,
|
||||
status=status,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update subscription status: {str(e)}")
|
||||
|
||||
async def cancel_subscription(
|
||||
self,
|
||||
subscription_id: str,
|
||||
reason: str = None
|
||||
) -> Optional[Subscription]:
|
||||
"""Cancel a subscription"""
|
||||
try:
|
||||
# Get subscription to find tenant_id for cache invalidation
|
||||
subscription = await self.get_by_id(subscription_id)
|
||||
if not subscription:
|
||||
raise ValidationError(f"Subscription {subscription_id} not found")
|
||||
|
||||
update_data = {
|
||||
"status": "cancelled",
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
updated_subscription = await self.update(subscription_id, update_data)
|
||||
|
||||
# Invalidate cache
|
||||
await self._invalidate_cache(str(subscription.tenant_id))
|
||||
|
||||
logger.info("Subscription cancelled",
|
||||
subscription_id=subscription_id,
|
||||
reason=reason)
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cancel subscription",
|
||||
subscription_id=subscription_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to cancel subscription: {str(e)}")
|
||||
|
||||
async def suspend_subscription(
|
||||
self,
|
||||
subscription_id: str,
|
||||
reason: str = None
|
||||
) -> Optional[Subscription]:
|
||||
"""Suspend a subscription"""
|
||||
try:
|
||||
# Get subscription to find tenant_id for cache invalidation
|
||||
subscription = await self.get_by_id(subscription_id)
|
||||
if not subscription:
|
||||
raise ValidationError(f"Subscription {subscription_id} not found")
|
||||
|
||||
update_data = {
|
||||
"status": "suspended",
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
updated_subscription = await self.update(subscription_id, update_data)
|
||||
|
||||
# Invalidate cache
|
||||
await self._invalidate_cache(str(subscription.tenant_id))
|
||||
|
||||
logger.info("Subscription suspended",
|
||||
subscription_id=subscription_id,
|
||||
reason=reason)
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to suspend subscription",
|
||||
subscription_id=subscription_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to suspend subscription: {str(e)}")
|
||||
|
||||
async def reactivate_subscription(
|
||||
self,
|
||||
subscription_id: str
|
||||
) -> Optional[Subscription]:
|
||||
"""Reactivate a cancelled or suspended subscription"""
|
||||
try:
|
||||
# Get subscription to find tenant_id for cache invalidation
|
||||
subscription = await self.get_by_id(subscription_id)
|
||||
if not subscription:
|
||||
raise ValidationError(f"Subscription {subscription_id} not found")
|
||||
|
||||
# Reset billing date when reactivating
|
||||
next_billing_date = datetime.utcnow() + timedelta(days=30)
|
||||
|
||||
update_data = {
|
||||
"status": "active",
|
||||
"next_billing_date": next_billing_date,
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
updated_subscription = await self.update(subscription_id, update_data)
|
||||
|
||||
# Invalidate cache
|
||||
await self._invalidate_cache(str(subscription.tenant_id))
|
||||
|
||||
logger.info("Subscription reactivated",
|
||||
subscription_id=subscription_id,
|
||||
next_billing_date=next_billing_date)
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to reactivate subscription",
|
||||
subscription_id=subscription_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to reactivate subscription: {str(e)}")
|
||||
|
||||
async def get_subscriptions_due_for_billing(
|
||||
self,
|
||||
days_ahead: int = 3
|
||||
) -> List[Subscription]:
|
||||
"""Get subscriptions that need billing in the next N days"""
|
||||
try:
|
||||
cutoff_date = datetime.utcnow() + timedelta(days=days_ahead)
|
||||
|
||||
query_text = """
|
||||
SELECT * FROM subscriptions
|
||||
WHERE status = 'active'
|
||||
AND next_billing_date <= :cutoff_date
|
||||
ORDER BY next_billing_date ASC
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {
|
||||
"cutoff_date": cutoff_date
|
||||
})
|
||||
|
||||
subscriptions = []
|
||||
for row in result.fetchall():
|
||||
record_dict = dict(row._mapping)
|
||||
subscription = self.model(**record_dict)
|
||||
subscriptions.append(subscription)
|
||||
|
||||
return subscriptions
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get subscriptions due for billing",
|
||||
days_ahead=days_ahead,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def update_billing_date(
|
||||
self,
|
||||
subscription_id: str,
|
||||
next_billing_date: datetime
|
||||
) -> Optional[Subscription]:
|
||||
"""Update next billing date for subscription"""
|
||||
try:
|
||||
updated_subscription = await self.update(subscription_id, {
|
||||
"next_billing_date": next_billing_date,
|
||||
"updated_at": datetime.utcnow()
|
||||
})
|
||||
|
||||
logger.info("Subscription billing date updated",
|
||||
subscription_id=subscription_id,
|
||||
next_billing_date=next_billing_date)
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update billing date",
|
||||
subscription_id=subscription_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update billing date: {str(e)}")
|
||||
|
||||
async def get_subscription_statistics(self) -> Dict[str, Any]:
|
||||
"""Get subscription statistics"""
|
||||
try:
|
||||
# Get counts by plan
|
||||
plan_query = text("""
|
||||
SELECT plan, COUNT(*) as count
|
||||
FROM subscriptions
|
||||
WHERE status = 'active'
|
||||
GROUP BY plan
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(plan_query)
|
||||
subscriptions_by_plan = {row.plan: row.count for row in result.fetchall()}
|
||||
|
||||
# Get counts by status
|
||||
status_query = text("""
|
||||
SELECT status, COUNT(*) as count
|
||||
FROM subscriptions
|
||||
GROUP BY status
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(status_query)
|
||||
subscriptions_by_status = {row.status: row.count for row in result.fetchall()}
|
||||
|
||||
# Get revenue statistics
|
||||
revenue_query = text("""
|
||||
SELECT
|
||||
SUM(monthly_price) as total_monthly_revenue,
|
||||
AVG(monthly_price) as avg_monthly_price,
|
||||
COUNT(*) as total_active_subscriptions
|
||||
FROM subscriptions
|
||||
WHERE status = 'active'
|
||||
""")
|
||||
|
||||
revenue_result = await self.session.execute(revenue_query)
|
||||
revenue_row = revenue_result.fetchone()
|
||||
|
||||
# Get upcoming billing count
|
||||
thirty_days_ahead = datetime.utcnow() + timedelta(days=30)
|
||||
upcoming_billing = len(await self.get_subscriptions_due_for_billing(30))
|
||||
|
||||
return {
|
||||
"subscriptions_by_plan": subscriptions_by_plan,
|
||||
"subscriptions_by_status": subscriptions_by_status,
|
||||
"total_monthly_revenue": float(revenue_row.total_monthly_revenue or 0),
|
||||
"avg_monthly_price": float(revenue_row.avg_monthly_price or 0),
|
||||
"total_active_subscriptions": int(revenue_row.total_active_subscriptions or 0),
|
||||
"upcoming_billing_30d": upcoming_billing
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get subscription statistics", error=str(e))
|
||||
return {
|
||||
"subscriptions_by_plan": {},
|
||||
"subscriptions_by_status": {},
|
||||
"total_monthly_revenue": 0.0,
|
||||
"avg_monthly_price": 0.0,
|
||||
"total_active_subscriptions": 0,
|
||||
"upcoming_billing_30d": 0
|
||||
}
|
||||
|
||||
async def cleanup_old_subscriptions(self, days_old: int = 730) -> int:
|
||||
"""Clean up very old cancelled subscriptions (2 years)"""
|
||||
try:
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days_old)
|
||||
|
||||
query_text = """
|
||||
DELETE FROM subscriptions
|
||||
WHERE status IN ('cancelled', 'suspended')
|
||||
AND updated_at < :cutoff_date
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up old subscriptions",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup old subscriptions",
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Cleanup failed: {str(e)}")
|
||||
|
||||
async def _invalidate_cache(self, tenant_id: str) -> None:
|
||||
"""
|
||||
Invalidate subscription cache for a tenant
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
"""
|
||||
try:
|
||||
from app.services.subscription_cache import get_subscription_cache_service
|
||||
|
||||
cache_service = get_subscription_cache_service()
|
||||
await cache_service.invalidate_subscription_cache(tenant_id)
|
||||
|
||||
logger.debug("Invalidated subscription cache from repository",
|
||||
tenant_id=tenant_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to invalidate cache (non-critical)",
|
||||
tenant_id=tenant_id, error=str(e))
|
||||
|
||||
# ========================================================================
|
||||
# TENANT-INDEPENDENT SUBSCRIPTION METHODS (New Architecture)
|
||||
# ========================================================================
|
||||
|
||||
async def create_tenant_independent_subscription(
|
||||
self,
|
||||
subscription_data: Dict[str, Any]
|
||||
) -> Subscription:
|
||||
"""Create a subscription not linked to any tenant (for registration flow)"""
|
||||
try:
|
||||
# Validate required data for tenant-independent subscription
|
||||
# user_id may not exist during registration, so validate other required fields
|
||||
required_fields = ["plan", "subscription_id", "customer_id"]
|
||||
validation_result = self._validate_tenant_data(subscription_data, required_fields)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid subscription data: {validation_result['errors']}")
|
||||
|
||||
# Ensure tenant_id is not provided (this is tenant-independent)
|
||||
if "tenant_id" in subscription_data and subscription_data["tenant_id"]:
|
||||
raise ValidationError("tenant_id should not be provided for tenant-independent subscriptions")
|
||||
|
||||
# Set tenant-independent specific fields
|
||||
subscription_data["tenant_id"] = None
|
||||
subscription_data["is_tenant_linked"] = False
|
||||
subscription_data["tenant_linking_status"] = "pending"
|
||||
subscription_data["linked_at"] = None
|
||||
|
||||
# Set default values based on plan from centralized configuration
|
||||
plan = subscription_data["plan"]
|
||||
plan_info = SubscriptionPlanMetadata.get_plan_info(plan)
|
||||
|
||||
# Set defaults from centralized plan configuration
|
||||
if "monthly_price" not in subscription_data:
|
||||
billing_cycle = subscription_data.get("billing_cycle", "monthly")
|
||||
subscription_data["monthly_price"] = float(
|
||||
PlanPricing.get_price(plan, billing_cycle)
|
||||
)
|
||||
|
||||
if "max_users" not in subscription_data:
|
||||
subscription_data["max_users"] = QuotaLimits.get_limit('MAX_USERS', plan) or -1
|
||||
|
||||
if "max_locations" not in subscription_data:
|
||||
subscription_data["max_locations"] = QuotaLimits.get_limit('MAX_LOCATIONS', plan) or -1
|
||||
|
||||
if "max_products" not in subscription_data:
|
||||
subscription_data["max_products"] = QuotaLimits.get_limit('MAX_PRODUCTS', plan) or -1
|
||||
|
||||
if "features" not in subscription_data:
|
||||
subscription_data["features"] = {
|
||||
feature: True for feature in plan_info.get("features", [])
|
||||
}
|
||||
|
||||
# Set default subscription values
|
||||
if "status" not in subscription_data:
|
||||
subscription_data["status"] = "pending_tenant_linking"
|
||||
if "billing_cycle" not in subscription_data:
|
||||
subscription_data["billing_cycle"] = "monthly"
|
||||
if "next_billing_date" not in subscription_data:
|
||||
# Set next billing date based on cycle
|
||||
if subscription_data["billing_cycle"] == "yearly":
|
||||
subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=365)
|
||||
else:
|
||||
subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=30)
|
||||
|
||||
# Check if subscription with this subscription_id already exists
|
||||
existing_subscription = await self.get_by_provider_id(subscription_data['subscription_id'])
|
||||
if existing_subscription:
|
||||
# Update the existing subscription instead of creating a duplicate
|
||||
updated_subscription = await self.update(str(existing_subscription.id), subscription_data)
|
||||
|
||||
logger.info("Existing tenant-independent subscription updated",
|
||||
subscription_id=subscription_data['subscription_id'],
|
||||
user_id=subscription_data.get('user_id'),
|
||||
plan=subscription_data.get('plan'))
|
||||
|
||||
return updated_subscription
|
||||
else:
|
||||
# Create new subscription, but handle potential duplicate errors
|
||||
try:
|
||||
subscription = await self.create(subscription_data)
|
||||
|
||||
logger.info("Tenant-independent subscription created successfully",
|
||||
subscription_id=subscription.id,
|
||||
user_id=subscription.user_id,
|
||||
plan=subscription.plan,
|
||||
monthly_price=subscription.monthly_price)
|
||||
|
||||
return subscription
|
||||
except DuplicateRecordError:
|
||||
# Another process may have created the subscription between our check and create
|
||||
# Try to get the existing subscription and return it
|
||||
final_subscription = await self.get_by_provider_id(subscription_data['subscription_id'])
|
||||
if final_subscription:
|
||||
logger.info("Race condition detected: subscription already created by another process",
|
||||
subscription_id=subscription_data['subscription_id'])
|
||||
return final_subscription
|
||||
else:
|
||||
# This shouldn't happen, but re-raise the error if we can't find it
|
||||
raise
|
||||
|
||||
except (ValidationError, DuplicateRecordError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create tenant-independent subscription",
|
||||
user_id=subscription_data.get("user_id"),
|
||||
plan=subscription_data.get("plan"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create tenant-independent subscription: {str(e)}")
|
||||
|
||||
async def get_pending_tenant_linking_subscriptions(self) -> List[Subscription]:
|
||||
"""Get all subscriptions waiting to be linked to tenants"""
|
||||
try:
|
||||
subscriptions = await self.get_multi(
|
||||
filters={
|
||||
"tenant_linking_status": "pending",
|
||||
"is_tenant_linked": False
|
||||
},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return subscriptions
|
||||
except Exception as e:
|
||||
logger.error("Failed to get pending tenant linking subscriptions",
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get pending subscriptions: {str(e)}")
|
||||
|
||||
async def get_pending_subscriptions_by_user(self, user_id: str) -> List[Subscription]:
|
||||
"""Get pending tenant linking subscriptions for a specific user"""
|
||||
try:
|
||||
subscriptions = await self.get_multi(
|
||||
filters={
|
||||
"user_id": user_id,
|
||||
"tenant_linking_status": "pending",
|
||||
"is_tenant_linked": False
|
||||
},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return subscriptions
|
||||
except Exception as e:
|
||||
logger.error("Failed to get pending subscriptions by user",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get pending subscriptions: {str(e)}")
|
||||
|
||||
async def link_subscription_to_tenant(
|
||||
self,
|
||||
subscription_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str
|
||||
) -> Subscription:
|
||||
"""Link a pending subscription to a tenant"""
|
||||
try:
|
||||
# Get the subscription first
|
||||
subscription = await self.get_by_id(subscription_id)
|
||||
if not subscription:
|
||||
raise ValidationError(f"Subscription {subscription_id} not found")
|
||||
|
||||
# Validate subscription can be linked
|
||||
if not subscription.can_be_linked_to_tenant(user_id):
|
||||
raise ValidationError(
|
||||
f"Subscription {subscription_id} cannot be linked to tenant by user {user_id}. "
|
||||
f"Current status: {subscription.tenant_linking_status}, "
|
||||
f"User: {subscription.user_id}, "
|
||||
f"Already linked: {subscription.is_tenant_linked}"
|
||||
)
|
||||
|
||||
# Update subscription with tenant information
|
||||
update_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"is_tenant_linked": True,
|
||||
"tenant_linking_status": "completed",
|
||||
"linked_at": datetime.utcnow(),
|
||||
"status": "active", # Activate subscription when linked to tenant
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
updated_subscription = await self.update(subscription_id, update_data)
|
||||
|
||||
# Invalidate cache for the tenant
|
||||
await self._invalidate_cache(tenant_id)
|
||||
|
||||
logger.info("Subscription linked to tenant successfully",
|
||||
subscription_id=subscription_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id)
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to link subscription to tenant",
|
||||
subscription_id=subscription_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to link subscription to tenant: {str(e)}")
|
||||
|
||||
async def cleanup_orphaned_subscriptions(self, days_old: int = 30) -> int:
|
||||
"""Clean up subscriptions that were never linked to tenants"""
|
||||
try:
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days_old)
|
||||
|
||||
query_text = """
|
||||
DELETE FROM subscriptions
|
||||
WHERE tenant_linking_status = 'pending'
|
||||
AND is_tenant_linked = FALSE
|
||||
AND created_at < :cutoff_date
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up orphaned subscriptions",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup orphaned subscriptions",
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Cleanup failed: {str(e)}")
|
||||
|
||||
async def get_by_customer_id(self, customer_id: str) -> List[Subscription]:
|
||||
"""
|
||||
Get subscriptions by Stripe customer ID
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
|
||||
Returns:
|
||||
List of Subscription objects
|
||||
"""
|
||||
try:
|
||||
query = select(Subscription).where(Subscription.customer_id == customer_id)
|
||||
result = await self.session.execute(query)
|
||||
subscriptions = result.scalars().all()
|
||||
|
||||
logger.debug("Found subscriptions by customer_id",
|
||||
customer_id=customer_id,
|
||||
count=len(subscriptions))
|
||||
return subscriptions
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting subscriptions by customer_id",
|
||||
customer_id=customer_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get subscriptions by customer_id: {str(e)}")
|
||||
218
services/tenant/app/repositories/tenant_location_repository.py
Normal file
218
services/tenant/app/repositories/tenant_location_repository.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""
|
||||
Tenant Location Repository
|
||||
Handles database operations for tenant location data
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update, delete
|
||||
from sqlalchemy.orm import selectinload
|
||||
import structlog
|
||||
|
||||
from app.models.tenant_location import TenantLocation
|
||||
from app.models.tenants import Tenant
|
||||
from shared.database.exceptions import DatabaseError
|
||||
from .base import BaseRepository
|
||||
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TenantLocationRepository(BaseRepository):
|
||||
"""Repository for tenant location operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
super().__init__(TenantLocation, session)
|
||||
|
||||
async def create_location(self, location_data: Dict[str, Any]) -> TenantLocation:
|
||||
"""
|
||||
Create a new tenant location
|
||||
|
||||
Args:
|
||||
location_data: Dictionary containing location information
|
||||
|
||||
Returns:
|
||||
Created TenantLocation object
|
||||
"""
|
||||
try:
|
||||
# Create new location instance
|
||||
location = TenantLocation(**location_data)
|
||||
self.session.add(location)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(location)
|
||||
logger.info(f"Created new tenant location: {location.id} for tenant {location.tenant_id}")
|
||||
return location
|
||||
except Exception as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Failed to create tenant location: {str(e)}")
|
||||
raise DatabaseError(f"Failed to create tenant location: {str(e)}")
|
||||
|
||||
async def get_location_by_id(self, location_id: str) -> Optional[TenantLocation]:
|
||||
"""
|
||||
Get a location by its ID
|
||||
|
||||
Args:
|
||||
location_id: UUID of the location
|
||||
|
||||
Returns:
|
||||
TenantLocation object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
stmt = select(TenantLocation).where(TenantLocation.id == location_id)
|
||||
result = await self.session.execute(stmt)
|
||||
location = result.scalar_one_or_none()
|
||||
return location
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get location by ID: {str(e)}")
|
||||
raise DatabaseError(f"Failed to get location by ID: {str(e)}")
|
||||
|
||||
async def get_locations_by_tenant(self, tenant_id: str) -> List[TenantLocation]:
|
||||
"""
|
||||
Get all locations for a specific tenant
|
||||
|
||||
Args:
|
||||
tenant_id: UUID of the tenant
|
||||
|
||||
Returns:
|
||||
List of TenantLocation objects
|
||||
"""
|
||||
try:
|
||||
stmt = select(TenantLocation).where(TenantLocation.tenant_id == tenant_id)
|
||||
result = await self.session.execute(stmt)
|
||||
locations = result.scalars().all()
|
||||
return locations
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get locations by tenant: {str(e)}")
|
||||
raise DatabaseError(f"Failed to get locations by tenant: {str(e)}")
|
||||
|
||||
async def get_location_by_type(self, tenant_id: str, location_type: str) -> Optional[TenantLocation]:
|
||||
"""
|
||||
Get a location by tenant and type
|
||||
|
||||
Args:
|
||||
tenant_id: UUID of the tenant
|
||||
location_type: Type of location (e.g., 'central_production', 'retail_outlet')
|
||||
|
||||
Returns:
|
||||
TenantLocation object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
stmt = select(TenantLocation).where(
|
||||
TenantLocation.tenant_id == tenant_id,
|
||||
TenantLocation.location_type == location_type
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
location = result.scalar_one_or_none()
|
||||
return location
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get location by type: {str(e)}")
|
||||
raise DatabaseError(f"Failed to get location by type: {str(e)}")
|
||||
|
||||
async def update_location(self, location_id: str, location_data: Dict[str, Any]) -> Optional[TenantLocation]:
|
||||
"""
|
||||
Update a tenant location
|
||||
|
||||
Args:
|
||||
location_id: UUID of the location to update
|
||||
location_data: Dictionary containing updated location information
|
||||
|
||||
Returns:
|
||||
Updated TenantLocation object if successful, None if location not found
|
||||
"""
|
||||
try:
|
||||
stmt = (
|
||||
update(TenantLocation)
|
||||
.where(TenantLocation.id == location_id)
|
||||
.values(**location_data)
|
||||
)
|
||||
await self.session.execute(stmt)
|
||||
|
||||
# Now fetch the updated location
|
||||
location_stmt = select(TenantLocation).where(TenantLocation.id == location_id)
|
||||
result = await self.session.execute(location_stmt)
|
||||
location = result.scalar_one_or_none()
|
||||
|
||||
if location:
|
||||
await self.session.commit()
|
||||
logger.info(f"Updated tenant location: {location_id}")
|
||||
return location
|
||||
else:
|
||||
await self.session.rollback()
|
||||
logger.warning(f"Location not found for update: {location_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Failed to update location: {str(e)}")
|
||||
raise DatabaseError(f"Failed to update location: {str(e)}")
|
||||
|
||||
async def delete_location(self, location_id: str) -> bool:
|
||||
"""
|
||||
Delete a tenant location
|
||||
|
||||
Args:
|
||||
location_id: UUID of the location to delete
|
||||
|
||||
Returns:
|
||||
True if deleted successfully, False if location not found
|
||||
"""
|
||||
try:
|
||||
stmt = delete(TenantLocation).where(TenantLocation.id == location_id)
|
||||
result = await self.session.execute(stmt)
|
||||
|
||||
if result.rowcount > 0:
|
||||
await self.session.commit()
|
||||
logger.info(f"Deleted tenant location: {location_id}")
|
||||
return True
|
||||
else:
|
||||
await self.session.rollback()
|
||||
logger.warning(f"Location not found for deletion: {location_id}")
|
||||
return False
|
||||
except Exception as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Failed to delete location: {str(e)}")
|
||||
raise DatabaseError(f"Failed to delete location: {str(e)}")
|
||||
|
||||
async def get_active_locations_by_tenant(self, tenant_id: str) -> List[TenantLocation]:
|
||||
"""
|
||||
Get all active locations for a specific tenant
|
||||
|
||||
Args:
|
||||
tenant_id: UUID of the tenant
|
||||
|
||||
Returns:
|
||||
List of active TenantLocation objects
|
||||
"""
|
||||
try:
|
||||
stmt = select(TenantLocation).where(
|
||||
TenantLocation.tenant_id == tenant_id,
|
||||
TenantLocation.is_active == True
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
locations = result.scalars().all()
|
||||
return locations
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get active locations by tenant: {str(e)}")
|
||||
raise DatabaseError(f"Failed to get active locations by tenant: {str(e)}")
|
||||
|
||||
async def get_locations_by_tenant_with_type(self, tenant_id: str, location_types: List[str]) -> List[TenantLocation]:
|
||||
"""
|
||||
Get locations for a specific tenant filtered by location types
|
||||
|
||||
Args:
|
||||
tenant_id: UUID of the tenant
|
||||
location_types: List of location types to filter by
|
||||
|
||||
Returns:
|
||||
List of TenantLocation objects matching the criteria
|
||||
"""
|
||||
try:
|
||||
stmt = select(TenantLocation).where(
|
||||
TenantLocation.tenant_id == tenant_id,
|
||||
TenantLocation.location_type.in_(location_types)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
locations = result.scalars().all()
|
||||
return locations
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get locations by tenant and type: {str(e)}")
|
||||
raise DatabaseError(f"Failed to get locations by tenant and type: {str(e)}")
|
||||
588
services/tenant/app/repositories/tenant_member_repository.py
Normal file
588
services/tenant/app/repositories/tenant_member_repository.py
Normal file
@@ -0,0 +1,588 @@
|
||||
"""
|
||||
Tenant Member Repository
|
||||
Repository for tenant membership operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, text, and_
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
import json
|
||||
|
||||
from .base import TenantBaseRepository
|
||||
from app.models.tenants import TenantMember
|
||||
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
|
||||
from shared.config.base import is_internal_service
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TenantMemberRepository(TenantBaseRepository):
|
||||
"""Repository for tenant member operations"""
|
||||
|
||||
def __init__(self, model_class, session: AsyncSession, cache_ttl: Optional[int] = 300):
|
||||
# Member data changes more frequently, shorter cache time (5 minutes)
|
||||
super().__init__(model_class, session, cache_ttl)
|
||||
|
||||
async def create_membership(self, membership_data: Dict[str, Any]) -> TenantMember:
|
||||
"""Create a new tenant membership with validation"""
|
||||
try:
|
||||
# Validate membership data
|
||||
validation_result = self._validate_tenant_data(
|
||||
membership_data,
|
||||
["tenant_id", "user_id", "role"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid membership data: {validation_result['errors']}")
|
||||
|
||||
# Check for existing membership
|
||||
existing_membership = await self.get_membership(
|
||||
membership_data["tenant_id"],
|
||||
membership_data["user_id"]
|
||||
)
|
||||
|
||||
if existing_membership and existing_membership.is_active:
|
||||
raise DuplicateRecordError(f"User is already an active member of this tenant")
|
||||
|
||||
# Set default values
|
||||
if "is_active" not in membership_data:
|
||||
membership_data["is_active"] = True
|
||||
if "joined_at" not in membership_data:
|
||||
membership_data["joined_at"] = datetime.utcnow()
|
||||
|
||||
# Set permissions based on role
|
||||
if "permissions" not in membership_data:
|
||||
membership_data["permissions"] = self._get_default_permissions(
|
||||
membership_data["role"]
|
||||
)
|
||||
|
||||
# If reactivating existing membership
|
||||
if existing_membership and not existing_membership.is_active:
|
||||
# Update existing membership
|
||||
update_data = {
|
||||
key: value for key, value in membership_data.items()
|
||||
if key not in ["tenant_id", "user_id"]
|
||||
}
|
||||
membership = await self.update(existing_membership.id, update_data)
|
||||
else:
|
||||
# Create new membership
|
||||
membership = await self.create(membership_data)
|
||||
|
||||
logger.info("Tenant membership created",
|
||||
membership_id=membership.id,
|
||||
tenant_id=membership.tenant_id,
|
||||
user_id=membership.user_id,
|
||||
role=membership.role)
|
||||
|
||||
return membership
|
||||
|
||||
except (ValidationError, DuplicateRecordError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create membership",
|
||||
tenant_id=membership_data.get("tenant_id"),
|
||||
user_id=membership_data.get("user_id"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create membership: {str(e)}")
|
||||
|
||||
async def get_membership(self, tenant_id: str, user_id: str) -> Optional[TenantMember]:
|
||||
"""Get specific membership by tenant and user"""
|
||||
try:
|
||||
# Validate that user_id is a proper UUID format for actual users
|
||||
# Service names like 'inventory-service' should be handled differently
|
||||
import uuid
|
||||
try:
|
||||
uuid.UUID(user_id)
|
||||
is_valid_uuid = True
|
||||
except ValueError:
|
||||
is_valid_uuid = False
|
||||
|
||||
# For internal service access, return None to indicate no user membership
|
||||
# Service access should be handled at the API layer
|
||||
if not is_valid_uuid:
|
||||
if is_internal_service(user_id):
|
||||
# This is a known internal service request, return None
|
||||
# Service access is granted at the API endpoint level
|
||||
logger.debug("Internal service detected in membership lookup",
|
||||
service=user_id,
|
||||
tenant_id=tenant_id)
|
||||
return None
|
||||
elif user_id == "unknown-service":
|
||||
# Special handling for 'unknown-service' which commonly occurs in demo sessions
|
||||
# This happens when service identification fails during demo operations
|
||||
logger.warning("Demo session service identification issue",
|
||||
service=user_id,
|
||||
tenant_id=tenant_id,
|
||||
message="Service not properly identified - likely demo session context")
|
||||
return None
|
||||
else:
|
||||
# This is an unknown service
|
||||
# Return None to prevent database errors, but log a warning
|
||||
logger.warning("Unknown service detected in membership lookup",
|
||||
service=user_id,
|
||||
tenant_id=tenant_id,
|
||||
message="Service not in internal services registry")
|
||||
return None
|
||||
|
||||
memberships = await self.get_multi(
|
||||
filters={
|
||||
"tenant_id": tenant_id,
|
||||
"user_id": user_id
|
||||
},
|
||||
limit=1,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return memberships[0] if memberships else None
|
||||
except Exception as e:
|
||||
logger.error("Failed to get membership",
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get membership: {str(e)}")
|
||||
|
||||
async def get_tenant_members(
|
||||
self,
|
||||
tenant_id: str,
|
||||
active_only: bool = True,
|
||||
role: str = None,
|
||||
include_user_info: bool = False
|
||||
) -> List[TenantMember]:
|
||||
"""Get all members of a tenant with optional user info enrichment"""
|
||||
try:
|
||||
filters = {"tenant_id": tenant_id}
|
||||
|
||||
if active_only:
|
||||
filters["is_active"] = True
|
||||
|
||||
if role:
|
||||
filters["role"] = role
|
||||
|
||||
members = await self.get_multi(
|
||||
filters=filters,
|
||||
order_by="joined_at",
|
||||
order_desc=False
|
||||
)
|
||||
|
||||
# If include_user_info is True, enrich with user data from auth service
|
||||
if include_user_info and members:
|
||||
members = await self._enrich_members_with_user_info(members)
|
||||
|
||||
return members
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenant members",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get members: {str(e)}")
|
||||
|
||||
async def _enrich_members_with_user_info(self, members: List[TenantMember]) -> List[TenantMember]:
|
||||
"""Enrich member objects with user information from auth service using batch endpoint"""
|
||||
try:
|
||||
import httpx
|
||||
import os
|
||||
|
||||
if not members:
|
||||
return members
|
||||
|
||||
# Get unique user IDs
|
||||
user_ids = list(set([str(member.user_id) for member in members]))
|
||||
|
||||
if not user_ids:
|
||||
return members
|
||||
|
||||
# Fetch user data from auth service using batch endpoint
|
||||
# Using internal service communication
|
||||
auth_service_url = os.getenv('AUTH_SERVICE_URL', 'http://auth-service:8000')
|
||||
|
||||
user_data_map = {}
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
# Use batch endpoint for efficiency
|
||||
response = await client.post(
|
||||
f"{auth_service_url}/api/v1/auth/users/batch",
|
||||
json={"user_ids": user_ids},
|
||||
timeout=10.0,
|
||||
headers={"x-internal-service": "tenant-service"}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
batch_result = response.json()
|
||||
user_data_map = batch_result.get("users", {})
|
||||
logger.info(
|
||||
"Batch user fetch successful",
|
||||
requested_count=len(user_ids),
|
||||
found_count=batch_result.get("found_count", 0)
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Batch user fetch failed, falling back to individual calls",
|
||||
status_code=response.status_code
|
||||
)
|
||||
# Fallback to individual calls if batch fails
|
||||
for user_id in user_ids:
|
||||
try:
|
||||
response = await client.get(
|
||||
f"{auth_service_url}/api/v1/auth/users/{user_id}",
|
||||
timeout=5.0,
|
||||
headers={"x-internal-service": "tenant-service"}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
user_data = response.json()
|
||||
user_data_map[user_id] = user_data
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch user data for {user_id}", error=str(e))
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Batch user fetch failed, falling back to individual calls", error=str(e))
|
||||
# Fallback to individual calls
|
||||
for user_id in user_ids:
|
||||
try:
|
||||
response = await client.get(
|
||||
f"{auth_service_url}/api/v1/auth/users/{user_id}",
|
||||
timeout=5.0,
|
||||
headers={"x-internal-service": "tenant-service"}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
user_data = response.json()
|
||||
user_data_map[user_id] = user_data
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch user data for {user_id}", error=str(e))
|
||||
continue
|
||||
|
||||
# Enrich members with user data
|
||||
for member in members:
|
||||
user_id_str = str(member.user_id)
|
||||
if user_id_str in user_data_map and user_data_map[user_id_str] is not None:
|
||||
user_data = user_data_map[user_id_str]
|
||||
# Add user fields as attributes to the member object
|
||||
member.user_email = user_data.get("email")
|
||||
member.user_full_name = user_data.get("full_name")
|
||||
member.user = user_data # Store full user object for compatibility
|
||||
else:
|
||||
# Set defaults for missing users
|
||||
member.user_email = None
|
||||
member.user_full_name = "Unknown User"
|
||||
member.user = None
|
||||
|
||||
return members
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to enrich members with user info", error=str(e))
|
||||
# Return members without enrichment if it fails
|
||||
return members
|
||||
|
||||
async def get_user_memberships(
|
||||
self,
|
||||
user_id: str,
|
||||
active_only: bool = True
|
||||
) -> List[TenantMember]:
|
||||
"""Get all tenants a user is a member of"""
|
||||
try:
|
||||
filters = {"user_id": user_id}
|
||||
|
||||
if active_only:
|
||||
filters["is_active"] = True
|
||||
|
||||
return await self.get_multi(
|
||||
filters=filters,
|
||||
order_by="joined_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get user memberships",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get memberships: {str(e)}")
|
||||
|
||||
async def verify_user_access(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Verify if user has access to tenant and return access details"""
|
||||
try:
|
||||
membership = await self.get_membership(tenant_id, user_id)
|
||||
|
||||
if not membership or not membership.is_active:
|
||||
return {
|
||||
"has_access": False,
|
||||
"role": "none",
|
||||
"permissions": []
|
||||
}
|
||||
|
||||
# Parse permissions
|
||||
permissions = []
|
||||
if membership.permissions:
|
||||
try:
|
||||
permissions = json.loads(membership.permissions)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid permissions JSON for membership",
|
||||
membership_id=membership.id)
|
||||
permissions = self._get_default_permissions(membership.role)
|
||||
|
||||
return {
|
||||
"has_access": True,
|
||||
"role": membership.role,
|
||||
"permissions": permissions,
|
||||
"membership_id": str(membership.id),
|
||||
"joined_at": membership.joined_at.isoformat() if membership.joined_at else None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to verify user access",
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"has_access": False,
|
||||
"role": "none",
|
||||
"permissions": []
|
||||
}
|
||||
|
||||
async def update_member_role(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
new_role: str,
|
||||
updated_by: str = None
|
||||
) -> Optional[TenantMember]:
|
||||
"""Update member role and permissions"""
|
||||
try:
|
||||
valid_roles = ["owner", "admin", "member", "viewer"]
|
||||
if new_role not in valid_roles:
|
||||
raise ValidationError(f"Invalid role. Must be one of: {valid_roles}")
|
||||
|
||||
membership = await self.get_membership(tenant_id, user_id)
|
||||
if not membership:
|
||||
raise ValidationError("Membership not found")
|
||||
|
||||
# Get new permissions based on role
|
||||
new_permissions = self._get_default_permissions(new_role)
|
||||
|
||||
updated_membership = await self.update(membership.id, {
|
||||
"role": new_role,
|
||||
"permissions": json.dumps(new_permissions)
|
||||
})
|
||||
|
||||
logger.info("Member role updated",
|
||||
membership_id=membership.id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
old_role=membership.role,
|
||||
new_role=new_role,
|
||||
updated_by=updated_by)
|
||||
|
||||
return updated_membership
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to update member role",
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
new_role=new_role,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update role: {str(e)}")
|
||||
|
||||
async def deactivate_membership(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
deactivated_by: str = None
|
||||
) -> Optional[TenantMember]:
|
||||
"""Deactivate a membership (remove user from tenant)"""
|
||||
try:
|
||||
membership = await self.get_membership(tenant_id, user_id)
|
||||
if not membership:
|
||||
raise ValidationError("Membership not found")
|
||||
|
||||
# Don't allow deactivating the owner
|
||||
if membership.role == "owner":
|
||||
raise ValidationError("Cannot deactivate the owner membership")
|
||||
|
||||
updated_membership = await self.update(membership.id, {
|
||||
"is_active": False
|
||||
})
|
||||
|
||||
logger.info("Membership deactivated",
|
||||
membership_id=membership.id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
deactivated_by=deactivated_by)
|
||||
|
||||
return updated_membership
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to deactivate membership",
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to deactivate membership: {str(e)}")
|
||||
|
||||
async def reactivate_membership(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
reactivated_by: str = None
|
||||
) -> Optional[TenantMember]:
|
||||
"""Reactivate a deactivated membership"""
|
||||
try:
|
||||
membership = await self.get_membership(tenant_id, user_id)
|
||||
if not membership:
|
||||
raise ValidationError("Membership not found")
|
||||
|
||||
updated_membership = await self.update(membership.id, {
|
||||
"is_active": True,
|
||||
"joined_at": datetime.utcnow() # Update join date
|
||||
})
|
||||
|
||||
logger.info("Membership reactivated",
|
||||
membership_id=membership.id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
reactivated_by=reactivated_by)
|
||||
|
||||
return updated_membership
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to reactivate membership",
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to reactivate membership: {str(e)}")
|
||||
|
||||
async def get_membership_statistics(self, tenant_id: str) -> Dict[str, Any]:
|
||||
"""Get membership statistics for a tenant"""
|
||||
try:
|
||||
# Get counts by role
|
||||
role_query = text("""
|
||||
SELECT role, COUNT(*) as count
|
||||
FROM tenant_members
|
||||
WHERE tenant_id = :tenant_id AND is_active = true
|
||||
GROUP BY role
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(role_query, {"tenant_id": tenant_id})
|
||||
members_by_role = {row.role: row.count for row in result.fetchall()}
|
||||
|
||||
# Get basic counts
|
||||
total_members = await self.count(filters={"tenant_id": tenant_id})
|
||||
active_members = await self.count(filters={
|
||||
"tenant_id": tenant_id,
|
||||
"is_active": True
|
||||
})
|
||||
|
||||
# Get recent activity (members joined in last 30 days)
|
||||
thirty_days_ago = datetime.utcnow() - timedelta(days=30)
|
||||
recent_joins = len(await self.get_multi(
|
||||
filters={
|
||||
"tenant_id": tenant_id,
|
||||
"is_active": True
|
||||
},
|
||||
limit=1000 # High limit to get accurate count
|
||||
))
|
||||
|
||||
# Filter for recent joins (manual filtering since we can't use date range in filters easily)
|
||||
recent_members = 0
|
||||
all_active_members = await self.get_tenant_members(tenant_id, active_only=True)
|
||||
for member in all_active_members:
|
||||
if member.joined_at and member.joined_at >= thirty_days_ago:
|
||||
recent_members += 1
|
||||
|
||||
return {
|
||||
"total_members": total_members,
|
||||
"active_members": active_members,
|
||||
"inactive_members": total_members - active_members,
|
||||
"members_by_role": members_by_role,
|
||||
"recent_joins_30d": recent_members
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get membership statistics",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"total_members": 0,
|
||||
"active_members": 0,
|
||||
"inactive_members": 0,
|
||||
"members_by_role": {},
|
||||
"recent_joins_30d": 0
|
||||
}
|
||||
|
||||
def _get_default_permissions(self, role: str) -> str:
|
||||
"""Get default permissions JSON string for a role"""
|
||||
permission_map = {
|
||||
"owner": ["read", "write", "admin", "delete"],
|
||||
"admin": ["read", "write", "admin"],
|
||||
"member": ["read", "write"],
|
||||
"viewer": ["read"]
|
||||
}
|
||||
|
||||
permissions = permission_map.get(role, ["read"])
|
||||
return json.dumps(permissions)
|
||||
|
||||
async def bulk_update_permissions(
|
||||
self,
|
||||
tenant_id: str,
|
||||
role_permissions: Dict[str, List[str]]
|
||||
) -> int:
|
||||
"""Bulk update permissions for all members of specific roles"""
|
||||
try:
|
||||
updated_count = 0
|
||||
|
||||
for role, permissions in role_permissions.items():
|
||||
members = await self.get_tenant_members(
|
||||
tenant_id, active_only=True, role=role
|
||||
)
|
||||
|
||||
for member in members:
|
||||
await self.update(member.id, {
|
||||
"permissions": json.dumps(permissions)
|
||||
})
|
||||
updated_count += 1
|
||||
|
||||
logger.info("Bulk updated member permissions",
|
||||
tenant_id=tenant_id,
|
||||
updated_count=updated_count,
|
||||
roles=list(role_permissions.keys()))
|
||||
|
||||
return updated_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to bulk update permissions",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Bulk permission update failed: {str(e)}")
|
||||
|
||||
async def cleanup_inactive_memberships(self, days_old: int = 180) -> int:
|
||||
"""Clean up old inactive memberships"""
|
||||
try:
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days_old)
|
||||
|
||||
query_text = """
|
||||
DELETE FROM tenant_members
|
||||
WHERE is_active = false
|
||||
AND created_at < :cutoff_date
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up inactive memberships",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup inactive memberships",
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Cleanup failed: {str(e)}")
|
||||
680
services/tenant/app/repositories/tenant_repository.py
Normal file
680
services/tenant/app/repositories/tenant_repository.py
Normal file
@@ -0,0 +1,680 @@
|
||||
"""
|
||||
Tenant Repository
|
||||
Repository for tenant operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, text, and_
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from .base import TenantBaseRepository
|
||||
from app.models.tenants import Tenant, Subscription
|
||||
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TenantRepository(TenantBaseRepository):
|
||||
"""Repository for tenant operations"""
|
||||
|
||||
def __init__(self, model_class, session: AsyncSession, cache_ttl: Optional[int] = 600):
|
||||
# Tenants are relatively stable, longer cache time (10 minutes)
|
||||
super().__init__(model_class, session, cache_ttl)
|
||||
|
||||
async def create_tenant(self, tenant_data: Dict[str, Any]) -> Tenant:
|
||||
"""Create a new tenant with validation"""
|
||||
try:
|
||||
# Validate tenant data
|
||||
validation_result = self._validate_tenant_data(
|
||||
tenant_data,
|
||||
["name", "address", "postal_code", "owner_id"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid tenant data: {validation_result['errors']}")
|
||||
|
||||
# Generate subdomain if not provided
|
||||
if "subdomain" not in tenant_data or not tenant_data["subdomain"]:
|
||||
subdomain = await self._generate_unique_subdomain(tenant_data["name"])
|
||||
tenant_data["subdomain"] = subdomain
|
||||
else:
|
||||
# Check if provided subdomain is unique
|
||||
existing_tenant = await self.get_by_subdomain(tenant_data["subdomain"])
|
||||
if existing_tenant:
|
||||
raise DuplicateRecordError(f"Subdomain {tenant_data['subdomain']} already exists")
|
||||
|
||||
# Set default values
|
||||
if "business_type" not in tenant_data:
|
||||
tenant_data["business_type"] = "bakery"
|
||||
if "city" not in tenant_data:
|
||||
tenant_data["city"] = "Madrid"
|
||||
if "is_active" not in tenant_data:
|
||||
tenant_data["is_active"] = True
|
||||
if "ml_model_trained" not in tenant_data:
|
||||
tenant_data["ml_model_trained"] = False
|
||||
|
||||
# Create tenant
|
||||
tenant = await self.create(tenant_data)
|
||||
|
||||
logger.info("Tenant created successfully",
|
||||
tenant_id=tenant.id,
|
||||
name=tenant.name,
|
||||
subdomain=tenant.subdomain,
|
||||
owner_id=tenant.owner_id)
|
||||
|
||||
return tenant
|
||||
|
||||
except (ValidationError, DuplicateRecordError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create tenant",
|
||||
name=tenant_data.get("name"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create tenant: {str(e)}")
|
||||
|
||||
async def get_by_subdomain(self, subdomain: str) -> Optional[Tenant]:
|
||||
"""Get tenant by subdomain"""
|
||||
try:
|
||||
return await self.get_by_field("subdomain", subdomain)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenant by subdomain",
|
||||
subdomain=subdomain,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get tenant: {str(e)}")
|
||||
|
||||
async def get_tenants_by_owner(self, owner_id: str) -> List[Tenant]:
|
||||
"""Get all tenants owned by a user"""
|
||||
try:
|
||||
return await self.get_multi(
|
||||
filters={"owner_id": owner_id, "is_active": True},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenants by owner",
|
||||
owner_id=owner_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get tenants: {str(e)}")
|
||||
|
||||
async def get_active_tenants(self, skip: int = 0, limit: int = 100) -> List[Tenant]:
|
||||
"""Get all active tenants"""
|
||||
return await self.get_active_records(skip=skip, limit=limit)
|
||||
|
||||
async def search_tenants(
|
||||
self,
|
||||
search_term: str,
|
||||
business_type: str = None,
|
||||
city: str = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50
|
||||
) -> List[Tenant]:
|
||||
"""Search tenants by name, address, or other criteria"""
|
||||
try:
|
||||
# Build search conditions
|
||||
conditions = ["is_active = true"]
|
||||
params = {"skip": skip, "limit": limit}
|
||||
|
||||
# Add text search
|
||||
conditions.append("(LOWER(name) LIKE LOWER(:search_term) OR LOWER(address) LIKE LOWER(:search_term))")
|
||||
params["search_term"] = f"%{search_term}%"
|
||||
|
||||
# Add business type filter
|
||||
if business_type:
|
||||
conditions.append("business_type = :business_type")
|
||||
params["business_type"] = business_type
|
||||
|
||||
# Add city filter
|
||||
if city:
|
||||
conditions.append("LOWER(city) = LOWER(:city)")
|
||||
params["city"] = city
|
||||
|
||||
query_text = f"""
|
||||
SELECT * FROM tenants
|
||||
WHERE {' AND '.join(conditions)}
|
||||
ORDER BY name ASC
|
||||
LIMIT :limit OFFSET :skip
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), params)
|
||||
|
||||
tenants = []
|
||||
for row in result.fetchall():
|
||||
record_dict = dict(row._mapping)
|
||||
tenant = self.model(**record_dict)
|
||||
tenants.append(tenant)
|
||||
|
||||
return tenants
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to search tenants",
|
||||
search_term=search_term,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def update_tenant_model_status(
|
||||
self,
|
||||
tenant_id: str,
|
||||
ml_model_trained: bool,
|
||||
last_training_date: datetime = None
|
||||
) -> Optional[Tenant]:
|
||||
"""Update tenant model training status"""
|
||||
try:
|
||||
update_data = {
|
||||
"ml_model_trained": ml_model_trained,
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
if last_training_date:
|
||||
update_data["last_training_date"] = last_training_date
|
||||
elif ml_model_trained:
|
||||
update_data["last_training_date"] = datetime.utcnow()
|
||||
|
||||
updated_tenant = await self.update(tenant_id, update_data)
|
||||
|
||||
logger.info("Tenant model status updated",
|
||||
tenant_id=tenant_id,
|
||||
ml_model_trained=ml_model_trained,
|
||||
last_training_date=last_training_date)
|
||||
|
||||
return updated_tenant
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update tenant model status",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update model status: {str(e)}")
|
||||
|
||||
|
||||
async def get_tenants_by_location(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
radius_km: float = 10.0,
|
||||
limit: int = 50
|
||||
) -> List[Tenant]:
|
||||
"""Get tenants within a geographic radius"""
|
||||
try:
|
||||
# Using Haversine formula for distance calculation
|
||||
query_text = """
|
||||
SELECT *,
|
||||
(6371 * acos(
|
||||
cos(radians(:latitude)) *
|
||||
cos(radians(latitude)) *
|
||||
cos(radians(longitude) - radians(:longitude)) +
|
||||
sin(radians(:latitude)) *
|
||||
sin(radians(latitude))
|
||||
)) AS distance_km
|
||||
FROM tenants
|
||||
WHERE is_active = true
|
||||
AND latitude IS NOT NULL
|
||||
AND longitude IS NOT NULL
|
||||
HAVING distance_km <= :radius_km
|
||||
ORDER BY distance_km ASC
|
||||
LIMIT :limit
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {
|
||||
"latitude": latitude,
|
||||
"longitude": longitude,
|
||||
"radius_km": radius_km,
|
||||
"limit": limit
|
||||
})
|
||||
|
||||
tenants = []
|
||||
for row in result.fetchall():
|
||||
# Create tenant object (excluding the calculated distance_km field)
|
||||
record_dict = dict(row._mapping)
|
||||
record_dict.pop("distance_km", None) # Remove calculated field
|
||||
tenant = self.model(**record_dict)
|
||||
tenants.append(tenant)
|
||||
|
||||
return tenants
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenants by location",
|
||||
latitude=latitude,
|
||||
longitude=longitude,
|
||||
radius_km=radius_km,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def get_tenant_statistics(self) -> Dict[str, Any]:
|
||||
"""Get global tenant statistics"""
|
||||
try:
|
||||
# Get basic counts
|
||||
total_tenants = await self.count()
|
||||
active_tenants = await self.count(filters={"is_active": True})
|
||||
|
||||
# Get tenants by business type
|
||||
business_type_query = text("""
|
||||
SELECT business_type, COUNT(*) as count
|
||||
FROM tenants
|
||||
WHERE is_active = true
|
||||
GROUP BY business_type
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(business_type_query)
|
||||
business_type_stats = {row.business_type: row.count for row in result.fetchall()}
|
||||
|
||||
# Get tenants by subscription tier - now from subscriptions table
|
||||
tier_query = text("""
|
||||
SELECT s.plan as subscription_tier, COUNT(*) as count
|
||||
FROM tenants t
|
||||
LEFT JOIN subscriptions s ON t.id = s.tenant_id AND s.status = 'active'
|
||||
WHERE t.is_active = true
|
||||
GROUP BY s.plan
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
tier_result = await self.session.execute(tier_query)
|
||||
tier_stats = {}
|
||||
for row in tier_result.fetchall():
|
||||
tier = row.subscription_tier if row.subscription_tier else "no_subscription"
|
||||
tier_stats[tier] = row.count
|
||||
|
||||
# Get model training statistics
|
||||
model_query = text("""
|
||||
SELECT
|
||||
COUNT(CASE WHEN ml_model_trained = true THEN 1 END) as trained_count,
|
||||
COUNT(CASE WHEN ml_model_trained = false THEN 1 END) as untrained_count,
|
||||
AVG(EXTRACT(EPOCH FROM (NOW() - last_training_date))/86400) as avg_days_since_training
|
||||
FROM tenants
|
||||
WHERE is_active = true
|
||||
""")
|
||||
|
||||
model_result = await self.session.execute(model_query)
|
||||
model_row = model_result.fetchone()
|
||||
|
||||
# Get recent registrations (last 30 days)
|
||||
thirty_days_ago = datetime.utcnow() - timedelta(days=30)
|
||||
recent_registrations = await self.count(filters={
|
||||
"created_at": f">= '{thirty_days_ago.isoformat()}'"
|
||||
})
|
||||
|
||||
return {
|
||||
"total_tenants": total_tenants,
|
||||
"active_tenants": active_tenants,
|
||||
"inactive_tenants": total_tenants - active_tenants,
|
||||
"tenants_by_business_type": business_type_stats,
|
||||
"tenants_by_subscription": tier_stats,
|
||||
"model_training": {
|
||||
"trained_tenants": int(model_row.trained_count or 0),
|
||||
"untrained_tenants": int(model_row.untrained_count or 0),
|
||||
"avg_days_since_training": float(model_row.avg_days_since_training or 0)
|
||||
} if model_row else {
|
||||
"trained_tenants": 0,
|
||||
"untrained_tenants": 0,
|
||||
"avg_days_since_training": 0.0
|
||||
},
|
||||
"recent_registrations_30d": recent_registrations
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenant statistics", error=str(e))
|
||||
return {
|
||||
"total_tenants": 0,
|
||||
"active_tenants": 0,
|
||||
"inactive_tenants": 0,
|
||||
"tenants_by_business_type": {},
|
||||
"tenants_by_subscription": {},
|
||||
"model_training": {
|
||||
"trained_tenants": 0,
|
||||
"untrained_tenants": 0,
|
||||
"avg_days_since_training": 0.0
|
||||
},
|
||||
"recent_registrations_30d": 0
|
||||
}
|
||||
|
||||
async def _generate_unique_subdomain(self, name: str) -> str:
|
||||
"""Generate a unique subdomain from tenant name"""
|
||||
try:
|
||||
# Clean the name to create a subdomain
|
||||
subdomain = name.lower().replace(' ', '-')
|
||||
# Remove accents
|
||||
subdomain = subdomain.replace('á', 'a').replace('é', 'e').replace('í', 'i').replace('ó', 'o').replace('ú', 'u')
|
||||
subdomain = subdomain.replace('ñ', 'n')
|
||||
# Keep only alphanumeric and hyphens
|
||||
subdomain = ''.join(c for c in subdomain if c.isalnum() or c == '-')
|
||||
# Remove multiple consecutive hyphens
|
||||
while '--' in subdomain:
|
||||
subdomain = subdomain.replace('--', '-')
|
||||
# Remove leading/trailing hyphens
|
||||
subdomain = subdomain.strip('-')
|
||||
|
||||
# Ensure minimum length
|
||||
if len(subdomain) < 3:
|
||||
subdomain = f"tenant-{subdomain}"
|
||||
|
||||
# Check if subdomain exists
|
||||
existing_tenant = await self.get_by_subdomain(subdomain)
|
||||
if not existing_tenant:
|
||||
return subdomain
|
||||
|
||||
# If it exists, add a unique suffix
|
||||
counter = 1
|
||||
while True:
|
||||
candidate = f"{subdomain}-{counter}"
|
||||
existing_tenant = await self.get_by_subdomain(candidate)
|
||||
if not existing_tenant:
|
||||
return candidate
|
||||
counter += 1
|
||||
|
||||
# Prevent infinite loop
|
||||
if counter > 9999:
|
||||
return f"{subdomain}-{uuid.uuid4().hex[:6]}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to generate unique subdomain",
|
||||
name=name,
|
||||
error=str(e))
|
||||
# Fallback to UUID-based subdomain
|
||||
return f"tenant-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
async def deactivate_tenant(self, tenant_id: str) -> Optional[Tenant]:
|
||||
"""Deactivate a tenant"""
|
||||
return await self.deactivate_record(tenant_id)
|
||||
|
||||
async def activate_tenant(self, tenant_id: str) -> Optional[Tenant]:
|
||||
"""Activate a tenant"""
|
||||
return await self.activate_record(tenant_id)
|
||||
|
||||
async def get_child_tenants(self, parent_tenant_id: str) -> List[Tenant]:
|
||||
"""Get all child tenants for a parent tenant"""
|
||||
try:
|
||||
return await self.get_multi(
|
||||
filters={"parent_tenant_id": parent_tenant_id, "is_active": True},
|
||||
order_by="created_at",
|
||||
order_desc=False
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get child tenants",
|
||||
parent_tenant_id=parent_tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get child tenants: {str(e)}")
|
||||
|
||||
async def get(self, record_id: Any) -> Optional[Tenant]:
|
||||
"""Get tenant by ID - alias for get_by_id for compatibility"""
|
||||
return await self.get_by_id(record_id)
|
||||
|
||||
async def get_child_tenant_count(self, parent_tenant_id: str) -> int:
|
||||
"""Get count of child tenants for a parent tenant"""
|
||||
try:
|
||||
child_tenants = await self.get_child_tenants(parent_tenant_id)
|
||||
return len(child_tenants)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get child tenant count",
|
||||
parent_tenant_id=parent_tenant_id,
|
||||
error=str(e))
|
||||
return 0
|
||||
|
||||
async def get_user_tenants_with_hierarchy(self, user_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all tenants a user has access to, organized in hierarchy.
|
||||
Returns parent tenants with their children nested.
|
||||
"""
|
||||
try:
|
||||
# Get all tenants where user is owner or member
|
||||
query_text = """
|
||||
SELECT DISTINCT t.*
|
||||
FROM tenants t
|
||||
LEFT JOIN tenant_members tm ON t.id = tm.tenant_id
|
||||
WHERE (t.owner_id = :user_id OR tm.user_id = :user_id)
|
||||
AND t.is_active = true
|
||||
ORDER BY t.tenant_type DESC, t.created_at ASC
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"user_id": user_id})
|
||||
|
||||
tenants = []
|
||||
for row in result.fetchall():
|
||||
record_dict = dict(row._mapping)
|
||||
tenant = self.model(**record_dict)
|
||||
tenants.append(tenant)
|
||||
|
||||
# Organize into hierarchy
|
||||
tenant_hierarchy = []
|
||||
parent_map = {}
|
||||
|
||||
# First pass: collect all parent/standalone tenants
|
||||
for tenant in tenants:
|
||||
if tenant.tenant_type in ['parent', 'standalone']:
|
||||
tenant_dict = {
|
||||
'id': str(tenant.id),
|
||||
'name': tenant.name,
|
||||
'subdomain': tenant.subdomain,
|
||||
'tenant_type': tenant.tenant_type,
|
||||
'business_type': tenant.business_type,
|
||||
'business_model': tenant.business_model,
|
||||
'city': tenant.city,
|
||||
'is_active': tenant.is_active,
|
||||
'children': [] if tenant.tenant_type == 'parent' else None
|
||||
}
|
||||
tenant_hierarchy.append(tenant_dict)
|
||||
parent_map[str(tenant.id)] = tenant_dict
|
||||
|
||||
# Second pass: attach children to their parents
|
||||
for tenant in tenants:
|
||||
if tenant.tenant_type == 'child' and tenant.parent_tenant_id:
|
||||
parent_id = str(tenant.parent_tenant_id)
|
||||
if parent_id in parent_map:
|
||||
child_dict = {
|
||||
'id': str(tenant.id),
|
||||
'name': tenant.name,
|
||||
'subdomain': tenant.subdomain,
|
||||
'tenant_type': 'child',
|
||||
'parent_tenant_id': parent_id,
|
||||
'city': tenant.city,
|
||||
'is_active': tenant.is_active
|
||||
}
|
||||
parent_map[parent_id]['children'].append(child_dict)
|
||||
|
||||
return tenant_hierarchy
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get user tenants with hierarchy",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def get_tenants_by_session_id(self, session_id: str) -> List[Tenant]:
|
||||
"""
|
||||
Get tenants associated with a specific demo session using the demo_session_id field.
|
||||
"""
|
||||
try:
|
||||
return await self.get_multi(
|
||||
filters={
|
||||
"demo_session_id": session_id,
|
||||
"is_active": True
|
||||
},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenants by session ID",
|
||||
session_id=session_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get tenants by session ID: {str(e)}")
|
||||
|
||||
async def get_professional_demo_tenants(self, session_id: str) -> List[Tenant]:
|
||||
"""
|
||||
Get professional demo tenants filtered by session.
|
||||
|
||||
Args:
|
||||
session_id: Required demo session ID to filter tenants
|
||||
|
||||
Returns:
|
||||
List of professional demo tenants for this specific session
|
||||
"""
|
||||
try:
|
||||
filters = {
|
||||
"business_model": "professional_bakery",
|
||||
"is_demo": True,
|
||||
"is_active": True,
|
||||
"demo_session_id": session_id # Always filter by session
|
||||
}
|
||||
|
||||
return await self.get_multi(
|
||||
filters=filters,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get professional demo tenants",
|
||||
session_id=session_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get professional demo tenants: {str(e)}")
|
||||
|
||||
async def get_enterprise_demo_tenants(self, session_id: str) -> List[Tenant]:
|
||||
"""
|
||||
Get enterprise demo tenants (parent and children) filtered by session.
|
||||
|
||||
Args:
|
||||
session_id: Required demo session ID to filter tenants
|
||||
|
||||
Returns:
|
||||
List of enterprise demo tenants (1 parent + 3 children) for this specific session
|
||||
"""
|
||||
try:
|
||||
# Get enterprise demo parent tenants for this session
|
||||
parent_tenants = await self.get_multi(
|
||||
filters={
|
||||
"tenant_type": "parent",
|
||||
"is_demo": True,
|
||||
"is_active": True,
|
||||
"demo_session_id": session_id # Always filter by session
|
||||
},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
|
||||
# Get child tenants for the enterprise demo session
|
||||
child_tenants = await self.get_multi(
|
||||
filters={
|
||||
"tenant_type": "child",
|
||||
"is_demo": True,
|
||||
"is_active": True,
|
||||
"demo_session_id": session_id # Always filter by session
|
||||
},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
|
||||
# Combine parent and child tenants
|
||||
return parent_tenants + child_tenants
|
||||
except Exception as e:
|
||||
logger.error("Failed to get enterprise demo tenants",
|
||||
session_id=session_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get enterprise demo tenants: {str(e)}")
|
||||
|
||||
async def get_by_customer_id(self, customer_id: str) -> Optional[Tenant]:
|
||||
"""
|
||||
Get tenant by Stripe customer ID
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
|
||||
Returns:
|
||||
Tenant object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Find tenant by joining with subscriptions table
|
||||
# Tenant doesn't have customer_id directly, so we need to find via subscription
|
||||
query = select(Tenant).join(
|
||||
Subscription, Subscription.tenant_id == Tenant.id
|
||||
).where(Subscription.customer_id == customer_id)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
tenant = result.scalar_one_or_none()
|
||||
|
||||
if tenant:
|
||||
logger.debug("Found tenant by customer_id",
|
||||
customer_id=customer_id,
|
||||
tenant_id=str(tenant.id))
|
||||
return tenant
|
||||
else:
|
||||
logger.debug("No tenant found for customer_id",
|
||||
customer_id=customer_id)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting tenant by customer_id",
|
||||
customer_id=customer_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get tenant by customer_id: {str(e)}")
|
||||
|
||||
async def get_user_primary_tenant(self, user_id: str) -> Optional[Tenant]:
|
||||
"""
|
||||
Get the primary tenant for a user (the tenant they own)
|
||||
|
||||
Args:
|
||||
user_id: User ID to find primary tenant for
|
||||
|
||||
Returns:
|
||||
Tenant object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
logger.debug("Getting primary tenant for user", user_id=user_id)
|
||||
|
||||
# Query for tenant where user is the owner
|
||||
query = select(Tenant).where(Tenant.owner_id == user_id)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
tenant = result.scalar_one_or_none()
|
||||
|
||||
if tenant:
|
||||
logger.debug("Found primary tenant for user",
|
||||
user_id=user_id,
|
||||
tenant_id=str(tenant.id))
|
||||
return tenant
|
||||
else:
|
||||
logger.debug("No primary tenant found for user", user_id=user_id)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting primary tenant for user",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get primary tenant for user: {str(e)}")
|
||||
|
||||
async def get_any_user_tenant(self, user_id: str) -> Optional[Tenant]:
|
||||
"""
|
||||
Get any tenant that the user has access to (via tenant_members)
|
||||
|
||||
Args:
|
||||
user_id: User ID to find accessible tenants for
|
||||
|
||||
Returns:
|
||||
Tenant object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
logger.debug("Getting any accessible tenant for user", user_id=user_id)
|
||||
|
||||
# Query for tenant members where user has access
|
||||
from app.models.tenants import TenantMember
|
||||
|
||||
query = select(Tenant).join(
|
||||
TenantMember, Tenant.id == TenantMember.tenant_id
|
||||
).where(TenantMember.user_id == user_id)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
tenant = result.scalar_one_or_none()
|
||||
|
||||
if tenant:
|
||||
logger.debug("Found accessible tenant for user",
|
||||
user_id=user_id,
|
||||
tenant_id=str(tenant.id))
|
||||
return tenant
|
||||
else:
|
||||
logger.debug("No accessible tenants found for user", user_id=user_id)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting accessible tenant for user",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get accessible tenant for user: {str(e)}")
|
||||
@@ -0,0 +1,82 @@
|
||||
# services/tenant/app/repositories/tenant_settings_repository.py
|
||||
"""
|
||||
Tenant Settings Repository
|
||||
Data access layer for tenant settings
|
||||
"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
import structlog
|
||||
|
||||
from ..models.tenant_settings import TenantSettings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TenantSettingsRepository:
|
||||
"""Repository for TenantSettings data access"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_by_tenant_id(self, tenant_id: UUID) -> Optional[TenantSettings]:
|
||||
"""
|
||||
Get tenant settings by tenant ID
|
||||
|
||||
Args:
|
||||
tenant_id: UUID of the tenant
|
||||
|
||||
Returns:
|
||||
TenantSettings or None if not found
|
||||
"""
|
||||
result = await self.db.execute(
|
||||
select(TenantSettings).where(TenantSettings.tenant_id == tenant_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create(self, settings: TenantSettings) -> TenantSettings:
|
||||
"""
|
||||
Create new tenant settings
|
||||
|
||||
Args:
|
||||
settings: TenantSettings instance to create
|
||||
|
||||
Returns:
|
||||
Created TenantSettings instance
|
||||
"""
|
||||
self.db.add(settings)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(settings)
|
||||
return settings
|
||||
|
||||
async def update(self, settings: TenantSettings) -> TenantSettings:
|
||||
"""
|
||||
Update tenant settings
|
||||
|
||||
Args:
|
||||
settings: TenantSettings instance with updates
|
||||
|
||||
Returns:
|
||||
Updated TenantSettings instance
|
||||
"""
|
||||
await self.db.commit()
|
||||
await self.db.refresh(settings)
|
||||
return settings
|
||||
|
||||
async def delete(self, tenant_id: UUID) -> None:
|
||||
"""
|
||||
Delete tenant settings
|
||||
|
||||
Args:
|
||||
tenant_id: UUID of the tenant
|
||||
"""
|
||||
result = await self.db.execute(
|
||||
select(TenantSettings).where(TenantSettings.tenant_id == tenant_id)
|
||||
)
|
||||
settings = result.scalar_one_or_none()
|
||||
|
||||
if settings:
|
||||
await self.db.delete(settings)
|
||||
await self.db.commit()
|
||||
Reference in New Issue
Block a user