Initial commit - production deployment

This commit is contained in:
2026-01-21 17:17:16 +01:00
commit c23d00dd92
2289 changed files with 638440 additions and 0 deletions

View File

@@ -0,0 +1,16 @@
"""
Tenant Service Repositories
Repository implementations for tenant service
"""
from .base import TenantBaseRepository
from .tenant_repository import TenantRepository
from .tenant_member_repository import TenantMemberRepository
from .subscription_repository import SubscriptionRepository
__all__ = [
"TenantBaseRepository",
"TenantRepository",
"TenantMemberRepository",
"SubscriptionRepository"
]

View File

@@ -0,0 +1,234 @@
"""
Base Repository for Tenant Service
Service-specific repository base class with tenant management utilities
"""
from typing import Optional, List, Dict, Any, Type
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timedelta
import structlog
import json
from shared.database.repository import BaseRepository
from shared.database.exceptions import DatabaseError
logger = structlog.get_logger()
class TenantBaseRepository(BaseRepository):
"""Base repository for tenant service with common tenant operations"""
def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 600):
# Tenant data is relatively stable, medium cache time (10 minutes)
super().__init__(model, session, cache_ttl)
async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List:
"""Get records by tenant ID"""
if hasattr(self.model, 'tenant_id'):
return await self.get_multi(
skip=skip,
limit=limit,
filters={"tenant_id": tenant_id},
order_by="created_at",
order_desc=True
)
return await self.get_multi(skip=skip, limit=limit)
async def get_by_user_id(self, user_id: str, skip: int = 0, limit: int = 100) -> List:
"""Get records by user ID (for cross-service references)"""
if hasattr(self.model, 'user_id'):
return await self.get_multi(
skip=skip,
limit=limit,
filters={"user_id": user_id},
order_by="created_at",
order_desc=True
)
elif hasattr(self.model, 'owner_id'):
return await self.get_multi(
skip=skip,
limit=limit,
filters={"owner_id": user_id},
order_by="created_at",
order_desc=True
)
return []
async def get_active_records(self, skip: int = 0, limit: int = 100) -> List:
"""Get active records (if model has is_active field)"""
if hasattr(self.model, 'is_active'):
return await self.get_multi(
skip=skip,
limit=limit,
filters={"is_active": True},
order_by="created_at",
order_desc=True
)
return await self.get_multi(skip=skip, limit=limit)
async def deactivate_record(self, record_id: Any) -> Optional:
"""Deactivate a record instead of deleting it"""
if hasattr(self.model, 'is_active'):
return await self.update(record_id, {"is_active": False})
return await self.delete(record_id)
async def activate_record(self, record_id: Any) -> Optional:
"""Activate a record"""
if hasattr(self.model, 'is_active'):
return await self.update(record_id, {"is_active": True})
return await self.get_by_id(record_id)
async def cleanup_old_records(self, days_old: int = 365) -> int:
"""Clean up old tenant records (very conservative - 1 year)"""
try:
cutoff_date = datetime.utcnow() - timedelta(days=days_old)
table_name = self.model.__tablename__
# Only delete inactive records that are very old
conditions = [
"created_at < :cutoff_date"
]
if hasattr(self.model, 'is_active'):
conditions.append("is_active = false")
query_text = f"""
DELETE FROM {table_name}
WHERE {' AND '.join(conditions)}
"""
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
deleted_count = result.rowcount
logger.info(f"Cleaned up old {self.model.__name__} records",
deleted_count=deleted_count,
days_old=days_old)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup old records",
model=self.model.__name__,
error=str(e))
raise DatabaseError(f"Cleanup failed: {str(e)}")
async def get_statistics_by_tenant(self, tenant_id: str) -> Dict[str, Any]:
"""Get statistics for a tenant"""
try:
table_name = self.model.__tablename__
# Get basic counts
total_records = await self.count(filters={"tenant_id": tenant_id})
# Get active records if applicable
active_records = total_records
if hasattr(self.model, 'is_active'):
active_records = await self.count(filters={
"tenant_id": tenant_id,
"is_active": True
})
# Get recent activity (records in last 7 days)
seven_days_ago = datetime.utcnow() - timedelta(days=7)
recent_query = text(f"""
SELECT COUNT(*) as count
FROM {table_name}
WHERE tenant_id = :tenant_id
AND created_at >= :seven_days_ago
""")
result = await self.session.execute(recent_query, {
"tenant_id": tenant_id,
"seven_days_ago": seven_days_ago
})
recent_records = result.scalar() or 0
return {
"total_records": total_records,
"active_records": active_records,
"inactive_records": total_records - active_records,
"recent_records_7d": recent_records
}
except Exception as e:
logger.error("Failed to get tenant statistics",
model=self.model.__name__,
tenant_id=tenant_id,
error=str(e))
return {
"total_records": 0,
"active_records": 0,
"inactive_records": 0,
"recent_records_7d": 0
}
def _validate_tenant_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
"""Validate tenant-related data"""
errors = []
for field in required_fields:
if field not in data or not data[field]:
errors.append(f"Missing required field: {field}")
# Validate tenant_id format if present
if "tenant_id" in data and data["tenant_id"]:
tenant_id = data["tenant_id"]
if not isinstance(tenant_id, str) or len(tenant_id) < 1:
errors.append("Invalid tenant_id format")
# Validate user_id format if present
if "user_id" in data and data["user_id"]:
user_id = data["user_id"]
if not isinstance(user_id, str) or len(user_id) < 1:
errors.append("Invalid user_id format")
# Validate owner_id format if present
if "owner_id" in data and data["owner_id"]:
owner_id = data["owner_id"]
if not isinstance(owner_id, str) or len(owner_id) < 1:
errors.append("Invalid owner_id format")
# Validate email format if present
if "email" in data and data["email"]:
email = data["email"]
if "@" not in email or "." not in email.split("@")[-1]:
errors.append("Invalid email format")
# Validate phone format if present (basic validation)
if "phone" in data and data["phone"]:
phone = data["phone"]
if not isinstance(phone, str) or len(phone) < 9:
errors.append("Invalid phone format")
# Validate coordinates if present
if "latitude" in data and data["latitude"] is not None:
try:
lat = float(data["latitude"])
if lat < -90 or lat > 90:
errors.append("Invalid latitude - must be between -90 and 90")
except (ValueError, TypeError):
errors.append("Invalid latitude format")
if "longitude" in data and data["longitude"] is not None:
try:
lng = float(data["longitude"])
if lng < -180 or lng > 180:
errors.append("Invalid longitude - must be between -180 and 180")
except (ValueError, TypeError):
errors.append("Invalid longitude format")
# Validate JSON fields
json_fields = ["permissions"]
for field in json_fields:
if field in data and data[field]:
if isinstance(data[field], str):
try:
json.loads(data[field])
except json.JSONDecodeError:
errors.append(f"Invalid JSON format in {field}")
return {
"is_valid": len(errors) == 0,
"errors": errors
}

View File

@@ -0,0 +1,326 @@
"""
Repository for coupon data access and validation
"""
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import and_, select
from sqlalchemy.orm import selectinload
from app.models.coupon import CouponModel, CouponRedemptionModel
from shared.subscription.coupons import (
Coupon,
CouponRedemption,
CouponValidationResult,
DiscountType,
calculate_trial_end_date,
format_discount_description
)
class CouponRepository:
"""Data access layer for coupon operations"""
def __init__(self, db: AsyncSession):
self.db = db
async def get_coupon_by_code(self, code: str) -> Optional[Coupon]:
"""
Retrieve coupon by code.
Returns None if not found.
"""
result = await self.db.execute(
select(CouponModel).where(CouponModel.code == code.upper())
)
coupon_model = result.scalar_one_or_none()
if not coupon_model:
return None
return self._model_to_dataclass(coupon_model)
async def validate_coupon(
self,
code: str,
tenant_id: str
) -> CouponValidationResult:
"""
Validate a coupon code for a specific tenant.
Checks: existence, validity, redemption limits, and if tenant already used it.
"""
# Get coupon
coupon = await self.get_coupon_by_code(code)
if not coupon:
return CouponValidationResult(
valid=False,
coupon=None,
error_message="Código de cupón inválido",
discount_preview=None
)
# Check if coupon can be redeemed
can_redeem, reason = coupon.can_be_redeemed()
if not can_redeem:
error_messages = {
"Coupon is inactive": "Este cupón no está activo",
"Coupon is not yet valid": "Este cupón aún no es válido",
"Coupon has expired": "Este cupón ha expirado",
"Coupon has reached maximum redemptions": "Este cupón ha alcanzado su límite de usos"
}
return CouponValidationResult(
valid=False,
coupon=coupon,
error_message=error_messages.get(reason, reason),
discount_preview=None
)
# Check if tenant already redeemed this coupon
result = await self.db.execute(
select(CouponRedemptionModel).where(
and_(
CouponRedemptionModel.tenant_id == tenant_id,
CouponRedemptionModel.coupon_code == code.upper()
)
)
)
existing_redemption = result.scalar_one_or_none()
if existing_redemption:
return CouponValidationResult(
valid=False,
coupon=coupon,
error_message="Ya has utilizado este cupón",
discount_preview=None
)
# Generate discount preview
discount_preview = self._generate_discount_preview(coupon)
return CouponValidationResult(
valid=True,
coupon=coupon,
error_message=None,
discount_preview=discount_preview
)
async def redeem_coupon(
self,
code: str,
tenant_id: Optional[str],
base_trial_days: int = 0
) -> tuple[bool, Optional[CouponRedemption], Optional[str]]:
"""
Redeem a coupon for a tenant.
For tenant-independent registrations, tenant_id can be None initially.
Returns (success, redemption, error_message)
"""
# For tenant-independent registrations, skip tenant validation
if tenant_id:
# Validate first
validation = await self.validate_coupon(code, tenant_id)
if not validation.valid:
return False, None, validation.error_message
coupon = validation.coupon
else:
# Just get the coupon and validate its general availability
coupon = await self.get_coupon_by_code(code)
if not coupon:
return False, None, "Código de cupón inválido"
# Check if coupon can be redeemed
can_redeem, reason = coupon.can_be_redeemed()
if not can_redeem:
error_messages = {
"Coupon is inactive": "Este cupón no está activo",
"Coupon is not yet valid": "Este cupón aún no es válido",
"Coupon has expired": "Este cupón ha expirado",
"Coupon has reached maximum redemptions": "Este cupón ha alcanzado su límite de usos"
}
return False, None, error_messages.get(reason, reason)
# Calculate discount applied
discount_applied = self._calculate_discount_applied(
coupon,
base_trial_days
)
# Only create redemption record if tenant_id is provided
# For tenant-independent subscriptions, skip redemption record creation
if tenant_id:
# Create redemption record
redemption_model = CouponRedemptionModel(
tenant_id=tenant_id,
coupon_code=code.upper(),
redeemed_at=datetime.now(timezone.utc),
discount_applied=discount_applied,
extra_data={
"coupon_type": coupon.discount_type.value,
"coupon_value": coupon.discount_value
}
)
self.db.add(redemption_model)
# Increment coupon redemption count
result = await self.db.execute(
select(CouponModel).where(CouponModel.code == code.upper())
)
coupon_model = result.scalar_one_or_none()
if coupon_model:
coupon_model.current_redemptions += 1
try:
await self.db.commit()
await self.db.refresh(redemption_model)
redemption = CouponRedemption(
id=str(redemption_model.id),
tenant_id=redemption_model.tenant_id,
coupon_code=redemption_model.coupon_code,
redeemed_at=redemption_model.redeemed_at,
discount_applied=redemption_model.discount_applied,
extra_data=redemption_model.extra_data
)
return True, redemption, None
except Exception as e:
await self.db.rollback()
return False, None, f"Error al aplicar el cupón: {str(e)}"
else:
# For tenant-independent subscriptions, return discount without creating redemption
# The redemption will be created when the tenant is linked
redemption = CouponRedemption(
id="pending", # Temporary ID
tenant_id="pending", # Will be set during tenant linking
coupon_code=code.upper(),
redeemed_at=datetime.now(timezone.utc),
discount_applied=discount_applied,
extra_data={
"coupon_type": coupon.discount_type.value,
"coupon_value": coupon.discount_value
}
)
return True, redemption, None
async def get_redemption_by_tenant_and_code(
self,
tenant_id: str,
code: str
) -> Optional[CouponRedemption]:
"""Get existing redemption for tenant and coupon code"""
result = await self.db.execute(
select(CouponRedemptionModel).where(
and_(
CouponRedemptionModel.tenant_id == tenant_id,
CouponRedemptionModel.coupon_code == code.upper()
)
)
)
redemption_model = result.scalar_one_or_none()
if not redemption_model:
return None
return CouponRedemption(
id=str(redemption_model.id),
tenant_id=redemption_model.tenant_id,
coupon_code=redemption_model.coupon_code,
redeemed_at=redemption_model.redeemed_at,
discount_applied=redemption_model.discount_applied,
extra_data=redemption_model.extra_data
)
async def get_coupon_usage_stats(self, code: str) -> Optional[dict]:
"""Get usage statistics for a coupon"""
result = await self.db.execute(
select(CouponModel).where(CouponModel.code == code.upper())
)
coupon_model = result.scalar_one_or_none()
if not coupon_model:
return None
count_result = await self.db.execute(
select(CouponRedemptionModel).where(
CouponRedemptionModel.coupon_code == code.upper()
)
)
redemptions_count = len(count_result.scalars().all())
return {
"code": coupon_model.code,
"current_redemptions": coupon_model.current_redemptions,
"max_redemptions": coupon_model.max_redemptions,
"redemptions_remaining": (
coupon_model.max_redemptions - coupon_model.current_redemptions
if coupon_model.max_redemptions
else None
),
"active": coupon_model.active,
"valid_from": coupon_model.valid_from.isoformat(),
"valid_until": coupon_model.valid_until.isoformat() if coupon_model.valid_until else None
}
def _model_to_dataclass(self, model: CouponModel) -> Coupon:
"""Convert SQLAlchemy model to dataclass"""
return Coupon(
id=str(model.id),
code=model.code,
discount_type=DiscountType(model.discount_type),
discount_value=model.discount_value,
max_redemptions=model.max_redemptions,
current_redemptions=model.current_redemptions,
valid_from=model.valid_from,
valid_until=model.valid_until,
active=model.active,
created_at=model.created_at,
extra_data=model.extra_data
)
def _generate_discount_preview(self, coupon: Coupon) -> dict:
"""Generate a preview of the discount to be applied"""
description = format_discount_description(coupon)
preview = {
"description": description,
"discount_type": coupon.discount_type.value,
"discount_value": coupon.discount_value
}
if coupon.discount_type == DiscountType.TRIAL_EXTENSION:
trial_end = calculate_trial_end_date(0, coupon.discount_value)
preview["trial_end_date"] = trial_end.isoformat()
preview["total_trial_days"] = 0 + coupon.discount_value
return preview
def _calculate_discount_applied(
self,
coupon: Coupon,
base_trial_days: int
) -> dict:
"""Calculate the actual discount that will be applied"""
discount = {
"type": coupon.discount_type.value,
"value": coupon.discount_value,
"description": format_discount_description(coupon)
}
if coupon.discount_type == DiscountType.TRIAL_EXTENSION:
total_trial_days = base_trial_days + coupon.discount_value
trial_end = calculate_trial_end_date(base_trial_days, coupon.discount_value)
discount["base_trial_days"] = base_trial_days
discount["extension_days"] = coupon.discount_value
discount["total_trial_days"] = total_trial_days
discount["trial_end_date"] = trial_end.isoformat()
elif coupon.discount_type == DiscountType.PERCENTAGE:
discount["percentage_off"] = coupon.discount_value
elif coupon.discount_type == DiscountType.FIXED_AMOUNT:
discount["amount_off_cents"] = coupon.discount_value
discount["amount_off_euros"] = coupon.discount_value / 100
return discount

View File

@@ -0,0 +1,283 @@
"""
Event Repository
Data access layer for events
"""
from typing import List, Optional, Dict, Any
from datetime import date, datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, func
from uuid import UUID
import structlog
from app.models.events import Event, EventTemplate
from shared.database.repository import BaseRepository
logger = structlog.get_logger()
class EventRepository(BaseRepository[Event]):
"""Repository for event management"""
def __init__(self, session: AsyncSession):
super().__init__(Event, session)
async def get_events_by_date_range(
self,
tenant_id: UUID,
start_date: date,
end_date: date,
event_types: List[str] = None,
confirmed_only: bool = False
) -> List[Event]:
"""
Get events within a date range.
Args:
tenant_id: Tenant UUID
start_date: Start date (inclusive)
end_date: End date (inclusive)
event_types: Optional filter by event types
confirmed_only: Only return confirmed events
Returns:
List of Event objects
"""
try:
query = select(Event).where(
and_(
Event.tenant_id == tenant_id,
Event.event_date >= start_date,
Event.event_date <= end_date
)
)
if event_types:
query = query.where(Event.event_type.in_(event_types))
if confirmed_only:
query = query.where(Event.is_confirmed == True)
query = query.order_by(Event.event_date)
result = await self.session.execute(query)
events = result.scalars().all()
logger.debug("Retrieved events by date range",
tenant_id=str(tenant_id),
start_date=start_date.isoformat(),
end_date=end_date.isoformat(),
count=len(events))
return list(events)
except Exception as e:
logger.error("Failed to get events by date range",
tenant_id=str(tenant_id),
error=str(e))
return []
async def get_events_for_date(
self,
tenant_id: UUID,
event_date: date
) -> List[Event]:
"""
Get all events for a specific date.
Args:
tenant_id: Tenant UUID
event_date: Date to get events for
Returns:
List of Event objects
"""
try:
query = select(Event).where(
and_(
Event.tenant_id == tenant_id,
Event.event_date == event_date
)
).order_by(Event.start_time)
result = await self.session.execute(query)
events = result.scalars().all()
return list(events)
except Exception as e:
logger.error("Failed to get events for date",
tenant_id=str(tenant_id),
error=str(e))
return []
async def get_upcoming_events(
self,
tenant_id: UUID,
days_ahead: int = 30,
limit: int = 100
) -> List[Event]:
"""
Get upcoming events.
Args:
tenant_id: Tenant UUID
days_ahead: Number of days to look ahead
limit: Maximum number of events to return
Returns:
List of upcoming Event objects
"""
try:
from datetime import date, timedelta
today = date.today()
future_date = today + timedelta(days=days_ahead)
query = select(Event).where(
and_(
Event.tenant_id == tenant_id,
Event.event_date >= today,
Event.event_date <= future_date
)
).order_by(Event.event_date).limit(limit)
result = await self.session.execute(query)
events = result.scalars().all()
return list(events)
except Exception as e:
logger.error("Failed to get upcoming events",
tenant_id=str(tenant_id),
error=str(e))
return []
async def create_event(self, event_data: Dict[str, Any]) -> Event:
"""Create a new event"""
try:
event = Event(**event_data)
self.session.add(event)
await self.session.flush()
logger.info("Created event",
event_id=str(event.id),
event_name=event.event_name,
event_date=event.event_date.isoformat())
return event
except Exception as e:
logger.error("Failed to create event", error=str(e))
raise
async def update_event_actual_impact(
self,
event_id: UUID,
actual_impact_multiplier: float,
actual_sales_increase_percent: float
) -> Optional[Event]:
"""
Update event with actual impact after it occurs.
Args:
event_id: Event UUID
actual_impact_multiplier: Actual demand multiplier observed
actual_sales_increase_percent: Actual sales increase percentage
Returns:
Updated Event or None
"""
try:
event = await self.get(event_id)
if not event:
return None
event.actual_impact_multiplier = actual_impact_multiplier
event.actual_sales_increase_percent = actual_sales_increase_percent
await self.session.flush()
logger.info("Updated event actual impact",
event_id=str(event_id),
actual_multiplier=actual_impact_multiplier)
return event
except Exception as e:
logger.error("Failed to update event actual impact",
event_id=str(event_id),
error=str(e))
return None
async def get_events_by_type(
self,
tenant_id: UUID,
event_type: str,
limit: int = 100
) -> List[Event]:
"""Get events by type"""
try:
query = select(Event).where(
and_(
Event.tenant_id == tenant_id,
Event.event_type == event_type
)
).order_by(Event.event_date.desc()).limit(limit)
result = await self.session.execute(query)
events = result.scalars().all()
return list(events)
except Exception as e:
logger.error("Failed to get events by type",
tenant_id=str(tenant_id),
event_type=event_type,
error=str(e))
return []
class EventTemplateRepository(BaseRepository[EventTemplate]):
"""Repository for event template management"""
def __init__(self, session: AsyncSession):
super().__init__(EventTemplate, session)
async def get_active_templates(self, tenant_id: UUID) -> List[EventTemplate]:
"""Get all active event templates for a tenant"""
try:
query = select(EventTemplate).where(
and_(
EventTemplate.tenant_id == tenant_id,
EventTemplate.is_active == True
)
).order_by(EventTemplate.template_name)
result = await self.session.execute(query)
templates = result.scalars().all()
return list(templates)
except Exception as e:
logger.error("Failed to get active templates",
tenant_id=str(tenant_id),
error=str(e))
return []
async def create_template(self, template_data: Dict[str, Any]) -> EventTemplate:
"""Create a new event template"""
try:
template = EventTemplate(**template_data)
self.session.add(template)
await self.session.flush()
logger.info("Created event template",
template_id=str(template.id),
template_name=template.template_name)
return template
except Exception as e:
logger.error("Failed to create event template", error=str(e))
raise

View File

@@ -0,0 +1,812 @@
"""
Subscription Repository
Repository for subscription operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, text, and_
from datetime import datetime, timedelta
import structlog
import json
from .base import TenantBaseRepository
from app.models.tenants import Subscription
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
from shared.subscription.plans import SubscriptionPlanMetadata, QuotaLimits, PlanPricing
logger = structlog.get_logger()
class SubscriptionRepository(TenantBaseRepository):
"""Repository for subscription operations"""
def __init__(self, model_class, session: AsyncSession, cache_ttl: Optional[int] = 600):
# Subscriptions are relatively stable, medium cache time (10 minutes)
super().__init__(model_class, session, cache_ttl)
async def create_subscription(self, subscription_data: Dict[str, Any]) -> Subscription:
"""Create a new subscription with validation"""
try:
# Validate subscription data
validation_result = self._validate_tenant_data(
subscription_data,
["tenant_id", "plan"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid subscription data: {validation_result['errors']}")
# Check for existing active subscription
existing_subscription = await self.get_active_subscription(
subscription_data["tenant_id"]
)
if existing_subscription:
raise DuplicateRecordError(f"Tenant already has an active subscription")
# Set default values based on plan from centralized configuration
plan = subscription_data["plan"]
plan_info = SubscriptionPlanMetadata.get_plan_info(plan)
# Set defaults from centralized plan configuration
if "monthly_price" not in subscription_data:
billing_cycle = subscription_data.get("billing_cycle", "monthly")
subscription_data["monthly_price"] = float(
PlanPricing.get_price(plan, billing_cycle)
)
if "max_users" not in subscription_data:
subscription_data["max_users"] = QuotaLimits.get_limit('MAX_USERS', plan) or -1
if "max_locations" not in subscription_data:
subscription_data["max_locations"] = QuotaLimits.get_limit('MAX_LOCATIONS', plan) or -1
if "max_products" not in subscription_data:
subscription_data["max_products"] = QuotaLimits.get_limit('MAX_PRODUCTS', plan) or -1
if "features" not in subscription_data:
subscription_data["features"] = {
feature: True for feature in plan_info.get("features", [])
}
# Set default subscription values
if "status" not in subscription_data:
subscription_data["status"] = "active"
if "billing_cycle" not in subscription_data:
subscription_data["billing_cycle"] = "monthly"
if "next_billing_date" not in subscription_data:
# Set next billing date based on cycle
if subscription_data["billing_cycle"] == "yearly":
subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=365)
else:
subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=30)
# Check if subscription with this subscription_id already exists to prevent duplicates
if subscription_data.get('subscription_id'):
existing_subscription = await self.get_by_provider_id(subscription_data['subscription_id'])
if existing_subscription:
# Update the existing subscription instead of creating a duplicate
updated_subscription = await self.update(str(existing_subscription.id), subscription_data)
logger.info("Existing subscription updated",
subscription_id=subscription_data['subscription_id'],
tenant_id=subscription_data.get('tenant_id'),
plan=subscription_data.get('plan'))
return updated_subscription
# Create subscription
subscription = await self.create(subscription_data)
logger.info("Subscription created successfully",
subscription_id=subscription.id,
tenant_id=subscription.tenant_id,
plan=subscription.plan,
monthly_price=subscription.monthly_price)
return subscription
except (ValidationError, DuplicateRecordError):
raise
except Exception as e:
logger.error("Failed to create subscription",
tenant_id=subscription_data.get("tenant_id"),
plan=subscription_data.get("plan"),
error=str(e))
raise DatabaseError(f"Failed to create subscription: {str(e)}")
async def get_by_tenant_id(self, tenant_id: str) -> Optional[Subscription]:
"""Get subscription by tenant ID"""
try:
subscriptions = await self.get_multi(
filters={
"tenant_id": tenant_id
},
limit=1,
order_by="created_at",
order_desc=True
)
return subscriptions[0] if subscriptions else None
except Exception as e:
logger.error("Failed to get subscription by tenant ID",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get subscription: {str(e)}")
async def get_by_provider_id(self, subscription_id: str) -> Optional[Subscription]:
"""Get subscription by payment provider subscription ID"""
try:
subscriptions = await self.get_multi(
filters={
"subscription_id": subscription_id
},
limit=1,
order_by="created_at",
order_desc=True
)
return subscriptions[0] if subscriptions else None
except Exception as e:
logger.error("Failed to get subscription by provider ID",
subscription_id=subscription_id,
error=str(e))
raise DatabaseError(f"Failed to get subscription: {str(e)}")
async def get_active_subscription(self, tenant_id: str) -> Optional[Subscription]:
"""Get active subscription for tenant"""
try:
subscriptions = await self.get_multi(
filters={
"tenant_id": tenant_id,
"status": "active"
},
limit=1,
order_by="created_at",
order_desc=True
)
return subscriptions[0] if subscriptions else None
except Exception as e:
logger.error("Failed to get active subscription",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get subscription: {str(e)}")
async def get_tenant_subscriptions(
self,
tenant_id: str,
include_inactive: bool = False
) -> List[Subscription]:
"""Get all subscriptions for a tenant"""
try:
filters = {"tenant_id": tenant_id}
if not include_inactive:
filters["status"] = "active"
return await self.get_multi(
filters=filters,
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get tenant subscriptions",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get subscriptions: {str(e)}")
async def update_subscription_plan(
self,
subscription_id: str,
new_plan: str,
billing_cycle: str = "monthly"
) -> Optional[Subscription]:
"""Update subscription plan and pricing using centralized configuration"""
try:
valid_plans = ["starter", "professional", "enterprise"]
if new_plan not in valid_plans:
raise ValidationError(f"Invalid plan. Must be one of: {valid_plans}")
# Get current subscription to find tenant_id for cache invalidation
subscription = await self.get_by_id(subscription_id)
if not subscription:
raise ValidationError(f"Subscription {subscription_id} not found")
# Get new plan configuration from centralized source
plan_info = SubscriptionPlanMetadata.get_plan_info(new_plan)
# Update subscription with new plan details
update_data = {
"plan": new_plan,
"monthly_price": float(PlanPricing.get_price(new_plan, billing_cycle)),
"billing_cycle": billing_cycle,
"max_users": QuotaLimits.get_limit('MAX_USERS', new_plan) or -1,
"max_locations": QuotaLimits.get_limit('MAX_LOCATIONS', new_plan) or -1,
"max_products": QuotaLimits.get_limit('MAX_PRODUCTS', new_plan) or -1,
"features": {feature: True for feature in plan_info.get("features", [])},
"updated_at": datetime.utcnow()
}
updated_subscription = await self.update(subscription_id, update_data)
# Invalidate cache
await self._invalidate_cache(str(subscription.tenant_id))
logger.info("Subscription plan updated",
subscription_id=subscription_id,
new_plan=new_plan,
new_price=update_data["monthly_price"])
return updated_subscription
except ValidationError:
raise
except Exception as e:
logger.error("Failed to update subscription plan",
subscription_id=subscription_id,
new_plan=new_plan,
error=str(e))
raise DatabaseError(f"Failed to update plan: {str(e)}")
async def update_subscription_status(
self,
subscription_id: str,
status: str,
additional_data: Dict[str, Any] = None
) -> Optional[Subscription]:
"""Update subscription status with optional additional data"""
try:
# Get subscription to find tenant_id for cache invalidation
subscription = await self.get_by_id(subscription_id)
if not subscription:
raise ValidationError(f"Subscription {subscription_id} not found")
update_data = {
"status": status,
"updated_at": datetime.utcnow()
}
# Merge additional data if provided
if additional_data:
update_data.update(additional_data)
updated_subscription = await self.update(subscription_id, update_data)
# Invalidate cache
if subscription.tenant_id:
await self._invalidate_cache(str(subscription.tenant_id))
logger.info("Subscription status updated",
subscription_id=subscription_id,
new_status=status,
additional_data=additional_data)
return updated_subscription
except ValidationError:
raise
except Exception as e:
logger.error("Failed to update subscription status",
subscription_id=subscription_id,
status=status,
error=str(e))
raise DatabaseError(f"Failed to update subscription status: {str(e)}")
async def cancel_subscription(
self,
subscription_id: str,
reason: str = None
) -> Optional[Subscription]:
"""Cancel a subscription"""
try:
# Get subscription to find tenant_id for cache invalidation
subscription = await self.get_by_id(subscription_id)
if not subscription:
raise ValidationError(f"Subscription {subscription_id} not found")
update_data = {
"status": "cancelled",
"updated_at": datetime.utcnow()
}
updated_subscription = await self.update(subscription_id, update_data)
# Invalidate cache
await self._invalidate_cache(str(subscription.tenant_id))
logger.info("Subscription cancelled",
subscription_id=subscription_id,
reason=reason)
return updated_subscription
except Exception as e:
logger.error("Failed to cancel subscription",
subscription_id=subscription_id,
error=str(e))
raise DatabaseError(f"Failed to cancel subscription: {str(e)}")
async def suspend_subscription(
self,
subscription_id: str,
reason: str = None
) -> Optional[Subscription]:
"""Suspend a subscription"""
try:
# Get subscription to find tenant_id for cache invalidation
subscription = await self.get_by_id(subscription_id)
if not subscription:
raise ValidationError(f"Subscription {subscription_id} not found")
update_data = {
"status": "suspended",
"updated_at": datetime.utcnow()
}
updated_subscription = await self.update(subscription_id, update_data)
# Invalidate cache
await self._invalidate_cache(str(subscription.tenant_id))
logger.info("Subscription suspended",
subscription_id=subscription_id,
reason=reason)
return updated_subscription
except Exception as e:
logger.error("Failed to suspend subscription",
subscription_id=subscription_id,
error=str(e))
raise DatabaseError(f"Failed to suspend subscription: {str(e)}")
async def reactivate_subscription(
self,
subscription_id: str
) -> Optional[Subscription]:
"""Reactivate a cancelled or suspended subscription"""
try:
# Get subscription to find tenant_id for cache invalidation
subscription = await self.get_by_id(subscription_id)
if not subscription:
raise ValidationError(f"Subscription {subscription_id} not found")
# Reset billing date when reactivating
next_billing_date = datetime.utcnow() + timedelta(days=30)
update_data = {
"status": "active",
"next_billing_date": next_billing_date,
"updated_at": datetime.utcnow()
}
updated_subscription = await self.update(subscription_id, update_data)
# Invalidate cache
await self._invalidate_cache(str(subscription.tenant_id))
logger.info("Subscription reactivated",
subscription_id=subscription_id,
next_billing_date=next_billing_date)
return updated_subscription
except Exception as e:
logger.error("Failed to reactivate subscription",
subscription_id=subscription_id,
error=str(e))
raise DatabaseError(f"Failed to reactivate subscription: {str(e)}")
async def get_subscriptions_due_for_billing(
self,
days_ahead: int = 3
) -> List[Subscription]:
"""Get subscriptions that need billing in the next N days"""
try:
cutoff_date = datetime.utcnow() + timedelta(days=days_ahead)
query_text = """
SELECT * FROM subscriptions
WHERE status = 'active'
AND next_billing_date <= :cutoff_date
ORDER BY next_billing_date ASC
"""
result = await self.session.execute(text(query_text), {
"cutoff_date": cutoff_date
})
subscriptions = []
for row in result.fetchall():
record_dict = dict(row._mapping)
subscription = self.model(**record_dict)
subscriptions.append(subscription)
return subscriptions
except Exception as e:
logger.error("Failed to get subscriptions due for billing",
days_ahead=days_ahead,
error=str(e))
return []
async def update_billing_date(
self,
subscription_id: str,
next_billing_date: datetime
) -> Optional[Subscription]:
"""Update next billing date for subscription"""
try:
updated_subscription = await self.update(subscription_id, {
"next_billing_date": next_billing_date,
"updated_at": datetime.utcnow()
})
logger.info("Subscription billing date updated",
subscription_id=subscription_id,
next_billing_date=next_billing_date)
return updated_subscription
except Exception as e:
logger.error("Failed to update billing date",
subscription_id=subscription_id,
error=str(e))
raise DatabaseError(f"Failed to update billing date: {str(e)}")
async def get_subscription_statistics(self) -> Dict[str, Any]:
"""Get subscription statistics"""
try:
# Get counts by plan
plan_query = text("""
SELECT plan, COUNT(*) as count
FROM subscriptions
WHERE status = 'active'
GROUP BY plan
ORDER BY count DESC
""")
result = await self.session.execute(plan_query)
subscriptions_by_plan = {row.plan: row.count for row in result.fetchall()}
# Get counts by status
status_query = text("""
SELECT status, COUNT(*) as count
FROM subscriptions
GROUP BY status
ORDER BY count DESC
""")
result = await self.session.execute(status_query)
subscriptions_by_status = {row.status: row.count for row in result.fetchall()}
# Get revenue statistics
revenue_query = text("""
SELECT
SUM(monthly_price) as total_monthly_revenue,
AVG(monthly_price) as avg_monthly_price,
COUNT(*) as total_active_subscriptions
FROM subscriptions
WHERE status = 'active'
""")
revenue_result = await self.session.execute(revenue_query)
revenue_row = revenue_result.fetchone()
# Get upcoming billing count
thirty_days_ahead = datetime.utcnow() + timedelta(days=30)
upcoming_billing = len(await self.get_subscriptions_due_for_billing(30))
return {
"subscriptions_by_plan": subscriptions_by_plan,
"subscriptions_by_status": subscriptions_by_status,
"total_monthly_revenue": float(revenue_row.total_monthly_revenue or 0),
"avg_monthly_price": float(revenue_row.avg_monthly_price or 0),
"total_active_subscriptions": int(revenue_row.total_active_subscriptions or 0),
"upcoming_billing_30d": upcoming_billing
}
except Exception as e:
logger.error("Failed to get subscription statistics", error=str(e))
return {
"subscriptions_by_plan": {},
"subscriptions_by_status": {},
"total_monthly_revenue": 0.0,
"avg_monthly_price": 0.0,
"total_active_subscriptions": 0,
"upcoming_billing_30d": 0
}
async def cleanup_old_subscriptions(self, days_old: int = 730) -> int:
"""Clean up very old cancelled subscriptions (2 years)"""
try:
cutoff_date = datetime.utcnow() - timedelta(days=days_old)
query_text = """
DELETE FROM subscriptions
WHERE status IN ('cancelled', 'suspended')
AND updated_at < :cutoff_date
"""
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
deleted_count = result.rowcount
logger.info("Cleaned up old subscriptions",
deleted_count=deleted_count,
days_old=days_old)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup old subscriptions",
error=str(e))
raise DatabaseError(f"Cleanup failed: {str(e)}")
async def _invalidate_cache(self, tenant_id: str) -> None:
"""
Invalidate subscription cache for a tenant
Args:
tenant_id: Tenant ID
"""
try:
from app.services.subscription_cache import get_subscription_cache_service
cache_service = get_subscription_cache_service()
await cache_service.invalidate_subscription_cache(tenant_id)
logger.debug("Invalidated subscription cache from repository",
tenant_id=tenant_id)
except Exception as e:
logger.warning("Failed to invalidate cache (non-critical)",
tenant_id=tenant_id, error=str(e))
# ========================================================================
# TENANT-INDEPENDENT SUBSCRIPTION METHODS (New Architecture)
# ========================================================================
async def create_tenant_independent_subscription(
self,
subscription_data: Dict[str, Any]
) -> Subscription:
"""Create a subscription not linked to any tenant (for registration flow)"""
try:
# Validate required data for tenant-independent subscription
# user_id may not exist during registration, so validate other required fields
required_fields = ["plan", "subscription_id", "customer_id"]
validation_result = self._validate_tenant_data(subscription_data, required_fields)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid subscription data: {validation_result['errors']}")
# Ensure tenant_id is not provided (this is tenant-independent)
if "tenant_id" in subscription_data and subscription_data["tenant_id"]:
raise ValidationError("tenant_id should not be provided for tenant-independent subscriptions")
# Set tenant-independent specific fields
subscription_data["tenant_id"] = None
subscription_data["is_tenant_linked"] = False
subscription_data["tenant_linking_status"] = "pending"
subscription_data["linked_at"] = None
# Set default values based on plan from centralized configuration
plan = subscription_data["plan"]
plan_info = SubscriptionPlanMetadata.get_plan_info(plan)
# Set defaults from centralized plan configuration
if "monthly_price" not in subscription_data:
billing_cycle = subscription_data.get("billing_cycle", "monthly")
subscription_data["monthly_price"] = float(
PlanPricing.get_price(plan, billing_cycle)
)
if "max_users" not in subscription_data:
subscription_data["max_users"] = QuotaLimits.get_limit('MAX_USERS', plan) or -1
if "max_locations" not in subscription_data:
subscription_data["max_locations"] = QuotaLimits.get_limit('MAX_LOCATIONS', plan) or -1
if "max_products" not in subscription_data:
subscription_data["max_products"] = QuotaLimits.get_limit('MAX_PRODUCTS', plan) or -1
if "features" not in subscription_data:
subscription_data["features"] = {
feature: True for feature in plan_info.get("features", [])
}
# Set default subscription values
if "status" not in subscription_data:
subscription_data["status"] = "pending_tenant_linking"
if "billing_cycle" not in subscription_data:
subscription_data["billing_cycle"] = "monthly"
if "next_billing_date" not in subscription_data:
# Set next billing date based on cycle
if subscription_data["billing_cycle"] == "yearly":
subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=365)
else:
subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=30)
# Check if subscription with this subscription_id already exists
existing_subscription = await self.get_by_provider_id(subscription_data['subscription_id'])
if existing_subscription:
# Update the existing subscription instead of creating a duplicate
updated_subscription = await self.update(str(existing_subscription.id), subscription_data)
logger.info("Existing tenant-independent subscription updated",
subscription_id=subscription_data['subscription_id'],
user_id=subscription_data.get('user_id'),
plan=subscription_data.get('plan'))
return updated_subscription
else:
# Create new subscription, but handle potential duplicate errors
try:
subscription = await self.create(subscription_data)
logger.info("Tenant-independent subscription created successfully",
subscription_id=subscription.id,
user_id=subscription.user_id,
plan=subscription.plan,
monthly_price=subscription.monthly_price)
return subscription
except DuplicateRecordError:
# Another process may have created the subscription between our check and create
# Try to get the existing subscription and return it
final_subscription = await self.get_by_provider_id(subscription_data['subscription_id'])
if final_subscription:
logger.info("Race condition detected: subscription already created by another process",
subscription_id=subscription_data['subscription_id'])
return final_subscription
else:
# This shouldn't happen, but re-raise the error if we can't find it
raise
except (ValidationError, DuplicateRecordError):
raise
except Exception as e:
logger.error("Failed to create tenant-independent subscription",
user_id=subscription_data.get("user_id"),
plan=subscription_data.get("plan"),
error=str(e))
raise DatabaseError(f"Failed to create tenant-independent subscription: {str(e)}")
async def get_pending_tenant_linking_subscriptions(self) -> List[Subscription]:
"""Get all subscriptions waiting to be linked to tenants"""
try:
subscriptions = await self.get_multi(
filters={
"tenant_linking_status": "pending",
"is_tenant_linked": False
},
order_by="created_at",
order_desc=True
)
return subscriptions
except Exception as e:
logger.error("Failed to get pending tenant linking subscriptions",
error=str(e))
raise DatabaseError(f"Failed to get pending subscriptions: {str(e)}")
async def get_pending_subscriptions_by_user(self, user_id: str) -> List[Subscription]:
"""Get pending tenant linking subscriptions for a specific user"""
try:
subscriptions = await self.get_multi(
filters={
"user_id": user_id,
"tenant_linking_status": "pending",
"is_tenant_linked": False
},
order_by="created_at",
order_desc=True
)
return subscriptions
except Exception as e:
logger.error("Failed to get pending subscriptions by user",
user_id=user_id,
error=str(e))
raise DatabaseError(f"Failed to get pending subscriptions: {str(e)}")
async def link_subscription_to_tenant(
self,
subscription_id: str,
tenant_id: str,
user_id: str
) -> Subscription:
"""Link a pending subscription to a tenant"""
try:
# Get the subscription first
subscription = await self.get_by_id(subscription_id)
if not subscription:
raise ValidationError(f"Subscription {subscription_id} not found")
# Validate subscription can be linked
if not subscription.can_be_linked_to_tenant(user_id):
raise ValidationError(
f"Subscription {subscription_id} cannot be linked to tenant by user {user_id}. "
f"Current status: {subscription.tenant_linking_status}, "
f"User: {subscription.user_id}, "
f"Already linked: {subscription.is_tenant_linked}"
)
# Update subscription with tenant information
update_data = {
"tenant_id": tenant_id,
"is_tenant_linked": True,
"tenant_linking_status": "completed",
"linked_at": datetime.utcnow(),
"status": "active", # Activate subscription when linked to tenant
"updated_at": datetime.utcnow()
}
updated_subscription = await self.update(subscription_id, update_data)
# Invalidate cache for the tenant
await self._invalidate_cache(tenant_id)
logger.info("Subscription linked to tenant successfully",
subscription_id=subscription_id,
tenant_id=tenant_id,
user_id=user_id)
return updated_subscription
except Exception as e:
logger.error("Failed to link subscription to tenant",
subscription_id=subscription_id,
tenant_id=tenant_id,
user_id=user_id,
error=str(e))
raise DatabaseError(f"Failed to link subscription to tenant: {str(e)}")
async def cleanup_orphaned_subscriptions(self, days_old: int = 30) -> int:
"""Clean up subscriptions that were never linked to tenants"""
try:
cutoff_date = datetime.utcnow() - timedelta(days=days_old)
query_text = """
DELETE FROM subscriptions
WHERE tenant_linking_status = 'pending'
AND is_tenant_linked = FALSE
AND created_at < :cutoff_date
"""
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
deleted_count = result.rowcount
logger.info("Cleaned up orphaned subscriptions",
deleted_count=deleted_count,
days_old=days_old)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup orphaned subscriptions",
error=str(e))
raise DatabaseError(f"Cleanup failed: {str(e)}")
async def get_by_customer_id(self, customer_id: str) -> List[Subscription]:
"""
Get subscriptions by Stripe customer ID
Args:
customer_id: Stripe customer ID
Returns:
List of Subscription objects
"""
try:
query = select(Subscription).where(Subscription.customer_id == customer_id)
result = await self.session.execute(query)
subscriptions = result.scalars().all()
logger.debug("Found subscriptions by customer_id",
customer_id=customer_id,
count=len(subscriptions))
return subscriptions
except Exception as e:
logger.error("Error getting subscriptions by customer_id",
customer_id=customer_id,
error=str(e))
raise DatabaseError(f"Failed to get subscriptions by customer_id: {str(e)}")

View File

@@ -0,0 +1,218 @@
"""
Tenant Location Repository
Handles database operations for tenant location data
"""
from typing import List, Optional, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete
from sqlalchemy.orm import selectinload
import structlog
from app.models.tenant_location import TenantLocation
from app.models.tenants import Tenant
from shared.database.exceptions import DatabaseError
from .base import BaseRepository
logger = structlog.get_logger()
class TenantLocationRepository(BaseRepository):
"""Repository for tenant location operations"""
def __init__(self, session: AsyncSession):
super().__init__(TenantLocation, session)
async def create_location(self, location_data: Dict[str, Any]) -> TenantLocation:
"""
Create a new tenant location
Args:
location_data: Dictionary containing location information
Returns:
Created TenantLocation object
"""
try:
# Create new location instance
location = TenantLocation(**location_data)
self.session.add(location)
await self.session.commit()
await self.session.refresh(location)
logger.info(f"Created new tenant location: {location.id} for tenant {location.tenant_id}")
return location
except Exception as e:
await self.session.rollback()
logger.error(f"Failed to create tenant location: {str(e)}")
raise DatabaseError(f"Failed to create tenant location: {str(e)}")
async def get_location_by_id(self, location_id: str) -> Optional[TenantLocation]:
"""
Get a location by its ID
Args:
location_id: UUID of the location
Returns:
TenantLocation object if found, None otherwise
"""
try:
stmt = select(TenantLocation).where(TenantLocation.id == location_id)
result = await self.session.execute(stmt)
location = result.scalar_one_or_none()
return location
except Exception as e:
logger.error(f"Failed to get location by ID: {str(e)}")
raise DatabaseError(f"Failed to get location by ID: {str(e)}")
async def get_locations_by_tenant(self, tenant_id: str) -> List[TenantLocation]:
"""
Get all locations for a specific tenant
Args:
tenant_id: UUID of the tenant
Returns:
List of TenantLocation objects
"""
try:
stmt = select(TenantLocation).where(TenantLocation.tenant_id == tenant_id)
result = await self.session.execute(stmt)
locations = result.scalars().all()
return locations
except Exception as e:
logger.error(f"Failed to get locations by tenant: {str(e)}")
raise DatabaseError(f"Failed to get locations by tenant: {str(e)}")
async def get_location_by_type(self, tenant_id: str, location_type: str) -> Optional[TenantLocation]:
"""
Get a location by tenant and type
Args:
tenant_id: UUID of the tenant
location_type: Type of location (e.g., 'central_production', 'retail_outlet')
Returns:
TenantLocation object if found, None otherwise
"""
try:
stmt = select(TenantLocation).where(
TenantLocation.tenant_id == tenant_id,
TenantLocation.location_type == location_type
)
result = await self.session.execute(stmt)
location = result.scalar_one_or_none()
return location
except Exception as e:
logger.error(f"Failed to get location by type: {str(e)}")
raise DatabaseError(f"Failed to get location by type: {str(e)}")
async def update_location(self, location_id: str, location_data: Dict[str, Any]) -> Optional[TenantLocation]:
"""
Update a tenant location
Args:
location_id: UUID of the location to update
location_data: Dictionary containing updated location information
Returns:
Updated TenantLocation object if successful, None if location not found
"""
try:
stmt = (
update(TenantLocation)
.where(TenantLocation.id == location_id)
.values(**location_data)
)
await self.session.execute(stmt)
# Now fetch the updated location
location_stmt = select(TenantLocation).where(TenantLocation.id == location_id)
result = await self.session.execute(location_stmt)
location = result.scalar_one_or_none()
if location:
await self.session.commit()
logger.info(f"Updated tenant location: {location_id}")
return location
else:
await self.session.rollback()
logger.warning(f"Location not found for update: {location_id}")
return None
except Exception as e:
await self.session.rollback()
logger.error(f"Failed to update location: {str(e)}")
raise DatabaseError(f"Failed to update location: {str(e)}")
async def delete_location(self, location_id: str) -> bool:
"""
Delete a tenant location
Args:
location_id: UUID of the location to delete
Returns:
True if deleted successfully, False if location not found
"""
try:
stmt = delete(TenantLocation).where(TenantLocation.id == location_id)
result = await self.session.execute(stmt)
if result.rowcount > 0:
await self.session.commit()
logger.info(f"Deleted tenant location: {location_id}")
return True
else:
await self.session.rollback()
logger.warning(f"Location not found for deletion: {location_id}")
return False
except Exception as e:
await self.session.rollback()
logger.error(f"Failed to delete location: {str(e)}")
raise DatabaseError(f"Failed to delete location: {str(e)}")
async def get_active_locations_by_tenant(self, tenant_id: str) -> List[TenantLocation]:
"""
Get all active locations for a specific tenant
Args:
tenant_id: UUID of the tenant
Returns:
List of active TenantLocation objects
"""
try:
stmt = select(TenantLocation).where(
TenantLocation.tenant_id == tenant_id,
TenantLocation.is_active == True
)
result = await self.session.execute(stmt)
locations = result.scalars().all()
return locations
except Exception as e:
logger.error(f"Failed to get active locations by tenant: {str(e)}")
raise DatabaseError(f"Failed to get active locations by tenant: {str(e)}")
async def get_locations_by_tenant_with_type(self, tenant_id: str, location_types: List[str]) -> List[TenantLocation]:
"""
Get locations for a specific tenant filtered by location types
Args:
tenant_id: UUID of the tenant
location_types: List of location types to filter by
Returns:
List of TenantLocation objects matching the criteria
"""
try:
stmt = select(TenantLocation).where(
TenantLocation.tenant_id == tenant_id,
TenantLocation.location_type.in_(location_types)
)
result = await self.session.execute(stmt)
locations = result.scalars().all()
return locations
except Exception as e:
logger.error(f"Failed to get locations by tenant and type: {str(e)}")
raise DatabaseError(f"Failed to get locations by tenant and type: {str(e)}")

View File

@@ -0,0 +1,588 @@
"""
Tenant Member Repository
Repository for tenant membership operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, text, and_
from datetime import datetime, timedelta
import structlog
import json
from .base import TenantBaseRepository
from app.models.tenants import TenantMember
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
from shared.config.base import is_internal_service
logger = structlog.get_logger()
class TenantMemberRepository(TenantBaseRepository):
"""Repository for tenant member operations"""
def __init__(self, model_class, session: AsyncSession, cache_ttl: Optional[int] = 300):
# Member data changes more frequently, shorter cache time (5 minutes)
super().__init__(model_class, session, cache_ttl)
async def create_membership(self, membership_data: Dict[str, Any]) -> TenantMember:
"""Create a new tenant membership with validation"""
try:
# Validate membership data
validation_result = self._validate_tenant_data(
membership_data,
["tenant_id", "user_id", "role"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid membership data: {validation_result['errors']}")
# Check for existing membership
existing_membership = await self.get_membership(
membership_data["tenant_id"],
membership_data["user_id"]
)
if existing_membership and existing_membership.is_active:
raise DuplicateRecordError(f"User is already an active member of this tenant")
# Set default values
if "is_active" not in membership_data:
membership_data["is_active"] = True
if "joined_at" not in membership_data:
membership_data["joined_at"] = datetime.utcnow()
# Set permissions based on role
if "permissions" not in membership_data:
membership_data["permissions"] = self._get_default_permissions(
membership_data["role"]
)
# If reactivating existing membership
if existing_membership and not existing_membership.is_active:
# Update existing membership
update_data = {
key: value for key, value in membership_data.items()
if key not in ["tenant_id", "user_id"]
}
membership = await self.update(existing_membership.id, update_data)
else:
# Create new membership
membership = await self.create(membership_data)
logger.info("Tenant membership created",
membership_id=membership.id,
tenant_id=membership.tenant_id,
user_id=membership.user_id,
role=membership.role)
return membership
except (ValidationError, DuplicateRecordError):
raise
except Exception as e:
logger.error("Failed to create membership",
tenant_id=membership_data.get("tenant_id"),
user_id=membership_data.get("user_id"),
error=str(e))
raise DatabaseError(f"Failed to create membership: {str(e)}")
async def get_membership(self, tenant_id: str, user_id: str) -> Optional[TenantMember]:
"""Get specific membership by tenant and user"""
try:
# Validate that user_id is a proper UUID format for actual users
# Service names like 'inventory-service' should be handled differently
import uuid
try:
uuid.UUID(user_id)
is_valid_uuid = True
except ValueError:
is_valid_uuid = False
# For internal service access, return None to indicate no user membership
# Service access should be handled at the API layer
if not is_valid_uuid:
if is_internal_service(user_id):
# This is a known internal service request, return None
# Service access is granted at the API endpoint level
logger.debug("Internal service detected in membership lookup",
service=user_id,
tenant_id=tenant_id)
return None
elif user_id == "unknown-service":
# Special handling for 'unknown-service' which commonly occurs in demo sessions
# This happens when service identification fails during demo operations
logger.warning("Demo session service identification issue",
service=user_id,
tenant_id=tenant_id,
message="Service not properly identified - likely demo session context")
return None
else:
# This is an unknown service
# Return None to prevent database errors, but log a warning
logger.warning("Unknown service detected in membership lookup",
service=user_id,
tenant_id=tenant_id,
message="Service not in internal services registry")
return None
memberships = await self.get_multi(
filters={
"tenant_id": tenant_id,
"user_id": user_id
},
limit=1,
order_by="created_at",
order_desc=True
)
return memberships[0] if memberships else None
except Exception as e:
logger.error("Failed to get membership",
tenant_id=tenant_id,
user_id=user_id,
error=str(e))
raise DatabaseError(f"Failed to get membership: {str(e)}")
async def get_tenant_members(
self,
tenant_id: str,
active_only: bool = True,
role: str = None,
include_user_info: bool = False
) -> List[TenantMember]:
"""Get all members of a tenant with optional user info enrichment"""
try:
filters = {"tenant_id": tenant_id}
if active_only:
filters["is_active"] = True
if role:
filters["role"] = role
members = await self.get_multi(
filters=filters,
order_by="joined_at",
order_desc=False
)
# If include_user_info is True, enrich with user data from auth service
if include_user_info and members:
members = await self._enrich_members_with_user_info(members)
return members
except Exception as e:
logger.error("Failed to get tenant members",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get members: {str(e)}")
async def _enrich_members_with_user_info(self, members: List[TenantMember]) -> List[TenantMember]:
"""Enrich member objects with user information from auth service using batch endpoint"""
try:
import httpx
import os
if not members:
return members
# Get unique user IDs
user_ids = list(set([str(member.user_id) for member in members]))
if not user_ids:
return members
# Fetch user data from auth service using batch endpoint
# Using internal service communication
auth_service_url = os.getenv('AUTH_SERVICE_URL', 'http://auth-service:8000')
user_data_map = {}
async with httpx.AsyncClient() as client:
try:
# Use batch endpoint for efficiency
response = await client.post(
f"{auth_service_url}/api/v1/auth/users/batch",
json={"user_ids": user_ids},
timeout=10.0,
headers={"x-internal-service": "tenant-service"}
)
if response.status_code == 200:
batch_result = response.json()
user_data_map = batch_result.get("users", {})
logger.info(
"Batch user fetch successful",
requested_count=len(user_ids),
found_count=batch_result.get("found_count", 0)
)
else:
logger.warning(
"Batch user fetch failed, falling back to individual calls",
status_code=response.status_code
)
# Fallback to individual calls if batch fails
for user_id in user_ids:
try:
response = await client.get(
f"{auth_service_url}/api/v1/auth/users/{user_id}",
timeout=5.0,
headers={"x-internal-service": "tenant-service"}
)
if response.status_code == 200:
user_data = response.json()
user_data_map[user_id] = user_data
except Exception as e:
logger.warning(f"Failed to fetch user data for {user_id}", error=str(e))
continue
except Exception as e:
logger.warning("Batch user fetch failed, falling back to individual calls", error=str(e))
# Fallback to individual calls
for user_id in user_ids:
try:
response = await client.get(
f"{auth_service_url}/api/v1/auth/users/{user_id}",
timeout=5.0,
headers={"x-internal-service": "tenant-service"}
)
if response.status_code == 200:
user_data = response.json()
user_data_map[user_id] = user_data
except Exception as e:
logger.warning(f"Failed to fetch user data for {user_id}", error=str(e))
continue
# Enrich members with user data
for member in members:
user_id_str = str(member.user_id)
if user_id_str in user_data_map and user_data_map[user_id_str] is not None:
user_data = user_data_map[user_id_str]
# Add user fields as attributes to the member object
member.user_email = user_data.get("email")
member.user_full_name = user_data.get("full_name")
member.user = user_data # Store full user object for compatibility
else:
# Set defaults for missing users
member.user_email = None
member.user_full_name = "Unknown User"
member.user = None
return members
except Exception as e:
logger.warning("Failed to enrich members with user info", error=str(e))
# Return members without enrichment if it fails
return members
async def get_user_memberships(
self,
user_id: str,
active_only: bool = True
) -> List[TenantMember]:
"""Get all tenants a user is a member of"""
try:
filters = {"user_id": user_id}
if active_only:
filters["is_active"] = True
return await self.get_multi(
filters=filters,
order_by="joined_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get user memberships",
user_id=user_id,
error=str(e))
raise DatabaseError(f"Failed to get memberships: {str(e)}")
async def verify_user_access(
self,
user_id: str,
tenant_id: str
) -> Dict[str, Any]:
"""Verify if user has access to tenant and return access details"""
try:
membership = await self.get_membership(tenant_id, user_id)
if not membership or not membership.is_active:
return {
"has_access": False,
"role": "none",
"permissions": []
}
# Parse permissions
permissions = []
if membership.permissions:
try:
permissions = json.loads(membership.permissions)
except json.JSONDecodeError:
logger.warning("Invalid permissions JSON for membership",
membership_id=membership.id)
permissions = self._get_default_permissions(membership.role)
return {
"has_access": True,
"role": membership.role,
"permissions": permissions,
"membership_id": str(membership.id),
"joined_at": membership.joined_at.isoformat() if membership.joined_at else None
}
except Exception as e:
logger.error("Failed to verify user access",
user_id=user_id,
tenant_id=tenant_id,
error=str(e))
return {
"has_access": False,
"role": "none",
"permissions": []
}
async def update_member_role(
self,
tenant_id: str,
user_id: str,
new_role: str,
updated_by: str = None
) -> Optional[TenantMember]:
"""Update member role and permissions"""
try:
valid_roles = ["owner", "admin", "member", "viewer"]
if new_role not in valid_roles:
raise ValidationError(f"Invalid role. Must be one of: {valid_roles}")
membership = await self.get_membership(tenant_id, user_id)
if not membership:
raise ValidationError("Membership not found")
# Get new permissions based on role
new_permissions = self._get_default_permissions(new_role)
updated_membership = await self.update(membership.id, {
"role": new_role,
"permissions": json.dumps(new_permissions)
})
logger.info("Member role updated",
membership_id=membership.id,
tenant_id=tenant_id,
user_id=user_id,
old_role=membership.role,
new_role=new_role,
updated_by=updated_by)
return updated_membership
except ValidationError:
raise
except Exception as e:
logger.error("Failed to update member role",
tenant_id=tenant_id,
user_id=user_id,
new_role=new_role,
error=str(e))
raise DatabaseError(f"Failed to update role: {str(e)}")
async def deactivate_membership(
self,
tenant_id: str,
user_id: str,
deactivated_by: str = None
) -> Optional[TenantMember]:
"""Deactivate a membership (remove user from tenant)"""
try:
membership = await self.get_membership(tenant_id, user_id)
if not membership:
raise ValidationError("Membership not found")
# Don't allow deactivating the owner
if membership.role == "owner":
raise ValidationError("Cannot deactivate the owner membership")
updated_membership = await self.update(membership.id, {
"is_active": False
})
logger.info("Membership deactivated",
membership_id=membership.id,
tenant_id=tenant_id,
user_id=user_id,
deactivated_by=deactivated_by)
return updated_membership
except ValidationError:
raise
except Exception as e:
logger.error("Failed to deactivate membership",
tenant_id=tenant_id,
user_id=user_id,
error=str(e))
raise DatabaseError(f"Failed to deactivate membership: {str(e)}")
async def reactivate_membership(
self,
tenant_id: str,
user_id: str,
reactivated_by: str = None
) -> Optional[TenantMember]:
"""Reactivate a deactivated membership"""
try:
membership = await self.get_membership(tenant_id, user_id)
if not membership:
raise ValidationError("Membership not found")
updated_membership = await self.update(membership.id, {
"is_active": True,
"joined_at": datetime.utcnow() # Update join date
})
logger.info("Membership reactivated",
membership_id=membership.id,
tenant_id=tenant_id,
user_id=user_id,
reactivated_by=reactivated_by)
return updated_membership
except ValidationError:
raise
except Exception as e:
logger.error("Failed to reactivate membership",
tenant_id=tenant_id,
user_id=user_id,
error=str(e))
raise DatabaseError(f"Failed to reactivate membership: {str(e)}")
async def get_membership_statistics(self, tenant_id: str) -> Dict[str, Any]:
"""Get membership statistics for a tenant"""
try:
# Get counts by role
role_query = text("""
SELECT role, COUNT(*) as count
FROM tenant_members
WHERE tenant_id = :tenant_id AND is_active = true
GROUP BY role
ORDER BY count DESC
""")
result = await self.session.execute(role_query, {"tenant_id": tenant_id})
members_by_role = {row.role: row.count for row in result.fetchall()}
# Get basic counts
total_members = await self.count(filters={"tenant_id": tenant_id})
active_members = await self.count(filters={
"tenant_id": tenant_id,
"is_active": True
})
# Get recent activity (members joined in last 30 days)
thirty_days_ago = datetime.utcnow() - timedelta(days=30)
recent_joins = len(await self.get_multi(
filters={
"tenant_id": tenant_id,
"is_active": True
},
limit=1000 # High limit to get accurate count
))
# Filter for recent joins (manual filtering since we can't use date range in filters easily)
recent_members = 0
all_active_members = await self.get_tenant_members(tenant_id, active_only=True)
for member in all_active_members:
if member.joined_at and member.joined_at >= thirty_days_ago:
recent_members += 1
return {
"total_members": total_members,
"active_members": active_members,
"inactive_members": total_members - active_members,
"members_by_role": members_by_role,
"recent_joins_30d": recent_members
}
except Exception as e:
logger.error("Failed to get membership statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_members": 0,
"active_members": 0,
"inactive_members": 0,
"members_by_role": {},
"recent_joins_30d": 0
}
def _get_default_permissions(self, role: str) -> str:
"""Get default permissions JSON string for a role"""
permission_map = {
"owner": ["read", "write", "admin", "delete"],
"admin": ["read", "write", "admin"],
"member": ["read", "write"],
"viewer": ["read"]
}
permissions = permission_map.get(role, ["read"])
return json.dumps(permissions)
async def bulk_update_permissions(
self,
tenant_id: str,
role_permissions: Dict[str, List[str]]
) -> int:
"""Bulk update permissions for all members of specific roles"""
try:
updated_count = 0
for role, permissions in role_permissions.items():
members = await self.get_tenant_members(
tenant_id, active_only=True, role=role
)
for member in members:
await self.update(member.id, {
"permissions": json.dumps(permissions)
})
updated_count += 1
logger.info("Bulk updated member permissions",
tenant_id=tenant_id,
updated_count=updated_count,
roles=list(role_permissions.keys()))
return updated_count
except Exception as e:
logger.error("Failed to bulk update permissions",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Bulk permission update failed: {str(e)}")
async def cleanup_inactive_memberships(self, days_old: int = 180) -> int:
"""Clean up old inactive memberships"""
try:
cutoff_date = datetime.utcnow() - timedelta(days=days_old)
query_text = """
DELETE FROM tenant_members
WHERE is_active = false
AND created_at < :cutoff_date
"""
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
deleted_count = result.rowcount
logger.info("Cleaned up inactive memberships",
deleted_count=deleted_count,
days_old=days_old)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup inactive memberships",
error=str(e))
raise DatabaseError(f"Cleanup failed: {str(e)}")

View File

@@ -0,0 +1,680 @@
"""
Tenant Repository
Repository for tenant operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, text, and_
from datetime import datetime, timedelta
import structlog
import uuid
from .base import TenantBaseRepository
from app.models.tenants import Tenant, Subscription
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
logger = structlog.get_logger()
class TenantRepository(TenantBaseRepository):
"""Repository for tenant operations"""
def __init__(self, model_class, session: AsyncSession, cache_ttl: Optional[int] = 600):
# Tenants are relatively stable, longer cache time (10 minutes)
super().__init__(model_class, session, cache_ttl)
async def create_tenant(self, tenant_data: Dict[str, Any]) -> Tenant:
"""Create a new tenant with validation"""
try:
# Validate tenant data
validation_result = self._validate_tenant_data(
tenant_data,
["name", "address", "postal_code", "owner_id"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid tenant data: {validation_result['errors']}")
# Generate subdomain if not provided
if "subdomain" not in tenant_data or not tenant_data["subdomain"]:
subdomain = await self._generate_unique_subdomain(tenant_data["name"])
tenant_data["subdomain"] = subdomain
else:
# Check if provided subdomain is unique
existing_tenant = await self.get_by_subdomain(tenant_data["subdomain"])
if existing_tenant:
raise DuplicateRecordError(f"Subdomain {tenant_data['subdomain']} already exists")
# Set default values
if "business_type" not in tenant_data:
tenant_data["business_type"] = "bakery"
if "city" not in tenant_data:
tenant_data["city"] = "Madrid"
if "is_active" not in tenant_data:
tenant_data["is_active"] = True
if "ml_model_trained" not in tenant_data:
tenant_data["ml_model_trained"] = False
# Create tenant
tenant = await self.create(tenant_data)
logger.info("Tenant created successfully",
tenant_id=tenant.id,
name=tenant.name,
subdomain=tenant.subdomain,
owner_id=tenant.owner_id)
return tenant
except (ValidationError, DuplicateRecordError):
raise
except Exception as e:
logger.error("Failed to create tenant",
name=tenant_data.get("name"),
error=str(e))
raise DatabaseError(f"Failed to create tenant: {str(e)}")
async def get_by_subdomain(self, subdomain: str) -> Optional[Tenant]:
"""Get tenant by subdomain"""
try:
return await self.get_by_field("subdomain", subdomain)
except Exception as e:
logger.error("Failed to get tenant by subdomain",
subdomain=subdomain,
error=str(e))
raise DatabaseError(f"Failed to get tenant: {str(e)}")
async def get_tenants_by_owner(self, owner_id: str) -> List[Tenant]:
"""Get all tenants owned by a user"""
try:
return await self.get_multi(
filters={"owner_id": owner_id, "is_active": True},
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get tenants by owner",
owner_id=owner_id,
error=str(e))
raise DatabaseError(f"Failed to get tenants: {str(e)}")
async def get_active_tenants(self, skip: int = 0, limit: int = 100) -> List[Tenant]:
"""Get all active tenants"""
return await self.get_active_records(skip=skip, limit=limit)
async def search_tenants(
self,
search_term: str,
business_type: str = None,
city: str = None,
skip: int = 0,
limit: int = 50
) -> List[Tenant]:
"""Search tenants by name, address, or other criteria"""
try:
# Build search conditions
conditions = ["is_active = true"]
params = {"skip": skip, "limit": limit}
# Add text search
conditions.append("(LOWER(name) LIKE LOWER(:search_term) OR LOWER(address) LIKE LOWER(:search_term))")
params["search_term"] = f"%{search_term}%"
# Add business type filter
if business_type:
conditions.append("business_type = :business_type")
params["business_type"] = business_type
# Add city filter
if city:
conditions.append("LOWER(city) = LOWER(:city)")
params["city"] = city
query_text = f"""
SELECT * FROM tenants
WHERE {' AND '.join(conditions)}
ORDER BY name ASC
LIMIT :limit OFFSET :skip
"""
result = await self.session.execute(text(query_text), params)
tenants = []
for row in result.fetchall():
record_dict = dict(row._mapping)
tenant = self.model(**record_dict)
tenants.append(tenant)
return tenants
except Exception as e:
logger.error("Failed to search tenants",
search_term=search_term,
error=str(e))
return []
async def update_tenant_model_status(
self,
tenant_id: str,
ml_model_trained: bool,
last_training_date: datetime = None
) -> Optional[Tenant]:
"""Update tenant model training status"""
try:
update_data = {
"ml_model_trained": ml_model_trained,
"updated_at": datetime.utcnow()
}
if last_training_date:
update_data["last_training_date"] = last_training_date
elif ml_model_trained:
update_data["last_training_date"] = datetime.utcnow()
updated_tenant = await self.update(tenant_id, update_data)
logger.info("Tenant model status updated",
tenant_id=tenant_id,
ml_model_trained=ml_model_trained,
last_training_date=last_training_date)
return updated_tenant
except Exception as e:
logger.error("Failed to update tenant model status",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to update model status: {str(e)}")
async def get_tenants_by_location(
self,
latitude: float,
longitude: float,
radius_km: float = 10.0,
limit: int = 50
) -> List[Tenant]:
"""Get tenants within a geographic radius"""
try:
# Using Haversine formula for distance calculation
query_text = """
SELECT *,
(6371 * acos(
cos(radians(:latitude)) *
cos(radians(latitude)) *
cos(radians(longitude) - radians(:longitude)) +
sin(radians(:latitude)) *
sin(radians(latitude))
)) AS distance_km
FROM tenants
WHERE is_active = true
AND latitude IS NOT NULL
AND longitude IS NOT NULL
HAVING distance_km <= :radius_km
ORDER BY distance_km ASC
LIMIT :limit
"""
result = await self.session.execute(text(query_text), {
"latitude": latitude,
"longitude": longitude,
"radius_km": radius_km,
"limit": limit
})
tenants = []
for row in result.fetchall():
# Create tenant object (excluding the calculated distance_km field)
record_dict = dict(row._mapping)
record_dict.pop("distance_km", None) # Remove calculated field
tenant = self.model(**record_dict)
tenants.append(tenant)
return tenants
except Exception as e:
logger.error("Failed to get tenants by location",
latitude=latitude,
longitude=longitude,
radius_km=radius_km,
error=str(e))
return []
async def get_tenant_statistics(self) -> Dict[str, Any]:
"""Get global tenant statistics"""
try:
# Get basic counts
total_tenants = await self.count()
active_tenants = await self.count(filters={"is_active": True})
# Get tenants by business type
business_type_query = text("""
SELECT business_type, COUNT(*) as count
FROM tenants
WHERE is_active = true
GROUP BY business_type
ORDER BY count DESC
""")
result = await self.session.execute(business_type_query)
business_type_stats = {row.business_type: row.count for row in result.fetchall()}
# Get tenants by subscription tier - now from subscriptions table
tier_query = text("""
SELECT s.plan as subscription_tier, COUNT(*) as count
FROM tenants t
LEFT JOIN subscriptions s ON t.id = s.tenant_id AND s.status = 'active'
WHERE t.is_active = true
GROUP BY s.plan
ORDER BY count DESC
""")
tier_result = await self.session.execute(tier_query)
tier_stats = {}
for row in tier_result.fetchall():
tier = row.subscription_tier if row.subscription_tier else "no_subscription"
tier_stats[tier] = row.count
# Get model training statistics
model_query = text("""
SELECT
COUNT(CASE WHEN ml_model_trained = true THEN 1 END) as trained_count,
COUNT(CASE WHEN ml_model_trained = false THEN 1 END) as untrained_count,
AVG(EXTRACT(EPOCH FROM (NOW() - last_training_date))/86400) as avg_days_since_training
FROM tenants
WHERE is_active = true
""")
model_result = await self.session.execute(model_query)
model_row = model_result.fetchone()
# Get recent registrations (last 30 days)
thirty_days_ago = datetime.utcnow() - timedelta(days=30)
recent_registrations = await self.count(filters={
"created_at": f">= '{thirty_days_ago.isoformat()}'"
})
return {
"total_tenants": total_tenants,
"active_tenants": active_tenants,
"inactive_tenants": total_tenants - active_tenants,
"tenants_by_business_type": business_type_stats,
"tenants_by_subscription": tier_stats,
"model_training": {
"trained_tenants": int(model_row.trained_count or 0),
"untrained_tenants": int(model_row.untrained_count or 0),
"avg_days_since_training": float(model_row.avg_days_since_training or 0)
} if model_row else {
"trained_tenants": 0,
"untrained_tenants": 0,
"avg_days_since_training": 0.0
},
"recent_registrations_30d": recent_registrations
}
except Exception as e:
logger.error("Failed to get tenant statistics", error=str(e))
return {
"total_tenants": 0,
"active_tenants": 0,
"inactive_tenants": 0,
"tenants_by_business_type": {},
"tenants_by_subscription": {},
"model_training": {
"trained_tenants": 0,
"untrained_tenants": 0,
"avg_days_since_training": 0.0
},
"recent_registrations_30d": 0
}
async def _generate_unique_subdomain(self, name: str) -> str:
"""Generate a unique subdomain from tenant name"""
try:
# Clean the name to create a subdomain
subdomain = name.lower().replace(' ', '-')
# Remove accents
subdomain = subdomain.replace('á', 'a').replace('é', 'e').replace('í', 'i').replace('ó', 'o').replace('ú', 'u')
subdomain = subdomain.replace('ñ', 'n')
# Keep only alphanumeric and hyphens
subdomain = ''.join(c for c in subdomain if c.isalnum() or c == '-')
# Remove multiple consecutive hyphens
while '--' in subdomain:
subdomain = subdomain.replace('--', '-')
# Remove leading/trailing hyphens
subdomain = subdomain.strip('-')
# Ensure minimum length
if len(subdomain) < 3:
subdomain = f"tenant-{subdomain}"
# Check if subdomain exists
existing_tenant = await self.get_by_subdomain(subdomain)
if not existing_tenant:
return subdomain
# If it exists, add a unique suffix
counter = 1
while True:
candidate = f"{subdomain}-{counter}"
existing_tenant = await self.get_by_subdomain(candidate)
if not existing_tenant:
return candidate
counter += 1
# Prevent infinite loop
if counter > 9999:
return f"{subdomain}-{uuid.uuid4().hex[:6]}"
except Exception as e:
logger.error("Failed to generate unique subdomain",
name=name,
error=str(e))
# Fallback to UUID-based subdomain
return f"tenant-{uuid.uuid4().hex[:8]}"
async def deactivate_tenant(self, tenant_id: str) -> Optional[Tenant]:
"""Deactivate a tenant"""
return await self.deactivate_record(tenant_id)
async def activate_tenant(self, tenant_id: str) -> Optional[Tenant]:
"""Activate a tenant"""
return await self.activate_record(tenant_id)
async def get_child_tenants(self, parent_tenant_id: str) -> List[Tenant]:
"""Get all child tenants for a parent tenant"""
try:
return await self.get_multi(
filters={"parent_tenant_id": parent_tenant_id, "is_active": True},
order_by="created_at",
order_desc=False
)
except Exception as e:
logger.error("Failed to get child tenants",
parent_tenant_id=parent_tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get child tenants: {str(e)}")
async def get(self, record_id: Any) -> Optional[Tenant]:
"""Get tenant by ID - alias for get_by_id for compatibility"""
return await self.get_by_id(record_id)
async def get_child_tenant_count(self, parent_tenant_id: str) -> int:
"""Get count of child tenants for a parent tenant"""
try:
child_tenants = await self.get_child_tenants(parent_tenant_id)
return len(child_tenants)
except Exception as e:
logger.error("Failed to get child tenant count",
parent_tenant_id=parent_tenant_id,
error=str(e))
return 0
async def get_user_tenants_with_hierarchy(self, user_id: str) -> List[Dict[str, Any]]:
"""
Get all tenants a user has access to, organized in hierarchy.
Returns parent tenants with their children nested.
"""
try:
# Get all tenants where user is owner or member
query_text = """
SELECT DISTINCT t.*
FROM tenants t
LEFT JOIN tenant_members tm ON t.id = tm.tenant_id
WHERE (t.owner_id = :user_id OR tm.user_id = :user_id)
AND t.is_active = true
ORDER BY t.tenant_type DESC, t.created_at ASC
"""
result = await self.session.execute(text(query_text), {"user_id": user_id})
tenants = []
for row in result.fetchall():
record_dict = dict(row._mapping)
tenant = self.model(**record_dict)
tenants.append(tenant)
# Organize into hierarchy
tenant_hierarchy = []
parent_map = {}
# First pass: collect all parent/standalone tenants
for tenant in tenants:
if tenant.tenant_type in ['parent', 'standalone']:
tenant_dict = {
'id': str(tenant.id),
'name': tenant.name,
'subdomain': tenant.subdomain,
'tenant_type': tenant.tenant_type,
'business_type': tenant.business_type,
'business_model': tenant.business_model,
'city': tenant.city,
'is_active': tenant.is_active,
'children': [] if tenant.tenant_type == 'parent' else None
}
tenant_hierarchy.append(tenant_dict)
parent_map[str(tenant.id)] = tenant_dict
# Second pass: attach children to their parents
for tenant in tenants:
if tenant.tenant_type == 'child' and tenant.parent_tenant_id:
parent_id = str(tenant.parent_tenant_id)
if parent_id in parent_map:
child_dict = {
'id': str(tenant.id),
'name': tenant.name,
'subdomain': tenant.subdomain,
'tenant_type': 'child',
'parent_tenant_id': parent_id,
'city': tenant.city,
'is_active': tenant.is_active
}
parent_map[parent_id]['children'].append(child_dict)
return tenant_hierarchy
except Exception as e:
logger.error("Failed to get user tenants with hierarchy",
user_id=user_id,
error=str(e))
return []
async def get_tenants_by_session_id(self, session_id: str) -> List[Tenant]:
"""
Get tenants associated with a specific demo session using the demo_session_id field.
"""
try:
return await self.get_multi(
filters={
"demo_session_id": session_id,
"is_active": True
},
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get tenants by session ID",
session_id=session_id,
error=str(e))
raise DatabaseError(f"Failed to get tenants by session ID: {str(e)}")
async def get_professional_demo_tenants(self, session_id: str) -> List[Tenant]:
"""
Get professional demo tenants filtered by session.
Args:
session_id: Required demo session ID to filter tenants
Returns:
List of professional demo tenants for this specific session
"""
try:
filters = {
"business_model": "professional_bakery",
"is_demo": True,
"is_active": True,
"demo_session_id": session_id # Always filter by session
}
return await self.get_multi(
filters=filters,
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get professional demo tenants",
session_id=session_id,
error=str(e))
raise DatabaseError(f"Failed to get professional demo tenants: {str(e)}")
async def get_enterprise_demo_tenants(self, session_id: str) -> List[Tenant]:
"""
Get enterprise demo tenants (parent and children) filtered by session.
Args:
session_id: Required demo session ID to filter tenants
Returns:
List of enterprise demo tenants (1 parent + 3 children) for this specific session
"""
try:
# Get enterprise demo parent tenants for this session
parent_tenants = await self.get_multi(
filters={
"tenant_type": "parent",
"is_demo": True,
"is_active": True,
"demo_session_id": session_id # Always filter by session
},
order_by="created_at",
order_desc=True
)
# Get child tenants for the enterprise demo session
child_tenants = await self.get_multi(
filters={
"tenant_type": "child",
"is_demo": True,
"is_active": True,
"demo_session_id": session_id # Always filter by session
},
order_by="created_at",
order_desc=True
)
# Combine parent and child tenants
return parent_tenants + child_tenants
except Exception as e:
logger.error("Failed to get enterprise demo tenants",
session_id=session_id,
error=str(e))
raise DatabaseError(f"Failed to get enterprise demo tenants: {str(e)}")
async def get_by_customer_id(self, customer_id: str) -> Optional[Tenant]:
"""
Get tenant by Stripe customer ID
Args:
customer_id: Stripe customer ID
Returns:
Tenant object if found, None otherwise
"""
try:
# Find tenant by joining with subscriptions table
# Tenant doesn't have customer_id directly, so we need to find via subscription
query = select(Tenant).join(
Subscription, Subscription.tenant_id == Tenant.id
).where(Subscription.customer_id == customer_id)
result = await self.session.execute(query)
tenant = result.scalar_one_or_none()
if tenant:
logger.debug("Found tenant by customer_id",
customer_id=customer_id,
tenant_id=str(tenant.id))
return tenant
else:
logger.debug("No tenant found for customer_id",
customer_id=customer_id)
return None
except Exception as e:
logger.error("Error getting tenant by customer_id",
customer_id=customer_id,
error=str(e))
raise DatabaseError(f"Failed to get tenant by customer_id: {str(e)}")
async def get_user_primary_tenant(self, user_id: str) -> Optional[Tenant]:
"""
Get the primary tenant for a user (the tenant they own)
Args:
user_id: User ID to find primary tenant for
Returns:
Tenant object if found, None otherwise
"""
try:
logger.debug("Getting primary tenant for user", user_id=user_id)
# Query for tenant where user is the owner
query = select(Tenant).where(Tenant.owner_id == user_id)
result = await self.session.execute(query)
tenant = result.scalar_one_or_none()
if tenant:
logger.debug("Found primary tenant for user",
user_id=user_id,
tenant_id=str(tenant.id))
return tenant
else:
logger.debug("No primary tenant found for user", user_id=user_id)
return None
except Exception as e:
logger.error("Error getting primary tenant for user",
user_id=user_id,
error=str(e))
raise DatabaseError(f"Failed to get primary tenant for user: {str(e)}")
async def get_any_user_tenant(self, user_id: str) -> Optional[Tenant]:
"""
Get any tenant that the user has access to (via tenant_members)
Args:
user_id: User ID to find accessible tenants for
Returns:
Tenant object if found, None otherwise
"""
try:
logger.debug("Getting any accessible tenant for user", user_id=user_id)
# Query for tenant members where user has access
from app.models.tenants import TenantMember
query = select(Tenant).join(
TenantMember, Tenant.id == TenantMember.tenant_id
).where(TenantMember.user_id == user_id)
result = await self.session.execute(query)
tenant = result.scalar_one_or_none()
if tenant:
logger.debug("Found accessible tenant for user",
user_id=user_id,
tenant_id=str(tenant.id))
return tenant
else:
logger.debug("No accessible tenants found for user", user_id=user_id)
return None
except Exception as e:
logger.error("Error getting accessible tenant for user",
user_id=user_id,
error=str(e))
raise DatabaseError(f"Failed to get accessible tenant for user: {str(e)}")

View File

@@ -0,0 +1,82 @@
# services/tenant/app/repositories/tenant_settings_repository.py
"""
Tenant Settings Repository
Data access layer for tenant settings
"""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from typing import Optional
from uuid import UUID
import structlog
from ..models.tenant_settings import TenantSettings
logger = structlog.get_logger()
class TenantSettingsRepository:
"""Repository for TenantSettings data access"""
def __init__(self, db: AsyncSession):
self.db = db
async def get_by_tenant_id(self, tenant_id: UUID) -> Optional[TenantSettings]:
"""
Get tenant settings by tenant ID
Args:
tenant_id: UUID of the tenant
Returns:
TenantSettings or None if not found
"""
result = await self.db.execute(
select(TenantSettings).where(TenantSettings.tenant_id == tenant_id)
)
return result.scalar_one_or_none()
async def create(self, settings: TenantSettings) -> TenantSettings:
"""
Create new tenant settings
Args:
settings: TenantSettings instance to create
Returns:
Created TenantSettings instance
"""
self.db.add(settings)
await self.db.commit()
await self.db.refresh(settings)
return settings
async def update(self, settings: TenantSettings) -> TenantSettings:
"""
Update tenant settings
Args:
settings: TenantSettings instance with updates
Returns:
Updated TenantSettings instance
"""
await self.db.commit()
await self.db.refresh(settings)
return settings
async def delete(self, tenant_id: UUID) -> None:
"""
Delete tenant settings
Args:
tenant_id: UUID of the tenant
"""
result = await self.db.execute(
select(TenantSettings).where(TenantSettings.tenant_id == tenant_id)
)
settings = result.scalar_one_or_none()
if settings:
await self.db.delete(settings)
await self.db.commit()