Initial commit - production deployment
This commit is contained in:
0
services/tenant/app/__init__.py
Normal file
0
services/tenant/app/__init__.py
Normal file
8
services/tenant/app/api/__init__.py
Normal file
8
services/tenant/app/api/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Tenant API Package
|
||||
API endpoints for tenant management
|
||||
"""
|
||||
|
||||
from . import tenants
|
||||
|
||||
__all__ = ["tenants"]
|
||||
359
services/tenant/app/api/enterprise_upgrade.py
Normal file
359
services/tenant/app/api/enterprise_upgrade.py
Normal file
@@ -0,0 +1,359 @@
|
||||
"""
|
||||
Enterprise Upgrade API
|
||||
Endpoints for upgrading tenants to enterprise tier and managing child outlets
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, Any, Optional
|
||||
import uuid
|
||||
from datetime import datetime, date
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.tenants import Tenant
|
||||
from app.models.tenant_location import TenantLocation
|
||||
from app.services.tenant_service import EnhancedTenantService
|
||||
from app.core.config import settings
|
||||
from shared.auth.tenant_access import verify_tenant_permission_dep
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from shared.clients.subscription_client import SubscriptionServiceClient, get_subscription_service_client
|
||||
from shared.subscription.plans import SubscriptionTier, QuotaLimits
|
||||
from shared.database.base import create_database_manager
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Dependency injection for enhanced tenant service
|
||||
def get_enhanced_tenant_service():
|
||||
try:
|
||||
from app.core.config import settings
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
|
||||
return EnhancedTenantService(database_manager)
|
||||
except Exception as e:
|
||||
logger.error("Failed to create enhanced tenant service", error=str(e))
|
||||
raise HTTPException(status_code=500, detail="Service initialization failed")
|
||||
|
||||
|
||||
# Pydantic models for request bodies
|
||||
class EnterpriseUpgradeRequest(BaseModel):
|
||||
location_name: Optional[str] = Field(default="Central Production Facility")
|
||||
address: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
postal_code: Optional[str] = None
|
||||
latitude: Optional[float] = None
|
||||
longitude: Optional[float] = None
|
||||
production_capacity_kg: Optional[int] = Field(default=1000)
|
||||
|
||||
|
||||
class ChildOutletRequest(BaseModel):
|
||||
name: str
|
||||
subdomain: str
|
||||
address: str
|
||||
city: Optional[str] = None
|
||||
postal_code: str
|
||||
latitude: Optional[float] = None
|
||||
longitude: Optional[float] = None
|
||||
phone: Optional[str] = None
|
||||
email: Optional[str] = None
|
||||
delivery_days: Optional[list] = None
|
||||
|
||||
|
||||
@router.post("/tenants/{tenant_id}/upgrade-to-enterprise")
|
||||
async def upgrade_to_enterprise(
|
||||
tenant_id: str,
|
||||
upgrade_data: EnterpriseUpgradeRequest,
|
||||
subscription_client: SubscriptionServiceClient = Depends(get_subscription_service_client),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep)
|
||||
):
|
||||
"""
|
||||
Upgrade a tenant to enterprise tier with central production facility
|
||||
"""
|
||||
try:
|
||||
from app.core.database import database_manager
|
||||
from app.repositories.tenant_repository import TenantRepository
|
||||
|
||||
# Get the current tenant
|
||||
async with database_manager.get_session() as session:
|
||||
tenant_repo = TenantRepository(Tenant, session)
|
||||
tenant = await tenant_repo.get_by_id(tenant_id)
|
||||
if not tenant:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
# Verify current subscription allows upgrade to enterprise
|
||||
current_subscription = await subscription_client.get_subscription(tenant_id)
|
||||
if current_subscription['plan'] not in [SubscriptionTier.STARTER.value, SubscriptionTier.PROFESSIONAL.value]:
|
||||
raise HTTPException(status_code=400, detail="Only starter and professional tier tenants can be upgraded to enterprise")
|
||||
|
||||
# Verify user has admin/owner role
|
||||
# This is handled by current_user check
|
||||
|
||||
# Update tenant to parent type
|
||||
async with database_manager.get_session() as session:
|
||||
tenant_repo = TenantRepository(Tenant, session)
|
||||
updated_tenant = await tenant_repo.update(
|
||||
tenant_id,
|
||||
{
|
||||
'tenant_type': 'parent',
|
||||
'hierarchy_path': f"{tenant_id}" # Root path
|
||||
}
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
# Create central production location
|
||||
location_data = {
|
||||
'tenant_id': tenant_id,
|
||||
'name': upgrade_data.location_name,
|
||||
'location_type': 'central_production',
|
||||
'address': upgrade_data.address or tenant.address,
|
||||
'city': upgrade_data.city or tenant.city,
|
||||
'postal_code': upgrade_data.postal_code or tenant.postal_code,
|
||||
'latitude': upgrade_data.latitude or tenant.latitude,
|
||||
'longitude': upgrade_data.longitude or tenant.longitude,
|
||||
'capacity': upgrade_data.production_capacity_kg,
|
||||
'is_active': True
|
||||
}
|
||||
|
||||
from app.repositories.tenant_location_repository import TenantLocationRepository
|
||||
from app.core.database import database_manager
|
||||
|
||||
# Create async session
|
||||
async with database_manager.get_session() as session:
|
||||
location_repo = TenantLocationRepository(session)
|
||||
created_location = await location_repo.create_location(location_data)
|
||||
await session.commit()
|
||||
|
||||
# Update subscription to enterprise tier
|
||||
await subscription_client.update_subscription_plan(
|
||||
tenant_id=tenant_id,
|
||||
new_plan=SubscriptionTier.ENTERPRISE.value
|
||||
)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'tenant': updated_tenant,
|
||||
'production_location': created_location,
|
||||
'message': 'Tenant successfully upgraded to enterprise tier'
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to upgrade tenant: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/tenants/{parent_id}/add-child-outlet")
|
||||
async def add_child_outlet(
|
||||
parent_id: str,
|
||||
child_data: ChildOutletRequest,
|
||||
subscription_client: SubscriptionServiceClient = Depends(get_subscription_service_client),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep)
|
||||
):
|
||||
"""
|
||||
Add a new child outlet to a parent tenant
|
||||
"""
|
||||
try:
|
||||
from app.core.database import database_manager
|
||||
from app.repositories.tenant_repository import TenantRepository
|
||||
|
||||
# Get parent tenant and verify it's a parent
|
||||
async with database_manager.get_session() as session:
|
||||
tenant_repo = TenantRepository(Tenant, session)
|
||||
parent_tenant = await tenant_repo.get_by_id(parent_id)
|
||||
if not parent_tenant:
|
||||
raise HTTPException(status_code=400, detail="Parent tenant not found")
|
||||
|
||||
parent_dict = {
|
||||
'id': str(parent_tenant.id),
|
||||
'name': parent_tenant.name,
|
||||
'tenant_type': parent_tenant.tenant_type,
|
||||
'subscription_tier': parent_tenant.subscription_tier,
|
||||
'business_type': parent_tenant.business_type,
|
||||
'business_model': parent_tenant.business_model,
|
||||
'city': parent_tenant.city,
|
||||
'phone': parent_tenant.phone,
|
||||
'email': parent_tenant.email,
|
||||
'owner_id': parent_tenant.owner_id
|
||||
}
|
||||
|
||||
if parent_dict.get('tenant_type') != 'parent':
|
||||
raise HTTPException(status_code=400, detail="Tenant is not a parent type")
|
||||
|
||||
# Validate subscription tier
|
||||
from shared.clients import get_tenant_client
|
||||
from shared.subscription.plans import PlanFeatures
|
||||
|
||||
tenant_client = get_tenant_client(config=settings, service_name="tenant-service")
|
||||
subscription = await tenant_client.get_tenant_subscription(parent_id)
|
||||
|
||||
if not subscription:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="No active subscription found for parent tenant"
|
||||
)
|
||||
|
||||
tier = subscription.get("plan", "starter")
|
||||
if not PlanFeatures.validate_tenant_access(tier, "child"):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Creating child outlets requires Enterprise subscription. Current plan: {tier}"
|
||||
)
|
||||
|
||||
# Check if parent has reached child quota
|
||||
async with database_manager.get_session() as session:
|
||||
tenant_repo = TenantRepository(Tenant, session)
|
||||
current_child_count = await tenant_repo.get_child_tenant_count(parent_id)
|
||||
|
||||
# Get max children from subscription plan
|
||||
max_children = QuotaLimits.get_limit("MAX_CHILD_TENANTS", tier)
|
||||
|
||||
if max_children is not None and current_child_count >= max_children:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"Child tenant limit reached. Current: {current_child_count}, Maximum: {max_children}"
|
||||
)
|
||||
|
||||
# Create new child tenant
|
||||
child_id = str(uuid.uuid4())
|
||||
child_tenant_data = {
|
||||
'id': child_id,
|
||||
'name': child_data.name,
|
||||
'subdomain': child_data.subdomain,
|
||||
'business_type': parent_dict.get('business_type', 'bakery'),
|
||||
'business_model': parent_dict.get('business_model', 'retail_bakery'),
|
||||
'address': child_data.address,
|
||||
'city': child_data.city or parent_dict.get('city'),
|
||||
'postal_code': child_data.postal_code,
|
||||
'latitude': child_data.latitude,
|
||||
'longitude': child_data.longitude,
|
||||
'phone': child_data.phone or parent_dict.get('phone'),
|
||||
'email': child_data.email or parent_dict.get('email'),
|
||||
'parent_tenant_id': parent_id,
|
||||
'tenant_type': 'child',
|
||||
'hierarchy_path': f"{parent_id}.{child_id}",
|
||||
'owner_id': parent_dict.get('owner_id'), # Same owner as parent
|
||||
'is_active': True
|
||||
}
|
||||
|
||||
# Use database managed session
|
||||
async with database_manager.get_session() as session:
|
||||
tenant_repo = TenantRepository(Tenant, session)
|
||||
created_child = await tenant_repo.create(child_tenant_data)
|
||||
await session.commit()
|
||||
|
||||
created_child_dict = {
|
||||
'id': str(created_child.id),
|
||||
'name': created_child.name,
|
||||
'subdomain': created_child.subdomain
|
||||
}
|
||||
|
||||
# Create retail outlet location for the child
|
||||
location_data = {
|
||||
'tenant_id': uuid.UUID(child_id),
|
||||
'name': f"Outlet - {child_data.name}",
|
||||
'location_type': 'retail_outlet',
|
||||
'address': child_data.address,
|
||||
'city': child_data.city or parent_dict.get('city'),
|
||||
'postal_code': child_data.postal_code,
|
||||
'latitude': child_data.latitude,
|
||||
'longitude': child_data.longitude,
|
||||
'delivery_windows': child_data.delivery_days,
|
||||
'is_active': True
|
||||
}
|
||||
|
||||
from app.repositories.tenant_location_repository import TenantLocationRepository
|
||||
|
||||
# Create async session
|
||||
async with database_manager.get_session() as session:
|
||||
location_repo = TenantLocationRepository(session)
|
||||
created_location = await location_repo.create_location(location_data)
|
||||
await session.commit()
|
||||
|
||||
location_dict = {
|
||||
'id': str(created_location.id) if created_location else None,
|
||||
'name': created_location.name if created_location else None
|
||||
}
|
||||
|
||||
# Copy relevant settings from parent (with child-specific overrides)
|
||||
# This would typically involve copying settings via tenant settings service
|
||||
|
||||
# Create child subscription inheriting from parent
|
||||
await subscription_client.create_child_subscription(
|
||||
child_tenant_id=child_id,
|
||||
parent_tenant_id=parent_id
|
||||
)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'child_tenant': created_child_dict,
|
||||
'location': location_dict,
|
||||
'message': 'Child outlet successfully added'
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to add child outlet: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/tenants/{tenant_id}/hierarchy")
|
||||
async def get_tenant_hierarchy(
|
||||
tenant_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep)
|
||||
):
|
||||
"""
|
||||
Get tenant hierarchy information
|
||||
"""
|
||||
try:
|
||||
from app.core.database import database_manager
|
||||
from app.repositories.tenant_repository import TenantRepository
|
||||
|
||||
async with database_manager.get_session() as session:
|
||||
tenant_repo = TenantRepository(Tenant, session)
|
||||
tenant = await tenant_repo.get_by_id(tenant_id)
|
||||
if not tenant:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
result = {
|
||||
'tenant_id': tenant_id,
|
||||
'name': tenant.name,
|
||||
'tenant_type': tenant.tenant_type,
|
||||
'parent_tenant_id': tenant.parent_tenant_id,
|
||||
'hierarchy_path': tenant.hierarchy_path,
|
||||
'is_parent': tenant.tenant_type == 'parent',
|
||||
'is_child': tenant.tenant_type == 'child'
|
||||
}
|
||||
|
||||
# If this is a parent, include child count
|
||||
if tenant.tenant_type == 'parent':
|
||||
child_count = await tenant_repo.get_child_tenant_count(tenant_id)
|
||||
result['child_tenant_count'] = child_count
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get hierarchy: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/users/{user_id}/tenant-hierarchy")
|
||||
async def get_user_accessible_tenant_hierarchy(
|
||||
user_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep)
|
||||
):
|
||||
"""
|
||||
Get all tenants a user has access to, organized in hierarchy
|
||||
"""
|
||||
try:
|
||||
from app.core.database import database_manager
|
||||
from app.repositories.tenant_repository import TenantRepository
|
||||
|
||||
# Fetch all tenants where user has access, organized hierarchically
|
||||
async with database_manager.get_session() as session:
|
||||
tenant_repo = TenantRepository(Tenant, session)
|
||||
user_tenants = await tenant_repo.get_user_tenants_with_hierarchy(user_id)
|
||||
|
||||
return {
|
||||
'user_id': user_id,
|
||||
'tenants': user_tenants,
|
||||
'total_count': len(user_tenants)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get user hierarchy: {str(e)}")
|
||||
827
services/tenant/app/api/internal_demo.py
Normal file
827
services/tenant/app/api/internal_demo.py
Normal file
@@ -0,0 +1,827 @@
|
||||
"""
|
||||
Internal Demo Cloning API
|
||||
Service-to-service endpoint for cloning tenant data
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Header
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
import structlog
|
||||
import uuid
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from typing import Optional
|
||||
import os
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models.tenants import Tenant, Subscription, TenantMember
|
||||
from app.models.tenant_location import TenantLocation
|
||||
from shared.utils.demo_dates import adjust_date_for_demo, resolve_time_marker
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter(prefix="/internal/demo", tags=["internal"])
|
||||
|
||||
# Base demo tenant IDs
|
||||
DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"
|
||||
|
||||
|
||||
def parse_date_field(
|
||||
field_value: any,
|
||||
session_time: datetime,
|
||||
field_name: str = "date"
|
||||
) -> Optional[datetime]:
|
||||
"""
|
||||
Parse a date field from JSON, supporting BASE_TS markers and ISO timestamps.
|
||||
|
||||
Args:
|
||||
field_value: The date field value (can be BASE_TS marker, ISO string, or None)
|
||||
session_time: Session creation time (timezone-aware UTC)
|
||||
field_name: Name of the field (for logging)
|
||||
|
||||
Returns:
|
||||
Timezone-aware UTC datetime or None
|
||||
"""
|
||||
if field_value is None:
|
||||
return None
|
||||
|
||||
# Handle BASE_TS markers
|
||||
if isinstance(field_value, str) and field_value.startswith("BASE_TS"):
|
||||
try:
|
||||
return resolve_time_marker(field_value, session_time)
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.warning(
|
||||
"Failed to resolve BASE_TS marker",
|
||||
field_name=field_name,
|
||||
marker=field_value,
|
||||
error=str(e)
|
||||
)
|
||||
return None
|
||||
|
||||
# Handle ISO timestamps (legacy format - convert to absolute datetime)
|
||||
if isinstance(field_value, str) and ('T' in field_value or 'Z' in field_value):
|
||||
try:
|
||||
parsed_date = datetime.fromisoformat(field_value.replace('Z', '+00:00'))
|
||||
# Adjust relative to session time
|
||||
return adjust_date_for_demo(parsed_date, session_time)
|
||||
except (ValueError, AttributeError) as e:
|
||||
logger.warning(
|
||||
"Failed to parse ISO timestamp",
|
||||
field_name=field_name,
|
||||
value=field_value,
|
||||
error=str(e)
|
||||
)
|
||||
return None
|
||||
|
||||
logger.warning(
|
||||
"Unknown date format",
|
||||
field_name=field_name,
|
||||
value=field_value,
|
||||
value_type=type(field_value).__name__
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@router.post("/clone")
|
||||
async def clone_demo_data(
|
||||
base_tenant_id: str,
|
||||
virtual_tenant_id: str,
|
||||
demo_account_type: str,
|
||||
session_id: Optional[str] = None,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Clone tenant service data for a virtual demo tenant
|
||||
|
||||
This endpoint creates the virtual tenant record that will be used
|
||||
for the demo session. No actual data cloning is needed in tenant service
|
||||
beyond creating the tenant record itself.
|
||||
|
||||
Args:
|
||||
base_tenant_id: Template tenant UUID (not used, for consistency)
|
||||
virtual_tenant_id: Target virtual tenant UUID
|
||||
demo_account_type: Type of demo account
|
||||
session_id: Originating session ID for tracing
|
||||
|
||||
Returns:
|
||||
Cloning status and record count
|
||||
"""
|
||||
start_time = datetime.now(timezone.utc)
|
||||
|
||||
logger.info(
|
||||
"Starting tenant data cloning",
|
||||
virtual_tenant_id=virtual_tenant_id,
|
||||
demo_account_type=demo_account_type,
|
||||
session_id=session_id
|
||||
)
|
||||
|
||||
try:
|
||||
# Validate UUIDs
|
||||
virtual_uuid = uuid.UUID(virtual_tenant_id)
|
||||
|
||||
# Check if tenant already exists
|
||||
result = await db.execute(
|
||||
select(Tenant).where(Tenant.id == virtual_uuid)
|
||||
)
|
||||
existing_tenant = result.scalars().first()
|
||||
|
||||
if existing_tenant:
|
||||
logger.info(
|
||||
"Virtual tenant already exists",
|
||||
virtual_tenant_id=virtual_tenant_id,
|
||||
tenant_name=existing_tenant.name
|
||||
)
|
||||
|
||||
# Ensure the tenant has a subscription (copy from template if missing)
|
||||
from datetime import timedelta
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription).where(
|
||||
Subscription.tenant_id == virtual_uuid,
|
||||
Subscription.status == "active"
|
||||
)
|
||||
)
|
||||
existing_subscription = result.scalars().first()
|
||||
|
||||
if not existing_subscription:
|
||||
logger.info("Creating missing subscription for existing virtual tenant by copying from template",
|
||||
virtual_tenant_id=virtual_tenant_id,
|
||||
base_tenant_id=base_tenant_id)
|
||||
|
||||
# Load subscription from seed data instead of cloning from template
|
||||
try:
|
||||
from shared.utils.seed_data_paths import get_seed_data_path
|
||||
|
||||
if demo_account_type == "professional":
|
||||
json_file = get_seed_data_path("professional", "01-tenant.json")
|
||||
elif demo_account_type == "enterprise":
|
||||
json_file = get_seed_data_path("enterprise", "01-tenant.json")
|
||||
else:
|
||||
raise ValueError(f"Invalid demo account type: {demo_account_type}")
|
||||
|
||||
except ImportError:
|
||||
# Fallback to original path
|
||||
seed_data_dir = Path(__file__).parent.parent.parent.parent / "infrastructure" / "seed-data"
|
||||
if demo_account_type == "professional":
|
||||
json_file = seed_data_dir / "professional" / "01-tenant.json"
|
||||
elif demo_account_type == "enterprise":
|
||||
json_file = seed_data_dir / "enterprise" / "parent" / "01-tenant.json"
|
||||
else:
|
||||
raise ValueError(f"Invalid demo account type: {demo_account_type}")
|
||||
|
||||
if json_file.exists():
|
||||
import json
|
||||
with open(json_file, 'r', encoding='utf-8') as f:
|
||||
seed_data = json.load(f)
|
||||
|
||||
subscription_data = seed_data.get('subscription')
|
||||
if subscription_data:
|
||||
# Load subscription from seed data
|
||||
subscription = Subscription(
|
||||
tenant_id=virtual_uuid,
|
||||
plan=subscription_data.get('plan', 'professional'),
|
||||
status=subscription_data.get('status', 'active'),
|
||||
monthly_price=subscription_data.get('monthly_price', 299.00),
|
||||
billing_cycle=subscription_data.get('billing_cycle', 'monthly'),
|
||||
max_users=subscription_data.get('max_users', 10),
|
||||
max_locations=subscription_data.get('max_locations', 3),
|
||||
max_products=subscription_data.get('max_products', 500),
|
||||
features=subscription_data.get('features', {}),
|
||||
trial_ends_at=parse_date_field(
|
||||
subscription_data.get('trial_ends_at'),
|
||||
session_time,
|
||||
"trial_ends_at"
|
||||
),
|
||||
next_billing_date=parse_date_field(
|
||||
subscription_data.get('next_billing_date'),
|
||||
session_time,
|
||||
"next_billing_date"
|
||||
),
|
||||
subscription_id=subscription_data.get('stripe_subscription_id'),
|
||||
customer_id=subscription_data.get('stripe_customer_id'),
|
||||
cancelled_at=parse_date_field(
|
||||
subscription_data.get('cancelled_at'),
|
||||
session_time,
|
||||
"cancelled_at"
|
||||
),
|
||||
cancellation_effective_date=parse_date_field(
|
||||
subscription_data.get('cancellation_effective_date'),
|
||||
session_time,
|
||||
"cancellation_effective_date"
|
||||
),
|
||||
is_tenant_linked=True # Required for check constraint when tenant_id is set
|
||||
)
|
||||
|
||||
db.add(subscription)
|
||||
await db.commit()
|
||||
|
||||
logger.info("Subscription loaded from seed data successfully",
|
||||
virtual_tenant_id=virtual_tenant_id,
|
||||
plan=subscription.plan)
|
||||
else:
|
||||
logger.warning("No subscription found in seed data",
|
||||
virtual_tenant_id=virtual_tenant_id)
|
||||
else:
|
||||
logger.warning("Seed data file not found, falling back to default subscription",
|
||||
file_path=str(json_file))
|
||||
# Create default subscription if seed data not available
|
||||
subscription = Subscription(
|
||||
tenant_id=virtual_uuid,
|
||||
plan="professional" if demo_account_type == "professional" else "enterprise",
|
||||
status="active",
|
||||
monthly_price=299.00 if demo_account_type == "professional" else 799.00,
|
||||
max_users=10 if demo_account_type == "professional" else 50,
|
||||
max_locations=3 if demo_account_type == "professional" else -1,
|
||||
max_products=500 if demo_account_type == "professional" else -1,
|
||||
features={
|
||||
"production_planning": True,
|
||||
"procurement_management": True,
|
||||
"inventory_management": True,
|
||||
"sales_analytics": True,
|
||||
"multi_location": True,
|
||||
"advanced_reporting": True,
|
||||
"api_access": True,
|
||||
"priority_support": True
|
||||
},
|
||||
next_billing_date=datetime.now(timezone.utc) + timedelta(days=90),
|
||||
is_tenant_linked=True # Required for check constraint when tenant_id is set
|
||||
)
|
||||
|
||||
db.add(subscription)
|
||||
await db.commit()
|
||||
|
||||
logger.info("Default subscription created",
|
||||
virtual_tenant_id=virtual_tenant_id,
|
||||
plan=subscription.plan)
|
||||
|
||||
# Return success - idempotent operation
|
||||
duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
|
||||
return {
|
||||
"service": "tenant",
|
||||
"status": "completed",
|
||||
"records_cloned": 0 if existing_subscription else 1,
|
||||
"duration_ms": duration_ms,
|
||||
"details": {
|
||||
"tenant_already_exists": True,
|
||||
"tenant_id": str(virtual_uuid),
|
||||
"subscription_created": not existing_subscription
|
||||
}
|
||||
}
|
||||
|
||||
# Create virtual tenant record with required fields
|
||||
# Note: Use the actual demo user IDs from seed_demo_users.py
|
||||
# These match the demo users created in the auth service
|
||||
DEMO_OWNER_IDS = {
|
||||
"professional": "c1a2b3c4-d5e6-47a8-b9c0-d1e2f3a4b5c6", # María García López
|
||||
"enterprise": "d2e3f4a5-b6c7-48d9-e0f1-a2b3c4d5e6f7" # Carlos Martínez Ruiz
|
||||
}
|
||||
demo_owner_uuid = uuid.UUID(DEMO_OWNER_IDS.get(demo_account_type, DEMO_OWNER_IDS["professional"]))
|
||||
|
||||
tenant = Tenant(
|
||||
id=virtual_uuid,
|
||||
name=f"Demo Tenant - {demo_account_type.replace('_', ' ').title()}",
|
||||
address="Calle Demo 123", # Required field - provide demo address
|
||||
city="Madrid",
|
||||
postal_code="28001",
|
||||
business_type="bakery",
|
||||
is_demo=True,
|
||||
is_demo_template=False,
|
||||
demo_session_id=session_id, # Link tenant to demo session
|
||||
business_model=demo_account_type,
|
||||
is_active=True,
|
||||
timezone="Europe/Madrid",
|
||||
owner_id=demo_owner_uuid, # Required field - matches seed_demo_users.py
|
||||
tenant_type="parent" if demo_account_type in ["enterprise", "enterprise_parent"] else "standalone"
|
||||
)
|
||||
|
||||
db.add(tenant)
|
||||
await db.flush() # Flush to get the tenant ID
|
||||
|
||||
# Create demo subscription with appropriate tier based on demo account type
|
||||
|
||||
# Determine subscription tier based on demo account type
|
||||
if demo_account_type == "professional":
|
||||
plan = "professional"
|
||||
max_locations = 3
|
||||
elif demo_account_type in ["enterprise", "enterprise_parent"]:
|
||||
plan = "enterprise"
|
||||
max_locations = -1 # Unlimited
|
||||
elif demo_account_type == "enterprise_child":
|
||||
plan = "enterprise"
|
||||
max_locations = 1
|
||||
else:
|
||||
plan = "starter"
|
||||
max_locations = 1
|
||||
|
||||
demo_subscription = Subscription(
|
||||
tenant_id=tenant.id,
|
||||
plan=plan, # Set appropriate tier based on demo account type
|
||||
status="active",
|
||||
monthly_price=0.0, # Free for demo
|
||||
billing_cycle="monthly",
|
||||
max_users=-1, # Unlimited for demo
|
||||
max_locations=max_locations,
|
||||
max_products=-1, # Unlimited for demo
|
||||
features={},
|
||||
is_tenant_linked=True # Required for check constraint when tenant_id is set
|
||||
)
|
||||
db.add(demo_subscription)
|
||||
|
||||
# Create tenant member records for demo owner and staff
|
||||
import json
|
||||
|
||||
# Helper function to get permissions for role
|
||||
def get_permissions_for_role(role: str) -> str:
|
||||
permission_map = {
|
||||
"owner": ["read", "write", "admin", "delete"],
|
||||
"admin": ["read", "write", "admin"],
|
||||
"production_manager": ["read", "write"],
|
||||
"baker": ["read", "write"],
|
||||
"sales": ["read", "write"],
|
||||
"quality_control": ["read", "write"],
|
||||
"warehouse": ["read", "write"],
|
||||
"logistics": ["read", "write"],
|
||||
"procurement": ["read", "write"],
|
||||
"maintenance": ["read", "write"],
|
||||
"member": ["read", "write"],
|
||||
"viewer": ["read"]
|
||||
}
|
||||
permissions = permission_map.get(role, ["read"])
|
||||
return json.dumps(permissions)
|
||||
|
||||
# Define staff users for each demo account type (must match seed_demo_tenant_members.py)
|
||||
STAFF_USERS = {
|
||||
"professional": [
|
||||
# Owner
|
||||
{
|
||||
"user_id": uuid.UUID("c1a2b3c4-d5e6-47a8-b9c0-d1e2f3a4b5c6"),
|
||||
"role": "owner"
|
||||
},
|
||||
# Staff
|
||||
{
|
||||
"user_id": uuid.UUID("50000000-0000-0000-0000-000000000001"),
|
||||
"role": "baker"
|
||||
},
|
||||
{
|
||||
"user_id": uuid.UUID("50000000-0000-0000-0000-000000000002"),
|
||||
"role": "sales"
|
||||
},
|
||||
{
|
||||
"user_id": uuid.UUID("50000000-0000-0000-0000-000000000003"),
|
||||
"role": "quality_control"
|
||||
},
|
||||
{
|
||||
"user_id": uuid.UUID("50000000-0000-0000-0000-000000000004"),
|
||||
"role": "admin"
|
||||
},
|
||||
{
|
||||
"user_id": uuid.UUID("50000000-0000-0000-0000-000000000005"),
|
||||
"role": "warehouse"
|
||||
},
|
||||
{
|
||||
"user_id": uuid.UUID("50000000-0000-0000-0000-000000000006"),
|
||||
"role": "production_manager"
|
||||
}
|
||||
],
|
||||
"enterprise": [
|
||||
# Owner
|
||||
{
|
||||
"user_id": uuid.UUID("d2e3f4a5-b6c7-48d9-e0f1-a2b3c4d5e6f7"),
|
||||
"role": "owner"
|
||||
},
|
||||
# Staff
|
||||
{
|
||||
"user_id": uuid.UUID("50000000-0000-0000-0000-000000000011"),
|
||||
"role": "production_manager"
|
||||
},
|
||||
{
|
||||
"user_id": uuid.UUID("50000000-0000-0000-0000-000000000012"),
|
||||
"role": "quality_control"
|
||||
},
|
||||
{
|
||||
"user_id": uuid.UUID("50000000-0000-0000-0000-000000000013"),
|
||||
"role": "logistics"
|
||||
},
|
||||
{
|
||||
"user_id": uuid.UUID("50000000-0000-0000-0000-000000000014"),
|
||||
"role": "sales"
|
||||
},
|
||||
{
|
||||
"user_id": uuid.UUID("50000000-0000-0000-0000-000000000015"),
|
||||
"role": "procurement"
|
||||
},
|
||||
{
|
||||
"user_id": uuid.UUID("50000000-0000-0000-0000-000000000016"),
|
||||
"role": "maintenance"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Get staff users for this demo account type
|
||||
staff_users = STAFF_USERS.get(demo_account_type, [])
|
||||
|
||||
# Create tenant member records for all users (owner + staff)
|
||||
members_created = 0
|
||||
for staff_member in staff_users:
|
||||
tenant_member = TenantMember(
|
||||
tenant_id=virtual_uuid,
|
||||
user_id=staff_member["user_id"],
|
||||
role=staff_member["role"],
|
||||
permissions=get_permissions_for_role(staff_member["role"]),
|
||||
is_active=True,
|
||||
invited_by=demo_owner_uuid,
|
||||
invited_at=datetime.now(timezone.utc),
|
||||
joined_at=datetime.now(timezone.utc),
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
db.add(tenant_member)
|
||||
members_created += 1
|
||||
|
||||
logger.info(
|
||||
"Created tenant members for virtual tenant",
|
||||
virtual_tenant_id=virtual_tenant_id,
|
||||
members_created=members_created
|
||||
)
|
||||
|
||||
# Clone TenantLocations
|
||||
from app.models.tenant_location import TenantLocation
|
||||
|
||||
base_uuid = uuid.UUID(base_tenant_id)
|
||||
location_result = await db.execute(
|
||||
select(TenantLocation).where(TenantLocation.tenant_id == base_uuid)
|
||||
)
|
||||
base_locations = location_result.scalars().all()
|
||||
|
||||
records_cloned = 1 + members_created # Tenant + TenantMembers
|
||||
for base_location in base_locations:
|
||||
virtual_location = TenantLocation(
|
||||
id=uuid.uuid4(),
|
||||
tenant_id=virtual_tenant_id,
|
||||
name=base_location.name,
|
||||
location_type=base_location.location_type,
|
||||
address=base_location.address,
|
||||
city=base_location.city,
|
||||
postal_code=base_location.postal_code,
|
||||
latitude=base_location.latitude,
|
||||
longitude=base_location.longitude,
|
||||
capacity=base_location.capacity,
|
||||
delivery_windows=base_location.delivery_windows,
|
||||
operational_hours=base_location.operational_hours,
|
||||
max_delivery_radius_km=base_location.max_delivery_radius_km,
|
||||
delivery_schedule_config=base_location.delivery_schedule_config,
|
||||
is_active=base_location.is_active,
|
||||
contact_person=base_location.contact_person,
|
||||
contact_phone=base_location.contact_phone,
|
||||
contact_email=base_location.contact_email,
|
||||
metadata_=base_location.metadata_ if isinstance(base_location.metadata_, dict) else (base_location.metadata_ or {})
|
||||
)
|
||||
db.add(virtual_location)
|
||||
records_cloned += 1
|
||||
|
||||
logger.info("Cloned TenantLocations", count=len(base_locations))
|
||||
|
||||
# Subscription already created earlier based on demo_account_type (lines 179-206)
|
||||
# No need to clone from template - this prevents duplicate subscription creation
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(tenant)
|
||||
|
||||
duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
|
||||
|
||||
logger.info(
|
||||
"Virtual tenant created successfully with subscription",
|
||||
virtual_tenant_id=virtual_tenant_id,
|
||||
tenant_name=tenant.name,
|
||||
subscription_plan=plan,
|
||||
duration_ms=duration_ms
|
||||
)
|
||||
|
||||
records_cloned = 1 + members_created + 1 # Tenant + TenantMembers + Subscription
|
||||
|
||||
return {
|
||||
"service": "tenant",
|
||||
"status": "completed",
|
||||
"records_cloned": records_cloned,
|
||||
"duration_ms": duration_ms,
|
||||
"details": {
|
||||
"tenant_id": str(tenant.id),
|
||||
"tenant_name": tenant.name,
|
||||
"business_model": tenant.business_model,
|
||||
"owner_id": str(demo_owner_uuid),
|
||||
"members_created": members_created,
|
||||
"subscription_plan": plan,
|
||||
"subscription_created": True
|
||||
}
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
logger.error("Invalid UUID format", error=str(e), virtual_tenant_id=virtual_tenant_id)
|
||||
raise HTTPException(status_code=400, detail=f"Invalid UUID: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to clone tenant data",
|
||||
error=str(e),
|
||||
virtual_tenant_id=virtual_tenant_id,
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# Rollback on error
|
||||
await db.rollback()
|
||||
|
||||
return {
|
||||
"service": "tenant",
|
||||
"status": "failed",
|
||||
"records_cloned": 0,
|
||||
"duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
@router.post("/create-child")
|
||||
async def create_child_outlet(
|
||||
request: dict,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Create a child outlet tenant for enterprise demos
|
||||
|
||||
Args:
|
||||
request: JSON request body with child tenant details
|
||||
|
||||
Returns:
|
||||
Creation status and tenant details
|
||||
"""
|
||||
# Extract parameters from request body
|
||||
base_tenant_id = request.get("base_tenant_id")
|
||||
virtual_tenant_id = request.get("virtual_tenant_id")
|
||||
parent_tenant_id = request.get("parent_tenant_id")
|
||||
child_name = request.get("child_name")
|
||||
location = request.get("location", {})
|
||||
session_id = request.get("session_id")
|
||||
|
||||
start_time = datetime.now(timezone.utc)
|
||||
|
||||
logger.info(
|
||||
"Creating child outlet tenant",
|
||||
virtual_tenant_id=virtual_tenant_id,
|
||||
parent_tenant_id=parent_tenant_id,
|
||||
child_name=child_name,
|
||||
session_id=session_id
|
||||
)
|
||||
|
||||
try:
|
||||
# Validate UUIDs
|
||||
virtual_uuid = uuid.UUID(virtual_tenant_id)
|
||||
parent_uuid = uuid.UUID(parent_tenant_id)
|
||||
|
||||
# Check if child tenant already exists
|
||||
result = await db.execute(select(Tenant).where(Tenant.id == virtual_uuid))
|
||||
existing_tenant = result.scalars().first()
|
||||
|
||||
if existing_tenant:
|
||||
logger.info(
|
||||
"Child tenant already exists",
|
||||
virtual_tenant_id=virtual_tenant_id,
|
||||
tenant_name=existing_tenant.name
|
||||
)
|
||||
|
||||
# Return existing tenant - idempotent operation
|
||||
duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
|
||||
return {
|
||||
"service": "tenant",
|
||||
"status": "completed",
|
||||
"records_created": 0,
|
||||
"duration_ms": duration_ms,
|
||||
"details": {
|
||||
"tenant_id": str(virtual_uuid),
|
||||
"tenant_name": existing_tenant.name,
|
||||
"already_exists": True
|
||||
}
|
||||
}
|
||||
|
||||
# Get parent tenant to retrieve the correct owner_id
|
||||
parent_result = await db.execute(select(Tenant).where(Tenant.id == parent_uuid))
|
||||
parent_tenant = parent_result.scalars().first()
|
||||
|
||||
if not parent_tenant:
|
||||
logger.error("Parent tenant not found", parent_tenant_id=parent_tenant_id)
|
||||
return {
|
||||
"service": "tenant",
|
||||
"status": "failed",
|
||||
"records_created": 0,
|
||||
"duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
|
||||
"error": f"Parent tenant {parent_tenant_id} not found"
|
||||
}
|
||||
|
||||
# Use the parent's owner_id for the child tenant (enterprise demo owner)
|
||||
parent_owner_id = parent_tenant.owner_id
|
||||
|
||||
# Create child tenant with parent relationship
|
||||
child_tenant = Tenant(
|
||||
id=virtual_uuid,
|
||||
name=child_name,
|
||||
address=location.get("address", f"Calle Outlet {location.get('city', 'Madrid')}"),
|
||||
city=location.get("city", "Madrid"),
|
||||
postal_code=location.get("postal_code", "28001"),
|
||||
business_type="bakery",
|
||||
is_demo=True,
|
||||
is_demo_template=False,
|
||||
demo_session_id=session_id, # Link child tenant to demo session
|
||||
business_model="retail_outlet",
|
||||
is_active=True,
|
||||
timezone="Europe/Madrid",
|
||||
# Set parent relationship
|
||||
parent_tenant_id=parent_uuid,
|
||||
tenant_type="child",
|
||||
hierarchy_path=f"{str(parent_uuid)}.{str(virtual_uuid)}",
|
||||
|
||||
# Owner ID - MUST match the parent tenant owner (enterprise demo owner)
|
||||
# This ensures the parent owner can see and access child tenants
|
||||
owner_id=parent_owner_id
|
||||
)
|
||||
|
||||
db.add(child_tenant)
|
||||
await db.flush() # Flush to get the tenant ID
|
||||
|
||||
# Create TenantLocation for this retail outlet
|
||||
child_location = TenantLocation(
|
||||
id=uuid.uuid4(),
|
||||
tenant_id=virtual_uuid,
|
||||
name=f"{child_name} - Retail Outlet",
|
||||
location_type="retail_outlet",
|
||||
address=location.get("address", f"Calle Outlet {location.get('city', 'Madrid')}"),
|
||||
city=location.get("city", "Madrid"),
|
||||
postal_code=location.get("postal_code", "28001"),
|
||||
latitude=location.get("latitude"),
|
||||
longitude=location.get("longitude"),
|
||||
delivery_windows={
|
||||
"monday": "07:00-10:00",
|
||||
"wednesday": "07:00-10:00",
|
||||
"friday": "07:00-10:00"
|
||||
},
|
||||
operational_hours={
|
||||
"monday": "07:00-21:00",
|
||||
"tuesday": "07:00-21:00",
|
||||
"wednesday": "07:00-21:00",
|
||||
"thursday": "07:00-21:00",
|
||||
"friday": "07:00-21:00",
|
||||
"saturday": "08:00-21:00",
|
||||
"sunday": "09:00-21:00"
|
||||
},
|
||||
delivery_schedule_config={
|
||||
"delivery_days": ["monday", "wednesday", "friday"],
|
||||
"time_window": "07:00-10:00"
|
||||
},
|
||||
is_active=True
|
||||
)
|
||||
db.add(child_location)
|
||||
logger.info("Created TenantLocation for child", child_id=str(virtual_uuid), location_name=child_location.name)
|
||||
|
||||
# Create parent tenant lookup to get the correct plan for the child
|
||||
parent_result = await db.execute(
|
||||
select(Subscription).where(
|
||||
Subscription.tenant_id == parent_uuid,
|
||||
Subscription.status == "active"
|
||||
)
|
||||
)
|
||||
parent_subscription = parent_result.scalars().first()
|
||||
|
||||
# Child inherits the same plan as parent
|
||||
parent_plan = parent_subscription.plan if parent_subscription else "enterprise"
|
||||
|
||||
child_subscription = Subscription(
|
||||
tenant_id=child_tenant.id,
|
||||
plan=parent_plan, # Child inherits the same plan as parent
|
||||
status="active",
|
||||
monthly_price=0.0, # Free for demo
|
||||
billing_cycle="monthly",
|
||||
max_users=10, # Demo limits
|
||||
max_locations=1, # Single location for outlet
|
||||
max_products=200,
|
||||
features={},
|
||||
is_tenant_linked=True # Required for check constraint when tenant_id is set
|
||||
)
|
||||
db.add(child_subscription)
|
||||
|
||||
# Create basic tenant members like parent
|
||||
import json
|
||||
|
||||
# Use the parent's owner_id (already retrieved above)
|
||||
# This ensures consistency between tenant.owner_id and TenantMember records
|
||||
|
||||
# Create tenant member for owner
|
||||
child_owner_member = TenantMember(
|
||||
tenant_id=virtual_uuid,
|
||||
user_id=parent_owner_id,
|
||||
role="owner",
|
||||
permissions=json.dumps(["read", "write", "admin", "delete"]),
|
||||
is_active=True,
|
||||
invited_by=parent_owner_id,
|
||||
invited_at=datetime.now(timezone.utc),
|
||||
joined_at=datetime.now(timezone.utc),
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
db.add(child_owner_member)
|
||||
|
||||
# Create staff members for the outlet from parent enterprise users
|
||||
# Use parent's enterprise staff (from enterprise/parent/02-auth.json)
|
||||
staff_users = [
|
||||
{
|
||||
"user_id": uuid.UUID("f6c54d0f-5899-4952-ad94-7a492c07167a"), # Laura López - Logistics
|
||||
"role": "logistics_coord"
|
||||
},
|
||||
{
|
||||
"user_id": uuid.UUID("80765906-0074-4206-8f58-5867df1975fd"), # José Martínez - Quality
|
||||
"role": "quality_control"
|
||||
},
|
||||
{
|
||||
"user_id": uuid.UUID("701cb9d2-6049-4bb9-8d3a-1b3bd3aae45f"), # Francisco Moreno - Warehouse
|
||||
"role": "warehouse_supervisor"
|
||||
}
|
||||
]
|
||||
|
||||
members_created = 1 # Start with owner
|
||||
for staff_member in staff_users:
|
||||
tenant_member = TenantMember(
|
||||
tenant_id=virtual_uuid,
|
||||
user_id=staff_member["user_id"],
|
||||
role=staff_member["role"],
|
||||
permissions=json.dumps(["read", "write"]) if staff_member["role"] != "admin" else json.dumps(["read", "write", "admin"]),
|
||||
is_active=True,
|
||||
invited_by=parent_owner_id,
|
||||
invited_at=datetime.now(timezone.utc),
|
||||
joined_at=datetime.now(timezone.utc),
|
||||
created_at=datetime.now(timezone.utc)
|
||||
)
|
||||
db.add(tenant_member)
|
||||
members_created += 1
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(child_tenant)
|
||||
|
||||
duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
|
||||
|
||||
logger.info(
|
||||
"Child outlet created successfully",
|
||||
virtual_tenant_id=str(virtual_tenant_id),
|
||||
parent_tenant_id=str(parent_tenant_id),
|
||||
child_name=child_name,
|
||||
owner_id=str(parent_owner_id),
|
||||
duration_ms=duration_ms
|
||||
)
|
||||
|
||||
return {
|
||||
"service": "tenant",
|
||||
"status": "completed",
|
||||
"records_created": 2 + members_created, # Tenant + Subscription + Members
|
||||
"duration_ms": duration_ms,
|
||||
"details": {
|
||||
"tenant_id": str(child_tenant.id),
|
||||
"tenant_name": child_tenant.name,
|
||||
"parent_tenant_id": str(parent_tenant_id),
|
||||
"location": location,
|
||||
"members_created": members_created,
|
||||
"subscription_plan": "enterprise"
|
||||
}
|
||||
}
|
||||
|
||||
except ValueError as e:
|
||||
logger.error("Invalid UUID format", error=str(e), virtual_tenant_id=virtual_tenant_id)
|
||||
raise HTTPException(status_code=400, detail=f"Invalid UUID: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to create child outlet",
|
||||
error=str(e),
|
||||
virtual_tenant_id=virtual_tenant_id,
|
||||
parent_tenant_id=parent_tenant_id,
|
||||
exc_info=True
|
||||
)
|
||||
|
||||
# Rollback on error
|
||||
await db.rollback()
|
||||
|
||||
return {
|
||||
"service": "tenant",
|
||||
"status": "failed",
|
||||
"records_created": 0,
|
||||
"duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/clone/health")
|
||||
async def clone_health_check():
|
||||
"""
|
||||
Health check for internal cloning endpoint
|
||||
Used by orchestrator to verify service availability
|
||||
"""
|
||||
return {
|
||||
"service": "tenant",
|
||||
"clone_endpoint": "available",
|
||||
"version": "2.0.0"
|
||||
}
|
||||
445
services/tenant/app/api/network_alerts.py
Normal file
445
services/tenant/app/api/network_alerts.py
Normal file
@@ -0,0 +1,445 @@
|
||||
"""
|
||||
Network Alerts API
|
||||
Endpoints for aggregating and managing alerts across enterprise networks
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from pydantic import BaseModel, Field
|
||||
import structlog
|
||||
|
||||
from app.services.network_alerts_service import NetworkAlertsService
|
||||
from shared.auth.tenant_access import verify_tenant_permission_dep
|
||||
from shared.clients import get_tenant_client, get_alerts_client
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# Pydantic models for request/response
|
||||
class NetworkAlert(BaseModel):
|
||||
alert_id: str = Field(..., description="Unique alert ID")
|
||||
tenant_id: str = Field(..., description="Tenant ID where alert originated")
|
||||
tenant_name: str = Field(..., description="Tenant name")
|
||||
alert_type: str = Field(..., description="Type of alert: inventory, production, delivery, etc.")
|
||||
severity: str = Field(..., description="Severity: critical, high, medium, low")
|
||||
title: str = Field(..., description="Alert title")
|
||||
message: str = Field(..., description="Alert message")
|
||||
timestamp: str = Field(..., description="Alert timestamp")
|
||||
status: str = Field(..., description="Alert status: active, acknowledged, resolved")
|
||||
source_system: str = Field(..., description="System that generated the alert")
|
||||
related_entity_id: Optional[str] = Field(None, description="ID of related entity (product, route, etc.)")
|
||||
related_entity_type: Optional[str] = Field(None, description="Type of related entity")
|
||||
|
||||
|
||||
class AlertSeveritySummary(BaseModel):
|
||||
critical_count: int = Field(..., description="Number of critical alerts")
|
||||
high_count: int = Field(..., description="Number of high severity alerts")
|
||||
medium_count: int = Field(..., description="Number of medium severity alerts")
|
||||
low_count: int = Field(..., description="Number of low severity alerts")
|
||||
total_alerts: int = Field(..., description="Total number of alerts")
|
||||
|
||||
|
||||
class AlertTypeSummary(BaseModel):
|
||||
inventory_alerts: int = Field(..., description="Inventory-related alerts")
|
||||
production_alerts: int = Field(..., description="Production-related alerts")
|
||||
delivery_alerts: int = Field(..., description="Delivery-related alerts")
|
||||
equipment_alerts: int = Field(..., description="Equipment-related alerts")
|
||||
quality_alerts: int = Field(..., description="Quality-related alerts")
|
||||
other_alerts: int = Field(..., description="Other types of alerts")
|
||||
|
||||
|
||||
class NetworkAlertsSummary(BaseModel):
|
||||
total_alerts: int = Field(..., description="Total alerts across network")
|
||||
active_alerts: int = Field(..., description="Currently active alerts")
|
||||
acknowledged_alerts: int = Field(..., description="Acknowledged alerts")
|
||||
resolved_alerts: int = Field(..., description="Resolved alerts")
|
||||
severity_summary: AlertSeveritySummary = Field(..., description="Alerts by severity")
|
||||
type_summary: AlertTypeSummary = Field(..., description="Alerts by type")
|
||||
most_recent_alert: Optional[NetworkAlert] = Field(None, description="Most recent alert")
|
||||
|
||||
|
||||
class AlertCorrelation(BaseModel):
|
||||
correlation_id: str = Field(..., description="Correlation group ID")
|
||||
primary_alert: NetworkAlert = Field(..., description="Primary alert in the group")
|
||||
related_alerts: List[NetworkAlert] = Field(..., description="Alerts correlated with primary alert")
|
||||
correlation_type: str = Field(..., description="Type of correlation: causal, temporal, spatial")
|
||||
correlation_strength: float = Field(..., description="Correlation strength (0-1)")
|
||||
impact_analysis: str = Field(..., description="Analysis of combined impact")
|
||||
|
||||
|
||||
async def get_network_alerts_service() -> NetworkAlertsService:
|
||||
"""Dependency injection for NetworkAlertsService"""
|
||||
tenant_client = get_tenant_client(settings, "tenant-service")
|
||||
alerts_client = get_alerts_client(settings, "tenant-service")
|
||||
return NetworkAlertsService(tenant_client, alerts_client)
|
||||
|
||||
|
||||
@router.get("/tenants/{parent_id}/network/alerts",
|
||||
response_model=List[NetworkAlert],
|
||||
summary="Get aggregated alerts across network")
|
||||
async def get_network_alerts(
|
||||
parent_id: str,
|
||||
severity: Optional[str] = Query(None, description="Filter by severity: critical, high, medium, low"),
|
||||
alert_type: Optional[str] = Query(None, description="Filter by alert type"),
|
||||
status: Optional[str] = Query(None, description="Filter by status: active, acknowledged, resolved"),
|
||||
limit: int = Query(100, description="Maximum number of alerts to return"),
|
||||
network_alerts_service: NetworkAlertsService = Depends(get_network_alerts_service),
|
||||
verified_tenant: str = Depends(verify_tenant_permission_dep)
|
||||
):
|
||||
"""
|
||||
Get aggregated alerts across all child tenants in a parent network
|
||||
|
||||
This endpoint provides a unified view of alerts across the entire enterprise network,
|
||||
enabling network managers to identify and prioritize issues that require attention.
|
||||
"""
|
||||
try:
|
||||
# Verify this is a parent tenant
|
||||
tenant_info = await network_alerts_service.tenant_client.get_tenant(parent_id)
|
||||
if tenant_info.get('tenant_type') != 'parent':
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Only parent tenants can access network alerts"
|
||||
)
|
||||
|
||||
# Get all child tenants
|
||||
child_tenants = await network_alerts_service.get_child_tenants(parent_id)
|
||||
|
||||
if not child_tenants:
|
||||
return []
|
||||
|
||||
# Aggregate alerts from all child tenants
|
||||
all_alerts = []
|
||||
|
||||
for child in child_tenants:
|
||||
child_id = child['id']
|
||||
child_name = child['name']
|
||||
|
||||
# Get alerts for this child tenant
|
||||
child_alerts = await network_alerts_service.get_alerts_for_tenant(child_id)
|
||||
|
||||
# Enrich with tenant information and apply filters
|
||||
for alert in child_alerts:
|
||||
enriched_alert = {
|
||||
'alert_id': alert.get('alert_id', str(uuid.uuid4())),
|
||||
'tenant_id': child_id,
|
||||
'tenant_name': child_name,
|
||||
'alert_type': alert.get('alert_type', 'unknown'),
|
||||
'severity': alert.get('severity', 'medium'),
|
||||
'title': alert.get('title', 'No title'),
|
||||
'message': alert.get('message', 'No message'),
|
||||
'timestamp': alert.get('timestamp', datetime.now().isoformat()),
|
||||
'status': alert.get('status', 'active'),
|
||||
'source_system': alert.get('source_system', 'unknown'),
|
||||
'related_entity_id': alert.get('related_entity_id'),
|
||||
'related_entity_type': alert.get('related_entity_type')
|
||||
}
|
||||
|
||||
# Apply filters
|
||||
if severity and enriched_alert['severity'] != severity:
|
||||
continue
|
||||
if alert_type and enriched_alert['alert_type'] != alert_type:
|
||||
continue
|
||||
if status and enriched_alert['status'] != status:
|
||||
continue
|
||||
|
||||
all_alerts.append(enriched_alert)
|
||||
|
||||
# Sort by severity (critical first) and timestamp (newest first)
|
||||
severity_order = {'critical': 1, 'high': 2, 'medium': 3, 'low': 4}
|
||||
all_alerts.sort(key=lambda x: (severity_order.get(x['severity'], 5), -int(x['timestamp'] or 0)))
|
||||
|
||||
return all_alerts[:limit]
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get network alerts", parent_id=parent_id, error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get network alerts: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/tenants/{parent_id}/network/alerts/summary",
|
||||
response_model=NetworkAlertsSummary,
|
||||
summary="Get network alerts summary")
|
||||
async def get_network_alerts_summary(
|
||||
parent_id: str,
|
||||
network_alerts_service: NetworkAlertsService = Depends(get_network_alerts_service),
|
||||
verified_tenant: str = Depends(verify_tenant_permission_dep)
|
||||
):
|
||||
"""
|
||||
Get summary of alerts across the network
|
||||
|
||||
Provides aggregated metrics and statistics about alerts across all child tenants.
|
||||
"""
|
||||
try:
|
||||
# Verify this is a parent tenant
|
||||
tenant_info = await network_alerts_service.tenant_client.get_tenant(parent_id)
|
||||
if tenant_info.get('tenant_type') != 'parent':
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Only parent tenants can access network alerts summary"
|
||||
)
|
||||
|
||||
# Get all network alerts
|
||||
all_alerts = await network_alerts_service.get_network_alerts(parent_id)
|
||||
|
||||
if not all_alerts:
|
||||
return NetworkAlertsSummary(
|
||||
total_alerts=0,
|
||||
active_alerts=0,
|
||||
acknowledged_alerts=0,
|
||||
resolved_alerts=0,
|
||||
severity_summary=AlertSeveritySummary(
|
||||
critical_count=0,
|
||||
high_count=0,
|
||||
medium_count=0,
|
||||
low_count=0,
|
||||
total_alerts=0
|
||||
),
|
||||
type_summary=AlertTypeSummary(
|
||||
inventory_alerts=0,
|
||||
production_alerts=0,
|
||||
delivery_alerts=0,
|
||||
equipment_alerts=0,
|
||||
quality_alerts=0,
|
||||
other_alerts=0
|
||||
),
|
||||
most_recent_alert=None
|
||||
)
|
||||
|
||||
# Calculate summary metrics
|
||||
active_alerts = sum(1 for a in all_alerts if a['status'] == 'active')
|
||||
acknowledged_alerts = sum(1 for a in all_alerts if a['status'] == 'acknowledged')
|
||||
resolved_alerts = sum(1 for a in all_alerts if a['status'] == 'resolved')
|
||||
|
||||
# Calculate severity summary
|
||||
severity_summary = AlertSeveritySummary(
|
||||
critical_count=sum(1 for a in all_alerts if a['severity'] == 'critical'),
|
||||
high_count=sum(1 for a in all_alerts if a['severity'] == 'high'),
|
||||
medium_count=sum(1 for a in all_alerts if a['severity'] == 'medium'),
|
||||
low_count=sum(1 for a in all_alerts if a['severity'] == 'low'),
|
||||
total_alerts=len(all_alerts)
|
||||
)
|
||||
|
||||
# Calculate type summary
|
||||
type_summary = AlertTypeSummary(
|
||||
inventory_alerts=sum(1 for a in all_alerts if a['alert_type'] == 'inventory'),
|
||||
production_alerts=sum(1 for a in all_alerts if a['alert_type'] == 'production'),
|
||||
delivery_alerts=sum(1 for a in all_alerts if a['alert_type'] == 'delivery'),
|
||||
equipment_alerts=sum(1 for a in all_alerts if a['alert_type'] == 'equipment'),
|
||||
quality_alerts=sum(1 for a in all_alerts if a['alert_type'] == 'quality'),
|
||||
other_alerts=sum(1 for a in all_alerts if a['alert_type'] not in ['inventory', 'production', 'delivery', 'equipment', 'quality'])
|
||||
)
|
||||
|
||||
# Get most recent alert
|
||||
most_recent_alert = None
|
||||
if all_alerts:
|
||||
most_recent_alert = max(all_alerts, key=lambda x: x['timestamp'])
|
||||
|
||||
return NetworkAlertsSummary(
|
||||
total_alerts=len(all_alerts),
|
||||
active_alerts=active_alerts,
|
||||
acknowledged_alerts=acknowledged_alerts,
|
||||
resolved_alerts=resolved_alerts,
|
||||
severity_summary=severity_summary,
|
||||
type_summary=type_summary,
|
||||
most_recent_alert=most_recent_alert
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get network alerts summary", parent_id=parent_id, error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get alerts summary: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/tenants/{parent_id}/network/alerts/correlations",
|
||||
response_model=List[AlertCorrelation],
|
||||
summary="Get correlated alert groups")
|
||||
async def get_correlated_alerts(
|
||||
parent_id: str,
|
||||
min_correlation_strength: float = Query(0.7, ge=0.5, le=1.0, description="Minimum correlation strength"),
|
||||
network_alerts_service: NetworkAlertsService = Depends(get_network_alerts_service),
|
||||
verified_tenant: str = Depends(verify_tenant_permission_dep)
|
||||
):
|
||||
"""
|
||||
Get groups of correlated alerts
|
||||
|
||||
Identifies alerts that are related or have cascading effects across the network.
|
||||
"""
|
||||
try:
|
||||
# Verify this is a parent tenant
|
||||
tenant_info = await network_alerts_service.tenant_client.get_tenant(parent_id)
|
||||
if tenant_info.get('tenant_type') != 'parent':
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Only parent tenants can access alert correlations"
|
||||
)
|
||||
|
||||
# Get all network alerts
|
||||
all_alerts = await network_alerts_service.get_network_alerts(parent_id)
|
||||
|
||||
if not all_alerts:
|
||||
return []
|
||||
|
||||
# Detect correlations (simplified for demo)
|
||||
correlations = await network_alerts_service.detect_alert_correlations(
|
||||
all_alerts, min_correlation_strength
|
||||
)
|
||||
|
||||
return correlations
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get correlated alerts", parent_id=parent_id, error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get alert correlations: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/tenants/{parent_id}/network/alerts/{alert_id}/acknowledge",
|
||||
summary="Acknowledge network alert")
|
||||
async def acknowledge_network_alert(
|
||||
parent_id: str,
|
||||
alert_id: str,
|
||||
network_alerts_service: NetworkAlertsService = Depends(get_network_alerts_service),
|
||||
verified_tenant: str = Depends(verify_tenant_permission_dep)
|
||||
):
|
||||
"""
|
||||
Acknowledge a network alert
|
||||
|
||||
Marks an alert as acknowledged to indicate it's being addressed.
|
||||
"""
|
||||
try:
|
||||
# Verify this is a parent tenant
|
||||
tenant_info = await network_alerts_service.tenant_client.get_tenant(parent_id)
|
||||
if tenant_info.get('tenant_type') != 'parent':
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Only parent tenants can acknowledge network alerts"
|
||||
)
|
||||
|
||||
# Acknowledge the alert
|
||||
result = await network_alerts_service.acknowledge_alert(parent_id, alert_id)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'alert_id': alert_id,
|
||||
'status': 'acknowledged',
|
||||
'message': 'Alert acknowledged successfully'
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to acknowledge alert", parent_id=parent_id, alert_id=alert_id, error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Failed to acknowledge alert: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/tenants/{parent_id}/network/alerts/{alert_id}/resolve",
|
||||
summary="Resolve network alert")
|
||||
async def resolve_network_alert(
|
||||
parent_id: str,
|
||||
alert_id: str,
|
||||
resolution_notes: Optional[str] = Query(None, description="Notes about resolution"),
|
||||
network_alerts_service: NetworkAlertsService = Depends(get_network_alerts_service),
|
||||
verified_tenant: str = Depends(verify_tenant_permission_dep)
|
||||
):
|
||||
"""
|
||||
Resolve a network alert
|
||||
|
||||
Marks an alert as resolved after the issue has been addressed.
|
||||
"""
|
||||
try:
|
||||
# Verify this is a parent tenant
|
||||
tenant_info = await network_alerts_service.tenant_client.get_tenant(parent_id)
|
||||
if tenant_info.get('tenant_type') != 'parent':
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Only parent tenants can resolve network alerts"
|
||||
)
|
||||
|
||||
# Resolve the alert
|
||||
result = await network_alerts_service.resolve_alert(parent_id, alert_id, resolution_notes)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'alert_id': alert_id,
|
||||
'status': 'resolved',
|
||||
'resolution_notes': resolution_notes,
|
||||
'message': 'Alert resolved successfully'
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to resolve alert", parent_id=parent_id, alert_id=alert_id, error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Failed to resolve alert: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/tenants/{parent_id}/network/alerts/trends",
|
||||
summary="Get alert trends over time")
|
||||
async def get_alert_trends(
|
||||
parent_id: str,
|
||||
days: int = Query(30, ge=7, le=365, description="Number of days to analyze"),
|
||||
network_alerts_service: NetworkAlertsService = Depends(get_network_alerts_service),
|
||||
verified_tenant: str = Depends(verify_tenant_permission_dep)
|
||||
):
|
||||
"""
|
||||
Get alert trends over time
|
||||
|
||||
Analyzes how alert patterns change over time to identify systemic issues.
|
||||
"""
|
||||
try:
|
||||
# Verify this is a parent tenant
|
||||
tenant_info = await network_alerts_service.tenant_client.get_tenant(parent_id)
|
||||
if tenant_info.get('tenant_type') != 'parent':
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Only parent tenants can access alert trends"
|
||||
)
|
||||
|
||||
# Get alert trends
|
||||
trends = await network_alerts_service.get_alert_trends(parent_id, days)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'trends': trends,
|
||||
'period': f'Last {days} days'
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get alert trends", parent_id=parent_id, error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get alert trends: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/tenants/{parent_id}/network/alerts/prioritization",
|
||||
summary="Get prioritized alerts")
|
||||
async def get_prioritized_alerts(
|
||||
parent_id: str,
|
||||
limit: int = Query(10, description="Maximum number of alerts to return"),
|
||||
network_alerts_service: NetworkAlertsService = Depends(get_network_alerts_service),
|
||||
verified_tenant: str = Depends(verify_tenant_permission_dep)
|
||||
):
|
||||
"""
|
||||
Get prioritized alerts based on impact and urgency
|
||||
|
||||
Uses AI to prioritize alerts based on potential business impact and urgency.
|
||||
"""
|
||||
try:
|
||||
# Verify this is a parent tenant
|
||||
tenant_info = await network_alerts_service.tenant_client.get_tenant(parent_id)
|
||||
if tenant_info.get('tenant_type') != 'parent':
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Only parent tenants can access prioritized alerts"
|
||||
)
|
||||
|
||||
# Get prioritized alerts
|
||||
prioritized_alerts = await network_alerts_service.get_prioritized_alerts(parent_id, limit)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'prioritized_alerts': prioritized_alerts,
|
||||
'message': f'Top {len(prioritized_alerts)} prioritized alerts'
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get prioritized alerts", parent_id=parent_id, error=str(e))
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get prioritized alerts: {str(e)}")
|
||||
|
||||
|
||||
# Import datetime at runtime to avoid circular imports
|
||||
from datetime import datetime, timedelta
|
||||
import uuid
|
||||
129
services/tenant/app/api/onboarding.py
Normal file
129
services/tenant/app/api/onboarding.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""
|
||||
Onboarding Status API
|
||||
Provides lightweight onboarding status checks by aggregating counts from multiple services
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
import structlog
|
||||
import asyncio
|
||||
import httpx
|
||||
import os
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.core.config import settings
|
||||
from shared.auth.decorators import get_current_tenant_id_dep
|
||||
from shared.routing.route_builder import RouteBuilder
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
route_builder = RouteBuilder("tenants")
|
||||
|
||||
|
||||
@router.get(route_builder.build_base_route("{tenant_id}/onboarding/status", include_tenant_prefix=False))
|
||||
async def get_onboarding_status(
|
||||
tenant_id: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get lightweight onboarding status by fetching counts from each service.
|
||||
|
||||
Returns:
|
||||
- ingredients_count: Number of active ingredients
|
||||
- suppliers_count: Number of active suppliers
|
||||
- recipes_count: Number of active recipes
|
||||
- has_minimum_setup: Boolean indicating if minimum requirements are met
|
||||
- progress_percentage: Overall onboarding progress (0-100)
|
||||
"""
|
||||
try:
|
||||
# Service URLs from environment
|
||||
inventory_url = os.getenv("INVENTORY_SERVICE_URL", "http://inventory-service:8000")
|
||||
suppliers_url = os.getenv("SUPPLIERS_SERVICE_URL", "http://suppliers-service:8000")
|
||||
recipes_url = os.getenv("RECIPES_SERVICE_URL", "http://recipes-service:8000")
|
||||
|
||||
|
||||
# Fetch counts from all services in parallel
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
results = await asyncio.gather(
|
||||
client.get(
|
||||
f"{inventory_url}/internal/count",
|
||||
params={"tenant_id": tenant_id}
|
||||
),
|
||||
client.get(
|
||||
f"{suppliers_url}/internal/count",
|
||||
params={"tenant_id": tenant_id}
|
||||
),
|
||||
client.get(
|
||||
f"{recipes_url}/internal/count",
|
||||
params={"tenant_id": tenant_id}
|
||||
),
|
||||
return_exceptions=True
|
||||
)
|
||||
|
||||
# Extract counts with fallback to 0
|
||||
ingredients_count = 0
|
||||
suppliers_count = 0
|
||||
recipes_count = 0
|
||||
|
||||
if not isinstance(results[0], Exception) and results[0].status_code == 200:
|
||||
ingredients_count = results[0].json().get("count", 0)
|
||||
|
||||
if not isinstance(results[1], Exception) and results[1].status_code == 200:
|
||||
suppliers_count = results[1].json().get("count", 0)
|
||||
|
||||
if not isinstance(results[2], Exception) and results[2].status_code == 200:
|
||||
recipes_count = results[2].json().get("count", 0)
|
||||
|
||||
# Calculate minimum setup requirements
|
||||
# Minimum: 3 ingredients, 1 supplier, 1 recipe
|
||||
has_minimum_ingredients = ingredients_count >= 3
|
||||
has_minimum_suppliers = suppliers_count >= 1
|
||||
has_minimum_recipes = recipes_count >= 1
|
||||
|
||||
has_minimum_setup = all([
|
||||
has_minimum_ingredients,
|
||||
has_minimum_suppliers,
|
||||
has_minimum_recipes
|
||||
])
|
||||
|
||||
# Calculate progress percentage
|
||||
# Each requirement contributes 33.33%
|
||||
progress = 0
|
||||
if has_minimum_ingredients:
|
||||
progress += 33
|
||||
if has_minimum_suppliers:
|
||||
progress += 33
|
||||
if has_minimum_recipes:
|
||||
progress += 34
|
||||
|
||||
return {
|
||||
"ingredients_count": ingredients_count,
|
||||
"suppliers_count": suppliers_count,
|
||||
"recipes_count": recipes_count,
|
||||
"has_minimum_setup": has_minimum_setup,
|
||||
"progress_percentage": progress,
|
||||
"requirements": {
|
||||
"ingredients": {
|
||||
"current": ingredients_count,
|
||||
"minimum": 3,
|
||||
"met": has_minimum_ingredients
|
||||
},
|
||||
"suppliers": {
|
||||
"current": suppliers_count,
|
||||
"minimum": 1,
|
||||
"met": has_minimum_suppliers
|
||||
},
|
||||
"recipes": {
|
||||
"current": recipes_count,
|
||||
"minimum": 1,
|
||||
"met": has_minimum_recipes
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get onboarding status", tenant_id=tenant_id, error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to get onboarding status: {str(e)}"
|
||||
)
|
||||
330
services/tenant/app/api/plans.py
Normal file
330
services/tenant/app/api/plans.py
Normal file
@@ -0,0 +1,330 @@
|
||||
"""
|
||||
Subscription Plans API
|
||||
Public endpoint for fetching available subscription plans
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from typing import Dict, Any
|
||||
import structlog
|
||||
|
||||
from shared.subscription.plans import (
|
||||
SubscriptionTier,
|
||||
SubscriptionPlanMetadata,
|
||||
PlanPricing,
|
||||
QuotaLimits,
|
||||
PlanFeatures,
|
||||
FeatureCategories,
|
||||
UserFacingFeatures
|
||||
)
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter(prefix="/plans", tags=["subscription-plans"])
|
||||
|
||||
|
||||
@router.get("", response_model=Dict[str, Any])
|
||||
async def get_available_plans():
|
||||
"""
|
||||
Get all available subscription plans with complete metadata
|
||||
|
||||
**Public endpoint** - No authentication required
|
||||
|
||||
Returns:
|
||||
Dictionary containing plan metadata for all tiers
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"plans": {
|
||||
"starter": {
|
||||
"name": "Starter",
|
||||
"description": "Perfect for small bakeries getting started",
|
||||
"monthly_price": 49.00,
|
||||
"yearly_price": 490.00,
|
||||
"features": [...],
|
||||
"limits": {...}
|
||||
},
|
||||
...
|
||||
}
|
||||
}
|
||||
```
|
||||
"""
|
||||
try:
|
||||
plans_data = {}
|
||||
|
||||
for tier in SubscriptionTier:
|
||||
metadata = SubscriptionPlanMetadata.PLANS[tier]
|
||||
|
||||
# Convert Decimal to float for JSON serialization
|
||||
plans_data[tier.value] = {
|
||||
"name": metadata["name"],
|
||||
"description_key": metadata["description_key"],
|
||||
"tagline_key": metadata["tagline_key"],
|
||||
"popular": metadata["popular"],
|
||||
"monthly_price": float(metadata["monthly_price"]),
|
||||
"yearly_price": float(metadata["yearly_price"]),
|
||||
"trial_days": metadata["trial_days"],
|
||||
"features": metadata["features"],
|
||||
"hero_features": metadata.get("hero_features", []),
|
||||
"roi_badge": metadata.get("roi_badge"),
|
||||
"business_metrics": metadata.get("business_metrics"),
|
||||
"limits": metadata["limits"],
|
||||
"support_key": metadata["support_key"],
|
||||
"recommended_for_key": metadata["recommended_for_key"],
|
||||
"contact_sales": metadata.get("contact_sales", False),
|
||||
"custom_pricing": metadata.get("custom_pricing", False),
|
||||
}
|
||||
|
||||
logger.info("subscription_plans_fetched", tier_count=len(plans_data))
|
||||
|
||||
return {"plans": plans_data}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("failed_to_fetch_plans", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch subscription plans"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{tier}", response_model=Dict[str, Any])
|
||||
async def get_plan_by_tier(tier: str):
|
||||
"""
|
||||
Get metadata for a specific subscription tier
|
||||
|
||||
**Public endpoint** - No authentication required
|
||||
|
||||
Args:
|
||||
tier: Subscription tier (starter, professional, enterprise)
|
||||
|
||||
Returns:
|
||||
Plan metadata for the specified tier
|
||||
|
||||
Raises:
|
||||
404: If tier is not found
|
||||
"""
|
||||
try:
|
||||
# Validate tier
|
||||
tier_enum = SubscriptionTier(tier.lower())
|
||||
|
||||
metadata = SubscriptionPlanMetadata.PLANS[tier_enum]
|
||||
|
||||
plan_data = {
|
||||
"tier": tier_enum.value,
|
||||
"name": metadata["name"],
|
||||
"description_key": metadata["description_key"],
|
||||
"tagline_key": metadata["tagline_key"],
|
||||
"popular": metadata["popular"],
|
||||
"monthly_price": float(metadata["monthly_price"]),
|
||||
"yearly_price": float(metadata["yearly_price"]),
|
||||
"trial_days": metadata["trial_days"],
|
||||
"features": metadata["features"],
|
||||
"hero_features": metadata.get("hero_features", []),
|
||||
"roi_badge": metadata.get("roi_badge"),
|
||||
"business_metrics": metadata.get("business_metrics"),
|
||||
"limits": metadata["limits"],
|
||||
"support_key": metadata["support_key"],
|
||||
"recommended_for_key": metadata["recommended_for_key"],
|
||||
"contact_sales": metadata.get("contact_sales", False),
|
||||
"custom_pricing": metadata.get("custom_pricing", False),
|
||||
}
|
||||
|
||||
logger.info("subscription_plan_fetched", tier=tier)
|
||||
|
||||
return plan_data
|
||||
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Subscription tier '{tier}' not found"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("failed_to_fetch_plan", tier=tier, error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch subscription plan"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{tier}/features")
|
||||
async def get_plan_features(tier: str):
|
||||
"""
|
||||
Get all features available in a subscription tier
|
||||
|
||||
**Public endpoint** - No authentication required
|
||||
|
||||
Args:
|
||||
tier: Subscription tier (starter, professional, enterprise)
|
||||
|
||||
Returns:
|
||||
List of feature keys available in the tier
|
||||
"""
|
||||
try:
|
||||
tier_enum = SubscriptionTier(tier.lower())
|
||||
features = PlanFeatures.get_features(tier_enum.value)
|
||||
|
||||
return {
|
||||
"tier": tier_enum.value,
|
||||
"features": features,
|
||||
"feature_count": len(features)
|
||||
}
|
||||
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Subscription tier '{tier}' not found"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/{tier}/limits")
|
||||
async def get_plan_limits(tier: str):
|
||||
"""
|
||||
Get all quota limits for a subscription tier
|
||||
|
||||
**Public endpoint** - No authentication required
|
||||
|
||||
Args:
|
||||
tier: Subscription tier (starter, professional, enterprise)
|
||||
|
||||
Returns:
|
||||
All quota limits for the tier
|
||||
"""
|
||||
try:
|
||||
tier_enum = SubscriptionTier(tier.lower())
|
||||
|
||||
limits = {
|
||||
"tier": tier_enum.value,
|
||||
"team_and_organization": {
|
||||
"max_users": QuotaLimits.MAX_USERS[tier_enum],
|
||||
"max_locations": QuotaLimits.MAX_LOCATIONS[tier_enum],
|
||||
},
|
||||
"product_and_inventory": {
|
||||
"max_products": QuotaLimits.MAX_PRODUCTS[tier_enum],
|
||||
"max_recipes": QuotaLimits.MAX_RECIPES[tier_enum],
|
||||
"max_suppliers": QuotaLimits.MAX_SUPPLIERS[tier_enum],
|
||||
},
|
||||
"ml_and_analytics": {
|
||||
"training_jobs_per_day": QuotaLimits.TRAINING_JOBS_PER_DAY[tier_enum],
|
||||
"forecast_generation_per_day": QuotaLimits.FORECAST_GENERATION_PER_DAY[tier_enum],
|
||||
"dataset_size_rows": QuotaLimits.DATASET_SIZE_ROWS[tier_enum],
|
||||
"forecast_horizon_days": QuotaLimits.FORECAST_HORIZON_DAYS[tier_enum],
|
||||
"historical_data_access_days": QuotaLimits.HISTORICAL_DATA_ACCESS_DAYS[tier_enum],
|
||||
},
|
||||
"import_export": {
|
||||
"bulk_import_rows": QuotaLimits.BULK_IMPORT_ROWS[tier_enum],
|
||||
"bulk_export_rows": QuotaLimits.BULK_EXPORT_ROWS[tier_enum],
|
||||
},
|
||||
"integrations": {
|
||||
"pos_sync_interval_minutes": QuotaLimits.POS_SYNC_INTERVAL_MINUTES[tier_enum],
|
||||
"api_calls_per_hour": QuotaLimits.API_CALLS_PER_HOUR[tier_enum],
|
||||
"webhook_endpoints": QuotaLimits.WEBHOOK_ENDPOINTS[tier_enum],
|
||||
},
|
||||
"storage": {
|
||||
"file_storage_gb": QuotaLimits.FILE_STORAGE_GB[tier_enum],
|
||||
"report_retention_days": QuotaLimits.REPORT_RETENTION_DAYS[tier_enum],
|
||||
}
|
||||
}
|
||||
|
||||
return limits
|
||||
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Subscription tier '{tier}' not found"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/feature-categories")
|
||||
async def get_feature_categories():
|
||||
"""
|
||||
Get all feature categories with icons and translation keys
|
||||
|
||||
**Public endpoint** - No authentication required
|
||||
|
||||
Returns:
|
||||
Dictionary of feature categories
|
||||
"""
|
||||
try:
|
||||
return {
|
||||
"categories": FeatureCategories.CATEGORIES
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error("failed_to_fetch_feature_categories", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch feature categories"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/feature-descriptions")
|
||||
async def get_feature_descriptions():
|
||||
"""
|
||||
Get user-facing feature descriptions with translation keys
|
||||
|
||||
**Public endpoint** - No authentication required
|
||||
|
||||
Returns:
|
||||
Dictionary of feature descriptions mapped by feature key
|
||||
"""
|
||||
try:
|
||||
return {
|
||||
"features": UserFacingFeatures.FEATURE_DISPLAY
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error("failed_to_fetch_feature_descriptions", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to fetch feature descriptions"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/compare")
|
||||
async def compare_plans():
|
||||
"""
|
||||
Get plan comparison data for all tiers
|
||||
|
||||
**Public endpoint** - No authentication required
|
||||
|
||||
Returns:
|
||||
Comparison matrix of all plans with key features and limits
|
||||
"""
|
||||
try:
|
||||
comparison = {
|
||||
"tiers": ["starter", "professional", "enterprise"],
|
||||
"pricing": {},
|
||||
"key_features": {},
|
||||
"key_limits": {}
|
||||
}
|
||||
|
||||
for tier in SubscriptionTier:
|
||||
metadata = SubscriptionPlanMetadata.PLANS[tier]
|
||||
|
||||
# Pricing
|
||||
comparison["pricing"][tier.value] = {
|
||||
"monthly": float(metadata["monthly_price"]),
|
||||
"yearly": float(metadata["yearly_price"]),
|
||||
"savings_percentage": round(
|
||||
((float(metadata["monthly_price"]) * 12) - float(metadata["yearly_price"])) /
|
||||
(float(metadata["monthly_price"]) * 12) * 100
|
||||
)
|
||||
}
|
||||
|
||||
# Key features (first 10)
|
||||
comparison["key_features"][tier.value] = metadata["features"][:10]
|
||||
|
||||
# Key limits
|
||||
comparison["key_limits"][tier.value] = {
|
||||
"users": metadata["limits"]["users"],
|
||||
"locations": metadata["limits"]["locations"],
|
||||
"products": metadata["limits"]["products"],
|
||||
"forecasts_per_day": metadata["limits"]["forecasts_per_day"],
|
||||
"training_jobs_per_day": QuotaLimits.TRAINING_JOBS_PER_DAY[tier],
|
||||
}
|
||||
|
||||
return comparison
|
||||
|
||||
except Exception as e:
|
||||
logger.error("failed_to_compare_plans", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to generate plan comparison"
|
||||
)
|
||||
1264
services/tenant/app/api/subscription.py
Normal file
1264
services/tenant/app/api/subscription.py
Normal file
File diff suppressed because it is too large
Load Diff
595
services/tenant/app/api/tenant_hierarchy.py
Normal file
595
services/tenant/app/api/tenant_hierarchy.py
Normal file
@@ -0,0 +1,595 @@
|
||||
"""
|
||||
Tenant Hierarchy API - Handles parent-child tenant relationships
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Path
|
||||
from typing import List, Dict, Any
|
||||
from uuid import UUID
|
||||
|
||||
from app.schemas.tenants import (
|
||||
TenantResponse,
|
||||
ChildTenantCreate,
|
||||
BulkChildTenantsCreate,
|
||||
BulkChildTenantsResponse,
|
||||
ChildTenantResponse,
|
||||
TenantHierarchyResponse
|
||||
)
|
||||
from app.services.tenant_service import EnhancedTenantService
|
||||
from app.repositories.tenant_repository import TenantRepository
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from shared.routing.route_builder import RouteBuilder
|
||||
from shared.database.base import create_database_manager
|
||||
from shared.monitoring.metrics import track_endpoint_metrics
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
route_builder = RouteBuilder("tenants")
|
||||
|
||||
|
||||
# Dependency injection for enhanced tenant service
|
||||
def get_enhanced_tenant_service():
|
||||
try:
|
||||
from app.core.config import settings
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
|
||||
return EnhancedTenantService(database_manager)
|
||||
except Exception as e:
|
||||
logger.error("Failed to create enhanced tenant service", error=str(e))
|
||||
raise HTTPException(status_code=500, detail="Service initialization failed")
|
||||
|
||||
|
||||
@router.get(route_builder.build_base_route("{tenant_id}/children", include_tenant_prefix=False), response_model=List[TenantResponse])
|
||||
@track_endpoint_metrics("tenant_children_list")
|
||||
async def get_tenant_children(
|
||||
tenant_id: UUID = Path(..., description="Parent Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""
|
||||
Get all child tenants for a parent tenant.
|
||||
This endpoint returns all active child tenants associated with the specified parent tenant.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Get tenant children request received",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id"),
|
||||
user_type=current_user.get("type", "user"),
|
||||
is_service=current_user.get("type") == "service",
|
||||
role=current_user.get("role"),
|
||||
service_name=current_user.get("service", "none")
|
||||
)
|
||||
|
||||
# Skip access check for service-to-service calls
|
||||
is_service_call = current_user.get("type") == "service"
|
||||
if not is_service_call:
|
||||
# Verify user has access to the parent tenant
|
||||
access_info = await tenant_service.verify_user_access(current_user["user_id"], str(tenant_id))
|
||||
if not access_info.has_access:
|
||||
logger.warning(
|
||||
"Access denied to parent tenant",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to parent tenant"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Service-to-service call - bypassing access check",
|
||||
service=current_user.get("service"),
|
||||
tenant_id=str(tenant_id)
|
||||
)
|
||||
|
||||
# Get child tenants from repository
|
||||
from app.models.tenants import Tenant
|
||||
async with tenant_service.database_manager.get_session() as session:
|
||||
tenant_repo = TenantRepository(Tenant, session)
|
||||
child_tenants = await tenant_repo.get_child_tenants(str(tenant_id))
|
||||
|
||||
logger.debug(
|
||||
"Get tenant children successful",
|
||||
tenant_id=str(tenant_id),
|
||||
child_count=len(child_tenants)
|
||||
)
|
||||
|
||||
# Convert to plain dicts while still in session to avoid lazy-load issues
|
||||
child_dicts = []
|
||||
for child in child_tenants:
|
||||
# Handle subscription_tier safely - avoid lazy load
|
||||
try:
|
||||
# Try to get subscription_tier if subscriptions are already loaded
|
||||
sub_tier = child.__dict__.get('_subscription_tier_cache', 'enterprise')
|
||||
except:
|
||||
sub_tier = 'enterprise' # Default for enterprise children
|
||||
|
||||
child_dict = {
|
||||
'id': str(child.id),
|
||||
'name': child.name,
|
||||
'subdomain': child.subdomain,
|
||||
'business_type': child.business_type,
|
||||
'business_model': child.business_model,
|
||||
'address': child.address,
|
||||
'city': child.city,
|
||||
'postal_code': child.postal_code,
|
||||
'latitude': child.latitude,
|
||||
'longitude': child.longitude,
|
||||
'phone': child.phone,
|
||||
'email': child.email,
|
||||
'timezone': child.timezone,
|
||||
'owner_id': str(child.owner_id),
|
||||
'parent_tenant_id': str(child.parent_tenant_id) if child.parent_tenant_id else None,
|
||||
'tenant_type': child.tenant_type,
|
||||
'hierarchy_path': child.hierarchy_path,
|
||||
'subscription_tier': sub_tier, # Use the safely retrieved value
|
||||
'ml_model_trained': child.ml_model_trained,
|
||||
'last_training_date': child.last_training_date,
|
||||
'is_active': child.is_active,
|
||||
'is_demo': child.is_demo,
|
||||
'demo_session_id': child.demo_session_id,
|
||||
'created_at': child.created_at,
|
||||
'updated_at': child.updated_at
|
||||
}
|
||||
child_dicts.append(child_dict)
|
||||
|
||||
# Convert to Pydantic models outside the session without from_attributes
|
||||
child_responses = [TenantResponse(**child_dict) for child_dict in child_dicts]
|
||||
return child_responses
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Get tenant children failed",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id"),
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Get tenant children failed"
|
||||
)
|
||||
|
||||
|
||||
@router.get(route_builder.build_base_route("{tenant_id}/children/count", include_tenant_prefix=False))
|
||||
@track_endpoint_metrics("tenant_children_count")
|
||||
async def get_tenant_children_count(
|
||||
tenant_id: UUID = Path(..., description="Parent Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""
|
||||
Get count of child tenants for a parent tenant.
|
||||
This endpoint returns the number of active child tenants associated with the specified parent tenant.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Get tenant children count request received",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
# Skip access check for service-to-service calls
|
||||
is_service_call = current_user.get("type") == "service"
|
||||
if not is_service_call:
|
||||
# Verify user has access to the parent tenant
|
||||
access_info = await tenant_service.verify_user_access(current_user["user_id"], str(tenant_id))
|
||||
if not access_info.has_access:
|
||||
logger.warning(
|
||||
"Access denied to parent tenant",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to parent tenant"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Service-to-service call - bypassing access check",
|
||||
service=current_user.get("service"),
|
||||
tenant_id=str(tenant_id)
|
||||
)
|
||||
|
||||
# Get child count from repository
|
||||
from app.models.tenants import Tenant
|
||||
async with tenant_service.database_manager.get_session() as session:
|
||||
tenant_repo = TenantRepository(Tenant, session)
|
||||
child_count = await tenant_repo.get_child_tenant_count(str(tenant_id))
|
||||
|
||||
logger.debug(
|
||||
"Get tenant children count successful",
|
||||
tenant_id=str(tenant_id),
|
||||
child_count=child_count
|
||||
)
|
||||
|
||||
return {
|
||||
"parent_tenant_id": str(tenant_id),
|
||||
"child_count": child_count
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Get tenant children count failed",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id"),
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Get tenant children count failed"
|
||||
)
|
||||
|
||||
|
||||
@router.get(route_builder.build_base_route("{tenant_id}/hierarchy", include_tenant_prefix=False), response_model=TenantHierarchyResponse)
|
||||
@track_endpoint_metrics("tenant_hierarchy")
|
||||
async def get_tenant_hierarchy(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""
|
||||
Get tenant hierarchy information.
|
||||
|
||||
Returns hierarchy metadata for a tenant including:
|
||||
- Tenant type (standalone, parent, child)
|
||||
- Parent tenant ID (if this is a child)
|
||||
- Hierarchy path (materialized path)
|
||||
- Number of child tenants (for parent tenants)
|
||||
- Hierarchy level (depth in the tree)
|
||||
|
||||
This endpoint is used by the authentication layer for hierarchical access control
|
||||
and by enterprise features for network management.
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Get tenant hierarchy request received",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id"),
|
||||
user_type=current_user.get("type", "user"),
|
||||
is_service=current_user.get("type") == "service"
|
||||
)
|
||||
|
||||
# Get tenant from database
|
||||
from app.models.tenants import Tenant
|
||||
async with tenant_service.database_manager.get_session() as session:
|
||||
tenant_repo = TenantRepository(Tenant, session)
|
||||
|
||||
# Get the tenant
|
||||
tenant = await tenant_repo.get(str(tenant_id))
|
||||
if not tenant:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Tenant {tenant_id} not found"
|
||||
)
|
||||
|
||||
# Skip access check for service-to-service calls
|
||||
is_service_call = current_user.get("type") == "service"
|
||||
if not is_service_call:
|
||||
# Verify user has access to this tenant
|
||||
access_info = await tenant_service.verify_user_access(current_user["user_id"], str(tenant_id))
|
||||
if not access_info.has_access:
|
||||
logger.warning(
|
||||
"Access denied to tenant for hierarchy query",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"Service-to-service call - bypassing access check",
|
||||
service=current_user.get("service"),
|
||||
tenant_id=str(tenant_id)
|
||||
)
|
||||
|
||||
# Get child count if this is a parent tenant
|
||||
child_count = 0
|
||||
if tenant.tenant_type in ["parent", "standalone"]:
|
||||
child_count = await tenant_repo.get_child_tenant_count(str(tenant_id))
|
||||
|
||||
# Calculate hierarchy level from hierarchy_path
|
||||
hierarchy_level = 0
|
||||
if tenant.hierarchy_path:
|
||||
# hierarchy_path format: "parent_id" or "parent_id.child_id" or "parent_id.child_id.grandchild_id"
|
||||
hierarchy_level = tenant.hierarchy_path.count('.')
|
||||
|
||||
# Build response
|
||||
hierarchy_info = TenantHierarchyResponse(
|
||||
tenant_id=str(tenant.id),
|
||||
tenant_type=tenant.tenant_type or "standalone",
|
||||
parent_tenant_id=str(tenant.parent_tenant_id) if tenant.parent_tenant_id else None,
|
||||
hierarchy_path=tenant.hierarchy_path,
|
||||
child_count=child_count,
|
||||
hierarchy_level=hierarchy_level
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Get tenant hierarchy successful",
|
||||
tenant_id=str(tenant_id),
|
||||
tenant_type=tenant.tenant_type,
|
||||
parent_tenant_id=str(tenant.parent_tenant_id) if tenant.parent_tenant_id else None,
|
||||
child_count=child_count,
|
||||
hierarchy_level=hierarchy_level
|
||||
)
|
||||
|
||||
return hierarchy_info
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Get tenant hierarchy failed",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id"),
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Get tenant hierarchy failed"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/api/v1/tenants/{tenant_id}/bulk-children", response_model=BulkChildTenantsResponse)
|
||||
@track_endpoint_metrics("bulk_create_child_tenants")
|
||||
async def bulk_create_child_tenants(
|
||||
request: BulkChildTenantsCreate,
|
||||
tenant_id: str = Path(..., description="Parent tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""
|
||||
Bulk create child tenants for enterprise onboarding.
|
||||
|
||||
This endpoint creates multiple child tenants (outlets/branches) for an enterprise parent tenant
|
||||
and establishes the parent-child relationship. It's designed for use during the onboarding flow
|
||||
when an enterprise customer registers their network of locations.
|
||||
|
||||
Features:
|
||||
- Creates child tenants with proper hierarchy
|
||||
- Inherits subscription from parent
|
||||
- Optionally configures distribution routes
|
||||
- Returns detailed success/failure information
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Bulk child tenant creation request received",
|
||||
parent_tenant_id=tenant_id,
|
||||
child_count=len(request.child_tenants),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
# Verify parent tenant exists and user has access
|
||||
async with tenant_service.database_manager.get_session() as session:
|
||||
from app.models.tenants import Tenant
|
||||
tenant_repo = TenantRepository(Tenant, session)
|
||||
|
||||
parent_tenant = await tenant_repo.get_by_id(tenant_id)
|
||||
if not parent_tenant:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Parent tenant not found"
|
||||
)
|
||||
|
||||
# Verify user has access to parent tenant (owners/admins only)
|
||||
access_info = await tenant_service.verify_user_access(
|
||||
current_user["user_id"],
|
||||
tenant_id
|
||||
)
|
||||
if not access_info.has_access or access_info.role not in ["owner", "admin"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Only tenant owners/admins can create child tenants"
|
||||
)
|
||||
|
||||
# Verify parent is enterprise tier
|
||||
parent_subscription = await tenant_service.subscription_repo.get_active_subscription(tenant_id)
|
||||
if not parent_subscription or parent_subscription.plan != "enterprise":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Only enterprise tier tenants can have child tenants"
|
||||
)
|
||||
|
||||
# Update parent tenant type if it's still standalone
|
||||
if parent_tenant.tenant_type == "standalone":
|
||||
parent_tenant.tenant_type = "parent"
|
||||
parent_tenant.hierarchy_path = str(parent_tenant.id)
|
||||
await session.commit()
|
||||
await session.refresh(parent_tenant)
|
||||
|
||||
# Create child tenants
|
||||
created_tenants = []
|
||||
failed_tenants = []
|
||||
|
||||
for child_data in request.child_tenants:
|
||||
# Create a nested transaction (savepoint) for each child tenant
|
||||
# This allows us to rollback individual child tenant creation without affecting others
|
||||
async with session.begin_nested():
|
||||
try:
|
||||
# Create child tenant with full tenant model fields
|
||||
child_tenant = Tenant(
|
||||
name=child_data.name,
|
||||
subdomain=None, # Child tenants typically don't have subdomains
|
||||
business_type=child_data.business_type or parent_tenant.business_type,
|
||||
business_model=child_data.business_model or "retail_bakery", # Child outlets are typically retail
|
||||
address=child_data.address,
|
||||
city=child_data.city,
|
||||
postal_code=child_data.postal_code,
|
||||
latitude=child_data.latitude,
|
||||
longitude=child_data.longitude,
|
||||
phone=child_data.phone or parent_tenant.phone,
|
||||
email=child_data.email or parent_tenant.email,
|
||||
timezone=child_data.timezone or parent_tenant.timezone,
|
||||
owner_id=parent_tenant.owner_id,
|
||||
parent_tenant_id=parent_tenant.id,
|
||||
tenant_type="child",
|
||||
hierarchy_path=f"{parent_tenant.hierarchy_path}", # Will be updated after flush
|
||||
is_active=True,
|
||||
is_demo=parent_tenant.is_demo,
|
||||
demo_session_id=parent_tenant.demo_session_id,
|
||||
demo_expires_at=parent_tenant.demo_expires_at,
|
||||
metadata_={
|
||||
"location_code": child_data.location_code,
|
||||
"zone": child_data.zone,
|
||||
**(child_data.metadata or {})
|
||||
}
|
||||
)
|
||||
|
||||
session.add(child_tenant)
|
||||
await session.flush() # Get the ID without committing
|
||||
|
||||
# Update hierarchy_path now that we have the child tenant ID
|
||||
child_tenant.hierarchy_path = f"{parent_tenant.hierarchy_path}.{str(child_tenant.id)}"
|
||||
|
||||
# Create TenantLocation record for the child
|
||||
from app.models.tenant_location import TenantLocation
|
||||
location = TenantLocation(
|
||||
tenant_id=child_tenant.id,
|
||||
name=child_data.name,
|
||||
city=child_data.city,
|
||||
address=child_data.address,
|
||||
postal_code=child_data.postal_code,
|
||||
latitude=child_data.latitude,
|
||||
longitude=child_data.longitude,
|
||||
is_active=True,
|
||||
location_type="retail"
|
||||
)
|
||||
session.add(location)
|
||||
|
||||
# Inherit subscription from parent
|
||||
from app.models.tenants import Subscription
|
||||
from sqlalchemy import select
|
||||
parent_subscription_result = await session.execute(
|
||||
select(Subscription).where(
|
||||
Subscription.tenant_id == parent_tenant.id,
|
||||
Subscription.status == "active"
|
||||
)
|
||||
)
|
||||
parent_sub = parent_subscription_result.scalar_one_or_none()
|
||||
|
||||
if parent_sub:
|
||||
child_subscription = Subscription(
|
||||
tenant_id=child_tenant.id,
|
||||
plan=parent_sub.plan,
|
||||
status="active",
|
||||
billing_cycle=parent_sub.billing_cycle,
|
||||
monthly_price=0, # Child tenants don't pay separately
|
||||
trial_ends_at=parent_sub.trial_ends_at
|
||||
)
|
||||
session.add(child_subscription)
|
||||
|
||||
# Commit the nested transaction (savepoint)
|
||||
await session.flush()
|
||||
|
||||
# Refresh objects to get their final state
|
||||
await session.refresh(child_tenant)
|
||||
await session.refresh(location)
|
||||
|
||||
# Build response
|
||||
created_tenants.append(ChildTenantResponse(
|
||||
id=str(child_tenant.id),
|
||||
name=child_tenant.name,
|
||||
subdomain=child_tenant.subdomain,
|
||||
business_type=child_tenant.business_type,
|
||||
business_model=child_tenant.business_model,
|
||||
tenant_type=child_tenant.tenant_type,
|
||||
parent_tenant_id=str(child_tenant.parent_tenant_id),
|
||||
address=child_tenant.address,
|
||||
city=child_tenant.city,
|
||||
postal_code=child_tenant.postal_code,
|
||||
phone=child_tenant.phone,
|
||||
is_active=child_tenant.is_active,
|
||||
subscription_plan="enterprise",
|
||||
ml_model_trained=child_tenant.ml_model_trained,
|
||||
last_training_date=child_tenant.last_training_date,
|
||||
owner_id=str(child_tenant.owner_id),
|
||||
created_at=child_tenant.created_at,
|
||||
location_code=child_data.location_code,
|
||||
zone=child_data.zone,
|
||||
hierarchy_path=child_tenant.hierarchy_path
|
||||
))
|
||||
|
||||
logger.info(
|
||||
"Child tenant created successfully",
|
||||
child_tenant_id=str(child_tenant.id),
|
||||
child_name=child_tenant.name,
|
||||
location_code=child_data.location_code
|
||||
)
|
||||
|
||||
except Exception as child_error:
|
||||
logger.error(
|
||||
"Failed to create child tenant",
|
||||
child_name=child_data.name,
|
||||
error=str(child_error)
|
||||
)
|
||||
failed_tenants.append({
|
||||
"name": child_data.name,
|
||||
"location_code": child_data.location_code,
|
||||
"error": str(child_error)
|
||||
})
|
||||
# Nested transaction will automatically rollback on exception
|
||||
# This only rolls back the current child tenant, not the entire batch
|
||||
|
||||
# Commit all successful child tenant creations
|
||||
await session.commit()
|
||||
|
||||
# TODO: Configure distribution routes if requested
|
||||
distribution_configured = False
|
||||
if request.auto_configure_distribution and len(created_tenants) > 0:
|
||||
try:
|
||||
# This would call the distribution service to set up routes
|
||||
# For now, we'll skip this and just log
|
||||
logger.info(
|
||||
"Distribution route configuration requested",
|
||||
parent_tenant_id=tenant_id,
|
||||
child_count=len(created_tenants)
|
||||
)
|
||||
# distribution_configured = await configure_distribution_routes(...)
|
||||
except Exception as dist_error:
|
||||
logger.warning(
|
||||
"Failed to configure distribution routes",
|
||||
error=str(dist_error)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Bulk child tenant creation completed",
|
||||
parent_tenant_id=tenant_id,
|
||||
created_count=len(created_tenants),
|
||||
failed_count=len(failed_tenants)
|
||||
)
|
||||
|
||||
return BulkChildTenantsResponse(
|
||||
parent_tenant_id=tenant_id,
|
||||
created_count=len(created_tenants),
|
||||
failed_count=len(failed_tenants),
|
||||
created_tenants=created_tenants,
|
||||
failed_tenants=failed_tenants,
|
||||
distribution_configured=distribution_configured
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Bulk child tenant creation failed",
|
||||
parent_tenant_id=tenant_id,
|
||||
user_id=current_user.get("user_id"),
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Bulk child tenant creation failed: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
# Register the router in the main app
|
||||
def register_hierarchy_routes(app):
|
||||
"""Register hierarchy routes with the main application"""
|
||||
from shared.routing.route_builder import RouteBuilder
|
||||
route_builder = RouteBuilder("tenants")
|
||||
|
||||
# Include the hierarchy routes with proper tenant prefix
|
||||
app.include_router(
|
||||
router,
|
||||
prefix="/api/v1",
|
||||
tags=["tenant-hierarchy"]
|
||||
)
|
||||
628
services/tenant/app/api/tenant_locations.py
Normal file
628
services/tenant/app/api/tenant_locations.py
Normal file
@@ -0,0 +1,628 @@
|
||||
"""
|
||||
Tenant Locations API - Handles tenant location operations
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Path, Query
|
||||
from typing import List, Dict, Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from app.schemas.tenant_locations import (
|
||||
TenantLocationCreate,
|
||||
TenantLocationUpdate,
|
||||
TenantLocationResponse,
|
||||
TenantLocationsResponse,
|
||||
TenantLocationTypeFilter
|
||||
)
|
||||
from app.repositories.tenant_location_repository import TenantLocationRepository
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from shared.auth.access_control import admin_role_required
|
||||
from shared.monitoring.metrics import track_endpoint_metrics
|
||||
from shared.routing.route_builder import RouteBuilder
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
route_builder = RouteBuilder("tenants")
|
||||
|
||||
|
||||
# Dependency injection for tenant location repository
|
||||
async def get_tenant_location_repository():
|
||||
"""Get tenant location repository instance with proper session management"""
|
||||
try:
|
||||
from app.core.database import database_manager
|
||||
|
||||
# Use async context manager properly to ensure session is closed
|
||||
async with database_manager.get_session() as session:
|
||||
yield TenantLocationRepository(session)
|
||||
except Exception as e:
|
||||
logger.error("Failed to create tenant location repository", error=str(e))
|
||||
raise HTTPException(status_code=500, detail="Service initialization failed")
|
||||
|
||||
|
||||
@router.get(route_builder.build_base_route("{tenant_id}/locations", include_tenant_prefix=False), response_model=TenantLocationsResponse)
|
||||
@track_endpoint_metrics("tenant_locations_list")
|
||||
async def get_tenant_locations(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
location_types: str = Query(None, description="Comma-separated list of location types to filter"),
|
||||
is_active: Optional[bool] = Query(None, description="Filter by active status"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
location_repo: TenantLocationRepository = Depends(get_tenant_location_repository)
|
||||
):
|
||||
"""
|
||||
Get all locations for a tenant.
|
||||
|
||||
Args:
|
||||
tenant_id: ID of the tenant to get locations for
|
||||
location_types: Optional comma-separated list of location types to filter (e.g., "central_production,retail_outlet")
|
||||
is_active: Optional filter for active locations only
|
||||
current_user: Current user making the request
|
||||
location_repo: Tenant location repository instance
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Get tenant locations request received",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id"),
|
||||
location_types=location_types,
|
||||
is_active=is_active
|
||||
)
|
||||
|
||||
# Check that the user has access to this tenant
|
||||
# This would typically be checked via access control middleware
|
||||
# For now, we'll trust the gateway has validated tenant access
|
||||
|
||||
locations = []
|
||||
|
||||
if location_types:
|
||||
# Filter by specific location types
|
||||
types_list = [t.strip() for t in location_types.split(",")]
|
||||
locations = await location_repo.get_locations_by_tenant_with_type(str(tenant_id), types_list)
|
||||
elif is_active is True:
|
||||
# Get only active locations
|
||||
locations = await location_repo.get_active_locations_by_tenant(str(tenant_id))
|
||||
elif is_active is False:
|
||||
# Get only inactive locations (by getting all and filtering in memory - not efficient but functional)
|
||||
all_locations = await location_repo.get_locations_by_tenant(str(tenant_id))
|
||||
locations = [loc for loc in all_locations if not loc.is_active]
|
||||
else:
|
||||
# Get all locations
|
||||
locations = await location_repo.get_locations_by_tenant(str(tenant_id))
|
||||
|
||||
logger.debug(
|
||||
"Get tenant locations successful",
|
||||
tenant_id=str(tenant_id),
|
||||
location_count=len(locations)
|
||||
)
|
||||
|
||||
# Convert to response format - handle metadata field to avoid SQLAlchemy conflicts
|
||||
location_responses = []
|
||||
for loc in locations:
|
||||
# Create dict from ORM object manually to handle metadata field properly
|
||||
loc_dict = {
|
||||
'id': str(loc.id),
|
||||
'tenant_id': str(loc.tenant_id),
|
||||
'name': loc.name,
|
||||
'location_type': loc.location_type,
|
||||
'address': loc.address,
|
||||
'city': loc.city,
|
||||
'postal_code': loc.postal_code,
|
||||
'latitude': loc.latitude,
|
||||
'longitude': loc.longitude,
|
||||
'contact_person': loc.contact_person,
|
||||
'contact_phone': loc.contact_phone,
|
||||
'contact_email': loc.contact_email,
|
||||
'is_active': loc.is_active,
|
||||
'delivery_windows': loc.delivery_windows,
|
||||
'operational_hours': loc.operational_hours,
|
||||
'capacity': loc.capacity,
|
||||
'max_delivery_radius_km': loc.max_delivery_radius_km,
|
||||
'delivery_schedule_config': loc.delivery_schedule_config,
|
||||
'metadata': loc.metadata_, # Use the actual column name to avoid conflict
|
||||
'created_at': loc.created_at,
|
||||
'updated_at': loc.updated_at
|
||||
}
|
||||
location_responses.append(TenantLocationResponse.model_validate(loc_dict))
|
||||
|
||||
return TenantLocationsResponse(
|
||||
locations=location_responses,
|
||||
total=len(location_responses)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Get tenant locations failed",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id"),
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Get tenant locations failed"
|
||||
)
|
||||
|
||||
|
||||
@router.get(route_builder.build_base_route("{tenant_id}/locations/{location_id}", include_tenant_prefix=False), response_model=TenantLocationResponse)
|
||||
@track_endpoint_metrics("tenant_location_get")
|
||||
async def get_tenant_location(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
location_id: UUID = Path(..., description="Location ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
location_repo: TenantLocationRepository = Depends(get_tenant_location_repository)
|
||||
):
|
||||
"""
|
||||
Get a specific location for a tenant.
|
||||
|
||||
Args:
|
||||
tenant_id: ID of the tenant
|
||||
location_id: ID of the location to retrieve
|
||||
current_user: Current user making the request
|
||||
location_repo: Tenant location repository instance
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Get tenant location request received",
|
||||
tenant_id=str(tenant_id),
|
||||
location_id=str(location_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
# Get the specific location
|
||||
location = await location_repo.get_location_by_id(str(location_id))
|
||||
|
||||
if not location:
|
||||
logger.warning(
|
||||
"Location not found",
|
||||
tenant_id=str(tenant_id),
|
||||
location_id=str(location_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Location not found"
|
||||
)
|
||||
|
||||
# Verify that the location belongs to the specified tenant
|
||||
if str(location.tenant_id) != str(tenant_id):
|
||||
logger.warning(
|
||||
"Location does not belong to tenant",
|
||||
tenant_id=str(tenant_id),
|
||||
location_id=str(location_id),
|
||||
location_tenant_id=str(location.tenant_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Location not found"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Get tenant location successful",
|
||||
tenant_id=str(tenant_id),
|
||||
location_id=str(location_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
# Create dict from ORM object manually to handle metadata field properly
|
||||
loc_dict = {
|
||||
'id': str(location.id),
|
||||
'tenant_id': str(location.tenant_id),
|
||||
'name': location.name,
|
||||
'location_type': location.location_type,
|
||||
'address': location.address,
|
||||
'city': location.city,
|
||||
'postal_code': location.postal_code,
|
||||
'latitude': location.latitude,
|
||||
'longitude': location.longitude,
|
||||
'contact_person': location.contact_person,
|
||||
'contact_phone': location.contact_phone,
|
||||
'contact_email': location.contact_email,
|
||||
'is_active': location.is_active,
|
||||
'delivery_windows': location.delivery_windows,
|
||||
'operational_hours': location.operational_hours,
|
||||
'capacity': location.capacity,
|
||||
'max_delivery_radius_km': location.max_delivery_radius_km,
|
||||
'delivery_schedule_config': location.delivery_schedule_config,
|
||||
'metadata': location.metadata_, # Use the actual column name to avoid conflict
|
||||
'created_at': location.created_at,
|
||||
'updated_at': location.updated_at
|
||||
}
|
||||
return TenantLocationResponse.model_validate(loc_dict)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Get tenant location failed",
|
||||
tenant_id=str(tenant_id),
|
||||
location_id=str(location_id),
|
||||
user_id=current_user.get("user_id"),
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Get tenant location failed"
|
||||
)
|
||||
|
||||
|
||||
@router.post(route_builder.build_base_route("{tenant_id}/locations", include_tenant_prefix=False), response_model=TenantLocationResponse)
|
||||
@admin_role_required
|
||||
async def create_tenant_location(
|
||||
location_data: TenantLocationCreate,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
location_repo: TenantLocationRepository = Depends(get_tenant_location_repository)
|
||||
):
|
||||
"""
|
||||
Create a new location for a tenant.
|
||||
Requires admin or owner privileges.
|
||||
|
||||
Args:
|
||||
location_data: Location data to create
|
||||
tenant_id: ID of the tenant to create location for
|
||||
current_user: Current user making the request
|
||||
location_repo: Tenant location repository instance
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Create tenant location request received",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
# Verify that the tenant_id in the path matches the one in the data
|
||||
if str(tenant_id) != location_data.tenant_id:
|
||||
logger.warning(
|
||||
"Tenant ID mismatch",
|
||||
path_tenant_id=str(tenant_id),
|
||||
data_tenant_id=location_data.tenant_id,
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Tenant ID in path does not match data"
|
||||
)
|
||||
|
||||
# Prepare location data by excluding unset values
|
||||
location_dict = location_data.model_dump(exclude_unset=True)
|
||||
# Ensure tenant_id comes from the path for security
|
||||
location_dict['tenant_id'] = str(tenant_id)
|
||||
|
||||
created_location = await location_repo.create_location(location_dict)
|
||||
|
||||
logger.info(
|
||||
"Created tenant location successfully",
|
||||
tenant_id=str(tenant_id),
|
||||
location_id=str(created_location.id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
# Create dict from ORM object manually to handle metadata field properly
|
||||
loc_dict = {
|
||||
'id': str(created_location.id),
|
||||
'tenant_id': str(created_location.tenant_id),
|
||||
'name': created_location.name,
|
||||
'location_type': created_location.location_type,
|
||||
'address': created_location.address,
|
||||
'city': created_location.city,
|
||||
'postal_code': created_location.postal_code,
|
||||
'latitude': created_location.latitude,
|
||||
'longitude': created_location.longitude,
|
||||
'contact_person': created_location.contact_person,
|
||||
'contact_phone': created_location.contact_phone,
|
||||
'contact_email': created_location.contact_email,
|
||||
'is_active': created_location.is_active,
|
||||
'delivery_windows': created_location.delivery_windows,
|
||||
'operational_hours': created_location.operational_hours,
|
||||
'capacity': created_location.capacity,
|
||||
'max_delivery_radius_km': created_location.max_delivery_radius_km,
|
||||
'delivery_schedule_config': created_location.delivery_schedule_config,
|
||||
'metadata': created_location.metadata_, # Use the actual column name to avoid conflict
|
||||
'created_at': created_location.created_at,
|
||||
'updated_at': created_location.updated_at
|
||||
}
|
||||
return TenantLocationResponse.model_validate(loc_dict)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Create tenant location failed",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id"),
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Create tenant location failed"
|
||||
)
|
||||
|
||||
|
||||
@router.put(route_builder.build_base_route("{tenant_id}/locations/{location_id}", include_tenant_prefix=False), response_model=TenantLocationResponse)
|
||||
@admin_role_required
|
||||
async def update_tenant_location(
|
||||
update_data: TenantLocationUpdate,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
location_id: UUID = Path(..., description="Location ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
location_repo: TenantLocationRepository = Depends(get_tenant_location_repository)
|
||||
):
|
||||
"""
|
||||
Update a tenant location.
|
||||
Requires admin or owner privileges.
|
||||
|
||||
Args:
|
||||
update_data: Location data to update
|
||||
tenant_id: ID of the tenant
|
||||
location_id: ID of the location to update
|
||||
current_user: Current user making the request
|
||||
location_repo: Tenant location repository instance
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Update tenant location request received",
|
||||
tenant_id=str(tenant_id),
|
||||
location_id=str(location_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
# Check if the location exists and belongs to the tenant
|
||||
existing_location = await location_repo.get_location_by_id(str(location_id))
|
||||
if not existing_location:
|
||||
logger.warning(
|
||||
"Location not found for update",
|
||||
tenant_id=str(tenant_id),
|
||||
location_id=str(location_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Location not found"
|
||||
)
|
||||
|
||||
if str(existing_location.tenant_id) != str(tenant_id):
|
||||
logger.warning(
|
||||
"Location does not belong to tenant for update",
|
||||
tenant_id=str(tenant_id),
|
||||
location_id=str(location_id),
|
||||
location_tenant_id=str(existing_location.tenant_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Location not found"
|
||||
)
|
||||
|
||||
# Prepare update data by excluding unset values
|
||||
update_dict = update_data.model_dump(exclude_unset=True)
|
||||
|
||||
updated_location = await location_repo.update_location(str(location_id), update_dict)
|
||||
|
||||
if not updated_location:
|
||||
logger.error(
|
||||
"Failed to update location (not found after verification)",
|
||||
tenant_id=str(tenant_id),
|
||||
location_id=str(location_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Location not found"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Updated tenant location successfully",
|
||||
tenant_id=str(tenant_id),
|
||||
location_id=str(location_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
# Create dict from ORM object manually to handle metadata field properly
|
||||
loc_dict = {
|
||||
'id': str(updated_location.id),
|
||||
'tenant_id': str(updated_location.tenant_id),
|
||||
'name': updated_location.name,
|
||||
'location_type': updated_location.location_type,
|
||||
'address': updated_location.address,
|
||||
'city': updated_location.city,
|
||||
'postal_code': updated_location.postal_code,
|
||||
'latitude': updated_location.latitude,
|
||||
'longitude': updated_location.longitude,
|
||||
'contact_person': updated_location.contact_person,
|
||||
'contact_phone': updated_location.contact_phone,
|
||||
'contact_email': updated_location.contact_email,
|
||||
'is_active': updated_location.is_active,
|
||||
'delivery_windows': updated_location.delivery_windows,
|
||||
'operational_hours': updated_location.operational_hours,
|
||||
'capacity': updated_location.capacity,
|
||||
'max_delivery_radius_km': updated_location.max_delivery_radius_km,
|
||||
'delivery_schedule_config': updated_location.delivery_schedule_config,
|
||||
'metadata': updated_location.metadata_, # Use the actual column name to avoid conflict
|
||||
'created_at': updated_location.created_at,
|
||||
'updated_at': updated_location.updated_at
|
||||
}
|
||||
return TenantLocationResponse.model_validate(loc_dict)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Update tenant location failed",
|
||||
tenant_id=str(tenant_id),
|
||||
location_id=str(location_id),
|
||||
user_id=current_user.get("user_id"),
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Update tenant location failed"
|
||||
)
|
||||
|
||||
|
||||
@router.delete(route_builder.build_base_route("{tenant_id}/locations/{location_id}", include_tenant_prefix=False))
|
||||
@admin_role_required
|
||||
async def delete_tenant_location(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
location_id: UUID = Path(..., description="Location ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
location_repo: TenantLocationRepository = Depends(get_tenant_location_repository)
|
||||
):
|
||||
"""
|
||||
Delete a tenant location.
|
||||
Requires admin or owner privileges.
|
||||
|
||||
Args:
|
||||
tenant_id: ID of the tenant
|
||||
location_id: ID of the location to delete
|
||||
current_user: Current user making the request
|
||||
location_repo: Tenant location repository instance
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Delete tenant location request received",
|
||||
tenant_id=str(tenant_id),
|
||||
location_id=str(location_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
# Check if the location exists and belongs to the tenant
|
||||
existing_location = await location_repo.get_location_by_id(str(location_id))
|
||||
if not existing_location:
|
||||
logger.warning(
|
||||
"Location not found for deletion",
|
||||
tenant_id=str(tenant_id),
|
||||
location_id=str(location_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Location not found"
|
||||
)
|
||||
|
||||
if str(existing_location.tenant_id) != str(tenant_id):
|
||||
logger.warning(
|
||||
"Location does not belong to tenant for deletion",
|
||||
tenant_id=str(tenant_id),
|
||||
location_id=str(location_id),
|
||||
location_tenant_id=str(existing_location.tenant_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Location not found"
|
||||
)
|
||||
|
||||
deleted = await location_repo.delete_location(str(location_id))
|
||||
|
||||
if not deleted:
|
||||
logger.warning(
|
||||
"Location not found for deletion (race condition)",
|
||||
tenant_id=str(tenant_id),
|
||||
location_id=str(location_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Location not found"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Deleted tenant location successfully",
|
||||
tenant_id=str(tenant_id),
|
||||
location_id=str(location_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "Location deleted successfully",
|
||||
"location_id": str(location_id)
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Delete tenant location failed",
|
||||
tenant_id=str(tenant_id),
|
||||
location_id=str(location_id),
|
||||
user_id=current_user.get("user_id"),
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Delete tenant location failed"
|
||||
)
|
||||
|
||||
|
||||
@router.get(route_builder.build_base_route("{tenant_id}/locations/type/{location_type}", include_tenant_prefix=False), response_model=TenantLocationsResponse)
|
||||
@track_endpoint_metrics("tenant_locations_by_type")
|
||||
async def get_tenant_locations_by_type(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
location_type: str = Path(..., description="Location type to filter by", pattern=r'^(central_production|retail_outlet|warehouse|store|branch)$'),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
location_repo: TenantLocationRepository = Depends(get_tenant_location_repository)
|
||||
):
|
||||
"""
|
||||
Get all locations of a specific type for a tenant.
|
||||
|
||||
Args:
|
||||
tenant_id: ID of the tenant to get locations for
|
||||
location_type: Type of location to filter by
|
||||
current_user: Current user making the request
|
||||
location_repo: Tenant location repository instance
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Get tenant locations by type request received",
|
||||
tenant_id=str(tenant_id),
|
||||
location_type=location_type,
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
# Use the method that returns multiple locations by types
|
||||
location_list = await location_repo.get_locations_by_tenant_with_type(str(tenant_id), [location_type])
|
||||
|
||||
logger.debug(
|
||||
"Get tenant locations by type successful",
|
||||
tenant_id=str(tenant_id),
|
||||
location_type=location_type,
|
||||
location_count=len(location_list)
|
||||
)
|
||||
|
||||
# Convert to response format - handle metadata field to avoid SQLAlchemy conflicts
|
||||
location_responses = []
|
||||
for loc in location_list:
|
||||
# Create dict from ORM object manually to handle metadata field properly
|
||||
loc_dict = {
|
||||
'id': str(loc.id),
|
||||
'tenant_id': str(loc.tenant_id),
|
||||
'name': loc.name,
|
||||
'location_type': loc.location_type,
|
||||
'address': loc.address,
|
||||
'city': loc.city,
|
||||
'postal_code': loc.postal_code,
|
||||
'latitude': loc.latitude,
|
||||
'longitude': loc.longitude,
|
||||
'contact_person': loc.contact_person,
|
||||
'contact_phone': loc.contact_phone,
|
||||
'contact_email': loc.contact_email,
|
||||
'is_active': loc.is_active,
|
||||
'delivery_windows': loc.delivery_windows,
|
||||
'operational_hours': loc.operational_hours,
|
||||
'capacity': loc.capacity,
|
||||
'max_delivery_radius_km': loc.max_delivery_radius_km,
|
||||
'delivery_schedule_config': loc.delivery_schedule_config,
|
||||
'metadata': loc.metadata_, # Use the actual column name to avoid conflict
|
||||
'created_at': loc.created_at,
|
||||
'updated_at': loc.updated_at
|
||||
}
|
||||
location_responses.append(TenantLocationResponse.model_validate(loc_dict))
|
||||
|
||||
return TenantLocationsResponse(
|
||||
locations=location_responses,
|
||||
total=len(location_responses)
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Get tenant locations by type failed",
|
||||
tenant_id=str(tenant_id),
|
||||
location_type=location_type,
|
||||
user_id=current_user.get("user_id"),
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Get tenant locations by type failed"
|
||||
)
|
||||
483
services/tenant/app/api/tenant_members.py
Normal file
483
services/tenant/app/api/tenant_members.py
Normal file
@@ -0,0 +1,483 @@
|
||||
"""
|
||||
Tenant Member Management API - ATOMIC operations
|
||||
Handles team member CRUD operations
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Path, Query
|
||||
from typing import List, Dict, Any
|
||||
from uuid import UUID
|
||||
|
||||
from app.schemas.tenants import TenantMemberResponse, AddMemberWithUserCreate, TenantResponse
|
||||
from app.services.tenant_service import EnhancedTenantService
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from shared.routing.route_builder import RouteBuilder
|
||||
from shared.database.base import create_database_manager
|
||||
from shared.monitoring.metrics import track_endpoint_metrics
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
route_builder = RouteBuilder("tenants")
|
||||
|
||||
# Dependency injection for enhanced tenant service
|
||||
def get_enhanced_tenant_service():
|
||||
try:
|
||||
from app.core.config import settings
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
|
||||
return EnhancedTenantService(database_manager)
|
||||
except Exception as e:
|
||||
logger.error("Failed to create enhanced tenant service", error=str(e))
|
||||
raise HTTPException(status_code=500, detail="Service initialization failed")
|
||||
|
||||
@router.post(route_builder.build_base_route("{tenant_id}/members/with-user", include_tenant_prefix=False), response_model=TenantMemberResponse)
|
||||
@track_endpoint_metrics("tenant_add_member_with_user_creation")
|
||||
async def add_team_member_with_user_creation(
|
||||
member_data: AddMemberWithUserCreate,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""
|
||||
Add a team member to tenant with optional user creation (pilot phase).
|
||||
|
||||
This endpoint supports two modes:
|
||||
1. Adding an existing user: Set user_id and create_user=False
|
||||
2. Creating a new user: Set create_user=True and provide email, full_name, password
|
||||
|
||||
In pilot phase, this allows owners to directly create users with passwords.
|
||||
In production, this will be replaced with an invitation-based flow.
|
||||
"""
|
||||
try:
|
||||
# CRITICAL: Check subscription limit before adding user
|
||||
from app.services.subscription_limit_service import SubscriptionLimitService
|
||||
|
||||
limit_service = SubscriptionLimitService()
|
||||
limit_check = await limit_service.can_add_user(str(tenant_id))
|
||||
|
||||
if not limit_check.get('can_add', False):
|
||||
logger.warning(
|
||||
"User limit exceeded",
|
||||
tenant_id=str(tenant_id),
|
||||
current=limit_check.get('current_count'),
|
||||
max=limit_check.get('max_allowed'),
|
||||
reason=limit_check.get('reason')
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail={
|
||||
"error": "user_limit_exceeded",
|
||||
"message": limit_check.get('reason', 'User limit exceeded'),
|
||||
"current_count": limit_check.get('current_count'),
|
||||
"max_allowed": limit_check.get('max_allowed'),
|
||||
"upgrade_required": True
|
||||
}
|
||||
)
|
||||
|
||||
user_id_to_add = member_data.user_id
|
||||
|
||||
# If create_user is True, create the user first via auth service
|
||||
if member_data.create_user:
|
||||
logger.info(
|
||||
"Creating new user before adding to tenant",
|
||||
tenant_id=str(tenant_id),
|
||||
email=member_data.email,
|
||||
requested_by=current_user["user_id"]
|
||||
)
|
||||
|
||||
# Call auth service to create user
|
||||
from shared.clients.auth_client import AuthServiceClient
|
||||
from app.core.config import settings
|
||||
|
||||
auth_client = AuthServiceClient(settings)
|
||||
|
||||
# Map tenant role to user role
|
||||
# tenant roles: admin, member, viewer
|
||||
# user roles: admin, manager, user
|
||||
user_role_map = {
|
||||
"admin": "admin",
|
||||
"member": "manager",
|
||||
"viewer": "user"
|
||||
}
|
||||
user_role = user_role_map.get(member_data.role, "user")
|
||||
|
||||
try:
|
||||
user_create_data = {
|
||||
"email": member_data.email,
|
||||
"full_name": member_data.full_name,
|
||||
"password": member_data.password,
|
||||
"phone": member_data.phone,
|
||||
"role": user_role,
|
||||
"language": member_data.language or "es",
|
||||
"timezone": member_data.timezone or "Europe/Madrid"
|
||||
}
|
||||
|
||||
created_user = await auth_client.create_user_by_owner(user_create_data)
|
||||
user_id_to_add = created_user.get("id")
|
||||
|
||||
logger.info(
|
||||
"User created successfully",
|
||||
user_id=user_id_to_add,
|
||||
email=member_data.email,
|
||||
tenant_id=str(tenant_id)
|
||||
)
|
||||
|
||||
except Exception as auth_error:
|
||||
logger.error(
|
||||
"Failed to create user via auth service",
|
||||
error=str(auth_error),
|
||||
email=member_data.email
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to create user account: {str(auth_error)}"
|
||||
)
|
||||
|
||||
# Add the user (existing or newly created) to the tenant
|
||||
result = await tenant_service.add_team_member(
|
||||
str(tenant_id),
|
||||
user_id_to_add,
|
||||
member_data.role,
|
||||
current_user["user_id"]
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Team member added successfully",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=user_id_to_add,
|
||||
role=member_data.role,
|
||||
user_was_created=member_data.create_user
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Add team member with user creation failed",
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(e)
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to add team member"
|
||||
)
|
||||
|
||||
|
||||
@router.post(route_builder.build_base_route("{tenant_id}/members", include_tenant_prefix=False), response_model=TenantMemberResponse)
|
||||
@track_endpoint_metrics("tenant_add_member")
|
||||
async def add_team_member(
|
||||
user_id: str,
|
||||
role: str,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""Add an existing team member to tenant (legacy endpoint)"""
|
||||
|
||||
try:
|
||||
# CRITICAL: Check subscription limit before adding user
|
||||
from app.services.subscription_limit_service import SubscriptionLimitService
|
||||
|
||||
limit_service = SubscriptionLimitService()
|
||||
limit_check = await limit_service.can_add_user(str(tenant_id))
|
||||
|
||||
if not limit_check.get('can_add', False):
|
||||
logger.warning(
|
||||
"User limit exceeded",
|
||||
tenant_id=str(tenant_id),
|
||||
current=limit_check.get('current_count'),
|
||||
max=limit_check.get('max_allowed'),
|
||||
reason=limit_check.get('reason')
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail={
|
||||
"error": "user_limit_exceeded",
|
||||
"message": limit_check.get('reason', 'User limit exceeded'),
|
||||
"current_count": limit_check.get('current_count'),
|
||||
"max_allowed": limit_check.get('max_allowed'),
|
||||
"upgrade_required": True
|
||||
}
|
||||
)
|
||||
|
||||
result = await tenant_service.add_team_member(
|
||||
str(tenant_id),
|
||||
user_id,
|
||||
role,
|
||||
current_user["user_id"]
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Team member added successfully",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=user_id,
|
||||
role=role
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Add team member failed",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=user_id,
|
||||
role=role,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to add team member"
|
||||
)
|
||||
|
||||
@router.get(route_builder.build_base_route("{tenant_id}/members", include_tenant_prefix=False), response_model=List[TenantMemberResponse])
|
||||
async def get_team_members(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
active_only: bool = Query(True, description="Only return active members"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""Get all team members for a tenant with enhanced filtering"""
|
||||
|
||||
try:
|
||||
members = await tenant_service.get_team_members(
|
||||
str(tenant_id),
|
||||
current_user["user_id"],
|
||||
active_only=active_only
|
||||
)
|
||||
return members
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Get team members failed",
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get team members"
|
||||
)
|
||||
|
||||
@router.put(route_builder.build_base_route("{tenant_id}/members/{member_user_id}/role", include_tenant_prefix=False), response_model=TenantMemberResponse)
|
||||
@track_endpoint_metrics("tenant_update_member_role")
|
||||
async def update_member_role(
|
||||
new_role: str,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
member_user_id: str = Path(..., description="Member user ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""Update team member role with enhanced permission validation"""
|
||||
|
||||
try:
|
||||
result = await tenant_service.update_member_role(
|
||||
str(tenant_id),
|
||||
member_user_id,
|
||||
new_role,
|
||||
current_user["user_id"]
|
||||
)
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Update member role failed",
|
||||
tenant_id=str(tenant_id),
|
||||
member_user_id=member_user_id,
|
||||
new_role=new_role,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update member role"
|
||||
)
|
||||
|
||||
@router.delete(route_builder.build_base_route("{tenant_id}/members/{member_user_id}", include_tenant_prefix=False))
|
||||
@track_endpoint_metrics("tenant_remove_member")
|
||||
async def remove_team_member(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
member_user_id: str = Path(..., description="Member user ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""Remove team member from tenant with enhanced validation"""
|
||||
|
||||
try:
|
||||
success = await tenant_service.remove_team_member(
|
||||
str(tenant_id),
|
||||
member_user_id,
|
||||
current_user["user_id"]
|
||||
)
|
||||
|
||||
if success:
|
||||
return {"success": True, "message": "Team member removed successfully"}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to remove team member"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Remove team member failed",
|
||||
tenant_id=str(tenant_id),
|
||||
member_user_id=member_user_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to remove team member"
|
||||
)
|
||||
|
||||
@router.delete(route_builder.build_base_route("user/{user_id}/memberships", include_tenant_prefix=False))
|
||||
@track_endpoint_metrics("user_memberships_delete")
|
||||
async def delete_user_memberships(
|
||||
user_id: str = Path(..., description="User ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""
|
||||
Delete all tenant memberships for a user.
|
||||
Used by auth service when deleting a user account.
|
||||
Only accessible by internal services.
|
||||
"""
|
||||
|
||||
logger.info(
|
||||
"Delete user memberships request received",
|
||||
user_id=user_id,
|
||||
requesting_service=current_user.get("service", "unknown"),
|
||||
is_service=current_user.get("type") == "service"
|
||||
)
|
||||
|
||||
# Only allow internal service calls
|
||||
if current_user.get("type") != "service":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="This endpoint is only accessible to internal services"
|
||||
)
|
||||
|
||||
try:
|
||||
result = await tenant_service.delete_user_memberships(user_id)
|
||||
|
||||
logger.info(
|
||||
"User memberships deleted successfully",
|
||||
user_id=user_id,
|
||||
deleted_count=result.get("deleted_count"),
|
||||
total_memberships=result.get("total_memberships")
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "User memberships deleted successfully",
|
||||
"summary": result
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Delete user memberships failed",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to delete user memberships"
|
||||
)
|
||||
|
||||
@router.post(route_builder.build_base_route("{tenant_id}/transfer-ownership", include_tenant_prefix=False), response_model=TenantResponse)
|
||||
@track_endpoint_metrics("tenant_transfer_ownership")
|
||||
async def transfer_ownership(
|
||||
new_owner_id: str,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""
|
||||
Transfer tenant ownership to another admin.
|
||||
Only the current owner or internal services can perform this action.
|
||||
"""
|
||||
|
||||
logger.info(
|
||||
"Transfer ownership request received",
|
||||
tenant_id=str(tenant_id),
|
||||
new_owner_id=new_owner_id,
|
||||
requesting_user=current_user.get("user_id"),
|
||||
is_service=current_user.get("type") == "service"
|
||||
)
|
||||
|
||||
try:
|
||||
# Get current tenant to find current owner
|
||||
tenant_info = await tenant_service.get_tenant_by_id(str(tenant_id))
|
||||
if not tenant_info:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Tenant not found"
|
||||
)
|
||||
|
||||
current_owner_id = tenant_info.owner_id
|
||||
|
||||
result = await tenant_service.transfer_tenant_ownership(
|
||||
str(tenant_id),
|
||||
current_owner_id,
|
||||
new_owner_id,
|
||||
requesting_user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Ownership transferred successfully",
|
||||
tenant_id=str(tenant_id),
|
||||
from_owner=current_owner_id,
|
||||
to_owner=new_owner_id
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Transfer ownership failed",
|
||||
tenant_id=str(tenant_id),
|
||||
new_owner_id=new_owner_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to transfer ownership"
|
||||
)
|
||||
|
||||
@router.get(route_builder.build_base_route("{tenant_id}/admins", include_tenant_prefix=False), response_model=List[TenantMemberResponse])
|
||||
@track_endpoint_metrics("tenant_get_admins")
|
||||
async def get_tenant_admins(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""
|
||||
Get all admins (owner + admins) for a tenant.
|
||||
Used by auth service to check for other admins before tenant deletion.
|
||||
"""
|
||||
|
||||
logger.info(
|
||||
"Get tenant admins request received",
|
||||
tenant_id=str(tenant_id),
|
||||
requesting_user=current_user.get("user_id"),
|
||||
is_service=current_user.get("type") == "service"
|
||||
)
|
||||
|
||||
try:
|
||||
admins = await tenant_service.get_tenant_admins(str(tenant_id))
|
||||
|
||||
logger.info(
|
||||
"Retrieved tenant admins",
|
||||
tenant_id=str(tenant_id),
|
||||
admin_count=len(admins)
|
||||
)
|
||||
|
||||
return admins
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Get tenant admins failed",
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get tenant admins"
|
||||
)
|
||||
734
services/tenant/app/api/tenant_operations.py
Normal file
734
services/tenant/app/api/tenant_operations.py
Normal file
@@ -0,0 +1,734 @@
|
||||
"""
|
||||
Tenant Operations API - BUSINESS operations
|
||||
Handles complex tenant operations, registration, search, and analytics
|
||||
|
||||
NOTE: All subscription-related endpoints have been moved to subscription.py
|
||||
as part of the architecture redesign for better separation of concerns.
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from datetime import datetime
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Path, Query
|
||||
from typing import List, Dict, Any, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from app.schemas.tenants import (
|
||||
BakeryRegistration, TenantResponse, TenantAccessResponse,
|
||||
TenantSearchRequest
|
||||
)
|
||||
from app.services.tenant_service import EnhancedTenantService
|
||||
from app.services.payment_service import PaymentService
|
||||
from app.models import AuditLog
|
||||
from shared.auth.decorators import (
|
||||
get_current_user_dep,
|
||||
require_admin_role_dep
|
||||
)
|
||||
from app.core.database import get_db
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from shared.auth.access_control import owner_role_required, admin_role_required
|
||||
from shared.routing.route_builder import RouteBuilder
|
||||
from shared.database.base import create_database_manager
|
||||
from shared.monitoring.metrics import track_endpoint_metrics
|
||||
from shared.security import create_audit_logger, AuditSeverity, AuditAction
|
||||
from shared.config.base import is_internal_service
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
route_builder = RouteBuilder("tenants")
|
||||
|
||||
# Initialize audit logger
|
||||
audit_logger = create_audit_logger("tenant-service", AuditLog)
|
||||
|
||||
|
||||
# Dependency injection for enhanced tenant service
|
||||
def get_enhanced_tenant_service():
|
||||
try:
|
||||
from app.core.config import settings
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
|
||||
return EnhancedTenantService(database_manager)
|
||||
except Exception as e:
|
||||
logger.error("Failed to create enhanced tenant service", error=str(e))
|
||||
raise HTTPException(status_code=500, detail="Service initialization failed")
|
||||
|
||||
|
||||
def get_payment_service():
|
||||
try:
|
||||
return PaymentService()
|
||||
except Exception as e:
|
||||
logger.error("Failed to create payment service", error=str(e))
|
||||
raise HTTPException(status_code=500, detail="Payment service initialization failed")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TENANT REGISTRATION & ACCESS OPERATIONS
|
||||
# ============================================================================
|
||||
|
||||
@router.post(route_builder.build_base_route("register", include_tenant_prefix=False), response_model=TenantResponse)
|
||||
async def register_bakery(
|
||||
bakery_data: BakeryRegistration,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service),
|
||||
payment_service: PaymentService = Depends(get_payment_service),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Register a new bakery/tenant with enhanced validation and features"""
|
||||
|
||||
try:
|
||||
coupon_validation = None
|
||||
success = None
|
||||
discount = None
|
||||
error = None
|
||||
|
||||
result = await tenant_service.create_bakery(
|
||||
bakery_data,
|
||||
current_user["user_id"]
|
||||
)
|
||||
|
||||
tenant_id = result.id
|
||||
|
||||
if bakery_data.link_existing_subscription and bakery_data.subscription_id:
|
||||
logger.info("Linking existing subscription to new tenant",
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=bakery_data.subscription_id,
|
||||
user_id=current_user["user_id"])
|
||||
|
||||
try:
|
||||
from app.services.subscription_service import SubscriptionService
|
||||
|
||||
subscription_service = SubscriptionService(db)
|
||||
|
||||
linking_result = await subscription_service.link_subscription_to_tenant(
|
||||
subscription_id=bakery_data.subscription_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user["user_id"]
|
||||
)
|
||||
|
||||
logger.info("Subscription linked successfully during tenant registration",
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=bakery_data.subscription_id)
|
||||
|
||||
except Exception as linking_error:
|
||||
logger.error("Error linking subscription during tenant registration",
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=bakery_data.subscription_id,
|
||||
error=str(linking_error))
|
||||
|
||||
elif bakery_data.coupon_code:
|
||||
from app.services.coupon_service import CouponService
|
||||
|
||||
coupon_service = CouponService(db)
|
||||
coupon_validation = await coupon_service.validate_coupon_code(
|
||||
bakery_data.coupon_code,
|
||||
tenant_id
|
||||
)
|
||||
|
||||
if not coupon_validation["valid"]:
|
||||
logger.warning(
|
||||
"Invalid coupon code provided during registration",
|
||||
coupon_code=bakery_data.coupon_code,
|
||||
error=coupon_validation["error_message"]
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=coupon_validation["error_message"]
|
||||
)
|
||||
|
||||
success, discount, error = await coupon_service.redeem_coupon(
|
||||
bakery_data.coupon_code,
|
||||
tenant_id,
|
||||
base_trial_days=0
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info("Coupon redeemed during registration",
|
||||
coupon_code=bakery_data.coupon_code,
|
||||
tenant_id=tenant_id)
|
||||
else:
|
||||
logger.warning("Failed to redeem coupon during registration",
|
||||
coupon_code=bakery_data.coupon_code,
|
||||
error=error)
|
||||
else:
|
||||
try:
|
||||
from app.repositories.subscription_repository import SubscriptionRepository
|
||||
from app.models.tenants import Subscription
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from app.core.config import settings
|
||||
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
|
||||
async with database_manager.get_session() as session:
|
||||
subscription_repo = SubscriptionRepository(Subscription, session)
|
||||
|
||||
existing_subscription = await subscription_repo.get_by_tenant_id(str(result.id))
|
||||
|
||||
if existing_subscription:
|
||||
logger.info(
|
||||
"Tenant already has an active subscription, skipping default subscription creation",
|
||||
tenant_id=str(result.id),
|
||||
existing_plan=existing_subscription.plan,
|
||||
subscription_id=str(existing_subscription.id)
|
||||
)
|
||||
else:
|
||||
trial_end_date = datetime.now(timezone.utc)
|
||||
next_billing_date = trial_end_date
|
||||
|
||||
await subscription_repo.create_subscription({
|
||||
"tenant_id": str(result.id),
|
||||
"plan": "starter",
|
||||
"status": "trial",
|
||||
"billing_cycle": "monthly",
|
||||
"next_billing_date": next_billing_date,
|
||||
"trial_ends_at": trial_end_date
|
||||
})
|
||||
await session.commit()
|
||||
|
||||
logger.info(
|
||||
"Default subscription created for new tenant",
|
||||
tenant_id=str(result.id),
|
||||
plan="starter",
|
||||
trial_days=0
|
||||
)
|
||||
except Exception as subscription_error:
|
||||
logger.error(
|
||||
"Failed to create default subscription for tenant",
|
||||
tenant_id=str(result.id),
|
||||
error=str(subscription_error)
|
||||
)
|
||||
|
||||
if coupon_validation and coupon_validation["valid"]:
|
||||
from app.core.config import settings
|
||||
from app.services.coupon_service import CouponService
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
|
||||
|
||||
async with database_manager.get_session() as session:
|
||||
coupon_service = CouponService(session)
|
||||
success, discount, error = await coupon_service.redeem_coupon(
|
||||
bakery_data.coupon_code,
|
||||
result.id,
|
||||
base_trial_days=0
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.info(
|
||||
"Coupon redeemed successfully",
|
||||
tenant_id=result.id,
|
||||
coupon_code=bakery_data.coupon_code,
|
||||
discount=discount
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Failed to redeem coupon after registration",
|
||||
tenant_id=result.id,
|
||||
coupon_code=bakery_data.coupon_code,
|
||||
error=error
|
||||
)
|
||||
|
||||
logger.info("Bakery registered successfully",
|
||||
name=bakery_data.name,
|
||||
owner_email=current_user.get('email'),
|
||||
tenant_id=result.id,
|
||||
coupon_applied=bakery_data.coupon_code is not None)
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Bakery registration failed",
|
||||
name=bakery_data.name,
|
||||
owner_id=current_user["user_id"],
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Bakery registration failed"
|
||||
)
|
||||
|
||||
|
||||
@router.get(route_builder.build_base_route("{tenant_id}/my-access", include_tenant_prefix=False), response_model=TenantAccessResponse)
|
||||
async def get_current_user_tenant_access(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep)
|
||||
):
|
||||
"""Get current user's access to tenant with role and permissions"""
|
||||
|
||||
try:
|
||||
from app.core.config import settings
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
|
||||
tenant_service = EnhancedTenantService(database_manager)
|
||||
|
||||
access_info = await tenant_service.verify_user_access(current_user["user_id"], str(tenant_id))
|
||||
return access_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Current user access verification failed",
|
||||
user_id=current_user["user_id"],
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Access verification failed"
|
||||
)
|
||||
|
||||
|
||||
@router.get(route_builder.build_base_route("{tenant_id}/access/{user_id}", include_tenant_prefix=False), response_model=TenantAccessResponse)
|
||||
async def verify_tenant_access(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
user_id: str = Path(..., description="User ID")
|
||||
):
|
||||
"""Verify if user has access to tenant - Enhanced version with detailed permissions"""
|
||||
|
||||
if is_internal_service(user_id):
|
||||
logger.info("Service access granted", service=user_id, tenant_id=str(tenant_id))
|
||||
return TenantAccessResponse(
|
||||
has_access=True,
|
||||
role="service",
|
||||
permissions=["read", "write"]
|
||||
)
|
||||
|
||||
try:
|
||||
from app.core.config import settings
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
|
||||
tenant_service = EnhancedTenantService(database_manager)
|
||||
|
||||
access_info = await tenant_service.verify_user_access(user_id, str(tenant_id))
|
||||
return access_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Access verification failed",
|
||||
user_id=user_id,
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Access verification failed"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TENANT SEARCH & DISCOVERY OPERATIONS
|
||||
# ============================================================================
|
||||
|
||||
@router.get(route_builder.build_base_route("subdomain/{subdomain}", include_tenant_prefix=False), response_model=TenantResponse)
|
||||
@track_endpoint_metrics("tenant_get_by_subdomain")
|
||||
async def get_tenant_by_subdomain(
|
||||
subdomain: str = Path(..., description="Tenant subdomain"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""Get tenant by subdomain with enhanced validation"""
|
||||
|
||||
tenant = await tenant_service.get_tenant_by_subdomain(subdomain)
|
||||
if not tenant:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Tenant not found"
|
||||
)
|
||||
|
||||
access = await tenant_service.verify_user_access(current_user["user_id"], tenant.id)
|
||||
if not access.has_access:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied to tenant"
|
||||
)
|
||||
|
||||
return tenant
|
||||
|
||||
|
||||
@router.get(route_builder.build_base_route("user/{user_id}/owned", include_tenant_prefix=False), response_model=List[TenantResponse])
|
||||
async def get_user_owned_tenants(
|
||||
user_id: str = Path(..., description="User ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""Get all tenants owned by a user with enhanced data"""
|
||||
|
||||
user_role = current_user.get('role', '').lower()
|
||||
|
||||
is_demo_user = current_user.get("is_demo", False) and user_id == "demo-user"
|
||||
|
||||
if user_id != current_user["user_id"] and not is_demo_user and user_role != 'admin':
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Can only access your own tenants"
|
||||
)
|
||||
|
||||
if current_user.get("is_demo", False):
|
||||
demo_session_id = current_user.get("demo_session_id")
|
||||
demo_account_type = current_user.get("demo_account_type", "")
|
||||
|
||||
if demo_session_id:
|
||||
logger.info("Fetching virtual tenants for demo session",
|
||||
demo_session_id=demo_session_id,
|
||||
demo_account_type=demo_account_type)
|
||||
|
||||
virtual_tenants = await tenant_service.get_virtual_tenants_for_session(demo_session_id, demo_account_type)
|
||||
return virtual_tenants
|
||||
else:
|
||||
virtual_tenants = await tenant_service.get_demo_tenants_by_session_type(
|
||||
demo_account_type,
|
||||
str(current_user["user_id"])
|
||||
)
|
||||
return virtual_tenants
|
||||
|
||||
actual_user_id = current_user["user_id"] if is_demo_user else user_id
|
||||
tenants = await tenant_service.get_user_tenants(actual_user_id)
|
||||
return tenants
|
||||
|
||||
|
||||
@router.get(route_builder.build_base_route("search", include_tenant_prefix=False), response_model=List[TenantResponse])
|
||||
@track_endpoint_metrics("tenant_search")
|
||||
async def search_tenants(
|
||||
search_term: str = Query(..., description="Search term"),
|
||||
business_type: Optional[str] = Query(None, description="Business type filter"),
|
||||
city: Optional[str] = Query(None, description="City filter"),
|
||||
skip: int = Query(0, ge=0, description="Number of records to skip"),
|
||||
limit: int = Query(50, ge=1, le=100, description="Maximum number of records to return"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""Search tenants with advanced filters and pagination"""
|
||||
|
||||
tenants = await tenant_service.search_tenants(
|
||||
search_term=search_term,
|
||||
business_type=business_type,
|
||||
city=city,
|
||||
skip=skip,
|
||||
limit=limit
|
||||
)
|
||||
return tenants
|
||||
|
||||
|
||||
@router.get(route_builder.build_base_route("nearby", include_tenant_prefix=False), response_model=List[TenantResponse])
|
||||
@track_endpoint_metrics("tenant_get_nearby")
|
||||
async def get_nearby_tenants(
|
||||
latitude: float = Query(..., description="Latitude coordinate"),
|
||||
longitude: float = Query(..., description="Longitude coordinate"),
|
||||
radius_km: float = Query(10.0, ge=0.1, le=100.0, description="Search radius in kilometers"),
|
||||
limit: int = Query(50, ge=1, le=100, description="Maximum number of results"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""Get tenants near a geographic location with enhanced geospatial search"""
|
||||
|
||||
tenants = await tenant_service.get_tenants_near_location(
|
||||
latitude=latitude,
|
||||
longitude=longitude,
|
||||
radius_km=radius_km,
|
||||
limit=limit
|
||||
)
|
||||
return tenants
|
||||
|
||||
|
||||
@router.get(route_builder.build_base_route("users/{user_id}", include_tenant_prefix=False), response_model=List[TenantResponse])
|
||||
@track_endpoint_metrics("tenant_get_user_tenants")
|
||||
async def get_user_tenants(
|
||||
user_id: str = Path(..., description="User ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""Get all tenants owned by a user - Fixed endpoint for frontend"""
|
||||
|
||||
is_demo_user = current_user.get("is_demo", False)
|
||||
is_service_account = current_user.get("type") == "service"
|
||||
user_role = current_user.get('role', '').lower()
|
||||
|
||||
if user_id != current_user["user_id"] and not is_service_account and not (is_demo_user and user_id == "demo-user") and user_role != 'admin':
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Can only access your own tenants"
|
||||
)
|
||||
|
||||
try:
|
||||
tenants = await tenant_service.get_user_tenants(user_id)
|
||||
logger.info("Retrieved user tenants", user_id=user_id, tenant_count=len(tenants))
|
||||
return tenants
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Get user tenants failed", user_id=user_id, error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get user tenants"
|
||||
)
|
||||
|
||||
|
||||
@router.get(route_builder.build_base_route("members/user/{user_id}", include_tenant_prefix=False))
|
||||
@track_endpoint_metrics("tenant_get_user_memberships")
|
||||
async def get_user_memberships(
|
||||
user_id: str = Path(..., description="User ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""Get all tenant memberships for a user (for authentication service)"""
|
||||
|
||||
is_demo_user = current_user.get("is_demo", False)
|
||||
is_service_account = current_user.get("type") == "service"
|
||||
user_role = current_user.get('role', '').lower()
|
||||
|
||||
if user_id != current_user["user_id"] and not is_service_account and not (is_demo_user and user_id == "demo-user") and user_role != 'admin':
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Can only access your own memberships"
|
||||
)
|
||||
|
||||
try:
|
||||
memberships = await tenant_service.get_user_memberships(user_id)
|
||||
logger.info("Retrieved user memberships", user_id=user_id, membership_count=len(memberships))
|
||||
return memberships
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Get user memberships failed", user_id=user_id, error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get user memberships"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TENANT MODEL STATUS OPERATIONS
|
||||
# ============================================================================
|
||||
|
||||
@router.put(route_builder.build_base_route("{tenant_id}/model-status", include_tenant_prefix=False))
|
||||
@track_endpoint_metrics("tenant_update_model_status")
|
||||
async def update_tenant_model_status(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
ml_model_trained: bool = Query(..., description="Whether model is trained"),
|
||||
last_training_date: Optional[datetime] = Query(None, description="Last training date"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""Update tenant model training status with enhanced tracking"""
|
||||
|
||||
try:
|
||||
result = await tenant_service.update_model_status(
|
||||
str(tenant_id),
|
||||
ml_model_trained,
|
||||
current_user["user_id"],
|
||||
last_training_date
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Model status update failed",
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to update model status"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TENANT ACTIVATION/DEACTIVATION OPERATIONS
|
||||
# ============================================================================
|
||||
|
||||
@router.post(route_builder.build_base_route("{tenant_id}/deactivate", include_tenant_prefix=False))
|
||||
@track_endpoint_metrics("tenant_deactivate")
|
||||
@owner_role_required
|
||||
async def deactivate_tenant(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""Deactivate a tenant (owner only) with enhanced validation"""
|
||||
|
||||
try:
|
||||
success = await tenant_service.deactivate_tenant(
|
||||
str(tenant_id),
|
||||
current_user["user_id"]
|
||||
)
|
||||
|
||||
if success:
|
||||
try:
|
||||
from app.core.database import get_db_session
|
||||
async with get_db_session() as db:
|
||||
await audit_logger.log_event(
|
||||
db_session=db,
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user["user_id"],
|
||||
action=AuditAction.DEACTIVATE.value,
|
||||
resource_type="tenant",
|
||||
resource_id=str(tenant_id),
|
||||
severity=AuditSeverity.CRITICAL.value,
|
||||
description=f"Owner {current_user.get('email', current_user['user_id'])} deactivated tenant",
|
||||
endpoint="/{tenant_id}/deactivate",
|
||||
method="POST"
|
||||
)
|
||||
except Exception as audit_error:
|
||||
logger.warning("Failed to log audit event", error=str(audit_error))
|
||||
|
||||
return {"success": True, "message": "Tenant deactivated successfully"}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to deactivate tenant"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Tenant deactivation failed",
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to deactivate tenant"
|
||||
)
|
||||
|
||||
|
||||
@router.post(route_builder.build_base_route("{tenant_id}/activate", include_tenant_prefix=False))
|
||||
@track_endpoint_metrics("tenant_activate")
|
||||
@owner_role_required
|
||||
async def activate_tenant(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""Activate a previously deactivated tenant (owner only) with enhanced validation"""
|
||||
|
||||
try:
|
||||
success = await tenant_service.activate_tenant(
|
||||
str(tenant_id),
|
||||
current_user["user_id"]
|
||||
)
|
||||
|
||||
if success:
|
||||
try:
|
||||
from app.core.database import get_db_session
|
||||
async with get_db_session() as db:
|
||||
await audit_logger.log_event(
|
||||
db_session=db,
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user["user_id"],
|
||||
action=AuditAction.ACTIVATE.value,
|
||||
resource_type="tenant",
|
||||
resource_id=str(tenant_id),
|
||||
severity=AuditSeverity.HIGH.value,
|
||||
description=f"Owner {current_user.get('email', current_user['user_id'])} activated tenant",
|
||||
endpoint="/{tenant_id}/activate",
|
||||
method="POST"
|
||||
)
|
||||
except Exception as audit_error:
|
||||
logger.warning("Failed to log audit event", error=str(audit_error))
|
||||
|
||||
return {"success": True, "message": "Tenant activated successfully"}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to activate tenant"
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Tenant activation failed",
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to activate tenant"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TENANT STATISTICS & ANALYTICS
|
||||
# ============================================================================
|
||||
|
||||
@router.get(route_builder.build_base_route("statistics", include_tenant_prefix=False), dependencies=[Depends(require_admin_role_dep)])
|
||||
@track_endpoint_metrics("tenant_get_statistics")
|
||||
async def get_tenant_statistics(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""Get comprehensive tenant statistics (admin only) with enhanced analytics"""
|
||||
|
||||
try:
|
||||
stats = await tenant_service.get_tenant_statistics()
|
||||
return stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Get tenant statistics failed", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get tenant statistics"
|
||||
)
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# USER-TENANT RELATIONSHIP OPERATIONS
|
||||
# ============================================================================
|
||||
|
||||
@router.get(route_builder.build_base_route("users/{user_id}/primary-tenant", include_tenant_prefix=False))
|
||||
async def get_user_primary_tenant(
|
||||
user_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep)
|
||||
):
|
||||
"""
|
||||
Get the primary tenant for a user
|
||||
|
||||
This endpoint is used by the auth service to validate user subscriptions
|
||||
during login. It returns the user's primary tenant (the one they own or
|
||||
have primary access to).
|
||||
|
||||
Args:
|
||||
user_id: The user ID to look up
|
||||
|
||||
Returns:
|
||||
Dictionary with user's primary tenant information, or None if no tenant found
|
||||
|
||||
Example Response:
|
||||
{
|
||||
"user_id": "user-uuid",
|
||||
"tenant_id": "tenant-uuid",
|
||||
"tenant_name": "Bakery Name",
|
||||
"tenant_type": "standalone",
|
||||
"is_owner": true
|
||||
}
|
||||
"""
|
||||
try:
|
||||
from app.core.database import database_manager
|
||||
from app.repositories.tenant_repository import TenantRepository
|
||||
from app.models.tenants import Tenant
|
||||
|
||||
async with database_manager.get_session() as session:
|
||||
tenant_repo = TenantRepository(Tenant, session)
|
||||
|
||||
# Get user's primary tenant (the one they own)
|
||||
primary_tenant = await tenant_repo.get_user_primary_tenant(user_id)
|
||||
|
||||
if primary_tenant:
|
||||
logger.info("Found primary tenant for user",
|
||||
user_id=user_id,
|
||||
tenant_id=str(primary_tenant.id),
|
||||
tenant_name=primary_tenant.name)
|
||||
return {
|
||||
'user_id': user_id,
|
||||
'tenant_id': str(primary_tenant.id),
|
||||
'tenant_name': primary_tenant.name,
|
||||
'tenant_type': primary_tenant.tenant_type,
|
||||
'is_owner': True
|
||||
}
|
||||
else:
|
||||
# If no primary tenant found, check if user has access to any tenant
|
||||
any_tenant = await tenant_repo.get_any_user_tenant(user_id)
|
||||
|
||||
if any_tenant:
|
||||
logger.info("Found accessible tenant for user",
|
||||
user_id=user_id,
|
||||
tenant_id=str(any_tenant.id),
|
||||
tenant_name=any_tenant.name)
|
||||
return {
|
||||
'user_id': user_id,
|
||||
'tenant_id': str(any_tenant.id),
|
||||
'tenant_name': any_tenant.name,
|
||||
'tenant_type': any_tenant.tenant_type,
|
||||
'is_owner': False
|
||||
}
|
||||
else:
|
||||
logger.info("No tenant found for user", user_id=user_id)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get primary tenant for user {user_id}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get primary tenant: {str(e)}")
|
||||
186
services/tenant/app/api/tenant_settings.py
Normal file
186
services/tenant/app/api/tenant_settings.py
Normal file
@@ -0,0 +1,186 @@
|
||||
# services/tenant/app/api/tenant_settings.py
|
||||
"""
|
||||
Tenant Settings API Endpoints
|
||||
REST API for managing tenant-specific operational settings
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from uuid import UUID
|
||||
from typing import Dict, Any
|
||||
|
||||
from app.core.database import get_db
|
||||
from shared.routing.route_builder import RouteBuilder
|
||||
from ..services.tenant_settings_service import TenantSettingsService
|
||||
from ..schemas.tenant_settings import (
|
||||
TenantSettingsResponse,
|
||||
TenantSettingsUpdate,
|
||||
CategoryUpdateRequest,
|
||||
CategoryResetResponse
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
route_builder = RouteBuilder("tenants")
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{tenant_id}/settings",
|
||||
response_model=TenantSettingsResponse,
|
||||
summary="Get all tenant settings",
|
||||
description="Retrieve all operational settings for a tenant. Creates default settings if none exist."
|
||||
)
|
||||
async def get_tenant_settings(
|
||||
tenant_id: UUID,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get all settings for a tenant
|
||||
|
||||
- **tenant_id**: UUID of the tenant
|
||||
|
||||
Returns all setting categories with their current values.
|
||||
If settings don't exist, default values are created and returned.
|
||||
"""
|
||||
service = TenantSettingsService(db)
|
||||
settings = await service.get_settings(tenant_id)
|
||||
return settings
|
||||
|
||||
|
||||
@router.put(
|
||||
"/{tenant_id}/settings",
|
||||
response_model=TenantSettingsResponse,
|
||||
summary="Update tenant settings",
|
||||
description="Update one or more setting categories for a tenant. Only provided categories are updated."
|
||||
)
|
||||
async def update_tenant_settings(
|
||||
tenant_id: UUID,
|
||||
updates: TenantSettingsUpdate,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Update tenant settings
|
||||
|
||||
- **tenant_id**: UUID of the tenant
|
||||
- **updates**: Object containing setting categories to update
|
||||
|
||||
Only provided categories will be updated. Omitted categories remain unchanged.
|
||||
All values are validated against min/max constraints.
|
||||
"""
|
||||
service = TenantSettingsService(db)
|
||||
settings = await service.update_settings(tenant_id, updates)
|
||||
return settings
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{tenant_id}/settings/{category}",
|
||||
response_model=Dict[str, Any],
|
||||
summary="Get settings for a specific category",
|
||||
description="Retrieve settings for a single category (procurement, inventory, production, supplier, pos, or order)"
|
||||
)
|
||||
async def get_category_settings(
|
||||
tenant_id: UUID,
|
||||
category: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get settings for a specific category
|
||||
|
||||
- **tenant_id**: UUID of the tenant
|
||||
- **category**: Category name (procurement, inventory, production, supplier, pos, order)
|
||||
|
||||
Returns settings for the specified category only.
|
||||
|
||||
Valid categories:
|
||||
- procurement: Auto-approval and procurement planning settings
|
||||
- inventory: Stock thresholds and temperature monitoring
|
||||
- production: Capacity, quality, and scheduling settings
|
||||
- supplier: Payment terms and performance thresholds
|
||||
- pos: POS integration sync settings
|
||||
- order: Discount and delivery settings
|
||||
"""
|
||||
service = TenantSettingsService(db)
|
||||
category_settings = await service.get_category(tenant_id, category)
|
||||
return {
|
||||
"tenant_id": str(tenant_id),
|
||||
"category": category,
|
||||
"settings": category_settings
|
||||
}
|
||||
|
||||
|
||||
@router.put(
|
||||
"/{tenant_id}/settings/{category}",
|
||||
response_model=TenantSettingsResponse,
|
||||
summary="Update settings for a specific category",
|
||||
description="Update all or some fields within a single category"
|
||||
)
|
||||
async def update_category_settings(
|
||||
tenant_id: UUID,
|
||||
category: str,
|
||||
request: CategoryUpdateRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Update settings for a specific category
|
||||
|
||||
- **tenant_id**: UUID of the tenant
|
||||
- **category**: Category name
|
||||
- **request**: Object containing the settings to update
|
||||
|
||||
Updates only the specified category. All values are validated.
|
||||
"""
|
||||
service = TenantSettingsService(db)
|
||||
settings = await service.update_category(tenant_id, category, request.settings)
|
||||
return settings
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{tenant_id}/settings/{category}/reset",
|
||||
response_model=CategoryResetResponse,
|
||||
summary="Reset category to default values",
|
||||
description="Reset a specific category to its default values"
|
||||
)
|
||||
async def reset_category_settings(
|
||||
tenant_id: UUID,
|
||||
category: str,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Reset a category to default values
|
||||
|
||||
- **tenant_id**: UUID of the tenant
|
||||
- **category**: Category name
|
||||
|
||||
Resets all settings in the specified category to their default values.
|
||||
This operation cannot be undone.
|
||||
"""
|
||||
service = TenantSettingsService(db)
|
||||
reset_settings = await service.reset_category(tenant_id, category)
|
||||
|
||||
return CategoryResetResponse(
|
||||
category=category,
|
||||
settings=reset_settings,
|
||||
message=f"Category '{category}' has been reset to default values"
|
||||
)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{tenant_id}/settings",
|
||||
status_code=status.HTTP_204_NO_CONTENT,
|
||||
summary="Delete tenant settings",
|
||||
description="Delete all settings for a tenant (used when tenant is deleted)"
|
||||
)
|
||||
async def delete_tenant_settings(
|
||||
tenant_id: UUID,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Delete tenant settings
|
||||
|
||||
- **tenant_id**: UUID of the tenant
|
||||
|
||||
This endpoint is typically called automatically when a tenant is deleted.
|
||||
It removes all setting data for the tenant.
|
||||
"""
|
||||
service = TenantSettingsService(db)
|
||||
await service.delete_settings(tenant_id)
|
||||
return None
|
||||
285
services/tenant/app/api/tenants.py
Normal file
285
services/tenant/app/api/tenants.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""
|
||||
Tenant API - ATOMIC operations
|
||||
Handles basic CRUD operations for tenants
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Path, Query
|
||||
from typing import Dict, Any, List
|
||||
from uuid import UUID
|
||||
|
||||
from app.schemas.tenants import TenantResponse, TenantUpdate
|
||||
from app.services.tenant_service import EnhancedTenantService
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from shared.auth.access_control import admin_role_required
|
||||
from shared.routing.route_builder import RouteBuilder
|
||||
from shared.database.base import create_database_manager
|
||||
from shared.monitoring.metrics import track_endpoint_metrics
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
route_builder = RouteBuilder("tenants")
|
||||
|
||||
# Dependency injection for enhanced tenant service
|
||||
def get_enhanced_tenant_service():
|
||||
try:
|
||||
from app.core.config import settings
|
||||
database_manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
|
||||
return EnhancedTenantService(database_manager)
|
||||
except Exception as e:
|
||||
logger.error("Failed to create enhanced tenant service", error=str(e))
|
||||
raise HTTPException(status_code=500, detail="Service initialization failed")
|
||||
|
||||
@router.get(route_builder.build_base_route("", include_tenant_prefix=False), response_model=List[TenantResponse])
|
||||
@track_endpoint_metrics("tenants_list")
|
||||
async def get_active_tenants(
|
||||
skip: int = Query(0, ge=0, description="Number of records to skip"),
|
||||
limit: int = Query(100, ge=1, le=1000, description="Maximum number of records to return"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""Get all active tenants - Available to service accounts and admins"""
|
||||
|
||||
logger.info(
|
||||
"Get active tenants request received",
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
user_id=current_user.get("user_id"),
|
||||
user_type=current_user.get("type", "user"),
|
||||
is_service=current_user.get("type") == "service",
|
||||
role=current_user.get("role"),
|
||||
service_name=current_user.get("service", "none")
|
||||
)
|
||||
|
||||
# Allow service accounts to call this endpoint
|
||||
if current_user.get("type") != "service":
|
||||
# For non-service users, could add additional role checks here if needed
|
||||
logger.debug(
|
||||
"Non-service user requesting active tenants",
|
||||
user_id=current_user.get("user_id"),
|
||||
role=current_user.get("role")
|
||||
)
|
||||
|
||||
tenants = await tenant_service.get_active_tenants(skip=skip, limit=limit)
|
||||
|
||||
logger.debug(
|
||||
"Get active tenants successful",
|
||||
count=len(tenants),
|
||||
skip=skip,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
return tenants
|
||||
|
||||
@router.get(route_builder.build_base_route("{tenant_id}", include_tenant_prefix=False), response_model=TenantResponse)
|
||||
@track_endpoint_metrics("tenant_get")
|
||||
async def get_tenant(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""Get tenant by ID - ATOMIC operation - ENHANCED with logging"""
|
||||
|
||||
logger.info(
|
||||
"Tenant GET request received",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id"),
|
||||
user_type=current_user.get("type", "user"),
|
||||
is_service=current_user.get("type") == "service",
|
||||
role=current_user.get("role"),
|
||||
service_name=current_user.get("service", "none")
|
||||
)
|
||||
|
||||
tenant = await tenant_service.get_tenant_by_id(str(tenant_id))
|
||||
if not tenant:
|
||||
logger.warning(
|
||||
"Tenant not found",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Tenant not found"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Tenant GET request successful",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id")
|
||||
)
|
||||
|
||||
return tenant
|
||||
|
||||
@router.put(route_builder.build_base_route("{tenant_id}", include_tenant_prefix=False), response_model=TenantResponse)
|
||||
@admin_role_required
|
||||
async def update_tenant(
|
||||
update_data: TenantUpdate,
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""Update tenant information - ATOMIC operation (Admin+ only)"""
|
||||
|
||||
try:
|
||||
result = await tenant_service.update_tenant(
|
||||
str(tenant_id),
|
||||
update_data,
|
||||
current_user["user_id"]
|
||||
)
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Tenant update failed",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user["user_id"],
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Tenant update failed"
|
||||
)
|
||||
|
||||
@router.get(route_builder.build_base_route("user/{user_id}/tenants", include_tenant_prefix=False), response_model=List[TenantResponse])
|
||||
@track_endpoint_metrics("user_tenants_list")
|
||||
async def get_user_tenants(
|
||||
user_id: str = Path(..., description="User ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""Get all tenants accessible by a user"""
|
||||
|
||||
logger.info(
|
||||
"Get user tenants request received",
|
||||
user_id=user_id,
|
||||
requesting_user=current_user.get("user_id"),
|
||||
is_demo=current_user.get("is_demo", False)
|
||||
)
|
||||
|
||||
# Allow demo users to access tenant information for demo-user
|
||||
is_demo_user = current_user.get("is_demo", False)
|
||||
is_service_account = current_user.get("type") == "service"
|
||||
|
||||
# For demo sessions, when frontend requests with "demo-user", use the actual demo owner ID
|
||||
actual_user_id = user_id
|
||||
if is_demo_user and user_id == "demo-user":
|
||||
actual_user_id = current_user.get("user_id")
|
||||
logger.info(
|
||||
"Demo session: mapping demo-user to actual owner",
|
||||
requested_user_id=user_id,
|
||||
actual_user_id=actual_user_id
|
||||
)
|
||||
|
||||
if current_user.get("user_id") != actual_user_id and not is_service_account and not (is_demo_user and user_id == "demo-user"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Can only access own tenants"
|
||||
)
|
||||
|
||||
try:
|
||||
# For demo sessions, use session-specific filtering to prevent cross-session data leakage
|
||||
if is_demo_user:
|
||||
demo_session_id = current_user.get("demo_session_id")
|
||||
demo_account_type = current_user.get("demo_account_type", "professional")
|
||||
|
||||
logger.info(
|
||||
"Demo session detected for get_user_tenants",
|
||||
user_id=user_id,
|
||||
actual_user_id=actual_user_id,
|
||||
demo_session_id=demo_session_id,
|
||||
demo_account_type=demo_account_type,
|
||||
has_session_id=bool(demo_session_id)
|
||||
)
|
||||
|
||||
if demo_session_id:
|
||||
# Get only tenants for this specific demo session
|
||||
tenants = await tenant_service.get_virtual_tenants_for_session(demo_session_id, demo_account_type)
|
||||
logger.info(
|
||||
"Get demo session tenants successful",
|
||||
user_id=user_id,
|
||||
demo_session_id=demo_session_id,
|
||||
demo_account_type=demo_account_type,
|
||||
tenant_count=len(tenants),
|
||||
tenant_ids=[str(t.id) for t in tenants] if tenants else []
|
||||
)
|
||||
return tenants
|
||||
else:
|
||||
logger.warning(
|
||||
"Demo user without session ID - falling back to regular user tenants",
|
||||
user_id=actual_user_id
|
||||
)
|
||||
|
||||
# Regular users or demo fallback: get tenants by ownership and membership
|
||||
tenants = await tenant_service.get_user_tenants(actual_user_id)
|
||||
|
||||
logger.debug(
|
||||
"Get user tenants successful",
|
||||
user_id=user_id,
|
||||
tenant_count=len(tenants)
|
||||
)
|
||||
|
||||
return tenants
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Get user tenants failed",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get user tenants"
|
||||
)
|
||||
|
||||
@router.delete(route_builder.build_base_route("{tenant_id}", include_tenant_prefix=False))
|
||||
@track_endpoint_metrics("tenant_delete")
|
||||
async def delete_tenant(
|
||||
tenant_id: UUID = Path(..., description="Tenant ID"),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_dep),
|
||||
tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
|
||||
):
|
||||
"""Delete tenant and all associated data - ATOMIC operation (Owner/Admin or System only)"""
|
||||
|
||||
logger.info(
|
||||
"Tenant DELETE request received",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id"),
|
||||
user_type=current_user.get("type", "user"),
|
||||
is_service=current_user.get("type") == "service",
|
||||
role=current_user.get("role"),
|
||||
service_name=current_user.get("service", "none")
|
||||
)
|
||||
|
||||
try:
|
||||
# Allow internal service calls to bypass admin check
|
||||
skip_admin_check = current_user.get("type") == "service"
|
||||
|
||||
result = await tenant_service.delete_tenant(
|
||||
str(tenant_id),
|
||||
requesting_user_id=current_user.get("user_id"),
|
||||
skip_admin_check=skip_admin_check
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Tenant DELETE request successful",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id"),
|
||||
deleted_items=result.get("deleted_items")
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "Tenant deleted successfully",
|
||||
"summary": result
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Tenant deletion failed",
|
||||
tenant_id=str(tenant_id),
|
||||
user_id=current_user.get("user_id"),
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Tenant deletion failed"
|
||||
)
|
||||
357
services/tenant/app/api/usage_forecast.py
Normal file
357
services/tenant/app/api/usage_forecast.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""
|
||||
Usage Forecasting API
|
||||
|
||||
This endpoint predicts when a tenant will hit their subscription limits
|
||||
based on historical usage growth rates.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
import redis.asyncio as redis
|
||||
|
||||
from shared.auth.decorators import get_current_user_dep
|
||||
from app.core.config import settings
|
||||
from app.core.database import database_manager
|
||||
from app.services.subscription_limit_service import SubscriptionLimitService
|
||||
|
||||
router = APIRouter(prefix="/usage-forecast", tags=["usage-forecast"])
|
||||
|
||||
|
||||
class UsageDataPoint(BaseModel):
|
||||
"""Single usage data point"""
|
||||
date: str
|
||||
value: int
|
||||
|
||||
|
||||
class MetricForecast(BaseModel):
|
||||
"""Forecast for a single metric"""
|
||||
metric: str
|
||||
label: str
|
||||
current: int
|
||||
limit: Optional[int] # None = unlimited
|
||||
unit: str
|
||||
daily_growth_rate: Optional[float] # None if not enough data
|
||||
predicted_breach_date: Optional[str] # ISO date string, None if unlimited or no breach
|
||||
days_until_breach: Optional[int] # None if unlimited or no breach
|
||||
usage_percentage: float
|
||||
status: str # 'safe', 'warning', 'critical', 'unlimited'
|
||||
trend_data: List[UsageDataPoint] # 30-day history
|
||||
|
||||
|
||||
class UsageForecastResponse(BaseModel):
|
||||
"""Complete usage forecast response"""
|
||||
tenant_id: str
|
||||
forecasted_at: str
|
||||
metrics: List[MetricForecast]
|
||||
|
||||
|
||||
async def get_redis_client() -> redis.Redis:
|
||||
"""Get Redis client for usage tracking"""
|
||||
return redis.from_url(
|
||||
settings.REDIS_URL,
|
||||
encoding="utf-8",
|
||||
decode_responses=True
|
||||
)
|
||||
|
||||
|
||||
async def get_usage_history(
|
||||
redis_client: redis.Redis,
|
||||
tenant_id: str,
|
||||
metric: str,
|
||||
days: int = 30
|
||||
) -> List[UsageDataPoint]:
|
||||
"""
|
||||
Get historical usage data for a metric from Redis
|
||||
|
||||
Usage data is stored with keys like:
|
||||
usage:daily:{tenant_id}:{metric}:{date}
|
||||
"""
|
||||
history = []
|
||||
today = datetime.utcnow().date()
|
||||
|
||||
for i in range(days):
|
||||
date = today - timedelta(days=i)
|
||||
date_str = date.isoformat()
|
||||
key = f"usage:daily:{tenant_id}:{metric}:{date_str}"
|
||||
|
||||
try:
|
||||
value = await redis_client.get(key)
|
||||
if value is not None:
|
||||
history.append(UsageDataPoint(
|
||||
date=date_str,
|
||||
value=int(value)
|
||||
))
|
||||
except Exception as e:
|
||||
print(f"Error fetching usage for {key}: {e}")
|
||||
continue
|
||||
|
||||
# Return in chronological order (oldest first)
|
||||
return list(reversed(history))
|
||||
|
||||
|
||||
def calculate_growth_rate(history: List[UsageDataPoint]) -> Optional[float]:
|
||||
"""
|
||||
Calculate daily growth rate using linear regression
|
||||
|
||||
Returns average daily increase, or None if insufficient data
|
||||
"""
|
||||
if len(history) < 7: # Need at least 7 days of data
|
||||
return None
|
||||
|
||||
# Simple linear regression
|
||||
n = len(history)
|
||||
sum_x = sum(range(n))
|
||||
sum_y = sum(point.value for point in history)
|
||||
sum_xy = sum(i * point.value for i, point in enumerate(history))
|
||||
sum_x_squared = sum(i * i for i in range(n))
|
||||
|
||||
# Calculate slope (daily growth rate)
|
||||
denominator = (n * sum_x_squared) - (sum_x ** 2)
|
||||
if denominator == 0:
|
||||
return None
|
||||
|
||||
slope = ((n * sum_xy) - (sum_x * sum_y)) / denominator
|
||||
|
||||
return max(slope, 0) # Can't have negative growth for breach prediction
|
||||
|
||||
|
||||
def predict_breach_date(
|
||||
current: int,
|
||||
limit: int,
|
||||
daily_growth_rate: float
|
||||
) -> Optional[tuple[str, int]]:
|
||||
"""
|
||||
Predict when usage will breach the limit
|
||||
|
||||
Returns (breach_date_iso, days_until_breach) or None if no breach predicted
|
||||
"""
|
||||
if daily_growth_rate <= 0:
|
||||
return None
|
||||
|
||||
remaining_capacity = limit - current
|
||||
if remaining_capacity <= 0:
|
||||
# Already at or over limit
|
||||
return datetime.utcnow().date().isoformat(), 0
|
||||
|
||||
days_until_breach = int(remaining_capacity / daily_growth_rate)
|
||||
|
||||
if days_until_breach > 365: # Don't predict beyond 1 year
|
||||
return None
|
||||
|
||||
breach_date = datetime.utcnow().date() + timedelta(days=days_until_breach)
|
||||
|
||||
return breach_date.isoformat(), days_until_breach
|
||||
|
||||
|
||||
def determine_status(usage_percentage: float, days_until_breach: Optional[int]) -> str:
|
||||
"""Determine metric status based on usage and time to breach"""
|
||||
if usage_percentage >= 100:
|
||||
return 'critical'
|
||||
elif usage_percentage >= 90:
|
||||
return 'critical'
|
||||
elif usage_percentage >= 80 or (days_until_breach is not None and days_until_breach <= 14):
|
||||
return 'warning'
|
||||
else:
|
||||
return 'safe'
|
||||
|
||||
|
||||
@router.get("", response_model=UsageForecastResponse)
|
||||
async def get_usage_forecast(
|
||||
tenant_id: str = Query(..., description="Tenant ID"),
|
||||
current_user: dict = Depends(get_current_user_dep)
|
||||
) -> UsageForecastResponse:
|
||||
"""
|
||||
Get usage forecasts for all metrics
|
||||
|
||||
Predicts when the tenant will hit their subscription limits based on
|
||||
historical usage growth rates from the past 30 days.
|
||||
|
||||
Returns predictions for:
|
||||
- Users
|
||||
- Locations
|
||||
- Products
|
||||
- Recipes
|
||||
- Suppliers
|
||||
- Training jobs (daily)
|
||||
- Forecasts (daily)
|
||||
- API calls (hourly average converted to daily)
|
||||
- File storage
|
||||
"""
|
||||
# Initialize services
|
||||
redis_client = await get_redis_client()
|
||||
limit_service = SubscriptionLimitService(database_manager=database_manager)
|
||||
|
||||
try:
|
||||
# Get current usage summary (includes limits)
|
||||
usage_summary = await limit_service.get_usage_summary(tenant_id)
|
||||
|
||||
if not usage_summary or 'error' in usage_summary:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No active subscription found for tenant {tenant_id}"
|
||||
)
|
||||
|
||||
# Extract usage data
|
||||
usage = usage_summary.get('usage', {})
|
||||
|
||||
# Define metrics to forecast
|
||||
metric_configs = [
|
||||
{
|
||||
'key': 'users',
|
||||
'label': 'Users',
|
||||
'current': usage.get('users', {}).get('current', 0),
|
||||
'limit': usage.get('users', {}).get('limit'),
|
||||
'unit': ''
|
||||
},
|
||||
{
|
||||
'key': 'locations',
|
||||
'label': 'Locations',
|
||||
'current': usage.get('locations', {}).get('current', 0),
|
||||
'limit': usage.get('locations', {}).get('limit'),
|
||||
'unit': ''
|
||||
},
|
||||
{
|
||||
'key': 'products',
|
||||
'label': 'Products',
|
||||
'current': usage.get('products', {}).get('current', 0),
|
||||
'limit': usage.get('products', {}).get('limit'),
|
||||
'unit': ''
|
||||
},
|
||||
{
|
||||
'key': 'recipes',
|
||||
'label': 'Recipes',
|
||||
'current': usage.get('recipes', {}).get('current', 0),
|
||||
'limit': usage.get('recipes', {}).get('limit'),
|
||||
'unit': ''
|
||||
},
|
||||
{
|
||||
'key': 'suppliers',
|
||||
'label': 'Suppliers',
|
||||
'current': usage.get('suppliers', {}).get('current', 0),
|
||||
'limit': usage.get('suppliers', {}).get('limit'),
|
||||
'unit': ''
|
||||
},
|
||||
{
|
||||
'key': 'training_jobs',
|
||||
'label': 'Training Jobs',
|
||||
'current': usage.get('training_jobs_today', {}).get('current', 0),
|
||||
'limit': usage.get('training_jobs_today', {}).get('limit'),
|
||||
'unit': '/day'
|
||||
},
|
||||
{
|
||||
'key': 'forecasts',
|
||||
'label': 'Forecasts',
|
||||
'current': usage.get('forecasts_today', {}).get('current', 0),
|
||||
'limit': usage.get('forecasts_today', {}).get('limit'),
|
||||
'unit': '/day'
|
||||
},
|
||||
{
|
||||
'key': 'api_calls',
|
||||
'label': 'API Calls',
|
||||
'current': usage.get('api_calls_this_hour', {}).get('current', 0),
|
||||
'limit': usage.get('api_calls_this_hour', {}).get('limit'),
|
||||
'unit': '/hour'
|
||||
},
|
||||
{
|
||||
'key': 'storage',
|
||||
'label': 'File Storage',
|
||||
'current': int(usage.get('file_storage_used_gb', {}).get('current', 0)),
|
||||
'limit': usage.get('file_storage_used_gb', {}).get('limit'),
|
||||
'unit': ' GB'
|
||||
}
|
||||
]
|
||||
|
||||
forecasts: List[MetricForecast] = []
|
||||
|
||||
for config in metric_configs:
|
||||
metric_key = config['key']
|
||||
current = config['current']
|
||||
limit = config['limit']
|
||||
|
||||
# Get usage history
|
||||
history = await get_usage_history(redis_client, tenant_id, metric_key, days=30)
|
||||
|
||||
# Calculate usage percentage
|
||||
if limit is None or limit == -1:
|
||||
usage_percentage = 0.0
|
||||
status = 'unlimited'
|
||||
growth_rate = None
|
||||
breach_date = None
|
||||
days_until = None
|
||||
else:
|
||||
usage_percentage = (current / limit * 100) if limit > 0 else 0
|
||||
|
||||
# Calculate growth rate
|
||||
growth_rate = calculate_growth_rate(history) if history else None
|
||||
|
||||
# Predict breach
|
||||
if growth_rate is not None and growth_rate > 0:
|
||||
breach_result = predict_breach_date(current, limit, growth_rate)
|
||||
if breach_result:
|
||||
breach_date, days_until = breach_result
|
||||
else:
|
||||
breach_date, days_until = None, None
|
||||
else:
|
||||
breach_date, days_until = None, None
|
||||
|
||||
# Determine status
|
||||
status = determine_status(usage_percentage, days_until)
|
||||
|
||||
forecasts.append(MetricForecast(
|
||||
metric=metric_key,
|
||||
label=config['label'],
|
||||
current=current,
|
||||
limit=limit,
|
||||
unit=config['unit'],
|
||||
daily_growth_rate=growth_rate,
|
||||
predicted_breach_date=breach_date,
|
||||
days_until_breach=days_until,
|
||||
usage_percentage=round(usage_percentage, 1),
|
||||
status=status,
|
||||
trend_data=history[-30:] # Last 30 days
|
||||
))
|
||||
|
||||
return UsageForecastResponse(
|
||||
tenant_id=tenant_id,
|
||||
forecasted_at=datetime.utcnow().isoformat(),
|
||||
metrics=forecasts
|
||||
)
|
||||
|
||||
finally:
|
||||
await redis_client.close()
|
||||
|
||||
|
||||
@router.post("/track-usage")
|
||||
async def track_daily_usage(
|
||||
tenant_id: str,
|
||||
metric: str,
|
||||
value: int,
|
||||
current_user: dict = Depends(get_current_user_dep)
|
||||
):
|
||||
"""
|
||||
Manually track daily usage for a metric
|
||||
|
||||
This endpoint is called by services to record daily usage snapshots.
|
||||
The data is stored in Redis with a 60-day TTL.
|
||||
"""
|
||||
redis_client = await get_redis_client()
|
||||
|
||||
try:
|
||||
date_str = datetime.utcnow().date().isoformat()
|
||||
key = f"usage:daily:{tenant_id}:{metric}:{date_str}"
|
||||
|
||||
# Store usage with 60-day TTL
|
||||
await redis_client.setex(key, 60 * 24 * 60 * 60, str(value))
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"tenant_id": tenant_id,
|
||||
"metric": metric,
|
||||
"value": value,
|
||||
"date": date_str
|
||||
}
|
||||
|
||||
finally:
|
||||
await redis_client.close()
|
||||
97
services/tenant/app/api/webhooks.py
Normal file
97
services/tenant/app/api/webhooks.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Webhook endpoints for handling payment provider events
|
||||
These endpoints receive events from payment providers like Stripe
|
||||
All event processing is handled by SubscriptionOrchestrationService
|
||||
"""
|
||||
|
||||
import structlog
|
||||
import stripe
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.subscription_orchestration_service import SubscriptionOrchestrationService
|
||||
from app.core.config import settings
|
||||
from app.core.database import get_db
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def get_subscription_orchestration_service(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
) -> SubscriptionOrchestrationService:
|
||||
"""Dependency injection for SubscriptionOrchestrationService"""
|
||||
try:
|
||||
return SubscriptionOrchestrationService(db)
|
||||
except Exception as e:
|
||||
logger.error("Failed to create subscription orchestration service", error=str(e))
|
||||
raise HTTPException(status_code=500, detail="Service initialization failed")
|
||||
|
||||
|
||||
@router.post("/webhooks/stripe")
|
||||
async def stripe_webhook(
|
||||
request: Request,
|
||||
orchestration_service: SubscriptionOrchestrationService = Depends(get_subscription_orchestration_service)
|
||||
):
|
||||
"""
|
||||
Stripe webhook endpoint to handle payment events
|
||||
This endpoint verifies webhook signatures and processes Stripe events
|
||||
"""
|
||||
try:
|
||||
# Get the payload and signature
|
||||
payload = await request.body()
|
||||
sig_header = request.headers.get('stripe-signature')
|
||||
|
||||
if not sig_header:
|
||||
logger.error("Missing stripe-signature header")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Missing signature header"
|
||||
)
|
||||
|
||||
# Verify the webhook signature
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
|
||||
)
|
||||
except stripe.error.SignatureVerificationError as e:
|
||||
logger.error("Invalid webhook signature", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid signature"
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error("Invalid payload", error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid payload"
|
||||
)
|
||||
|
||||
# Get event type and data
|
||||
event_type = event['type']
|
||||
event_data = event['data']['object']
|
||||
|
||||
logger.info("Processing Stripe webhook event",
|
||||
event_type=event_type,
|
||||
event_id=event.get('id'))
|
||||
|
||||
# Use orchestration service to handle the event
|
||||
result = await orchestration_service.handle_payment_webhook(event_type, event_data)
|
||||
|
||||
logger.info("Webhook event processed via orchestration service",
|
||||
event_type=event_type,
|
||||
actions_taken=result.get("actions_taken", []))
|
||||
|
||||
return {"success": True, "event_type": event_type, "actions_taken": result.get("actions_taken", [])}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Error processing Stripe webhook", error=str(e), exc_info=True)
|
||||
# Return 200 OK even on processing errors to prevent Stripe retries
|
||||
# Only return 4xx for signature verification failures
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Webhook processing error",
|
||||
"details": str(e)
|
||||
}
|
||||
308
services/tenant/app/api/whatsapp_admin.py
Normal file
308
services/tenant/app/api/whatsapp_admin.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# services/tenant/app/api/whatsapp_admin.py
|
||||
"""
|
||||
WhatsApp Admin API Endpoints
|
||||
Admin-only endpoints for managing WhatsApp phone number assignments
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from uuid import UUID
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
import httpx
|
||||
import os
|
||||
|
||||
from app.core.database import get_db
|
||||
from app.models.tenant_settings import TenantSettings
|
||||
from app.models.tenants import Tenant
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ================================================================
|
||||
# SCHEMAS
|
||||
# ================================================================
|
||||
|
||||
class WhatsAppPhoneNumberInfo(BaseModel):
|
||||
"""Information about a WhatsApp phone number from Meta API"""
|
||||
id: str = Field(..., description="Phone Number ID")
|
||||
display_phone_number: str = Field(..., description="Display phone number (e.g., +34 612 345 678)")
|
||||
verified_name: str = Field(..., description="Verified business name")
|
||||
quality_rating: str = Field(..., description="Quality rating (GREEN, YELLOW, RED)")
|
||||
|
||||
|
||||
class TenantWhatsAppStatus(BaseModel):
|
||||
"""WhatsApp status for a tenant"""
|
||||
tenant_id: UUID
|
||||
tenant_name: str
|
||||
whatsapp_enabled: bool
|
||||
phone_number_id: Optional[str] = None
|
||||
display_phone_number: Optional[str] = None
|
||||
|
||||
|
||||
class AssignPhoneNumberRequest(BaseModel):
|
||||
"""Request to assign phone number to tenant"""
|
||||
phone_number_id: str = Field(..., description="Meta WhatsApp Phone Number ID")
|
||||
display_phone_number: str = Field(..., description="Display format (e.g., '+34 612 345 678')")
|
||||
|
||||
|
||||
class AssignPhoneNumberResponse(BaseModel):
|
||||
"""Response after assigning phone number"""
|
||||
success: bool
|
||||
message: str
|
||||
tenant_id: UUID
|
||||
phone_number_id: str
|
||||
display_phone_number: str
|
||||
|
||||
|
||||
# ================================================================
|
||||
# ENDPOINTS
|
||||
# ================================================================
|
||||
|
||||
@router.get(
|
||||
"/admin/whatsapp/phone-numbers",
|
||||
response_model=List[WhatsAppPhoneNumberInfo],
|
||||
summary="List available WhatsApp phone numbers",
|
||||
description="Get all phone numbers available in the master WhatsApp Business Account"
|
||||
)
|
||||
async def list_available_phone_numbers():
|
||||
"""
|
||||
List all phone numbers from the master WhatsApp Business Account
|
||||
|
||||
Requires:
|
||||
- WHATSAPP_BUSINESS_ACCOUNT_ID environment variable
|
||||
- WHATSAPP_ACCESS_TOKEN environment variable
|
||||
|
||||
Returns list of available phone numbers with their status
|
||||
"""
|
||||
business_account_id = os.getenv("WHATSAPP_BUSINESS_ACCOUNT_ID")
|
||||
access_token = os.getenv("WHATSAPP_ACCESS_TOKEN")
|
||||
api_version = os.getenv("WHATSAPP_API_VERSION", "v18.0")
|
||||
|
||||
if not business_account_id or not access_token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="WhatsApp master account not configured. Set WHATSAPP_BUSINESS_ACCOUNT_ID and WHATSAPP_ACCESS_TOKEN environment variables."
|
||||
)
|
||||
|
||||
try:
|
||||
# Fetch phone numbers from Meta Graph API
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.get(
|
||||
f"https://graph.facebook.com/{api_version}/{business_account_id}/phone_numbers",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
params={
|
||||
"fields": "id,display_phone_number,verified_name,quality_rating"
|
||||
}
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
error_data = response.json()
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Meta API error: {error_data.get('error', {}).get('message', 'Unknown error')}"
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
phone_numbers = data.get("data", [])
|
||||
|
||||
return [
|
||||
WhatsAppPhoneNumberInfo(
|
||||
id=phone.get("id"),
|
||||
display_phone_number=phone.get("display_phone_number"),
|
||||
verified_name=phone.get("verified_name", ""),
|
||||
quality_rating=phone.get("quality_rating", "UNKNOWN")
|
||||
)
|
||||
for phone in phone_numbers
|
||||
]
|
||||
|
||||
except httpx.HTTPError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_502_BAD_GATEWAY,
|
||||
detail=f"Failed to fetch phone numbers from Meta: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/admin/whatsapp/tenants",
|
||||
response_model=List[TenantWhatsAppStatus],
|
||||
summary="List all tenants with WhatsApp status",
|
||||
description="Get WhatsApp configuration status for all tenants"
|
||||
)
|
||||
async def list_tenant_whatsapp_status(
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
List all tenants with their WhatsApp configuration status
|
||||
|
||||
Returns:
|
||||
- tenant_id: Tenant UUID
|
||||
- tenant_name: Tenant name
|
||||
- whatsapp_enabled: Whether WhatsApp is enabled
|
||||
- phone_number_id: Assigned phone number ID (if any)
|
||||
- display_phone_number: Display format (if any)
|
||||
"""
|
||||
# Query all tenants with their settings
|
||||
query = select(Tenant, TenantSettings).outerjoin(
|
||||
TenantSettings,
|
||||
Tenant.id == TenantSettings.tenant_id
|
||||
)
|
||||
|
||||
result = await db.execute(query)
|
||||
rows = result.all()
|
||||
|
||||
tenant_statuses = []
|
||||
for tenant, settings in rows:
|
||||
notification_settings = settings.notification_settings if settings else {}
|
||||
|
||||
tenant_statuses.append(
|
||||
TenantWhatsAppStatus(
|
||||
tenant_id=tenant.id,
|
||||
tenant_name=tenant.name,
|
||||
whatsapp_enabled=notification_settings.get("whatsapp_enabled", False),
|
||||
phone_number_id=notification_settings.get("whatsapp_phone_number_id", ""),
|
||||
display_phone_number=notification_settings.get("whatsapp_display_phone_number", "")
|
||||
)
|
||||
)
|
||||
|
||||
return tenant_statuses
|
||||
|
||||
|
||||
@router.post(
|
||||
"/admin/whatsapp/tenants/{tenant_id}/assign-phone",
|
||||
response_model=AssignPhoneNumberResponse,
|
||||
summary="Assign phone number to tenant",
|
||||
description="Assign a WhatsApp phone number from the master account to a tenant"
|
||||
)
|
||||
async def assign_phone_number_to_tenant(
|
||||
tenant_id: UUID,
|
||||
request: AssignPhoneNumberRequest,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Assign a WhatsApp phone number to a tenant
|
||||
|
||||
- **tenant_id**: UUID of the tenant
|
||||
- **phone_number_id**: Meta Phone Number ID from master account
|
||||
- **display_phone_number**: Human-readable format (e.g., "+34 612 345 678")
|
||||
|
||||
This will:
|
||||
1. Validate the tenant exists
|
||||
2. Check if phone number is already assigned to another tenant
|
||||
3. Update tenant's notification settings
|
||||
4. Enable WhatsApp for the tenant
|
||||
"""
|
||||
# Verify tenant exists
|
||||
tenant_query = select(Tenant).where(Tenant.id == tenant_id)
|
||||
tenant_result = await db.execute(tenant_query)
|
||||
tenant = tenant_result.scalar_one_or_none()
|
||||
|
||||
if not tenant:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Tenant {tenant_id} not found"
|
||||
)
|
||||
|
||||
# Check if phone number is already assigned to another tenant
|
||||
settings_query = select(TenantSettings).where(TenantSettings.tenant_id != tenant_id)
|
||||
settings_result = await db.execute(settings_query)
|
||||
all_settings = settings_result.scalars().all()
|
||||
|
||||
for settings in all_settings:
|
||||
notification_settings = settings.notification_settings or {}
|
||||
if notification_settings.get("whatsapp_phone_number_id") == request.phone_number_id:
|
||||
# Get the other tenant's name
|
||||
other_tenant_query = select(Tenant).where(Tenant.id == settings.tenant_id)
|
||||
other_tenant_result = await db.execute(other_tenant_query)
|
||||
other_tenant = other_tenant_result.scalar_one_or_none()
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Phone number {request.display_phone_number} is already assigned to tenant '{other_tenant.name if other_tenant else 'Unknown'}'"
|
||||
)
|
||||
|
||||
# Get or create tenant settings
|
||||
settings_query = select(TenantSettings).where(TenantSettings.tenant_id == tenant_id)
|
||||
settings_result = await db.execute(settings_query)
|
||||
settings = settings_result.scalar_one_or_none()
|
||||
|
||||
if not settings:
|
||||
# Create default settings
|
||||
settings = TenantSettings(
|
||||
tenant_id=tenant_id,
|
||||
**TenantSettings.get_default_settings()
|
||||
)
|
||||
db.add(settings)
|
||||
|
||||
# Update notification settings
|
||||
notification_settings = settings.notification_settings or {}
|
||||
notification_settings["whatsapp_enabled"] = True
|
||||
notification_settings["whatsapp_phone_number_id"] = request.phone_number_id
|
||||
notification_settings["whatsapp_display_phone_number"] = request.display_phone_number
|
||||
|
||||
settings.notification_settings = notification_settings
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(settings)
|
||||
|
||||
return AssignPhoneNumberResponse(
|
||||
success=True,
|
||||
message=f"Phone number {request.display_phone_number} assigned to tenant '{tenant.name}'",
|
||||
tenant_id=tenant_id,
|
||||
phone_number_id=request.phone_number_id,
|
||||
display_phone_number=request.display_phone_number
|
||||
)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/admin/whatsapp/tenants/{tenant_id}/unassign-phone",
|
||||
response_model=AssignPhoneNumberResponse,
|
||||
summary="Unassign phone number from tenant",
|
||||
description="Remove WhatsApp phone number assignment from a tenant"
|
||||
)
|
||||
async def unassign_phone_number_from_tenant(
|
||||
tenant_id: UUID,
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Unassign WhatsApp phone number from a tenant
|
||||
|
||||
- **tenant_id**: UUID of the tenant
|
||||
|
||||
This will:
|
||||
1. Clear the phone number assignment
|
||||
2. Disable WhatsApp for the tenant
|
||||
"""
|
||||
# Get tenant settings
|
||||
settings_query = select(TenantSettings).where(TenantSettings.tenant_id == tenant_id)
|
||||
settings_result = await db.execute(settings_query)
|
||||
settings = settings_result.scalar_one_or_none()
|
||||
|
||||
if not settings:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Settings not found for tenant {tenant_id}"
|
||||
)
|
||||
|
||||
# Get current values for response
|
||||
notification_settings = settings.notification_settings or {}
|
||||
old_phone_id = notification_settings.get("whatsapp_phone_number_id", "")
|
||||
old_display_phone = notification_settings.get("whatsapp_display_phone_number", "")
|
||||
|
||||
# Update notification settings
|
||||
notification_settings["whatsapp_enabled"] = False
|
||||
notification_settings["whatsapp_phone_number_id"] = ""
|
||||
notification_settings["whatsapp_display_phone_number"] = ""
|
||||
|
||||
settings.notification_settings = notification_settings
|
||||
|
||||
await db.commit()
|
||||
|
||||
return AssignPhoneNumberResponse(
|
||||
success=True,
|
||||
message=f"Phone number unassigned from tenant",
|
||||
tenant_id=tenant_id,
|
||||
phone_number_id=old_phone_id,
|
||||
display_phone_number=old_display_phone
|
||||
)
|
||||
0
services/tenant/app/core/__init__.py
Normal file
0
services/tenant/app/core/__init__.py
Normal file
133
services/tenant/app/core/config.py
Normal file
133
services/tenant/app/core/config.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# ================================================================
|
||||
# TENANT SERVICE CONFIGURATION
|
||||
# services/tenant/app/core/config.py
|
||||
# ================================================================
|
||||
|
||||
"""
|
||||
Tenant service configuration
|
||||
Multi-tenant management and subscription handling
|
||||
"""
|
||||
|
||||
from shared.config.base import BaseServiceSettings
|
||||
import os
|
||||
from typing import Dict, Tuple, ClassVar
|
||||
|
||||
class TenantSettings(BaseServiceSettings):
|
||||
"""Tenant service specific settings"""
|
||||
|
||||
# Service Identity
|
||||
APP_NAME: str = "Tenant Service"
|
||||
SERVICE_NAME: str = "tenant-service"
|
||||
DESCRIPTION: str = "Multi-tenant management and subscription service"
|
||||
|
||||
# Database configuration (secure approach - build from components)
|
||||
@property
|
||||
def DATABASE_URL(self) -> str:
|
||||
"""Build database URL from secure components"""
|
||||
# Try complete URL first (for backward compatibility)
|
||||
complete_url = os.getenv("TENANT_DATABASE_URL")
|
||||
if complete_url:
|
||||
return complete_url
|
||||
|
||||
# Build from components (secure approach)
|
||||
user = os.getenv("TENANT_DB_USER", "tenant_user")
|
||||
password = os.getenv("TENANT_DB_PASSWORD", "tenant_pass123")
|
||||
host = os.getenv("TENANT_DB_HOST", "localhost")
|
||||
port = os.getenv("TENANT_DB_PORT", "5432")
|
||||
name = os.getenv("TENANT_DB_NAME", "tenant_db")
|
||||
|
||||
return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{name}"
|
||||
|
||||
# Redis Database (dedicated for tenant data)
|
||||
REDIS_DB: int = 4
|
||||
|
||||
# Service URLs for usage tracking
|
||||
RECIPES_SERVICE_URL: str = os.getenv("RECIPES_SERVICE_URL", "http://recipes-service:8004")
|
||||
SUPPLIERS_SERVICE_URL: str = os.getenv("SUPPLIERS_SERVICE_URL", "http://suppliers-service:8005")
|
||||
|
||||
# Subscription Plans
|
||||
DEFAULT_PLAN: str = os.getenv("DEFAULT_PLAN", "basic")
|
||||
TRIAL_PERIOD_DAYS: int = int(os.getenv("TRIAL_PERIOD_DAYS", "0"))
|
||||
|
||||
# Plan Limits
|
||||
BASIC_PLAN_LOCATIONS: int = int(os.getenv("BASIC_PLAN_LOCATIONS", "1"))
|
||||
BASIC_PLAN_PREDICTIONS_PER_DAY: int = int(os.getenv("BASIC_PLAN_PREDICTIONS_PER_DAY", "100"))
|
||||
BASIC_PLAN_DATA_RETENTION_DAYS: int = int(os.getenv("BASIC_PLAN_DATA_RETENTION_DAYS", "90"))
|
||||
|
||||
PREMIUM_PLAN_LOCATIONS: int = int(os.getenv("PREMIUM_PLAN_LOCATIONS", "5"))
|
||||
PREMIUM_PLAN_PREDICTIONS_PER_DAY: int = int(os.getenv("PREMIUM_PLAN_PREDICTIONS_PER_DAY", "1000"))
|
||||
PREMIUM_PLAN_DATA_RETENTION_DAYS: int = int(os.getenv("PREMIUM_PLAN_DATA_RETENTION_DAYS", "365"))
|
||||
|
||||
ENTERPRISE_PLAN_LOCATIONS: int = int(os.getenv("ENTERPRISE_PLAN_LOCATIONS", "50"))
|
||||
ENTERPRISE_PLAN_PREDICTIONS_PER_DAY: int = int(os.getenv("ENTERPRISE_PLAN_PREDICTIONS_PER_DAY", "10000"))
|
||||
ENTERPRISE_PLAN_DATA_RETENTION_DAYS: int = int(os.getenv("ENTERPRISE_PLAN_DATA_RETENTION_DAYS", "1095"))
|
||||
|
||||
# Billing Configuration
|
||||
BILLING_ENABLED: bool = os.getenv("BILLING_ENABLED", "false").lower() == "true"
|
||||
BILLING_CURRENCY: str = os.getenv("BILLING_CURRENCY", "EUR")
|
||||
BILLING_CYCLE_DAYS: int = int(os.getenv("BILLING_CYCLE_DAYS", "30"))
|
||||
|
||||
# Stripe Proration Configuration
|
||||
DEFAULT_PRORATION_BEHAVIOR: str = os.getenv("DEFAULT_PRORATION_BEHAVIOR", "create_prorations")
|
||||
UPGRADE_PRORATION_BEHAVIOR: str = os.getenv("UPGRADE_PRORATION_BEHAVIOR", "create_prorations")
|
||||
DOWNGRADE_PRORATION_BEHAVIOR: str = os.getenv("DOWNGRADE_PRORATION_BEHAVIOR", "none")
|
||||
BILLING_CYCLE_CHANGE_PRORATION: str = os.getenv("BILLING_CYCLE_CHANGE_PRORATION", "create_prorations")
|
||||
|
||||
# Stripe Subscription Update Settings
|
||||
STRIPE_BILLING_CYCLE_ANCHOR: str = os.getenv("STRIPE_BILLING_CYCLE_ANCHOR", "unchanged")
|
||||
STRIPE_PAYMENT_BEHAVIOR: str = os.getenv("STRIPE_PAYMENT_BEHAVIOR", "error_if_incomplete")
|
||||
ALLOW_IMMEDIATE_SUBSCRIPTION_CHANGES: bool = os.getenv("ALLOW_IMMEDIATE_SUBSCRIPTION_CHANGES", "true").lower() == "true"
|
||||
|
||||
# Resource Limits
|
||||
MAX_API_CALLS_PER_MINUTE: int = int(os.getenv("MAX_API_CALLS_PER_MINUTE", "100"))
|
||||
MAX_STORAGE_MB: int = int(os.getenv("MAX_STORAGE_MB", "1024"))
|
||||
MAX_CONCURRENT_REQUESTS: int = int(os.getenv("MAX_CONCURRENT_REQUESTS", "10"))
|
||||
|
||||
# Spanish Business Configuration
|
||||
SPANISH_TAX_RATE: float = float(os.getenv("SPANISH_TAX_RATE", "0.21")) # IVA 21%
|
||||
INVOICE_LANGUAGE: str = os.getenv("INVOICE_LANGUAGE", "es")
|
||||
SUPPORT_EMAIL: str = os.getenv("SUPPORT_EMAIL", "soporte@bakeryforecast.es")
|
||||
|
||||
# Onboarding
|
||||
ONBOARDING_ENABLED: bool = os.getenv("ONBOARDING_ENABLED", "true").lower() == "true"
|
||||
DEMO_DATA_ENABLED: bool = os.getenv("DEMO_DATA_ENABLED", "true").lower() == "true"
|
||||
|
||||
# Compliance
|
||||
GDPR_COMPLIANCE_ENABLED: bool = True
|
||||
DATA_EXPORT_ENABLED: bool = True
|
||||
DATA_DELETION_ENABLED: bool = True
|
||||
|
||||
# Stripe Payment Configuration
|
||||
STRIPE_PUBLISHABLE_KEY: str = os.getenv("STRIPE_PUBLISHABLE_KEY", "")
|
||||
STRIPE_SECRET_KEY: str = os.getenv("STRIPE_SECRET_KEY", "")
|
||||
STRIPE_WEBHOOK_SECRET: str = os.getenv("STRIPE_WEBHOOK_SECRET", "")
|
||||
|
||||
# Stripe Price IDs for subscription plans
|
||||
STARTER_MONTHLY_PRICE_ID: str = os.getenv("STARTER_MONTHLY_PRICE_ID", "price_1Sp0p3IzCdnBmAVT2Gs7z5np")
|
||||
STARTER_YEARLY_PRICE_ID: str = os.getenv("STARTER_YEARLY_PRICE_ID", "price_1Sp0twIzCdnBmAVTD1lNLedx")
|
||||
PROFESSIONAL_MONTHLY_PRICE_ID: str = os.getenv("PROFESSIONAL_MONTHLY_PRICE_ID", "price_1Sp0w7IzCdnBmAVTp0Jxhh1u")
|
||||
PROFESSIONAL_YEARLY_PRICE_ID: str = os.getenv("PROFESSIONAL_YEARLY_PRICE_ID", "price_1Sp0yAIzCdnBmAVTLoGl4QCb")
|
||||
ENTERPRISE_MONTHLY_PRICE_ID: str = os.getenv("ENTERPRISE_MONTHLY_PRICE_ID", "price_1Sp0zAIzCdnBmAVTXpApF7YO")
|
||||
ENTERPRISE_YEARLY_PRICE_ID: str = os.getenv("ENTERPRISE_YEARLY_PRICE_ID", "price_1Sp15mIzCdnBmAVTuxffMpV5")
|
||||
|
||||
# Price ID mapping for easy lookup
|
||||
STRIPE_PRICE_ID_MAPPING: ClassVar[Dict[Tuple[str, str], str]] = {
|
||||
('starter', 'monthly'): STARTER_MONTHLY_PRICE_ID,
|
||||
('starter', 'yearly'): STARTER_YEARLY_PRICE_ID,
|
||||
('professional', 'monthly'): PROFESSIONAL_MONTHLY_PRICE_ID,
|
||||
('professional', 'yearly'): PROFESSIONAL_YEARLY_PRICE_ID,
|
||||
('enterprise', 'monthly'): ENTERPRISE_MONTHLY_PRICE_ID,
|
||||
('enterprise', 'yearly'): ENTERPRISE_YEARLY_PRICE_ID,
|
||||
}
|
||||
|
||||
# ============================================================
|
||||
# SCHEDULER CONFIGURATION
|
||||
# ============================================================
|
||||
|
||||
# Usage tracking scheduler
|
||||
USAGE_TRACKING_ENABLED: bool = os.getenv("USAGE_TRACKING_ENABLED", "true").lower() == "true"
|
||||
USAGE_TRACKING_HOUR: int = int(os.getenv("USAGE_TRACKING_HOUR", "2"))
|
||||
USAGE_TRACKING_MINUTE: int = int(os.getenv("USAGE_TRACKING_MINUTE", "0"))
|
||||
USAGE_TRACKING_TIMEZONE: str = os.getenv("USAGE_TRACKING_TIMEZONE", "UTC")
|
||||
|
||||
settings = TenantSettings()
|
||||
12
services/tenant/app/core/database.py
Normal file
12
services/tenant/app/core/database.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Database configuration for tenant service
|
||||
"""
|
||||
|
||||
from shared.database.base import DatabaseManager
|
||||
from app.core.config import settings
|
||||
|
||||
# Initialize database manager
|
||||
database_manager = DatabaseManager(settings.DATABASE_URL, service_name="tenant-service")
|
||||
|
||||
# Alias for convenience
|
||||
get_db = database_manager.get_db
|
||||
127
services/tenant/app/jobs/startup_seeder.py
Normal file
127
services/tenant/app/jobs/startup_seeder.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Startup seeder for tenant service.
|
||||
Seeds initial data (like pilot coupons) on service startup.
|
||||
All operations are idempotent - safe to run multiple times.
|
||||
"""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
|
||||
import structlog
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.coupon import CouponModel
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
async def ensure_pilot_coupon(session: AsyncSession) -> Optional[CouponModel]:
|
||||
"""
|
||||
Ensure the PILOT2025 coupon exists in the database.
|
||||
|
||||
This coupon provides 3 months (90 days) free trial extension
|
||||
for the first 20 pilot customers.
|
||||
|
||||
This function is idempotent - it will not create duplicates.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
|
||||
Returns:
|
||||
The coupon model (existing or newly created), or None if disabled
|
||||
"""
|
||||
# Check if pilot mode is enabled via environment variable
|
||||
pilot_mode_enabled = os.getenv("VITE_PILOT_MODE_ENABLED", "true").lower() == "true"
|
||||
|
||||
if not pilot_mode_enabled:
|
||||
logger.info("Pilot mode is disabled, skipping coupon seeding")
|
||||
return None
|
||||
|
||||
coupon_code = os.getenv("VITE_PILOT_COUPON_CODE", "PILOT2025")
|
||||
trial_months = int(os.getenv("VITE_PILOT_TRIAL_MONTHS", "3"))
|
||||
max_redemptions = int(os.getenv("PILOT_MAX_REDEMPTIONS", "20"))
|
||||
|
||||
# Check if coupon already exists
|
||||
result = await session.execute(
|
||||
select(CouponModel).where(CouponModel.code == coupon_code)
|
||||
)
|
||||
existing_coupon = result.scalars().first()
|
||||
|
||||
if existing_coupon:
|
||||
logger.info(
|
||||
"Pilot coupon already exists",
|
||||
code=coupon_code,
|
||||
current_redemptions=existing_coupon.current_redemptions,
|
||||
max_redemptions=existing_coupon.max_redemptions,
|
||||
active=existing_coupon.active
|
||||
)
|
||||
return existing_coupon
|
||||
|
||||
# Create new coupon
|
||||
now = datetime.now(timezone.utc)
|
||||
valid_until = now + timedelta(days=180) # Valid for 6 months
|
||||
trial_days = trial_months * 30 # Approximate days
|
||||
|
||||
coupon = CouponModel(
|
||||
id=uuid.uuid4(),
|
||||
code=coupon_code,
|
||||
discount_type="trial_extension",
|
||||
discount_value=trial_days,
|
||||
max_redemptions=max_redemptions,
|
||||
current_redemptions=0,
|
||||
valid_from=now,
|
||||
valid_until=valid_until,
|
||||
active=True,
|
||||
created_at=now,
|
||||
extra_data={
|
||||
"program": "pilot_launch_2025",
|
||||
"description": f"Programa piloto - {trial_months} meses gratis para los primeros {max_redemptions} clientes",
|
||||
"terms": "Válido para nuevos registros únicamente. Un cupón por cliente."
|
||||
}
|
||||
)
|
||||
|
||||
session.add(coupon)
|
||||
await session.commit()
|
||||
await session.refresh(coupon)
|
||||
|
||||
logger.info(
|
||||
"Pilot coupon created successfully",
|
||||
code=coupon_code,
|
||||
type="Trial Extension",
|
||||
value=f"{trial_days} days ({trial_months} months)",
|
||||
max_redemptions=max_redemptions,
|
||||
valid_until=valid_until.isoformat(),
|
||||
id=str(coupon.id)
|
||||
)
|
||||
|
||||
return coupon
|
||||
|
||||
|
||||
async def run_startup_seeders(database_manager) -> None:
|
||||
"""
|
||||
Run all startup seeders.
|
||||
|
||||
This function is called during service startup to ensure
|
||||
required seed data exists in the database.
|
||||
|
||||
Args:
|
||||
database_manager: The database manager instance
|
||||
"""
|
||||
logger.info("Running startup seeders...")
|
||||
|
||||
try:
|
||||
async with database_manager.get_session() as session:
|
||||
# Seed pilot coupon
|
||||
await ensure_pilot_coupon(session)
|
||||
|
||||
logger.info("Startup seeders completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
# Log but don't fail startup - seed data is not critical
|
||||
logger.warning(
|
||||
"Startup seeder encountered an error (non-fatal)",
|
||||
error=str(e)
|
||||
)
|
||||
103
services/tenant/app/jobs/subscription_downgrade.py
Normal file
103
services/tenant/app/jobs/subscription_downgrade.py
Normal file
@@ -0,0 +1,103 @@
|
||||
"""
|
||||
Background job to process subscription downgrades at period end
|
||||
|
||||
Runs periodically to check for subscriptions with:
|
||||
- status = 'pending_cancellation'
|
||||
- cancellation_effective_date <= now()
|
||||
|
||||
Converts them to 'inactive' status
|
||||
"""
|
||||
|
||||
import structlog
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_async_session_factory
|
||||
from app.models.tenants import Subscription
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
async def process_pending_cancellations():
|
||||
"""
|
||||
Process all subscriptions that have reached their cancellation_effective_date
|
||||
"""
|
||||
async_session_factory = get_async_session_factory()
|
||||
|
||||
async with async_session_factory() as session:
|
||||
try:
|
||||
query = select(Subscription).where(
|
||||
Subscription.status == 'pending_cancellation',
|
||||
Subscription.cancellation_effective_date <= datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
subscriptions_to_downgrade = result.scalars().all()
|
||||
|
||||
downgraded_count = 0
|
||||
|
||||
for subscription in subscriptions_to_downgrade:
|
||||
subscription.status = 'inactive'
|
||||
subscription.plan = 'free'
|
||||
subscription.monthly_price = 0.0
|
||||
|
||||
logger.info(
|
||||
"subscription_downgraded_to_inactive",
|
||||
tenant_id=str(subscription.tenant_id),
|
||||
previous_plan=subscription.plan,
|
||||
cancellation_effective_date=subscription.cancellation_effective_date.isoformat()
|
||||
)
|
||||
|
||||
downgraded_count += 1
|
||||
|
||||
if downgraded_count > 0:
|
||||
await session.commit()
|
||||
logger.info(
|
||||
"subscriptions_downgraded",
|
||||
count=downgraded_count
|
||||
)
|
||||
else:
|
||||
logger.debug("no_subscriptions_to_downgrade")
|
||||
|
||||
return downgraded_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"subscription_downgrade_job_failed",
|
||||
error=str(e)
|
||||
)
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
async def run_subscription_downgrade_job():
|
||||
"""
|
||||
Main entry point for the subscription downgrade job
|
||||
Runs in a loop with configurable interval
|
||||
"""
|
||||
interval_seconds = 3600 # Run every hour
|
||||
|
||||
logger.info("subscription_downgrade_job_started", interval_seconds=interval_seconds)
|
||||
|
||||
while True:
|
||||
try:
|
||||
downgraded_count = await process_pending_cancellations()
|
||||
|
||||
logger.info(
|
||||
"subscription_downgrade_job_completed",
|
||||
downgraded_count=downgraded_count
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"subscription_downgrade_job_error",
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
await asyncio.sleep(interval_seconds)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(run_subscription_downgrade_job())
|
||||
247
services/tenant/app/jobs/usage_tracking_scheduler.py
Normal file
247
services/tenant/app/jobs/usage_tracking_scheduler.py
Normal file
@@ -0,0 +1,247 @@
|
||||
"""
|
||||
Usage Tracking Scheduler
|
||||
Tracks daily usage snapshots for all active tenants
|
||||
"""
|
||||
import asyncio
|
||||
import structlog
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
from sqlalchemy import select, func
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class UsageTrackingScheduler:
|
||||
"""Scheduler for daily usage tracking"""
|
||||
|
||||
def __init__(self, db_manager, redis_client, config):
|
||||
self.db_manager = db_manager
|
||||
self.redis = redis_client
|
||||
self.config = config
|
||||
self._running = False
|
||||
self._task: Optional[asyncio.Task] = None
|
||||
|
||||
def seconds_until_target_time(self) -> float:
|
||||
"""Calculate seconds until next target time (default 2am UTC)"""
|
||||
now = datetime.now(timezone.utc)
|
||||
target = now.replace(
|
||||
hour=self.config.USAGE_TRACKING_HOUR,
|
||||
minute=self.config.USAGE_TRACKING_MINUTE,
|
||||
second=0,
|
||||
microsecond=0
|
||||
)
|
||||
|
||||
if target <= now:
|
||||
target += timedelta(days=1)
|
||||
|
||||
return (target - now).total_seconds()
|
||||
|
||||
async def _get_tenant_usage(self, session, tenant_id: str) -> dict:
|
||||
"""Get current usage counts for a tenant"""
|
||||
usage = {}
|
||||
|
||||
try:
|
||||
# Import models here to avoid circular imports
|
||||
from app.models.tenants import TenantMember
|
||||
|
||||
# Users count
|
||||
result = await session.execute(
|
||||
select(func.count()).select_from(TenantMember).where(TenantMember.tenant_id == tenant_id)
|
||||
)
|
||||
usage['users'] = result.scalar() or 0
|
||||
|
||||
# Get counts from other services via their databases
|
||||
# For now, we'll track basic metrics. More metrics can be added by querying other service databases
|
||||
|
||||
# Training jobs today (from Redis quota tracking)
|
||||
today_key = f"quota:training_jobs:{tenant_id}:{datetime.now(timezone.utc).strftime('%Y-%m-%d')}"
|
||||
training_count = await self.redis.get(today_key)
|
||||
usage['training_jobs'] = int(training_count) if training_count else 0
|
||||
|
||||
# Forecasts today (from Redis quota tracking)
|
||||
forecast_key = f"quota:forecasts:{tenant_id}:{datetime.now(timezone.utc).strftime('%Y-%m-%d')}"
|
||||
forecast_count = await self.redis.get(forecast_key)
|
||||
usage['forecasts'] = int(forecast_count) if forecast_count else 0
|
||||
|
||||
# API calls this hour (from Redis quota tracking)
|
||||
hour_key = f"quota:api_calls:{tenant_id}:{datetime.now(timezone.utc).strftime('%Y-%m-%d-%H')}"
|
||||
api_count = await self.redis.get(hour_key)
|
||||
usage['api_calls'] = int(api_count) if api_count else 0
|
||||
|
||||
# Storage (placeholder - implement based on file storage system)
|
||||
usage['storage'] = 0.0
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting usage for tenant", tenant_id=tenant_id, error=str(e), exc_info=True)
|
||||
return {}
|
||||
|
||||
return usage
|
||||
|
||||
async def _track_metrics(self, tenant_id: str, usage: dict):
|
||||
"""Track metrics to Redis"""
|
||||
from app.api.usage_forecast import track_usage_snapshot
|
||||
|
||||
for metric_name, value in usage.items():
|
||||
try:
|
||||
await track_usage_snapshot(tenant_id, metric_name, value)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to track metric",
|
||||
tenant_id=tenant_id,
|
||||
metric=metric_name,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
async def _run_cycle(self):
|
||||
"""Execute one tracking cycle"""
|
||||
start_time = datetime.now(timezone.utc)
|
||||
logger.info("Starting daily usage tracking cycle")
|
||||
|
||||
try:
|
||||
async with self.db_manager.get_session() as session:
|
||||
# Import models here to avoid circular imports
|
||||
from app.models.tenants import Tenant, Subscription
|
||||
from sqlalchemy import select
|
||||
|
||||
# Get all active tenants
|
||||
result = await session.execute(
|
||||
select(Tenant, Subscription)
|
||||
.join(Subscription, Tenant.id == Subscription.tenant_id)
|
||||
.where(Tenant.is_active == True)
|
||||
.where(Subscription.status.in_(['active', 'trialing', 'cancelled']))
|
||||
)
|
||||
|
||||
tenants_data = result.all()
|
||||
total_tenants = len(tenants_data)
|
||||
success_count = 0
|
||||
error_count = 0
|
||||
|
||||
logger.info(f"Found {total_tenants} active tenants to track")
|
||||
|
||||
# Process each tenant
|
||||
for tenant, subscription in tenants_data:
|
||||
try:
|
||||
usage = await self._get_tenant_usage(session, tenant.id)
|
||||
|
||||
if usage:
|
||||
await self._track_metrics(tenant.id, usage)
|
||||
success_count += 1
|
||||
else:
|
||||
logger.warning(
|
||||
"No usage data available for tenant",
|
||||
tenant_id=tenant.id
|
||||
)
|
||||
error_count += 1
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error tracking tenant usage",
|
||||
tenant_id=tenant.id,
|
||||
error=str(e),
|
||||
exc_info=True
|
||||
)
|
||||
error_count += 1
|
||||
|
||||
end_time = datetime.now(timezone.utc)
|
||||
duration = (end_time - start_time).total_seconds()
|
||||
|
||||
logger.info(
|
||||
"Daily usage tracking completed",
|
||||
total_tenants=total_tenants,
|
||||
success=success_count,
|
||||
errors=error_count,
|
||||
duration_seconds=duration
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Usage tracking cycle failed", error=str(e), exc_info=True)
|
||||
|
||||
async def _run_scheduler(self):
|
||||
"""Main scheduler loop"""
|
||||
logger.info(
|
||||
"Usage tracking scheduler loop started",
|
||||
target_hour=self.config.USAGE_TRACKING_HOUR,
|
||||
target_minute=self.config.USAGE_TRACKING_MINUTE
|
||||
)
|
||||
|
||||
# Initial delay to target time
|
||||
delay = self.seconds_until_target_time()
|
||||
logger.info(f"Waiting {delay/3600:.2f} hours until next run at {self.config.USAGE_TRACKING_HOUR:02d}:{self.config.USAGE_TRACKING_MINUTE:02d} UTC")
|
||||
|
||||
try:
|
||||
await asyncio.sleep(delay)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Scheduler cancelled during initial delay")
|
||||
return
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
await self._run_cycle()
|
||||
except Exception as e:
|
||||
logger.error("Scheduler cycle error", error=str(e), exc_info=True)
|
||||
|
||||
# Wait 24 hours until next run
|
||||
try:
|
||||
await asyncio.sleep(86400)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Scheduler cancelled during sleep")
|
||||
break
|
||||
|
||||
def start(self):
|
||||
"""Start the scheduler"""
|
||||
if not self.config.USAGE_TRACKING_ENABLED:
|
||||
logger.info("Usage tracking scheduler disabled by configuration")
|
||||
return
|
||||
|
||||
if self._running:
|
||||
logger.warning("Usage tracking scheduler already running")
|
||||
return
|
||||
|
||||
self._running = True
|
||||
self._task = asyncio.create_task(self._run_scheduler())
|
||||
logger.info("Usage tracking scheduler started successfully")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the scheduler gracefully"""
|
||||
if not self._running:
|
||||
logger.debug("Scheduler not running, nothing to stop")
|
||||
return
|
||||
|
||||
logger.info("Stopping usage tracking scheduler")
|
||||
self._running = False
|
||||
|
||||
if self._task:
|
||||
self._task.cancel()
|
||||
try:
|
||||
await self._task
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Scheduler task cancelled successfully")
|
||||
|
||||
logger.info("Usage tracking scheduler stopped")
|
||||
|
||||
|
||||
# Global instance
|
||||
_scheduler: Optional[UsageTrackingScheduler] = None
|
||||
|
||||
|
||||
async def start_scheduler(db_manager, redis_client, config):
|
||||
"""Start the usage tracking scheduler"""
|
||||
global _scheduler
|
||||
|
||||
try:
|
||||
_scheduler = UsageTrackingScheduler(db_manager, redis_client, config)
|
||||
_scheduler.start()
|
||||
logger.info("Usage tracking scheduler module initialized")
|
||||
except Exception as e:
|
||||
logger.error("Failed to start usage tracking scheduler", error=str(e), exc_info=True)
|
||||
raise
|
||||
|
||||
|
||||
async def stop_scheduler():
|
||||
"""Stop the usage tracking scheduler"""
|
||||
global _scheduler
|
||||
|
||||
if _scheduler:
|
||||
await _scheduler.stop()
|
||||
_scheduler = None
|
||||
logger.info("Usage tracking scheduler module stopped")
|
||||
175
services/tenant/app/main.py
Normal file
175
services/tenant/app/main.py
Normal file
@@ -0,0 +1,175 @@
|
||||
# services/tenant/app/main.py
|
||||
"""
|
||||
Tenant Service FastAPI application
|
||||
"""
|
||||
|
||||
from fastapi import FastAPI
|
||||
from sqlalchemy import text
|
||||
from app.core.config import settings
|
||||
from app.core.database import database_manager
|
||||
from app.api import tenants, tenant_members, tenant_operations, webhooks, plans, subscription, tenant_settings, whatsapp_admin, usage_forecast, enterprise_upgrade, tenant_locations, tenant_hierarchy, internal_demo, network_alerts, onboarding
|
||||
from shared.service_base import StandardFastAPIService
|
||||
from shared.monitoring.system_metrics import SystemMetricsCollector
|
||||
|
||||
|
||||
class TenantService(StandardFastAPIService):
|
||||
"""Tenant Service with standardized setup"""
|
||||
|
||||
expected_migration_version = "00001"
|
||||
|
||||
async def verify_migrations(self):
|
||||
"""Verify database schema matches the latest migrations."""
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
# Check if alembic_version table exists
|
||||
result = await session.execute(text("""
|
||||
SELECT EXISTS (
|
||||
SELECT FROM information_schema.tables
|
||||
WHERE table_schema = 'public'
|
||||
AND table_name = 'alembic_version'
|
||||
)
|
||||
"""))
|
||||
table_exists = result.scalar()
|
||||
|
||||
if table_exists:
|
||||
# If table exists, check the version
|
||||
result = await session.execute(text("SELECT version_num FROM alembic_version"))
|
||||
version = result.scalar()
|
||||
# For now, just log the version instead of strict checking to avoid startup failures
|
||||
self.logger.info(f"Migration verification successful: {version}")
|
||||
else:
|
||||
# If table doesn't exist, migrations might not have run yet
|
||||
# This is OK - the migration job should create it
|
||||
self.logger.warning("alembic_version table does not exist yet - migrations may not have run")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Migration verification failed (this may be expected during initial setup): {e}")
|
||||
|
||||
def __init__(self):
|
||||
# Define expected database tables for health checks
|
||||
tenant_expected_tables = ['tenants', 'tenant_members', 'subscriptions']
|
||||
# Note: api_prefix is empty because RouteBuilder already includes /api/v1
|
||||
super().__init__(
|
||||
service_name="tenant-service",
|
||||
app_name="Tenant Management Service",
|
||||
description="Multi-tenant bakery management service",
|
||||
version="1.0.0",
|
||||
log_level=settings.LOG_LEVEL,
|
||||
api_prefix="",
|
||||
database_manager=database_manager,
|
||||
expected_tables=tenant_expected_tables
|
||||
)
|
||||
|
||||
async def on_startup(self, app: FastAPI):
|
||||
"""Custom startup logic for tenant service"""
|
||||
# Verify migrations first
|
||||
await self.verify_migrations()
|
||||
|
||||
# Import models to ensure they're registered with SQLAlchemy
|
||||
from app.models.tenants import Tenant, TenantMember, Subscription
|
||||
from app.models.tenant_settings import TenantSettings
|
||||
self.logger.info("Tenant models imported successfully")
|
||||
|
||||
# Initialize Redis
|
||||
from shared.redis_utils import initialize_redis, get_redis_client
|
||||
await initialize_redis(settings.REDIS_URL, db=settings.REDIS_DB, max_connections=20)
|
||||
redis_client = await get_redis_client()
|
||||
self.logger.info("Redis initialized successfully")
|
||||
|
||||
# Initialize system metrics collection
|
||||
system_metrics = SystemMetricsCollector("tenant")
|
||||
self.logger.info("System metrics collection started")
|
||||
|
||||
# Start usage tracking scheduler
|
||||
from app.jobs.usage_tracking_scheduler import start_scheduler
|
||||
await start_scheduler(self.database_manager, redis_client, settings)
|
||||
self.logger.info("Usage tracking scheduler started")
|
||||
|
||||
# Run startup seeders (pilot coupon, etc.)
|
||||
from app.jobs.startup_seeder import run_startup_seeders
|
||||
await run_startup_seeders(self.database_manager)
|
||||
self.logger.info("Startup seeders completed")
|
||||
|
||||
async def on_shutdown(self, app: FastAPI):
|
||||
"""Custom shutdown logic for tenant service"""
|
||||
# Stop usage tracking scheduler
|
||||
from app.jobs.usage_tracking_scheduler import stop_scheduler
|
||||
await stop_scheduler()
|
||||
self.logger.info("Usage tracking scheduler stopped")
|
||||
|
||||
# Close Redis connection
|
||||
from shared.redis_utils import close_redis
|
||||
await close_redis()
|
||||
self.logger.info("Redis connection closed")
|
||||
|
||||
# Database cleanup is handled by the base class
|
||||
|
||||
def get_service_features(self):
|
||||
"""Return tenant-specific features"""
|
||||
return [
|
||||
"multi_tenant_management",
|
||||
"subscription_management",
|
||||
"tenant_isolation",
|
||||
"webhook_notifications",
|
||||
"member_management"
|
||||
]
|
||||
|
||||
def setup_custom_endpoints(self):
|
||||
"""Setup custom endpoints for tenant service"""
|
||||
# Note: Metrics are exported via OpenTelemetry OTLP to SigNoz
|
||||
# The /metrics endpoint is not needed as metrics are pushed automatically
|
||||
# @self.app.get("/metrics")
|
||||
# async def metrics():
|
||||
# """Prometheus metrics endpoint"""
|
||||
# if self.metrics_collector:
|
||||
# return self.metrics_collector.get_metrics()
|
||||
# return {"metrics": "not_available"}
|
||||
|
||||
|
||||
# Create service instance
|
||||
service = TenantService()
|
||||
|
||||
# Create FastAPI app with standardized setup
|
||||
app = service.create_app(
|
||||
docs_url="/docs",
|
||||
redoc_url="/redoc"
|
||||
)
|
||||
|
||||
# Setup standard endpoints
|
||||
service.setup_standard_endpoints()
|
||||
|
||||
# Setup custom endpoints
|
||||
service.setup_custom_endpoints()
|
||||
|
||||
# Include routers
|
||||
service.add_router(plans.router, tags=["subscription-plans"]) # Public endpoint
|
||||
service.add_router(internal_demo.router, tags=["internal-demo"]) # Internal demo data cloning
|
||||
service.add_router(subscription.router, tags=["subscription"])
|
||||
service.add_router(internal_demo.router, tags=["internal-demo"]) # Internal demo data cloning
|
||||
service.add_router(usage_forecast.router, prefix="/api/v1", tags=["usage-forecast"]) # Usage forecasting & predictive analytics
|
||||
service.add_router(internal_demo.router, tags=["internal-demo"]) # Internal demo data cloning
|
||||
# Register settings router BEFORE tenants router to ensure proper route matching
|
||||
service.add_router(tenant_settings.router, prefix="/api/v1/tenants", tags=["tenant-settings"])
|
||||
service.add_router(internal_demo.router, tags=["internal-demo"]) # Internal demo data cloning
|
||||
service.add_router(whatsapp_admin.router, prefix="/api/v1", tags=["whatsapp-admin"]) # Admin WhatsApp management
|
||||
service.add_router(internal_demo.router, tags=["internal-demo"]) # Internal demo data cloning
|
||||
service.add_router(tenants.router, tags=["tenants"])
|
||||
service.add_router(internal_demo.router, tags=["internal-demo"]) # Internal demo data cloning
|
||||
service.add_router(tenant_members.router, tags=["tenant-members"])
|
||||
service.add_router(internal_demo.router, tags=["internal-demo"]) # Internal demo data cloning
|
||||
service.add_router(tenant_operations.router, tags=["tenant-operations"])
|
||||
service.add_router(internal_demo.router, tags=["internal-demo"]) # Internal demo data cloning
|
||||
service.add_router(webhooks.router, tags=["webhooks"])
|
||||
service.add_router(internal_demo.router, tags=["internal-demo"]) # Internal demo data cloning
|
||||
service.add_router(enterprise_upgrade.router, tags=["enterprise"]) # Enterprise tier upgrade endpoints
|
||||
service.add_router(internal_demo.router, tags=["internal-demo"]) # Internal demo data cloning
|
||||
service.add_router(tenant_locations.router, tags=["tenant-locations"]) # Tenant locations endpoints
|
||||
service.add_router(internal_demo.router, tags=["internal-demo"]) # Internal demo data cloning
|
||||
service.add_router(tenant_hierarchy.router, tags=["tenant-hierarchy"]) # Tenant hierarchy endpoints
|
||||
service.add_router(internal_demo.router, tags=["internal-demo"]) # Internal demo data cloning
|
||||
service.add_router(network_alerts.router, tags=["network-alerts"]) # Network alerts aggregation endpoints
|
||||
service.add_router(onboarding.router, tags=["onboarding"]) # Onboarding status endpoints
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="0.0.0.0", port=8000)
|
||||
31
services/tenant/app/models/__init__.py
Normal file
31
services/tenant/app/models/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
Tenant Service Models Package
|
||||
|
||||
Import all models to ensure they are registered with SQLAlchemy Base.
|
||||
"""
|
||||
|
||||
# Import AuditLog model for this service
|
||||
from shared.security import create_audit_log_model
|
||||
from shared.database.base import Base
|
||||
|
||||
# Create audit log model for this service
|
||||
AuditLog = create_audit_log_model(Base)
|
||||
|
||||
# Import all models to register them with the Base metadata
|
||||
from .tenants import Tenant, TenantMember, Subscription
|
||||
from .tenant_location import TenantLocation
|
||||
from .coupon import CouponModel, CouponRedemptionModel
|
||||
from .events import Event, EventTemplate
|
||||
|
||||
# List all models for easier access
|
||||
__all__ = [
|
||||
"Tenant",
|
||||
"TenantMember",
|
||||
"Subscription",
|
||||
"TenantLocation",
|
||||
"AuditLog",
|
||||
"CouponModel",
|
||||
"CouponRedemptionModel",
|
||||
"Event",
|
||||
"EventTemplate",
|
||||
]
|
||||
64
services/tenant/app/models/coupon.py
Normal file
64
services/tenant/app/models/coupon.py
Normal file
@@ -0,0 +1,64 @@
|
||||
"""
|
||||
SQLAlchemy models for coupon system
|
||||
"""
|
||||
from datetime import datetime
|
||||
from sqlalchemy import Column, String, Integer, Boolean, DateTime, ForeignKey, JSON, Index
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
import uuid
|
||||
|
||||
from shared.database import Base
|
||||
|
||||
|
||||
class CouponModel(Base):
|
||||
"""Coupon configuration table"""
|
||||
__tablename__ = "coupons"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
code = Column(String(50), unique=True, nullable=False, index=True)
|
||||
discount_type = Column(String(20), nullable=False) # trial_extension, percentage, fixed_amount
|
||||
discount_value = Column(Integer, nullable=False) # Days/percentage/cents depending on type
|
||||
max_redemptions = Column(Integer, nullable=True) # None = unlimited
|
||||
current_redemptions = Column(Integer, nullable=False, default=0)
|
||||
valid_from = Column(DateTime(timezone=True), nullable=False)
|
||||
valid_until = Column(DateTime(timezone=True), nullable=True) # None = no expiry
|
||||
active = Column(Boolean, nullable=False, default=True)
|
||||
created_at = Column(DateTime(timezone=True), nullable=False, default=datetime.utcnow)
|
||||
extra_data = Column(JSON, nullable=True) # Renamed from metadata to avoid SQLAlchemy reserved name
|
||||
|
||||
# Relationships
|
||||
redemptions = relationship("CouponRedemptionModel", back_populates="coupon")
|
||||
|
||||
# Indexes for performance
|
||||
__table_args__ = (
|
||||
Index('idx_coupon_code_active', 'code', 'active'),
|
||||
Index('idx_coupon_valid_dates', 'valid_from', 'valid_until'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Coupon(code='{self.code}', type='{self.discount_type}', value={self.discount_value})>"
|
||||
|
||||
|
||||
class CouponRedemptionModel(Base):
|
||||
"""Coupon redemption history table"""
|
||||
__tablename__ = "coupon_redemptions"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(String(255), nullable=False, index=True)
|
||||
coupon_code = Column(String(50), ForeignKey('coupons.code'), nullable=False)
|
||||
redeemed_at = Column(DateTime(timezone=True), nullable=False, default=datetime.utcnow)
|
||||
discount_applied = Column(JSON, nullable=False) # Details of discount applied
|
||||
extra_data = Column(JSON, nullable=True) # Renamed from metadata to avoid SQLAlchemy reserved name
|
||||
|
||||
# Relationships
|
||||
coupon = relationship("CouponModel", back_populates="redemptions")
|
||||
|
||||
# Constraints
|
||||
__table_args__ = (
|
||||
Index('idx_redemption_tenant', 'tenant_id'),
|
||||
Index('idx_redemption_coupon', 'coupon_code'),
|
||||
Index('idx_redemption_tenant_coupon', 'tenant_id', 'coupon_code'), # Prevent duplicate redemptions
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<CouponRedemption(tenant_id='{self.tenant_id}', code='{self.coupon_code}')>"
|
||||
136
services/tenant/app/models/events.py
Normal file
136
services/tenant/app/models/events.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
Event Calendar Models
|
||||
Database models for tracking local events, promotions, and special occasions
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Text, Boolean, Float, Date
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from shared.database.base import Base
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
|
||||
class Event(Base):
|
||||
"""
|
||||
Table to track events that affect bakery demand.
|
||||
|
||||
Events include:
|
||||
- Local events (festivals, markets, concerts)
|
||||
- Promotions and sales
|
||||
- Weather events (heat waves, storms)
|
||||
- School holidays and breaks
|
||||
- Special occasions
|
||||
"""
|
||||
__tablename__ = "events"
|
||||
|
||||
# Primary identification
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
|
||||
# Event information
|
||||
event_name = Column(String(500), nullable=False)
|
||||
event_type = Column(String(100), nullable=False, index=True) # promotion, festival, holiday, weather, school_break, sport_event, etc.
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
# Date and time
|
||||
event_date = Column(Date, nullable=False, index=True)
|
||||
start_time = Column(DateTime(timezone=True), nullable=True)
|
||||
end_time = Column(DateTime(timezone=True), nullable=True)
|
||||
is_all_day = Column(Boolean, default=True)
|
||||
|
||||
# Impact estimation
|
||||
expected_impact = Column(String(50), nullable=True) # low, medium, high, very_high
|
||||
impact_multiplier = Column(Float, nullable=True) # Expected demand multiplier (e.g., 1.5 = 50% increase)
|
||||
affected_product_categories = Column(String(500), nullable=True) # Comma-separated categories
|
||||
|
||||
# Location
|
||||
location = Column(String(500), nullable=True)
|
||||
is_local = Column(Boolean, default=True) # True if event is near bakery
|
||||
|
||||
# Status
|
||||
is_confirmed = Column(Boolean, default=False)
|
||||
is_recurring = Column(Boolean, default=False)
|
||||
recurrence_pattern = Column(String(200), nullable=True) # e.g., "weekly:monday", "monthly:first_saturday"
|
||||
|
||||
# Actual impact (filled after event)
|
||||
actual_impact_multiplier = Column(Float, nullable=True)
|
||||
actual_sales_increase_percent = Column(Float, nullable=True)
|
||||
|
||||
# Metadata
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
created_by = Column(String(255), nullable=True)
|
||||
notes = Column(Text, nullable=True)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"tenant_id": str(self.tenant_id),
|
||||
"event_name": self.event_name,
|
||||
"event_type": self.event_type,
|
||||
"description": self.description,
|
||||
"event_date": self.event_date.isoformat() if self.event_date else None,
|
||||
"start_time": self.start_time.isoformat() if self.start_time else None,
|
||||
"end_time": self.end_time.isoformat() if self.end_time else None,
|
||||
"is_all_day": self.is_all_day,
|
||||
"expected_impact": self.expected_impact,
|
||||
"impact_multiplier": self.impact_multiplier,
|
||||
"affected_product_categories": self.affected_product_categories,
|
||||
"location": self.location,
|
||||
"is_local": self.is_local,
|
||||
"is_confirmed": self.is_confirmed,
|
||||
"is_recurring": self.is_recurring,
|
||||
"recurrence_pattern": self.recurrence_pattern,
|
||||
"actual_impact_multiplier": self.actual_impact_multiplier,
|
||||
"actual_sales_increase_percent": self.actual_sales_increase_percent,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None,
|
||||
"created_by": self.created_by,
|
||||
"notes": self.notes
|
||||
}
|
||||
|
||||
|
||||
class EventTemplate(Base):
|
||||
"""
|
||||
Template for recurring events.
|
||||
Allows easy creation of events based on patterns.
|
||||
"""
|
||||
__tablename__ = "event_templates"
|
||||
|
||||
# Primary identification
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
|
||||
# Template information
|
||||
template_name = Column(String(500), nullable=False)
|
||||
event_type = Column(String(100), nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
# Default values
|
||||
default_impact = Column(String(50), nullable=True)
|
||||
default_impact_multiplier = Column(Float, nullable=True)
|
||||
default_affected_categories = Column(String(500), nullable=True)
|
||||
|
||||
# Recurrence
|
||||
recurrence_pattern = Column(String(200), nullable=False) # e.g., "weekly:saturday", "monthly:last_sunday"
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
# Metadata
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": str(self.id),
|
||||
"tenant_id": str(self.tenant_id),
|
||||
"template_name": self.template_name,
|
||||
"event_type": self.event_type,
|
||||
"description": self.description,
|
||||
"default_impact": self.default_impact,
|
||||
"default_impact_multiplier": self.default_impact_multiplier,
|
||||
"default_affected_categories": self.default_affected_categories,
|
||||
"recurrence_pattern": self.recurrence_pattern,
|
||||
"is_active": self.is_active,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"updated_at": self.updated_at.isoformat() if self.updated_at else None
|
||||
}
|
||||
59
services/tenant/app/models/tenant_location.py
Normal file
59
services/tenant/app/models/tenant_location.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""
|
||||
Tenant Location Model
|
||||
Represents physical locations for enterprise tenants (central production, retail outlets)
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Float, ForeignKey, Text, Integer, JSON
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
from shared.database.base import Base
|
||||
|
||||
|
||||
class TenantLocation(Base):
|
||||
"""TenantLocation model - represents physical locations for enterprise tenants"""
|
||||
__tablename__ = "tenant_locations"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
|
||||
# Location information
|
||||
name = Column(String(200), nullable=False)
|
||||
location_type = Column(String(50), nullable=False) # central_production, retail_outlet
|
||||
address = Column(Text, nullable=False)
|
||||
city = Column(String(100), default="Madrid")
|
||||
postal_code = Column(String(10), nullable=False)
|
||||
latitude = Column(Float, nullable=True)
|
||||
longitude = Column(Float, nullable=True)
|
||||
|
||||
# Location-specific configuration
|
||||
delivery_windows = Column(JSON, nullable=True) # { "monday": "08:00-12:00,14:00-18:00", ... }
|
||||
capacity = Column(Integer, nullable=True) # For production capacity in kg/day or storage capacity
|
||||
max_delivery_radius_km = Column(Float, nullable=True, default=50.0)
|
||||
|
||||
# Operational hours
|
||||
operational_hours = Column(JSON, nullable=True) # { "monday": "06:00-20:00", ... }
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
# Contact information
|
||||
contact_person = Column(String(200), nullable=True)
|
||||
contact_phone = Column(String(20), nullable=True)
|
||||
contact_email = Column(String(255), nullable=True)
|
||||
|
||||
# Custom delivery scheduling configuration per location
|
||||
delivery_schedule_config = Column(JSON, nullable=True) # { "delivery_days": "Mon,Wed,Fri", "time_window": "07:00-10:00" }
|
||||
|
||||
# Metadata
|
||||
metadata_ = Column(JSON, nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
# Relationships
|
||||
tenant = relationship("Tenant", back_populates="locations")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<TenantLocation(id={self.id}, tenant_id={self.tenant_id}, name={self.name}, type={self.location_type})>"
|
||||
370
services/tenant/app/models/tenant_settings.py
Normal file
370
services/tenant/app/models/tenant_settings.py
Normal file
@@ -0,0 +1,370 @@
|
||||
# services/tenant/app/models/tenant_settings.py
|
||||
"""
|
||||
Tenant Settings Model
|
||||
Centralized configuration storage for all tenant-specific operational settings
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey, JSON
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
from shared.database.base import Base
|
||||
|
||||
|
||||
class TenantSettings(Base):
|
||||
"""
|
||||
Centralized tenant settings model
|
||||
Stores all operational configurations for a tenant across all services
|
||||
"""
|
||||
__tablename__ = "tenant_settings"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id", ondelete="CASCADE"), nullable=False, unique=True, index=True)
|
||||
|
||||
# Procurement & Auto-Approval Settings (Orders Service)
|
||||
procurement_settings = Column(JSON, nullable=False, default=lambda: {
|
||||
"auto_approve_enabled": True,
|
||||
"auto_approve_threshold_eur": 500.0,
|
||||
"auto_approve_min_supplier_score": 0.80,
|
||||
"require_approval_new_suppliers": True,
|
||||
"require_approval_critical_items": True,
|
||||
"procurement_lead_time_days": 3,
|
||||
"demand_forecast_days": 14,
|
||||
"safety_stock_percentage": 20.0,
|
||||
"po_approval_reminder_hours": 24,
|
||||
"po_critical_escalation_hours": 12,
|
||||
"use_reorder_rules": True,
|
||||
"economic_rounding": True,
|
||||
"respect_storage_limits": True,
|
||||
"use_supplier_minimums": True,
|
||||
"optimize_price_tiers": True
|
||||
})
|
||||
|
||||
# Inventory Management Settings (Inventory Service)
|
||||
inventory_settings = Column(JSON, nullable=False, default=lambda: {
|
||||
"low_stock_threshold": 10,
|
||||
"reorder_point": 20,
|
||||
"reorder_quantity": 50,
|
||||
"expiring_soon_days": 7,
|
||||
"expiration_warning_days": 3,
|
||||
"quality_score_threshold": 8.0,
|
||||
"temperature_monitoring_enabled": True,
|
||||
"refrigeration_temp_min": 1.0,
|
||||
"refrigeration_temp_max": 4.0,
|
||||
"freezer_temp_min": -20.0,
|
||||
"freezer_temp_max": -15.0,
|
||||
"room_temp_min": 18.0,
|
||||
"room_temp_max": 25.0,
|
||||
"temp_deviation_alert_minutes": 15,
|
||||
"critical_temp_deviation_minutes": 5
|
||||
})
|
||||
|
||||
# Production Settings (Production Service)
|
||||
production_settings = Column(JSON, nullable=False, default=lambda: {
|
||||
"planning_horizon_days": 7,
|
||||
"minimum_batch_size": 1.0,
|
||||
"maximum_batch_size": 100.0,
|
||||
"production_buffer_percentage": 10.0,
|
||||
"working_hours_per_day": 12,
|
||||
"max_overtime_hours": 4,
|
||||
"capacity_utilization_target": 0.85,
|
||||
"capacity_warning_threshold": 0.95,
|
||||
"quality_check_enabled": True,
|
||||
"minimum_yield_percentage": 85.0,
|
||||
"quality_score_threshold": 8.0,
|
||||
"schedule_optimization_enabled": True,
|
||||
"prep_time_buffer_minutes": 30,
|
||||
"cleanup_time_buffer_minutes": 15,
|
||||
"labor_cost_per_hour_eur": 15.0,
|
||||
"overhead_cost_percentage": 20.0
|
||||
})
|
||||
|
||||
# Supplier Settings (Suppliers Service)
|
||||
supplier_settings = Column(JSON, nullable=False, default=lambda: {
|
||||
"default_payment_terms_days": 30,
|
||||
"default_delivery_days": 3,
|
||||
"excellent_delivery_rate": 95.0,
|
||||
"good_delivery_rate": 90.0,
|
||||
"excellent_quality_rate": 98.0,
|
||||
"good_quality_rate": 95.0,
|
||||
"critical_delivery_delay_hours": 24,
|
||||
"critical_quality_rejection_rate": 10.0,
|
||||
"high_cost_variance_percentage": 15.0,
|
||||
# Supplier Approval Workflow Settings
|
||||
"require_supplier_approval": True,
|
||||
"auto_approve_for_admin_owner": True,
|
||||
"approval_required_roles": ["member", "viewer"]
|
||||
})
|
||||
|
||||
# POS Integration Settings (POS Service)
|
||||
pos_settings = Column(JSON, nullable=False, default=lambda: {
|
||||
"sync_interval_minutes": 5,
|
||||
"auto_sync_products": True,
|
||||
"auto_sync_transactions": True
|
||||
})
|
||||
|
||||
# Order & Business Rules Settings (Orders Service)
|
||||
order_settings = Column(JSON, nullable=False, default=lambda: {
|
||||
"max_discount_percentage": 50.0,
|
||||
"default_delivery_window_hours": 48,
|
||||
"dynamic_pricing_enabled": False,
|
||||
"discount_enabled": True,
|
||||
"delivery_tracking_enabled": True
|
||||
})
|
||||
|
||||
# Replenishment Planning Settings (Orchestrator Service)
|
||||
replenishment_settings = Column(JSON, nullable=False, default=lambda: {
|
||||
"projection_horizon_days": 7,
|
||||
"service_level": 0.95,
|
||||
"buffer_days": 1,
|
||||
"enable_auto_replenishment": True,
|
||||
"min_order_quantity": 1.0,
|
||||
"max_order_quantity": 1000.0,
|
||||
"demand_forecast_days": 14
|
||||
})
|
||||
|
||||
# Safety Stock Settings (Orchestrator Service)
|
||||
safety_stock_settings = Column(JSON, nullable=False, default=lambda: {
|
||||
"service_level": 0.95,
|
||||
"method": "statistical",
|
||||
"min_safety_stock": 0.0,
|
||||
"max_safety_stock": 100.0,
|
||||
"reorder_point_calculation": "safety_stock_plus_lead_time_demand"
|
||||
})
|
||||
|
||||
# MOQ Aggregation Settings (Orchestrator Service)
|
||||
moq_settings = Column(JSON, nullable=False, default=lambda: {
|
||||
"consolidation_window_days": 7,
|
||||
"allow_early_ordering": True,
|
||||
"enable_batch_optimization": True,
|
||||
"min_batch_size": 1.0,
|
||||
"max_batch_size": 1000.0
|
||||
})
|
||||
|
||||
# Supplier Selection Settings (Orchestrator Service)
|
||||
supplier_selection_settings = Column(JSON, nullable=False, default=lambda: {
|
||||
"price_weight": 0.40,
|
||||
"lead_time_weight": 0.20,
|
||||
"quality_weight": 0.20,
|
||||
"reliability_weight": 0.20,
|
||||
"diversification_threshold": 1000,
|
||||
"max_single_percentage": 0.70,
|
||||
"enable_supplier_score_optimization": True
|
||||
})
|
||||
|
||||
# ML Insights Settings (AI Insights Service)
|
||||
ml_insights_settings = Column(JSON, nullable=False, default=lambda: {
|
||||
# Inventory ML (Safety Stock Optimization)
|
||||
"inventory_lookback_days": 90,
|
||||
"inventory_min_history_days": 30,
|
||||
|
||||
# Production ML (Yield Prediction)
|
||||
"production_lookback_days": 90,
|
||||
"production_min_history_runs": 30,
|
||||
|
||||
# Procurement ML (Supplier Analysis & Price Forecasting)
|
||||
"supplier_analysis_lookback_days": 180,
|
||||
"supplier_analysis_min_orders": 10,
|
||||
"price_forecast_lookback_days": 180,
|
||||
"price_forecast_horizon_days": 30,
|
||||
|
||||
# Forecasting ML (Dynamic Rules)
|
||||
"rules_generation_lookback_days": 90,
|
||||
"rules_generation_min_samples": 10,
|
||||
|
||||
# Global ML Settings
|
||||
"enable_ml_insights": True,
|
||||
"ml_insights_auto_trigger": False,
|
||||
"ml_confidence_threshold": 0.80
|
||||
})
|
||||
|
||||
# Notification Settings (Notification Service)
|
||||
notification_settings = Column(JSON, nullable=False, default=lambda: {
|
||||
# WhatsApp Configuration (Shared Account Model)
|
||||
"whatsapp_enabled": False,
|
||||
"whatsapp_phone_number_id": "", # Meta WhatsApp Phone Number ID (from shared master account)
|
||||
"whatsapp_display_phone_number": "", # Display format for UI (e.g., "+34 612 345 678")
|
||||
"whatsapp_default_language": "es",
|
||||
|
||||
# Email Configuration
|
||||
"email_enabled": True,
|
||||
"email_from_address": "",
|
||||
"email_from_name": "",
|
||||
"email_reply_to": "",
|
||||
|
||||
# Notification Preferences
|
||||
"enable_po_notifications": True,
|
||||
"enable_inventory_alerts": True,
|
||||
"enable_production_alerts": True,
|
||||
"enable_forecast_alerts": True,
|
||||
|
||||
# Notification Channels
|
||||
"po_notification_channels": ["email"], # ["email", "whatsapp"]
|
||||
"inventory_alert_channels": ["email"],
|
||||
"production_alert_channels": ["email"],
|
||||
"forecast_alert_channels": ["email"]
|
||||
})
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc), nullable=False)
|
||||
|
||||
# Relationships
|
||||
tenant = relationship("Tenant", backref="settings")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<TenantSettings(tenant_id={self.tenant_id})>"
|
||||
|
||||
@staticmethod
|
||||
def get_default_settings() -> dict:
|
||||
"""
|
||||
Get default settings for all categories
|
||||
Returns a dictionary with default values for all setting categories
|
||||
"""
|
||||
return {
|
||||
"procurement_settings": {
|
||||
"auto_approve_enabled": True,
|
||||
"auto_approve_threshold_eur": 500.0,
|
||||
"auto_approve_min_supplier_score": 0.80,
|
||||
"require_approval_new_suppliers": True,
|
||||
"require_approval_critical_items": True,
|
||||
"procurement_lead_time_days": 3,
|
||||
"demand_forecast_days": 14,
|
||||
"safety_stock_percentage": 20.0,
|
||||
"po_approval_reminder_hours": 24,
|
||||
"po_critical_escalation_hours": 12,
|
||||
"use_reorder_rules": True,
|
||||
"economic_rounding": True,
|
||||
"respect_storage_limits": True,
|
||||
"use_supplier_minimums": True,
|
||||
"optimize_price_tiers": True
|
||||
},
|
||||
"inventory_settings": {
|
||||
"low_stock_threshold": 10,
|
||||
"reorder_point": 20,
|
||||
"reorder_quantity": 50,
|
||||
"expiring_soon_days": 7,
|
||||
"expiration_warning_days": 3,
|
||||
"quality_score_threshold": 8.0,
|
||||
"temperature_monitoring_enabled": True,
|
||||
"refrigeration_temp_min": 1.0,
|
||||
"refrigeration_temp_max": 4.0,
|
||||
"freezer_temp_min": -20.0,
|
||||
"freezer_temp_max": -15.0,
|
||||
"room_temp_min": 18.0,
|
||||
"room_temp_max": 25.0,
|
||||
"temp_deviation_alert_minutes": 15,
|
||||
"critical_temp_deviation_minutes": 5
|
||||
},
|
||||
"production_settings": {
|
||||
"planning_horizon_days": 7,
|
||||
"minimum_batch_size": 1.0,
|
||||
"maximum_batch_size": 100.0,
|
||||
"production_buffer_percentage": 10.0,
|
||||
"working_hours_per_day": 12,
|
||||
"max_overtime_hours": 4,
|
||||
"capacity_utilization_target": 0.85,
|
||||
"capacity_warning_threshold": 0.95,
|
||||
"quality_check_enabled": True,
|
||||
"minimum_yield_percentage": 85.0,
|
||||
"quality_score_threshold": 8.0,
|
||||
"schedule_optimization_enabled": True,
|
||||
"prep_time_buffer_minutes": 30,
|
||||
"cleanup_time_buffer_minutes": 15,
|
||||
"labor_cost_per_hour_eur": 15.0,
|
||||
"overhead_cost_percentage": 20.0
|
||||
},
|
||||
"supplier_settings": {
|
||||
"default_payment_terms_days": 30,
|
||||
"default_delivery_days": 3,
|
||||
"excellent_delivery_rate": 95.0,
|
||||
"good_delivery_rate": 90.0,
|
||||
"excellent_quality_rate": 98.0,
|
||||
"good_quality_rate": 95.0,
|
||||
"critical_delivery_delay_hours": 24,
|
||||
"critical_quality_rejection_rate": 10.0,
|
||||
"high_cost_variance_percentage": 15.0,
|
||||
"require_supplier_approval": True,
|
||||
"auto_approve_for_admin_owner": True,
|
||||
"approval_required_roles": ["member", "viewer"]
|
||||
},
|
||||
"pos_settings": {
|
||||
"sync_interval_minutes": 5,
|
||||
"auto_sync_products": True,
|
||||
"auto_sync_transactions": True
|
||||
},
|
||||
"order_settings": {
|
||||
"max_discount_percentage": 50.0,
|
||||
"default_delivery_window_hours": 48,
|
||||
"dynamic_pricing_enabled": False,
|
||||
"discount_enabled": True,
|
||||
"delivery_tracking_enabled": True
|
||||
},
|
||||
"replenishment_settings": {
|
||||
"projection_horizon_days": 7,
|
||||
"service_level": 0.95,
|
||||
"buffer_days": 1,
|
||||
"enable_auto_replenishment": True,
|
||||
"min_order_quantity": 1.0,
|
||||
"max_order_quantity": 1000.0,
|
||||
"demand_forecast_days": 14
|
||||
},
|
||||
"safety_stock_settings": {
|
||||
"service_level": 0.95,
|
||||
"method": "statistical",
|
||||
"min_safety_stock": 0.0,
|
||||
"max_safety_stock": 100.0,
|
||||
"reorder_point_calculation": "safety_stock_plus_lead_time_demand"
|
||||
},
|
||||
"moq_settings": {
|
||||
"consolidation_window_days": 7,
|
||||
"allow_early_ordering": True,
|
||||
"enable_batch_optimization": True,
|
||||
"min_batch_size": 1.0,
|
||||
"max_batch_size": 1000.0
|
||||
},
|
||||
"supplier_selection_settings": {
|
||||
"price_weight": 0.40,
|
||||
"lead_time_weight": 0.20,
|
||||
"quality_weight": 0.20,
|
||||
"reliability_weight": 0.20,
|
||||
"diversification_threshold": 1000,
|
||||
"max_single_percentage": 0.70,
|
||||
"enable_supplier_score_optimization": True
|
||||
},
|
||||
"ml_insights_settings": {
|
||||
"inventory_lookback_days": 90,
|
||||
"inventory_min_history_days": 30,
|
||||
"production_lookback_days": 90,
|
||||
"production_min_history_runs": 30,
|
||||
"supplier_analysis_lookback_days": 180,
|
||||
"supplier_analysis_min_orders": 10,
|
||||
"price_forecast_lookback_days": 180,
|
||||
"price_forecast_horizon_days": 30,
|
||||
"rules_generation_lookback_days": 90,
|
||||
"rules_generation_min_samples": 10,
|
||||
"enable_ml_insights": True,
|
||||
"ml_insights_auto_trigger": False,
|
||||
"ml_confidence_threshold": 0.80
|
||||
},
|
||||
"notification_settings": {
|
||||
"whatsapp_enabled": False,
|
||||
"whatsapp_phone_number_id": "",
|
||||
"whatsapp_display_phone_number": "",
|
||||
"whatsapp_default_language": "es",
|
||||
"email_enabled": True,
|
||||
"email_from_address": "",
|
||||
"email_from_name": "",
|
||||
"email_reply_to": "",
|
||||
"enable_po_notifications": True,
|
||||
"enable_inventory_alerts": True,
|
||||
"enable_production_alerts": True,
|
||||
"enable_forecast_alerts": True,
|
||||
"po_notification_channels": ["email"],
|
||||
"inventory_alert_channels": ["email"],
|
||||
"production_alert_channels": ["email"],
|
||||
"forecast_alert_channels": ["email"]
|
||||
}
|
||||
}
|
||||
221
services/tenant/app/models/tenants.py
Normal file
221
services/tenant/app/models/tenants.py
Normal file
@@ -0,0 +1,221 @@
|
||||
# services/tenant/app/models/tenants.py - FIXED VERSION
|
||||
"""
|
||||
Tenant models for bakery management - FIXED
|
||||
Removed cross-service User relationship to eliminate circular dependencies
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Boolean, DateTime, Float, ForeignKey, Text, Integer, JSON
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import relationship
|
||||
from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
from shared.database.base import Base
|
||||
|
||||
class Tenant(Base):
|
||||
"""Tenant/Bakery model"""
|
||||
__tablename__ = "tenants"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
name = Column(String(200), nullable=False)
|
||||
subdomain = Column(String(100), unique=True)
|
||||
business_type = Column(String(100), default="bakery")
|
||||
business_model = Column(String(100), default="individual_bakery") # individual_bakery, central_baker_satellite, retail_bakery, hybrid_bakery
|
||||
|
||||
# Location info
|
||||
address = Column(Text, nullable=False)
|
||||
city = Column(String(100), default="Madrid")
|
||||
postal_code = Column(String(10), nullable=False)
|
||||
latitude = Column(Float)
|
||||
longitude = Column(Float)
|
||||
|
||||
# Regional/Localization configuration
|
||||
timezone = Column(String(50), default="Europe/Madrid", nullable=False)
|
||||
currency = Column(String(3), default="EUR", nullable=False) # Currency code: EUR, USD, GBP
|
||||
language = Column(String(5), default="es", nullable=False) # Language code: es, en, eu
|
||||
|
||||
# Contact info
|
||||
phone = Column(String(20))
|
||||
email = Column(String(255))
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
||||
# Demo account flags
|
||||
is_demo = Column(Boolean, default=False, index=True)
|
||||
is_demo_template = Column(Boolean, default=False, index=True)
|
||||
base_demo_tenant_id = Column(UUID(as_uuid=True), nullable=True, index=True)
|
||||
demo_session_id = Column(String(100), nullable=True, index=True)
|
||||
demo_expires_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# ML status
|
||||
ml_model_trained = Column(Boolean, default=False)
|
||||
last_training_date = Column(DateTime(timezone=True))
|
||||
|
||||
# Additional metadata (JSON field for flexible data storage)
|
||||
metadata_ = Column(JSON, nullable=True)
|
||||
|
||||
# Ownership (user_id without FK - cross-service reference)
|
||||
owner_id = Column(UUID(as_uuid=True), nullable=False, index=True)
|
||||
|
||||
# Enterprise tier hierarchy fields
|
||||
parent_tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id", ondelete="RESTRICT"), nullable=True, index=True)
|
||||
tenant_type = Column(String(50), default="standalone", nullable=False) # standalone, parent, child
|
||||
hierarchy_path = Column(String(500), nullable=True) # Materialized path for queries
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
# 3D Secure (3DS) tracking
|
||||
threeds_authentication_required = Column(Boolean, default=False)
|
||||
threeds_authentication_required_at = Column(DateTime(timezone=True), nullable=True)
|
||||
threeds_authentication_completed = Column(Boolean, default=False)
|
||||
threeds_authentication_completed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
last_threeds_setup_intent_id = Column(String(255), nullable=True)
|
||||
threeds_action_type = Column(String(100), nullable=True)
|
||||
|
||||
# Relationships - only within tenant service
|
||||
members = relationship("TenantMember", back_populates="tenant", cascade="all, delete-orphan")
|
||||
subscriptions = relationship("Subscription", back_populates="tenant", cascade="all, delete-orphan")
|
||||
locations = relationship("TenantLocation", back_populates="tenant", cascade="all, delete-orphan")
|
||||
child_tenants = relationship("Tenant", back_populates="parent_tenant", remote_side=[id])
|
||||
parent_tenant = relationship("Tenant", back_populates="child_tenants", remote_side=[parent_tenant_id])
|
||||
|
||||
# REMOVED: users relationship - no cross-service SQLAlchemy relationships
|
||||
|
||||
@property
|
||||
def subscription_tier(self):
|
||||
"""
|
||||
Get current subscription tier from active subscription
|
||||
|
||||
Note: This is a computed property that requires subscription relationship to be loaded.
|
||||
For performance-critical operations, use the subscription cache service directly.
|
||||
"""
|
||||
# Find active subscription
|
||||
for subscription in self.subscriptions:
|
||||
if subscription.status == 'active':
|
||||
return subscription.plan
|
||||
return "starter" # Default fallback
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Tenant(id={self.id}, name={self.name})>"
|
||||
|
||||
class TenantMember(Base):
|
||||
"""
|
||||
Tenant membership model for team access.
|
||||
|
||||
This model represents TENANT-SPECIFIC roles, which are distinct from global user roles.
|
||||
|
||||
TENANT ROLES (stored here):
|
||||
- owner: Full control of the tenant, can transfer ownership, manage all aspects
|
||||
- admin: Tenant administrator, can manage team members and most operations
|
||||
- member: Standard team member, regular operational access
|
||||
- viewer: Read-only observer, view-only access to tenant data
|
||||
|
||||
ROLE MAPPING TO GLOBAL ROLES:
|
||||
When users are created through tenant management (pilot phase), their tenant role
|
||||
is mapped to a global user role in the Auth service:
|
||||
- tenant 'admin' → global 'admin' (system-wide admin access)
|
||||
- tenant 'member' → global 'manager' (management-level access)
|
||||
- tenant 'viewer' → global 'user' (basic user access)
|
||||
- tenant 'owner' → No automatic global role (owner is tenant-specific)
|
||||
|
||||
This mapping is implemented in app/api/tenant_members.py lines 68-76.
|
||||
|
||||
Note: user_id is a cross-service reference (no FK) to avoid circular dependencies.
|
||||
User enrichment is handled at the service layer via Auth service calls.
|
||||
"""
|
||||
__tablename__ = "tenant_members"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id", ondelete="CASCADE"), nullable=False)
|
||||
user_id = Column(UUID(as_uuid=True), nullable=False, index=True) # No FK - cross-service reference
|
||||
|
||||
# Role and permissions specific to this tenant
|
||||
# Valid values: 'owner', 'admin', 'member', 'viewer', 'network_admin'
|
||||
role = Column(String(50), default="member")
|
||||
permissions = Column(Text) # JSON string of permissions
|
||||
|
||||
# Status
|
||||
is_active = Column(Boolean, default=True)
|
||||
invited_by = Column(UUID(as_uuid=True)) # No FK - cross-service reference
|
||||
invited_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
joined_at = Column(DateTime(timezone=True))
|
||||
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
|
||||
# Relationships - only within tenant service
|
||||
tenant = relationship("Tenant", back_populates="members")
|
||||
|
||||
# REMOVED: user relationship - no cross-service SQLAlchemy relationships
|
||||
|
||||
def __repr__(self):
|
||||
return f"<TenantMember(tenant_id={self.tenant_id}, user_id={self.user_id}, role={self.role})>"
|
||||
|
||||
# Additional models for subscriptions, plans, etc.
|
||||
class Subscription(Base):
|
||||
"""Subscription model for tenant billing with tenant linking support"""
|
||||
__tablename__ = "subscriptions"
|
||||
|
||||
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
|
||||
tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id", ondelete="CASCADE"), nullable=True)
|
||||
|
||||
# User reference for tenant-independent subscriptions
|
||||
user_id = Column(UUID(as_uuid=True), nullable=True, index=True)
|
||||
|
||||
# Tenant linking status
|
||||
is_tenant_linked = Column(Boolean, default=False, nullable=False)
|
||||
tenant_linking_status = Column(String(50), nullable=True) # pending, completed, failed
|
||||
linked_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
plan = Column(String(50), default="starter") # starter, professional, enterprise
|
||||
status = Column(String(50), default="active") # active, pending_cancellation, inactive, suspended, pending_tenant_linking
|
||||
|
||||
# Billing
|
||||
monthly_price = Column(Float, default=0.0)
|
||||
billing_cycle = Column(String(20), default="monthly") # monthly, yearly
|
||||
next_billing_date = Column(DateTime(timezone=True))
|
||||
trial_ends_at = Column(DateTime(timezone=True))
|
||||
cancelled_at = Column(DateTime(timezone=True), nullable=True)
|
||||
cancellation_effective_date = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Payment provider references (generic names for provider-agnostic design)
|
||||
subscription_id = Column(String(255), nullable=True) # Payment provider subscription ID
|
||||
customer_id = Column(String(255), nullable=True) # Payment provider customer ID
|
||||
|
||||
# Limits
|
||||
max_users = Column(Integer, default=5)
|
||||
max_locations = Column(Integer, default=1)
|
||||
max_products = Column(Integer, default=50)
|
||||
|
||||
# Features - Store plan features as JSON
|
||||
features = Column(JSON)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
|
||||
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
|
||||
|
||||
# 3D Secure (3DS) tracking
|
||||
threeds_authentication_required = Column(Boolean, default=False)
|
||||
threeds_authentication_required_at = Column(DateTime(timezone=True), nullable=True)
|
||||
threeds_authentication_completed = Column(Boolean, default=False)
|
||||
threeds_authentication_completed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
last_threeds_setup_intent_id = Column(String(255), nullable=True)
|
||||
threeds_action_type = Column(String(100), nullable=True)
|
||||
|
||||
# Relationships
|
||||
tenant = relationship("Tenant")
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Subscription(id={self.id}, tenant_id={self.tenant_id}, user_id={self.user_id}, plan={self.plan}, status={self.status})>"
|
||||
|
||||
def is_pending_tenant_linking(self) -> bool:
|
||||
"""Check if subscription is waiting to be linked to a tenant"""
|
||||
return self.tenant_linking_status == "pending" and not self.is_tenant_linked
|
||||
|
||||
def can_be_linked_to_tenant(self, user_id: str) -> bool:
|
||||
"""Check if subscription can be linked to a tenant by the given user"""
|
||||
return (self.is_pending_tenant_linking() and
|
||||
str(self.user_id) == user_id and
|
||||
self.tenant_id is None)
|
||||
16
services/tenant/app/repositories/__init__.py
Normal file
16
services/tenant/app/repositories/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""
|
||||
Tenant Service Repositories
|
||||
Repository implementations for tenant service
|
||||
"""
|
||||
|
||||
from .base import TenantBaseRepository
|
||||
from .tenant_repository import TenantRepository
|
||||
from .tenant_member_repository import TenantMemberRepository
|
||||
from .subscription_repository import SubscriptionRepository
|
||||
|
||||
__all__ = [
|
||||
"TenantBaseRepository",
|
||||
"TenantRepository",
|
||||
"TenantMemberRepository",
|
||||
"SubscriptionRepository"
|
||||
]
|
||||
234
services/tenant/app/repositories/base.py
Normal file
234
services/tenant/app/repositories/base.py
Normal file
@@ -0,0 +1,234 @@
|
||||
"""
|
||||
Base Repository for Tenant Service
|
||||
Service-specific repository base class with tenant management utilities
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any, Type
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
import json
|
||||
|
||||
from shared.database.repository import BaseRepository
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TenantBaseRepository(BaseRepository):
|
||||
"""Base repository for tenant service with common tenant operations"""
|
||||
|
||||
def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 600):
|
||||
# Tenant data is relatively stable, medium cache time (10 minutes)
|
||||
super().__init__(model, session, cache_ttl)
|
||||
|
||||
async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List:
|
||||
"""Get records by tenant ID"""
|
||||
if hasattr(self.model, 'tenant_id'):
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={"tenant_id": tenant_id},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return await self.get_multi(skip=skip, limit=limit)
|
||||
|
||||
async def get_by_user_id(self, user_id: str, skip: int = 0, limit: int = 100) -> List:
|
||||
"""Get records by user ID (for cross-service references)"""
|
||||
if hasattr(self.model, 'user_id'):
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={"user_id": user_id},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
elif hasattr(self.model, 'owner_id'):
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={"owner_id": user_id},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return []
|
||||
|
||||
async def get_active_records(self, skip: int = 0, limit: int = 100) -> List:
|
||||
"""Get active records (if model has is_active field)"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={"is_active": True},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return await self.get_multi(skip=skip, limit=limit)
|
||||
|
||||
async def deactivate_record(self, record_id: Any) -> Optional:
|
||||
"""Deactivate a record instead of deleting it"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.update(record_id, {"is_active": False})
|
||||
return await self.delete(record_id)
|
||||
|
||||
async def activate_record(self, record_id: Any) -> Optional:
|
||||
"""Activate a record"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.update(record_id, {"is_active": True})
|
||||
return await self.get_by_id(record_id)
|
||||
|
||||
async def cleanup_old_records(self, days_old: int = 365) -> int:
|
||||
"""Clean up old tenant records (very conservative - 1 year)"""
|
||||
try:
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days_old)
|
||||
table_name = self.model.__tablename__
|
||||
|
||||
# Only delete inactive records that are very old
|
||||
conditions = [
|
||||
"created_at < :cutoff_date"
|
||||
]
|
||||
|
||||
if hasattr(self.model, 'is_active'):
|
||||
conditions.append("is_active = false")
|
||||
|
||||
query_text = f"""
|
||||
DELETE FROM {table_name}
|
||||
WHERE {' AND '.join(conditions)}
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info(f"Cleaned up old {self.model.__name__} records",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup old records",
|
||||
model=self.model.__name__,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Cleanup failed: {str(e)}")
|
||||
|
||||
async def get_statistics_by_tenant(self, tenant_id: str) -> Dict[str, Any]:
|
||||
"""Get statistics for a tenant"""
|
||||
try:
|
||||
table_name = self.model.__tablename__
|
||||
|
||||
# Get basic counts
|
||||
total_records = await self.count(filters={"tenant_id": tenant_id})
|
||||
|
||||
# Get active records if applicable
|
||||
active_records = total_records
|
||||
if hasattr(self.model, 'is_active'):
|
||||
active_records = await self.count(filters={
|
||||
"tenant_id": tenant_id,
|
||||
"is_active": True
|
||||
})
|
||||
|
||||
# Get recent activity (records in last 7 days)
|
||||
seven_days_ago = datetime.utcnow() - timedelta(days=7)
|
||||
recent_query = text(f"""
|
||||
SELECT COUNT(*) as count
|
||||
FROM {table_name}
|
||||
WHERE tenant_id = :tenant_id
|
||||
AND created_at >= :seven_days_ago
|
||||
""")
|
||||
|
||||
result = await self.session.execute(recent_query, {
|
||||
"tenant_id": tenant_id,
|
||||
"seven_days_ago": seven_days_ago
|
||||
})
|
||||
recent_records = result.scalar() or 0
|
||||
|
||||
return {
|
||||
"total_records": total_records,
|
||||
"active_records": active_records,
|
||||
"inactive_records": total_records - active_records,
|
||||
"recent_records_7d": recent_records
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenant statistics",
|
||||
model=self.model.__name__,
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"total_records": 0,
|
||||
"active_records": 0,
|
||||
"inactive_records": 0,
|
||||
"recent_records_7d": 0
|
||||
}
|
||||
|
||||
def _validate_tenant_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
|
||||
"""Validate tenant-related data"""
|
||||
errors = []
|
||||
|
||||
for field in required_fields:
|
||||
if field not in data or not data[field]:
|
||||
errors.append(f"Missing required field: {field}")
|
||||
|
||||
# Validate tenant_id format if present
|
||||
if "tenant_id" in data and data["tenant_id"]:
|
||||
tenant_id = data["tenant_id"]
|
||||
if not isinstance(tenant_id, str) or len(tenant_id) < 1:
|
||||
errors.append("Invalid tenant_id format")
|
||||
|
||||
# Validate user_id format if present
|
||||
if "user_id" in data and data["user_id"]:
|
||||
user_id = data["user_id"]
|
||||
if not isinstance(user_id, str) or len(user_id) < 1:
|
||||
errors.append("Invalid user_id format")
|
||||
|
||||
# Validate owner_id format if present
|
||||
if "owner_id" in data and data["owner_id"]:
|
||||
owner_id = data["owner_id"]
|
||||
if not isinstance(owner_id, str) or len(owner_id) < 1:
|
||||
errors.append("Invalid owner_id format")
|
||||
|
||||
# Validate email format if present
|
||||
if "email" in data and data["email"]:
|
||||
email = data["email"]
|
||||
if "@" not in email or "." not in email.split("@")[-1]:
|
||||
errors.append("Invalid email format")
|
||||
|
||||
# Validate phone format if present (basic validation)
|
||||
if "phone" in data and data["phone"]:
|
||||
phone = data["phone"]
|
||||
if not isinstance(phone, str) or len(phone) < 9:
|
||||
errors.append("Invalid phone format")
|
||||
|
||||
# Validate coordinates if present
|
||||
if "latitude" in data and data["latitude"] is not None:
|
||||
try:
|
||||
lat = float(data["latitude"])
|
||||
if lat < -90 or lat > 90:
|
||||
errors.append("Invalid latitude - must be between -90 and 90")
|
||||
except (ValueError, TypeError):
|
||||
errors.append("Invalid latitude format")
|
||||
|
||||
if "longitude" in data and data["longitude"] is not None:
|
||||
try:
|
||||
lng = float(data["longitude"])
|
||||
if lng < -180 or lng > 180:
|
||||
errors.append("Invalid longitude - must be between -180 and 180")
|
||||
except (ValueError, TypeError):
|
||||
errors.append("Invalid longitude format")
|
||||
|
||||
# Validate JSON fields
|
||||
json_fields = ["permissions"]
|
||||
for field in json_fields:
|
||||
if field in data and data[field]:
|
||||
if isinstance(data[field], str):
|
||||
try:
|
||||
json.loads(data[field])
|
||||
except json.JSONDecodeError:
|
||||
errors.append(f"Invalid JSON format in {field}")
|
||||
|
||||
return {
|
||||
"is_valid": len(errors) == 0,
|
||||
"errors": errors
|
||||
}
|
||||
326
services/tenant/app/repositories/coupon_repository.py
Normal file
326
services/tenant/app/repositories/coupon_repository.py
Normal file
@@ -0,0 +1,326 @@
|
||||
"""
|
||||
Repository for coupon data access and validation
|
||||
"""
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import and_, select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from app.models.coupon import CouponModel, CouponRedemptionModel
|
||||
from shared.subscription.coupons import (
|
||||
Coupon,
|
||||
CouponRedemption,
|
||||
CouponValidationResult,
|
||||
DiscountType,
|
||||
calculate_trial_end_date,
|
||||
format_discount_description
|
||||
)
|
||||
|
||||
|
||||
class CouponRepository:
|
||||
"""Data access layer for coupon operations"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_coupon_by_code(self, code: str) -> Optional[Coupon]:
|
||||
"""
|
||||
Retrieve coupon by code.
|
||||
Returns None if not found.
|
||||
"""
|
||||
result = await self.db.execute(
|
||||
select(CouponModel).where(CouponModel.code == code.upper())
|
||||
)
|
||||
coupon_model = result.scalar_one_or_none()
|
||||
|
||||
if not coupon_model:
|
||||
return None
|
||||
|
||||
return self._model_to_dataclass(coupon_model)
|
||||
|
||||
async def validate_coupon(
|
||||
self,
|
||||
code: str,
|
||||
tenant_id: str
|
||||
) -> CouponValidationResult:
|
||||
"""
|
||||
Validate a coupon code for a specific tenant.
|
||||
Checks: existence, validity, redemption limits, and if tenant already used it.
|
||||
"""
|
||||
# Get coupon
|
||||
coupon = await self.get_coupon_by_code(code)
|
||||
if not coupon:
|
||||
return CouponValidationResult(
|
||||
valid=False,
|
||||
coupon=None,
|
||||
error_message="Código de cupón inválido",
|
||||
discount_preview=None
|
||||
)
|
||||
|
||||
# Check if coupon can be redeemed
|
||||
can_redeem, reason = coupon.can_be_redeemed()
|
||||
if not can_redeem:
|
||||
error_messages = {
|
||||
"Coupon is inactive": "Este cupón no está activo",
|
||||
"Coupon is not yet valid": "Este cupón aún no es válido",
|
||||
"Coupon has expired": "Este cupón ha expirado",
|
||||
"Coupon has reached maximum redemptions": "Este cupón ha alcanzado su límite de usos"
|
||||
}
|
||||
return CouponValidationResult(
|
||||
valid=False,
|
||||
coupon=coupon,
|
||||
error_message=error_messages.get(reason, reason),
|
||||
discount_preview=None
|
||||
)
|
||||
|
||||
# Check if tenant already redeemed this coupon
|
||||
result = await self.db.execute(
|
||||
select(CouponRedemptionModel).where(
|
||||
and_(
|
||||
CouponRedemptionModel.tenant_id == tenant_id,
|
||||
CouponRedemptionModel.coupon_code == code.upper()
|
||||
)
|
||||
)
|
||||
)
|
||||
existing_redemption = result.scalar_one_or_none()
|
||||
|
||||
if existing_redemption:
|
||||
return CouponValidationResult(
|
||||
valid=False,
|
||||
coupon=coupon,
|
||||
error_message="Ya has utilizado este cupón",
|
||||
discount_preview=None
|
||||
)
|
||||
|
||||
# Generate discount preview
|
||||
discount_preview = self._generate_discount_preview(coupon)
|
||||
|
||||
return CouponValidationResult(
|
||||
valid=True,
|
||||
coupon=coupon,
|
||||
error_message=None,
|
||||
discount_preview=discount_preview
|
||||
)
|
||||
|
||||
async def redeem_coupon(
|
||||
self,
|
||||
code: str,
|
||||
tenant_id: Optional[str],
|
||||
base_trial_days: int = 0
|
||||
) -> tuple[bool, Optional[CouponRedemption], Optional[str]]:
|
||||
"""
|
||||
Redeem a coupon for a tenant.
|
||||
For tenant-independent registrations, tenant_id can be None initially.
|
||||
Returns (success, redemption, error_message)
|
||||
"""
|
||||
# For tenant-independent registrations, skip tenant validation
|
||||
if tenant_id:
|
||||
# Validate first
|
||||
validation = await self.validate_coupon(code, tenant_id)
|
||||
if not validation.valid:
|
||||
return False, None, validation.error_message
|
||||
coupon = validation.coupon
|
||||
else:
|
||||
# Just get the coupon and validate its general availability
|
||||
coupon = await self.get_coupon_by_code(code)
|
||||
if not coupon:
|
||||
return False, None, "Código de cupón inválido"
|
||||
|
||||
# Check if coupon can be redeemed
|
||||
can_redeem, reason = coupon.can_be_redeemed()
|
||||
if not can_redeem:
|
||||
error_messages = {
|
||||
"Coupon is inactive": "Este cupón no está activo",
|
||||
"Coupon is not yet valid": "Este cupón aún no es válido",
|
||||
"Coupon has expired": "Este cupón ha expirado",
|
||||
"Coupon has reached maximum redemptions": "Este cupón ha alcanzado su límite de usos"
|
||||
}
|
||||
return False, None, error_messages.get(reason, reason)
|
||||
|
||||
# Calculate discount applied
|
||||
discount_applied = self._calculate_discount_applied(
|
||||
coupon,
|
||||
base_trial_days
|
||||
)
|
||||
|
||||
# Only create redemption record if tenant_id is provided
|
||||
# For tenant-independent subscriptions, skip redemption record creation
|
||||
if tenant_id:
|
||||
# Create redemption record
|
||||
redemption_model = CouponRedemptionModel(
|
||||
tenant_id=tenant_id,
|
||||
coupon_code=code.upper(),
|
||||
redeemed_at=datetime.now(timezone.utc),
|
||||
discount_applied=discount_applied,
|
||||
extra_data={
|
||||
"coupon_type": coupon.discount_type.value,
|
||||
"coupon_value": coupon.discount_value
|
||||
}
|
||||
)
|
||||
|
||||
self.db.add(redemption_model)
|
||||
|
||||
# Increment coupon redemption count
|
||||
result = await self.db.execute(
|
||||
select(CouponModel).where(CouponModel.code == code.upper())
|
||||
)
|
||||
coupon_model = result.scalar_one_or_none()
|
||||
if coupon_model:
|
||||
coupon_model.current_redemptions += 1
|
||||
|
||||
try:
|
||||
await self.db.commit()
|
||||
await self.db.refresh(redemption_model)
|
||||
|
||||
redemption = CouponRedemption(
|
||||
id=str(redemption_model.id),
|
||||
tenant_id=redemption_model.tenant_id,
|
||||
coupon_code=redemption_model.coupon_code,
|
||||
redeemed_at=redemption_model.redeemed_at,
|
||||
discount_applied=redemption_model.discount_applied,
|
||||
extra_data=redemption_model.extra_data
|
||||
)
|
||||
|
||||
return True, redemption, None
|
||||
|
||||
except Exception as e:
|
||||
await self.db.rollback()
|
||||
return False, None, f"Error al aplicar el cupón: {str(e)}"
|
||||
else:
|
||||
# For tenant-independent subscriptions, return discount without creating redemption
|
||||
# The redemption will be created when the tenant is linked
|
||||
redemption = CouponRedemption(
|
||||
id="pending", # Temporary ID
|
||||
tenant_id="pending", # Will be set during tenant linking
|
||||
coupon_code=code.upper(),
|
||||
redeemed_at=datetime.now(timezone.utc),
|
||||
discount_applied=discount_applied,
|
||||
extra_data={
|
||||
"coupon_type": coupon.discount_type.value,
|
||||
"coupon_value": coupon.discount_value
|
||||
}
|
||||
)
|
||||
return True, redemption, None
|
||||
|
||||
async def get_redemption_by_tenant_and_code(
|
||||
self,
|
||||
tenant_id: str,
|
||||
code: str
|
||||
) -> Optional[CouponRedemption]:
|
||||
"""Get existing redemption for tenant and coupon code"""
|
||||
result = await self.db.execute(
|
||||
select(CouponRedemptionModel).where(
|
||||
and_(
|
||||
CouponRedemptionModel.tenant_id == tenant_id,
|
||||
CouponRedemptionModel.coupon_code == code.upper()
|
||||
)
|
||||
)
|
||||
)
|
||||
redemption_model = result.scalar_one_or_none()
|
||||
|
||||
if not redemption_model:
|
||||
return None
|
||||
|
||||
return CouponRedemption(
|
||||
id=str(redemption_model.id),
|
||||
tenant_id=redemption_model.tenant_id,
|
||||
coupon_code=redemption_model.coupon_code,
|
||||
redeemed_at=redemption_model.redeemed_at,
|
||||
discount_applied=redemption_model.discount_applied,
|
||||
extra_data=redemption_model.extra_data
|
||||
)
|
||||
|
||||
async def get_coupon_usage_stats(self, code: str) -> Optional[dict]:
|
||||
"""Get usage statistics for a coupon"""
|
||||
result = await self.db.execute(
|
||||
select(CouponModel).where(CouponModel.code == code.upper())
|
||||
)
|
||||
coupon_model = result.scalar_one_or_none()
|
||||
|
||||
if not coupon_model:
|
||||
return None
|
||||
|
||||
count_result = await self.db.execute(
|
||||
select(CouponRedemptionModel).where(
|
||||
CouponRedemptionModel.coupon_code == code.upper()
|
||||
)
|
||||
)
|
||||
redemptions_count = len(count_result.scalars().all())
|
||||
|
||||
return {
|
||||
"code": coupon_model.code,
|
||||
"current_redemptions": coupon_model.current_redemptions,
|
||||
"max_redemptions": coupon_model.max_redemptions,
|
||||
"redemptions_remaining": (
|
||||
coupon_model.max_redemptions - coupon_model.current_redemptions
|
||||
if coupon_model.max_redemptions
|
||||
else None
|
||||
),
|
||||
"active": coupon_model.active,
|
||||
"valid_from": coupon_model.valid_from.isoformat(),
|
||||
"valid_until": coupon_model.valid_until.isoformat() if coupon_model.valid_until else None
|
||||
}
|
||||
|
||||
def _model_to_dataclass(self, model: CouponModel) -> Coupon:
|
||||
"""Convert SQLAlchemy model to dataclass"""
|
||||
return Coupon(
|
||||
id=str(model.id),
|
||||
code=model.code,
|
||||
discount_type=DiscountType(model.discount_type),
|
||||
discount_value=model.discount_value,
|
||||
max_redemptions=model.max_redemptions,
|
||||
current_redemptions=model.current_redemptions,
|
||||
valid_from=model.valid_from,
|
||||
valid_until=model.valid_until,
|
||||
active=model.active,
|
||||
created_at=model.created_at,
|
||||
extra_data=model.extra_data
|
||||
)
|
||||
|
||||
def _generate_discount_preview(self, coupon: Coupon) -> dict:
|
||||
"""Generate a preview of the discount to be applied"""
|
||||
description = format_discount_description(coupon)
|
||||
|
||||
preview = {
|
||||
"description": description,
|
||||
"discount_type": coupon.discount_type.value,
|
||||
"discount_value": coupon.discount_value
|
||||
}
|
||||
|
||||
if coupon.discount_type == DiscountType.TRIAL_EXTENSION:
|
||||
trial_end = calculate_trial_end_date(0, coupon.discount_value)
|
||||
preview["trial_end_date"] = trial_end.isoformat()
|
||||
preview["total_trial_days"] = 0 + coupon.discount_value
|
||||
|
||||
return preview
|
||||
|
||||
def _calculate_discount_applied(
|
||||
self,
|
||||
coupon: Coupon,
|
||||
base_trial_days: int
|
||||
) -> dict:
|
||||
"""Calculate the actual discount that will be applied"""
|
||||
discount = {
|
||||
"type": coupon.discount_type.value,
|
||||
"value": coupon.discount_value,
|
||||
"description": format_discount_description(coupon)
|
||||
}
|
||||
|
||||
if coupon.discount_type == DiscountType.TRIAL_EXTENSION:
|
||||
total_trial_days = base_trial_days + coupon.discount_value
|
||||
trial_end = calculate_trial_end_date(base_trial_days, coupon.discount_value)
|
||||
|
||||
discount["base_trial_days"] = base_trial_days
|
||||
discount["extension_days"] = coupon.discount_value
|
||||
discount["total_trial_days"] = total_trial_days
|
||||
discount["trial_end_date"] = trial_end.isoformat()
|
||||
|
||||
elif coupon.discount_type == DiscountType.PERCENTAGE:
|
||||
discount["percentage_off"] = coupon.discount_value
|
||||
|
||||
elif coupon.discount_type == DiscountType.FIXED_AMOUNT:
|
||||
discount["amount_off_cents"] = coupon.discount_value
|
||||
discount["amount_off_euros"] = coupon.discount_value / 100
|
||||
|
||||
return discount
|
||||
283
services/tenant/app/repositories/event_repository.py
Normal file
283
services/tenant/app/repositories/event_repository.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
Event Repository
|
||||
Data access layer for events
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import date, datetime
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, or_, func
|
||||
from uuid import UUID
|
||||
import structlog
|
||||
|
||||
from app.models.events import Event, EventTemplate
|
||||
from shared.database.repository import BaseRepository
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class EventRepository(BaseRepository[Event]):
|
||||
"""Repository for event management"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
super().__init__(Event, session)
|
||||
|
||||
async def get_events_by_date_range(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
start_date: date,
|
||||
end_date: date,
|
||||
event_types: List[str] = None,
|
||||
confirmed_only: bool = False
|
||||
) -> List[Event]:
|
||||
"""
|
||||
Get events within a date range.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
start_date: Start date (inclusive)
|
||||
end_date: End date (inclusive)
|
||||
event_types: Optional filter by event types
|
||||
confirmed_only: Only return confirmed events
|
||||
|
||||
Returns:
|
||||
List of Event objects
|
||||
"""
|
||||
try:
|
||||
query = select(Event).where(
|
||||
and_(
|
||||
Event.tenant_id == tenant_id,
|
||||
Event.event_date >= start_date,
|
||||
Event.event_date <= end_date
|
||||
)
|
||||
)
|
||||
|
||||
if event_types:
|
||||
query = query.where(Event.event_type.in_(event_types))
|
||||
|
||||
if confirmed_only:
|
||||
query = query.where(Event.is_confirmed == True)
|
||||
|
||||
query = query.order_by(Event.event_date)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
events = result.scalars().all()
|
||||
|
||||
logger.debug("Retrieved events by date range",
|
||||
tenant_id=str(tenant_id),
|
||||
start_date=start_date.isoformat(),
|
||||
end_date=end_date.isoformat(),
|
||||
count=len(events))
|
||||
|
||||
return list(events)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get events by date range",
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def get_events_for_date(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
event_date: date
|
||||
) -> List[Event]:
|
||||
"""
|
||||
Get all events for a specific date.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
event_date: Date to get events for
|
||||
|
||||
Returns:
|
||||
List of Event objects
|
||||
"""
|
||||
try:
|
||||
query = select(Event).where(
|
||||
and_(
|
||||
Event.tenant_id == tenant_id,
|
||||
Event.event_date == event_date
|
||||
)
|
||||
).order_by(Event.start_time)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
events = result.scalars().all()
|
||||
|
||||
return list(events)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get events for date",
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def get_upcoming_events(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
days_ahead: int = 30,
|
||||
limit: int = 100
|
||||
) -> List[Event]:
|
||||
"""
|
||||
Get upcoming events.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
days_ahead: Number of days to look ahead
|
||||
limit: Maximum number of events to return
|
||||
|
||||
Returns:
|
||||
List of upcoming Event objects
|
||||
"""
|
||||
try:
|
||||
from datetime import date, timedelta
|
||||
|
||||
today = date.today()
|
||||
future_date = today + timedelta(days=days_ahead)
|
||||
|
||||
query = select(Event).where(
|
||||
and_(
|
||||
Event.tenant_id == tenant_id,
|
||||
Event.event_date >= today,
|
||||
Event.event_date <= future_date
|
||||
)
|
||||
).order_by(Event.event_date).limit(limit)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
events = result.scalars().all()
|
||||
|
||||
return list(events)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get upcoming events",
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def create_event(self, event_data: Dict[str, Any]) -> Event:
|
||||
"""Create a new event"""
|
||||
try:
|
||||
event = Event(**event_data)
|
||||
self.session.add(event)
|
||||
await self.session.flush()
|
||||
|
||||
logger.info("Created event",
|
||||
event_id=str(event.id),
|
||||
event_name=event.event_name,
|
||||
event_date=event.event_date.isoformat())
|
||||
|
||||
return event
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create event", error=str(e))
|
||||
raise
|
||||
|
||||
async def update_event_actual_impact(
|
||||
self,
|
||||
event_id: UUID,
|
||||
actual_impact_multiplier: float,
|
||||
actual_sales_increase_percent: float
|
||||
) -> Optional[Event]:
|
||||
"""
|
||||
Update event with actual impact after it occurs.
|
||||
|
||||
Args:
|
||||
event_id: Event UUID
|
||||
actual_impact_multiplier: Actual demand multiplier observed
|
||||
actual_sales_increase_percent: Actual sales increase percentage
|
||||
|
||||
Returns:
|
||||
Updated Event or None
|
||||
"""
|
||||
try:
|
||||
event = await self.get(event_id)
|
||||
if not event:
|
||||
return None
|
||||
|
||||
event.actual_impact_multiplier = actual_impact_multiplier
|
||||
event.actual_sales_increase_percent = actual_sales_increase_percent
|
||||
|
||||
await self.session.flush()
|
||||
|
||||
logger.info("Updated event actual impact",
|
||||
event_id=str(event_id),
|
||||
actual_multiplier=actual_impact_multiplier)
|
||||
|
||||
return event
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update event actual impact",
|
||||
event_id=str(event_id),
|
||||
error=str(e))
|
||||
return None
|
||||
|
||||
async def get_events_by_type(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
event_type: str,
|
||||
limit: int = 100
|
||||
) -> List[Event]:
|
||||
"""Get events by type"""
|
||||
try:
|
||||
query = select(Event).where(
|
||||
and_(
|
||||
Event.tenant_id == tenant_id,
|
||||
Event.event_type == event_type
|
||||
)
|
||||
).order_by(Event.event_date.desc()).limit(limit)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
events = result.scalars().all()
|
||||
|
||||
return list(events)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get events by type",
|
||||
tenant_id=str(tenant_id),
|
||||
event_type=event_type,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
|
||||
class EventTemplateRepository(BaseRepository[EventTemplate]):
|
||||
"""Repository for event template management"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
super().__init__(EventTemplate, session)
|
||||
|
||||
async def get_active_templates(self, tenant_id: UUID) -> List[EventTemplate]:
|
||||
"""Get all active event templates for a tenant"""
|
||||
try:
|
||||
query = select(EventTemplate).where(
|
||||
and_(
|
||||
EventTemplate.tenant_id == tenant_id,
|
||||
EventTemplate.is_active == True
|
||||
)
|
||||
).order_by(EventTemplate.template_name)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
templates = result.scalars().all()
|
||||
|
||||
return list(templates)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get active templates",
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def create_template(self, template_data: Dict[str, Any]) -> EventTemplate:
|
||||
"""Create a new event template"""
|
||||
try:
|
||||
template = EventTemplate(**template_data)
|
||||
self.session.add(template)
|
||||
await self.session.flush()
|
||||
|
||||
logger.info("Created event template",
|
||||
template_id=str(template.id),
|
||||
template_name=template.template_name)
|
||||
|
||||
return template
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create event template", error=str(e))
|
||||
raise
|
||||
812
services/tenant/app/repositories/subscription_repository.py
Normal file
812
services/tenant/app/repositories/subscription_repository.py
Normal file
@@ -0,0 +1,812 @@
|
||||
"""
|
||||
Subscription Repository
|
||||
Repository for subscription operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, text, and_
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
import json
|
||||
|
||||
from .base import TenantBaseRepository
|
||||
from app.models.tenants import Subscription
|
||||
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
|
||||
from shared.subscription.plans import SubscriptionPlanMetadata, QuotaLimits, PlanPricing
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class SubscriptionRepository(TenantBaseRepository):
|
||||
"""Repository for subscription operations"""
|
||||
|
||||
def __init__(self, model_class, session: AsyncSession, cache_ttl: Optional[int] = 600):
|
||||
# Subscriptions are relatively stable, medium cache time (10 minutes)
|
||||
super().__init__(model_class, session, cache_ttl)
|
||||
|
||||
async def create_subscription(self, subscription_data: Dict[str, Any]) -> Subscription:
|
||||
"""Create a new subscription with validation"""
|
||||
try:
|
||||
# Validate subscription data
|
||||
validation_result = self._validate_tenant_data(
|
||||
subscription_data,
|
||||
["tenant_id", "plan"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid subscription data: {validation_result['errors']}")
|
||||
|
||||
# Check for existing active subscription
|
||||
existing_subscription = await self.get_active_subscription(
|
||||
subscription_data["tenant_id"]
|
||||
)
|
||||
|
||||
if existing_subscription:
|
||||
raise DuplicateRecordError(f"Tenant already has an active subscription")
|
||||
|
||||
# Set default values based on plan from centralized configuration
|
||||
plan = subscription_data["plan"]
|
||||
plan_info = SubscriptionPlanMetadata.get_plan_info(plan)
|
||||
|
||||
# Set defaults from centralized plan configuration
|
||||
if "monthly_price" not in subscription_data:
|
||||
billing_cycle = subscription_data.get("billing_cycle", "monthly")
|
||||
subscription_data["monthly_price"] = float(
|
||||
PlanPricing.get_price(plan, billing_cycle)
|
||||
)
|
||||
|
||||
if "max_users" not in subscription_data:
|
||||
subscription_data["max_users"] = QuotaLimits.get_limit('MAX_USERS', plan) or -1
|
||||
|
||||
if "max_locations" not in subscription_data:
|
||||
subscription_data["max_locations"] = QuotaLimits.get_limit('MAX_LOCATIONS', plan) or -1
|
||||
|
||||
if "max_products" not in subscription_data:
|
||||
subscription_data["max_products"] = QuotaLimits.get_limit('MAX_PRODUCTS', plan) or -1
|
||||
|
||||
if "features" not in subscription_data:
|
||||
subscription_data["features"] = {
|
||||
feature: True for feature in plan_info.get("features", [])
|
||||
}
|
||||
|
||||
# Set default subscription values
|
||||
if "status" not in subscription_data:
|
||||
subscription_data["status"] = "active"
|
||||
if "billing_cycle" not in subscription_data:
|
||||
subscription_data["billing_cycle"] = "monthly"
|
||||
if "next_billing_date" not in subscription_data:
|
||||
# Set next billing date based on cycle
|
||||
if subscription_data["billing_cycle"] == "yearly":
|
||||
subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=365)
|
||||
else:
|
||||
subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=30)
|
||||
|
||||
# Check if subscription with this subscription_id already exists to prevent duplicates
|
||||
if subscription_data.get('subscription_id'):
|
||||
existing_subscription = await self.get_by_provider_id(subscription_data['subscription_id'])
|
||||
if existing_subscription:
|
||||
# Update the existing subscription instead of creating a duplicate
|
||||
updated_subscription = await self.update(str(existing_subscription.id), subscription_data)
|
||||
|
||||
logger.info("Existing subscription updated",
|
||||
subscription_id=subscription_data['subscription_id'],
|
||||
tenant_id=subscription_data.get('tenant_id'),
|
||||
plan=subscription_data.get('plan'))
|
||||
|
||||
return updated_subscription
|
||||
|
||||
# Create subscription
|
||||
subscription = await self.create(subscription_data)
|
||||
|
||||
logger.info("Subscription created successfully",
|
||||
subscription_id=subscription.id,
|
||||
tenant_id=subscription.tenant_id,
|
||||
plan=subscription.plan,
|
||||
monthly_price=subscription.monthly_price)
|
||||
|
||||
return subscription
|
||||
|
||||
except (ValidationError, DuplicateRecordError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create subscription",
|
||||
tenant_id=subscription_data.get("tenant_id"),
|
||||
plan=subscription_data.get("plan"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create subscription: {str(e)}")
|
||||
|
||||
async def get_by_tenant_id(self, tenant_id: str) -> Optional[Subscription]:
|
||||
"""Get subscription by tenant ID"""
|
||||
try:
|
||||
subscriptions = await self.get_multi(
|
||||
filters={
|
||||
"tenant_id": tenant_id
|
||||
},
|
||||
limit=1,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return subscriptions[0] if subscriptions else None
|
||||
except Exception as e:
|
||||
logger.error("Failed to get subscription by tenant ID",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get subscription: {str(e)}")
|
||||
|
||||
async def get_by_provider_id(self, subscription_id: str) -> Optional[Subscription]:
|
||||
"""Get subscription by payment provider subscription ID"""
|
||||
try:
|
||||
subscriptions = await self.get_multi(
|
||||
filters={
|
||||
"subscription_id": subscription_id
|
||||
},
|
||||
limit=1,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return subscriptions[0] if subscriptions else None
|
||||
except Exception as e:
|
||||
logger.error("Failed to get subscription by provider ID",
|
||||
subscription_id=subscription_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get subscription: {str(e)}")
|
||||
|
||||
async def get_active_subscription(self, tenant_id: str) -> Optional[Subscription]:
|
||||
"""Get active subscription for tenant"""
|
||||
try:
|
||||
subscriptions = await self.get_multi(
|
||||
filters={
|
||||
"tenant_id": tenant_id,
|
||||
"status": "active"
|
||||
},
|
||||
limit=1,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return subscriptions[0] if subscriptions else None
|
||||
except Exception as e:
|
||||
logger.error("Failed to get active subscription",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get subscription: {str(e)}")
|
||||
|
||||
async def get_tenant_subscriptions(
|
||||
self,
|
||||
tenant_id: str,
|
||||
include_inactive: bool = False
|
||||
) -> List[Subscription]:
|
||||
"""Get all subscriptions for a tenant"""
|
||||
try:
|
||||
filters = {"tenant_id": tenant_id}
|
||||
|
||||
if not include_inactive:
|
||||
filters["status"] = "active"
|
||||
|
||||
return await self.get_multi(
|
||||
filters=filters,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenant subscriptions",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get subscriptions: {str(e)}")
|
||||
|
||||
async def update_subscription_plan(
|
||||
self,
|
||||
subscription_id: str,
|
||||
new_plan: str,
|
||||
billing_cycle: str = "monthly"
|
||||
) -> Optional[Subscription]:
|
||||
"""Update subscription plan and pricing using centralized configuration"""
|
||||
try:
|
||||
valid_plans = ["starter", "professional", "enterprise"]
|
||||
if new_plan not in valid_plans:
|
||||
raise ValidationError(f"Invalid plan. Must be one of: {valid_plans}")
|
||||
|
||||
# Get current subscription to find tenant_id for cache invalidation
|
||||
subscription = await self.get_by_id(subscription_id)
|
||||
if not subscription:
|
||||
raise ValidationError(f"Subscription {subscription_id} not found")
|
||||
|
||||
# Get new plan configuration from centralized source
|
||||
plan_info = SubscriptionPlanMetadata.get_plan_info(new_plan)
|
||||
|
||||
# Update subscription with new plan details
|
||||
update_data = {
|
||||
"plan": new_plan,
|
||||
"monthly_price": float(PlanPricing.get_price(new_plan, billing_cycle)),
|
||||
"billing_cycle": billing_cycle,
|
||||
"max_users": QuotaLimits.get_limit('MAX_USERS', new_plan) or -1,
|
||||
"max_locations": QuotaLimits.get_limit('MAX_LOCATIONS', new_plan) or -1,
|
||||
"max_products": QuotaLimits.get_limit('MAX_PRODUCTS', new_plan) or -1,
|
||||
"features": {feature: True for feature in plan_info.get("features", [])},
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
updated_subscription = await self.update(subscription_id, update_data)
|
||||
|
||||
# Invalidate cache
|
||||
await self._invalidate_cache(str(subscription.tenant_id))
|
||||
|
||||
logger.info("Subscription plan updated",
|
||||
subscription_id=subscription_id,
|
||||
new_plan=new_plan,
|
||||
new_price=update_data["monthly_price"])
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to update subscription plan",
|
||||
subscription_id=subscription_id,
|
||||
new_plan=new_plan,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update plan: {str(e)}")
|
||||
|
||||
async def update_subscription_status(
|
||||
self,
|
||||
subscription_id: str,
|
||||
status: str,
|
||||
additional_data: Dict[str, Any] = None
|
||||
) -> Optional[Subscription]:
|
||||
"""Update subscription status with optional additional data"""
|
||||
try:
|
||||
# Get subscription to find tenant_id for cache invalidation
|
||||
subscription = await self.get_by_id(subscription_id)
|
||||
if not subscription:
|
||||
raise ValidationError(f"Subscription {subscription_id} not found")
|
||||
|
||||
update_data = {
|
||||
"status": status,
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
# Merge additional data if provided
|
||||
if additional_data:
|
||||
update_data.update(additional_data)
|
||||
|
||||
updated_subscription = await self.update(subscription_id, update_data)
|
||||
|
||||
# Invalidate cache
|
||||
if subscription.tenant_id:
|
||||
await self._invalidate_cache(str(subscription.tenant_id))
|
||||
|
||||
logger.info("Subscription status updated",
|
||||
subscription_id=subscription_id,
|
||||
new_status=status,
|
||||
additional_data=additional_data)
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to update subscription status",
|
||||
subscription_id=subscription_id,
|
||||
status=status,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update subscription status: {str(e)}")
|
||||
|
||||
async def cancel_subscription(
|
||||
self,
|
||||
subscription_id: str,
|
||||
reason: str = None
|
||||
) -> Optional[Subscription]:
|
||||
"""Cancel a subscription"""
|
||||
try:
|
||||
# Get subscription to find tenant_id for cache invalidation
|
||||
subscription = await self.get_by_id(subscription_id)
|
||||
if not subscription:
|
||||
raise ValidationError(f"Subscription {subscription_id} not found")
|
||||
|
||||
update_data = {
|
||||
"status": "cancelled",
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
updated_subscription = await self.update(subscription_id, update_data)
|
||||
|
||||
# Invalidate cache
|
||||
await self._invalidate_cache(str(subscription.tenant_id))
|
||||
|
||||
logger.info("Subscription cancelled",
|
||||
subscription_id=subscription_id,
|
||||
reason=reason)
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cancel subscription",
|
||||
subscription_id=subscription_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to cancel subscription: {str(e)}")
|
||||
|
||||
async def suspend_subscription(
|
||||
self,
|
||||
subscription_id: str,
|
||||
reason: str = None
|
||||
) -> Optional[Subscription]:
|
||||
"""Suspend a subscription"""
|
||||
try:
|
||||
# Get subscription to find tenant_id for cache invalidation
|
||||
subscription = await self.get_by_id(subscription_id)
|
||||
if not subscription:
|
||||
raise ValidationError(f"Subscription {subscription_id} not found")
|
||||
|
||||
update_data = {
|
||||
"status": "suspended",
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
updated_subscription = await self.update(subscription_id, update_data)
|
||||
|
||||
# Invalidate cache
|
||||
await self._invalidate_cache(str(subscription.tenant_id))
|
||||
|
||||
logger.info("Subscription suspended",
|
||||
subscription_id=subscription_id,
|
||||
reason=reason)
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to suspend subscription",
|
||||
subscription_id=subscription_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to suspend subscription: {str(e)}")
|
||||
|
||||
async def reactivate_subscription(
|
||||
self,
|
||||
subscription_id: str
|
||||
) -> Optional[Subscription]:
|
||||
"""Reactivate a cancelled or suspended subscription"""
|
||||
try:
|
||||
# Get subscription to find tenant_id for cache invalidation
|
||||
subscription = await self.get_by_id(subscription_id)
|
||||
if not subscription:
|
||||
raise ValidationError(f"Subscription {subscription_id} not found")
|
||||
|
||||
# Reset billing date when reactivating
|
||||
next_billing_date = datetime.utcnow() + timedelta(days=30)
|
||||
|
||||
update_data = {
|
||||
"status": "active",
|
||||
"next_billing_date": next_billing_date,
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
updated_subscription = await self.update(subscription_id, update_data)
|
||||
|
||||
# Invalidate cache
|
||||
await self._invalidate_cache(str(subscription.tenant_id))
|
||||
|
||||
logger.info("Subscription reactivated",
|
||||
subscription_id=subscription_id,
|
||||
next_billing_date=next_billing_date)
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to reactivate subscription",
|
||||
subscription_id=subscription_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to reactivate subscription: {str(e)}")
|
||||
|
||||
async def get_subscriptions_due_for_billing(
|
||||
self,
|
||||
days_ahead: int = 3
|
||||
) -> List[Subscription]:
|
||||
"""Get subscriptions that need billing in the next N days"""
|
||||
try:
|
||||
cutoff_date = datetime.utcnow() + timedelta(days=days_ahead)
|
||||
|
||||
query_text = """
|
||||
SELECT * FROM subscriptions
|
||||
WHERE status = 'active'
|
||||
AND next_billing_date <= :cutoff_date
|
||||
ORDER BY next_billing_date ASC
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {
|
||||
"cutoff_date": cutoff_date
|
||||
})
|
||||
|
||||
subscriptions = []
|
||||
for row in result.fetchall():
|
||||
record_dict = dict(row._mapping)
|
||||
subscription = self.model(**record_dict)
|
||||
subscriptions.append(subscription)
|
||||
|
||||
return subscriptions
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get subscriptions due for billing",
|
||||
days_ahead=days_ahead,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def update_billing_date(
|
||||
self,
|
||||
subscription_id: str,
|
||||
next_billing_date: datetime
|
||||
) -> Optional[Subscription]:
|
||||
"""Update next billing date for subscription"""
|
||||
try:
|
||||
updated_subscription = await self.update(subscription_id, {
|
||||
"next_billing_date": next_billing_date,
|
||||
"updated_at": datetime.utcnow()
|
||||
})
|
||||
|
||||
logger.info("Subscription billing date updated",
|
||||
subscription_id=subscription_id,
|
||||
next_billing_date=next_billing_date)
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update billing date",
|
||||
subscription_id=subscription_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update billing date: {str(e)}")
|
||||
|
||||
async def get_subscription_statistics(self) -> Dict[str, Any]:
|
||||
"""Get subscription statistics"""
|
||||
try:
|
||||
# Get counts by plan
|
||||
plan_query = text("""
|
||||
SELECT plan, COUNT(*) as count
|
||||
FROM subscriptions
|
||||
WHERE status = 'active'
|
||||
GROUP BY plan
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(plan_query)
|
||||
subscriptions_by_plan = {row.plan: row.count for row in result.fetchall()}
|
||||
|
||||
# Get counts by status
|
||||
status_query = text("""
|
||||
SELECT status, COUNT(*) as count
|
||||
FROM subscriptions
|
||||
GROUP BY status
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(status_query)
|
||||
subscriptions_by_status = {row.status: row.count for row in result.fetchall()}
|
||||
|
||||
# Get revenue statistics
|
||||
revenue_query = text("""
|
||||
SELECT
|
||||
SUM(monthly_price) as total_monthly_revenue,
|
||||
AVG(monthly_price) as avg_monthly_price,
|
||||
COUNT(*) as total_active_subscriptions
|
||||
FROM subscriptions
|
||||
WHERE status = 'active'
|
||||
""")
|
||||
|
||||
revenue_result = await self.session.execute(revenue_query)
|
||||
revenue_row = revenue_result.fetchone()
|
||||
|
||||
# Get upcoming billing count
|
||||
thirty_days_ahead = datetime.utcnow() + timedelta(days=30)
|
||||
upcoming_billing = len(await self.get_subscriptions_due_for_billing(30))
|
||||
|
||||
return {
|
||||
"subscriptions_by_plan": subscriptions_by_plan,
|
||||
"subscriptions_by_status": subscriptions_by_status,
|
||||
"total_monthly_revenue": float(revenue_row.total_monthly_revenue or 0),
|
||||
"avg_monthly_price": float(revenue_row.avg_monthly_price or 0),
|
||||
"total_active_subscriptions": int(revenue_row.total_active_subscriptions or 0),
|
||||
"upcoming_billing_30d": upcoming_billing
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get subscription statistics", error=str(e))
|
||||
return {
|
||||
"subscriptions_by_plan": {},
|
||||
"subscriptions_by_status": {},
|
||||
"total_monthly_revenue": 0.0,
|
||||
"avg_monthly_price": 0.0,
|
||||
"total_active_subscriptions": 0,
|
||||
"upcoming_billing_30d": 0
|
||||
}
|
||||
|
||||
async def cleanup_old_subscriptions(self, days_old: int = 730) -> int:
|
||||
"""Clean up very old cancelled subscriptions (2 years)"""
|
||||
try:
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days_old)
|
||||
|
||||
query_text = """
|
||||
DELETE FROM subscriptions
|
||||
WHERE status IN ('cancelled', 'suspended')
|
||||
AND updated_at < :cutoff_date
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up old subscriptions",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup old subscriptions",
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Cleanup failed: {str(e)}")
|
||||
|
||||
async def _invalidate_cache(self, tenant_id: str) -> None:
|
||||
"""
|
||||
Invalidate subscription cache for a tenant
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
"""
|
||||
try:
|
||||
from app.services.subscription_cache import get_subscription_cache_service
|
||||
|
||||
cache_service = get_subscription_cache_service()
|
||||
await cache_service.invalidate_subscription_cache(tenant_id)
|
||||
|
||||
logger.debug("Invalidated subscription cache from repository",
|
||||
tenant_id=tenant_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to invalidate cache (non-critical)",
|
||||
tenant_id=tenant_id, error=str(e))
|
||||
|
||||
# ========================================================================
|
||||
# TENANT-INDEPENDENT SUBSCRIPTION METHODS (New Architecture)
|
||||
# ========================================================================
|
||||
|
||||
async def create_tenant_independent_subscription(
|
||||
self,
|
||||
subscription_data: Dict[str, Any]
|
||||
) -> Subscription:
|
||||
"""Create a subscription not linked to any tenant (for registration flow)"""
|
||||
try:
|
||||
# Validate required data for tenant-independent subscription
|
||||
# user_id may not exist during registration, so validate other required fields
|
||||
required_fields = ["plan", "subscription_id", "customer_id"]
|
||||
validation_result = self._validate_tenant_data(subscription_data, required_fields)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid subscription data: {validation_result['errors']}")
|
||||
|
||||
# Ensure tenant_id is not provided (this is tenant-independent)
|
||||
if "tenant_id" in subscription_data and subscription_data["tenant_id"]:
|
||||
raise ValidationError("tenant_id should not be provided for tenant-independent subscriptions")
|
||||
|
||||
# Set tenant-independent specific fields
|
||||
subscription_data["tenant_id"] = None
|
||||
subscription_data["is_tenant_linked"] = False
|
||||
subscription_data["tenant_linking_status"] = "pending"
|
||||
subscription_data["linked_at"] = None
|
||||
|
||||
# Set default values based on plan from centralized configuration
|
||||
plan = subscription_data["plan"]
|
||||
plan_info = SubscriptionPlanMetadata.get_plan_info(plan)
|
||||
|
||||
# Set defaults from centralized plan configuration
|
||||
if "monthly_price" not in subscription_data:
|
||||
billing_cycle = subscription_data.get("billing_cycle", "monthly")
|
||||
subscription_data["monthly_price"] = float(
|
||||
PlanPricing.get_price(plan, billing_cycle)
|
||||
)
|
||||
|
||||
if "max_users" not in subscription_data:
|
||||
subscription_data["max_users"] = QuotaLimits.get_limit('MAX_USERS', plan) or -1
|
||||
|
||||
if "max_locations" not in subscription_data:
|
||||
subscription_data["max_locations"] = QuotaLimits.get_limit('MAX_LOCATIONS', plan) or -1
|
||||
|
||||
if "max_products" not in subscription_data:
|
||||
subscription_data["max_products"] = QuotaLimits.get_limit('MAX_PRODUCTS', plan) or -1
|
||||
|
||||
if "features" not in subscription_data:
|
||||
subscription_data["features"] = {
|
||||
feature: True for feature in plan_info.get("features", [])
|
||||
}
|
||||
|
||||
# Set default subscription values
|
||||
if "status" not in subscription_data:
|
||||
subscription_data["status"] = "pending_tenant_linking"
|
||||
if "billing_cycle" not in subscription_data:
|
||||
subscription_data["billing_cycle"] = "monthly"
|
||||
if "next_billing_date" not in subscription_data:
|
||||
# Set next billing date based on cycle
|
||||
if subscription_data["billing_cycle"] == "yearly":
|
||||
subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=365)
|
||||
else:
|
||||
subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=30)
|
||||
|
||||
# Check if subscription with this subscription_id already exists
|
||||
existing_subscription = await self.get_by_provider_id(subscription_data['subscription_id'])
|
||||
if existing_subscription:
|
||||
# Update the existing subscription instead of creating a duplicate
|
||||
updated_subscription = await self.update(str(existing_subscription.id), subscription_data)
|
||||
|
||||
logger.info("Existing tenant-independent subscription updated",
|
||||
subscription_id=subscription_data['subscription_id'],
|
||||
user_id=subscription_data.get('user_id'),
|
||||
plan=subscription_data.get('plan'))
|
||||
|
||||
return updated_subscription
|
||||
else:
|
||||
# Create new subscription, but handle potential duplicate errors
|
||||
try:
|
||||
subscription = await self.create(subscription_data)
|
||||
|
||||
logger.info("Tenant-independent subscription created successfully",
|
||||
subscription_id=subscription.id,
|
||||
user_id=subscription.user_id,
|
||||
plan=subscription.plan,
|
||||
monthly_price=subscription.monthly_price)
|
||||
|
||||
return subscription
|
||||
except DuplicateRecordError:
|
||||
# Another process may have created the subscription between our check and create
|
||||
# Try to get the existing subscription and return it
|
||||
final_subscription = await self.get_by_provider_id(subscription_data['subscription_id'])
|
||||
if final_subscription:
|
||||
logger.info("Race condition detected: subscription already created by another process",
|
||||
subscription_id=subscription_data['subscription_id'])
|
||||
return final_subscription
|
||||
else:
|
||||
# This shouldn't happen, but re-raise the error if we can't find it
|
||||
raise
|
||||
|
||||
except (ValidationError, DuplicateRecordError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create tenant-independent subscription",
|
||||
user_id=subscription_data.get("user_id"),
|
||||
plan=subscription_data.get("plan"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create tenant-independent subscription: {str(e)}")
|
||||
|
||||
async def get_pending_tenant_linking_subscriptions(self) -> List[Subscription]:
|
||||
"""Get all subscriptions waiting to be linked to tenants"""
|
||||
try:
|
||||
subscriptions = await self.get_multi(
|
||||
filters={
|
||||
"tenant_linking_status": "pending",
|
||||
"is_tenant_linked": False
|
||||
},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return subscriptions
|
||||
except Exception as e:
|
||||
logger.error("Failed to get pending tenant linking subscriptions",
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get pending subscriptions: {str(e)}")
|
||||
|
||||
async def get_pending_subscriptions_by_user(self, user_id: str) -> List[Subscription]:
|
||||
"""Get pending tenant linking subscriptions for a specific user"""
|
||||
try:
|
||||
subscriptions = await self.get_multi(
|
||||
filters={
|
||||
"user_id": user_id,
|
||||
"tenant_linking_status": "pending",
|
||||
"is_tenant_linked": False
|
||||
},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return subscriptions
|
||||
except Exception as e:
|
||||
logger.error("Failed to get pending subscriptions by user",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get pending subscriptions: {str(e)}")
|
||||
|
||||
async def link_subscription_to_tenant(
|
||||
self,
|
||||
subscription_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str
|
||||
) -> Subscription:
|
||||
"""Link a pending subscription to a tenant"""
|
||||
try:
|
||||
# Get the subscription first
|
||||
subscription = await self.get_by_id(subscription_id)
|
||||
if not subscription:
|
||||
raise ValidationError(f"Subscription {subscription_id} not found")
|
||||
|
||||
# Validate subscription can be linked
|
||||
if not subscription.can_be_linked_to_tenant(user_id):
|
||||
raise ValidationError(
|
||||
f"Subscription {subscription_id} cannot be linked to tenant by user {user_id}. "
|
||||
f"Current status: {subscription.tenant_linking_status}, "
|
||||
f"User: {subscription.user_id}, "
|
||||
f"Already linked: {subscription.is_tenant_linked}"
|
||||
)
|
||||
|
||||
# Update subscription with tenant information
|
||||
update_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"is_tenant_linked": True,
|
||||
"tenant_linking_status": "completed",
|
||||
"linked_at": datetime.utcnow(),
|
||||
"status": "active", # Activate subscription when linked to tenant
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
updated_subscription = await self.update(subscription_id, update_data)
|
||||
|
||||
# Invalidate cache for the tenant
|
||||
await self._invalidate_cache(tenant_id)
|
||||
|
||||
logger.info("Subscription linked to tenant successfully",
|
||||
subscription_id=subscription_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id)
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to link subscription to tenant",
|
||||
subscription_id=subscription_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to link subscription to tenant: {str(e)}")
|
||||
|
||||
async def cleanup_orphaned_subscriptions(self, days_old: int = 30) -> int:
|
||||
"""Clean up subscriptions that were never linked to tenants"""
|
||||
try:
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days_old)
|
||||
|
||||
query_text = """
|
||||
DELETE FROM subscriptions
|
||||
WHERE tenant_linking_status = 'pending'
|
||||
AND is_tenant_linked = FALSE
|
||||
AND created_at < :cutoff_date
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up orphaned subscriptions",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup orphaned subscriptions",
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Cleanup failed: {str(e)}")
|
||||
|
||||
async def get_by_customer_id(self, customer_id: str) -> List[Subscription]:
|
||||
"""
|
||||
Get subscriptions by Stripe customer ID
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
|
||||
Returns:
|
||||
List of Subscription objects
|
||||
"""
|
||||
try:
|
||||
query = select(Subscription).where(Subscription.customer_id == customer_id)
|
||||
result = await self.session.execute(query)
|
||||
subscriptions = result.scalars().all()
|
||||
|
||||
logger.debug("Found subscriptions by customer_id",
|
||||
customer_id=customer_id,
|
||||
count=len(subscriptions))
|
||||
return subscriptions
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting subscriptions by customer_id",
|
||||
customer_id=customer_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get subscriptions by customer_id: {str(e)}")
|
||||
218
services/tenant/app/repositories/tenant_location_repository.py
Normal file
218
services/tenant/app/repositories/tenant_location_repository.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""
|
||||
Tenant Location Repository
|
||||
Handles database operations for tenant location data
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update, delete
|
||||
from sqlalchemy.orm import selectinload
|
||||
import structlog
|
||||
|
||||
from app.models.tenant_location import TenantLocation
|
||||
from app.models.tenants import Tenant
|
||||
from shared.database.exceptions import DatabaseError
|
||||
from .base import BaseRepository
|
||||
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TenantLocationRepository(BaseRepository):
|
||||
"""Repository for tenant location operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
super().__init__(TenantLocation, session)
|
||||
|
||||
async def create_location(self, location_data: Dict[str, Any]) -> TenantLocation:
|
||||
"""
|
||||
Create a new tenant location
|
||||
|
||||
Args:
|
||||
location_data: Dictionary containing location information
|
||||
|
||||
Returns:
|
||||
Created TenantLocation object
|
||||
"""
|
||||
try:
|
||||
# Create new location instance
|
||||
location = TenantLocation(**location_data)
|
||||
self.session.add(location)
|
||||
await self.session.commit()
|
||||
await self.session.refresh(location)
|
||||
logger.info(f"Created new tenant location: {location.id} for tenant {location.tenant_id}")
|
||||
return location
|
||||
except Exception as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Failed to create tenant location: {str(e)}")
|
||||
raise DatabaseError(f"Failed to create tenant location: {str(e)}")
|
||||
|
||||
async def get_location_by_id(self, location_id: str) -> Optional[TenantLocation]:
|
||||
"""
|
||||
Get a location by its ID
|
||||
|
||||
Args:
|
||||
location_id: UUID of the location
|
||||
|
||||
Returns:
|
||||
TenantLocation object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
stmt = select(TenantLocation).where(TenantLocation.id == location_id)
|
||||
result = await self.session.execute(stmt)
|
||||
location = result.scalar_one_or_none()
|
||||
return location
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get location by ID: {str(e)}")
|
||||
raise DatabaseError(f"Failed to get location by ID: {str(e)}")
|
||||
|
||||
async def get_locations_by_tenant(self, tenant_id: str) -> List[TenantLocation]:
|
||||
"""
|
||||
Get all locations for a specific tenant
|
||||
|
||||
Args:
|
||||
tenant_id: UUID of the tenant
|
||||
|
||||
Returns:
|
||||
List of TenantLocation objects
|
||||
"""
|
||||
try:
|
||||
stmt = select(TenantLocation).where(TenantLocation.tenant_id == tenant_id)
|
||||
result = await self.session.execute(stmt)
|
||||
locations = result.scalars().all()
|
||||
return locations
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get locations by tenant: {str(e)}")
|
||||
raise DatabaseError(f"Failed to get locations by tenant: {str(e)}")
|
||||
|
||||
async def get_location_by_type(self, tenant_id: str, location_type: str) -> Optional[TenantLocation]:
|
||||
"""
|
||||
Get a location by tenant and type
|
||||
|
||||
Args:
|
||||
tenant_id: UUID of the tenant
|
||||
location_type: Type of location (e.g., 'central_production', 'retail_outlet')
|
||||
|
||||
Returns:
|
||||
TenantLocation object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
stmt = select(TenantLocation).where(
|
||||
TenantLocation.tenant_id == tenant_id,
|
||||
TenantLocation.location_type == location_type
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
location = result.scalar_one_or_none()
|
||||
return location
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get location by type: {str(e)}")
|
||||
raise DatabaseError(f"Failed to get location by type: {str(e)}")
|
||||
|
||||
async def update_location(self, location_id: str, location_data: Dict[str, Any]) -> Optional[TenantLocation]:
|
||||
"""
|
||||
Update a tenant location
|
||||
|
||||
Args:
|
||||
location_id: UUID of the location to update
|
||||
location_data: Dictionary containing updated location information
|
||||
|
||||
Returns:
|
||||
Updated TenantLocation object if successful, None if location not found
|
||||
"""
|
||||
try:
|
||||
stmt = (
|
||||
update(TenantLocation)
|
||||
.where(TenantLocation.id == location_id)
|
||||
.values(**location_data)
|
||||
)
|
||||
await self.session.execute(stmt)
|
||||
|
||||
# Now fetch the updated location
|
||||
location_stmt = select(TenantLocation).where(TenantLocation.id == location_id)
|
||||
result = await self.session.execute(location_stmt)
|
||||
location = result.scalar_one_or_none()
|
||||
|
||||
if location:
|
||||
await self.session.commit()
|
||||
logger.info(f"Updated tenant location: {location_id}")
|
||||
return location
|
||||
else:
|
||||
await self.session.rollback()
|
||||
logger.warning(f"Location not found for update: {location_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Failed to update location: {str(e)}")
|
||||
raise DatabaseError(f"Failed to update location: {str(e)}")
|
||||
|
||||
async def delete_location(self, location_id: str) -> bool:
|
||||
"""
|
||||
Delete a tenant location
|
||||
|
||||
Args:
|
||||
location_id: UUID of the location to delete
|
||||
|
||||
Returns:
|
||||
True if deleted successfully, False if location not found
|
||||
"""
|
||||
try:
|
||||
stmt = delete(TenantLocation).where(TenantLocation.id == location_id)
|
||||
result = await self.session.execute(stmt)
|
||||
|
||||
if result.rowcount > 0:
|
||||
await self.session.commit()
|
||||
logger.info(f"Deleted tenant location: {location_id}")
|
||||
return True
|
||||
else:
|
||||
await self.session.rollback()
|
||||
logger.warning(f"Location not found for deletion: {location_id}")
|
||||
return False
|
||||
except Exception as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Failed to delete location: {str(e)}")
|
||||
raise DatabaseError(f"Failed to delete location: {str(e)}")
|
||||
|
||||
async def get_active_locations_by_tenant(self, tenant_id: str) -> List[TenantLocation]:
|
||||
"""
|
||||
Get all active locations for a specific tenant
|
||||
|
||||
Args:
|
||||
tenant_id: UUID of the tenant
|
||||
|
||||
Returns:
|
||||
List of active TenantLocation objects
|
||||
"""
|
||||
try:
|
||||
stmt = select(TenantLocation).where(
|
||||
TenantLocation.tenant_id == tenant_id,
|
||||
TenantLocation.is_active == True
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
locations = result.scalars().all()
|
||||
return locations
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get active locations by tenant: {str(e)}")
|
||||
raise DatabaseError(f"Failed to get active locations by tenant: {str(e)}")
|
||||
|
||||
async def get_locations_by_tenant_with_type(self, tenant_id: str, location_types: List[str]) -> List[TenantLocation]:
|
||||
"""
|
||||
Get locations for a specific tenant filtered by location types
|
||||
|
||||
Args:
|
||||
tenant_id: UUID of the tenant
|
||||
location_types: List of location types to filter by
|
||||
|
||||
Returns:
|
||||
List of TenantLocation objects matching the criteria
|
||||
"""
|
||||
try:
|
||||
stmt = select(TenantLocation).where(
|
||||
TenantLocation.tenant_id == tenant_id,
|
||||
TenantLocation.location_type.in_(location_types)
|
||||
)
|
||||
result = await self.session.execute(stmt)
|
||||
locations = result.scalars().all()
|
||||
return locations
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get locations by tenant and type: {str(e)}")
|
||||
raise DatabaseError(f"Failed to get locations by tenant and type: {str(e)}")
|
||||
588
services/tenant/app/repositories/tenant_member_repository.py
Normal file
588
services/tenant/app/repositories/tenant_member_repository.py
Normal file
@@ -0,0 +1,588 @@
|
||||
"""
|
||||
Tenant Member Repository
|
||||
Repository for tenant membership operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, text, and_
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
import json
|
||||
|
||||
from .base import TenantBaseRepository
|
||||
from app.models.tenants import TenantMember
|
||||
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
|
||||
from shared.config.base import is_internal_service
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TenantMemberRepository(TenantBaseRepository):
|
||||
"""Repository for tenant member operations"""
|
||||
|
||||
def __init__(self, model_class, session: AsyncSession, cache_ttl: Optional[int] = 300):
|
||||
# Member data changes more frequently, shorter cache time (5 minutes)
|
||||
super().__init__(model_class, session, cache_ttl)
|
||||
|
||||
async def create_membership(self, membership_data: Dict[str, Any]) -> TenantMember:
|
||||
"""Create a new tenant membership with validation"""
|
||||
try:
|
||||
# Validate membership data
|
||||
validation_result = self._validate_tenant_data(
|
||||
membership_data,
|
||||
["tenant_id", "user_id", "role"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid membership data: {validation_result['errors']}")
|
||||
|
||||
# Check for existing membership
|
||||
existing_membership = await self.get_membership(
|
||||
membership_data["tenant_id"],
|
||||
membership_data["user_id"]
|
||||
)
|
||||
|
||||
if existing_membership and existing_membership.is_active:
|
||||
raise DuplicateRecordError(f"User is already an active member of this tenant")
|
||||
|
||||
# Set default values
|
||||
if "is_active" not in membership_data:
|
||||
membership_data["is_active"] = True
|
||||
if "joined_at" not in membership_data:
|
||||
membership_data["joined_at"] = datetime.utcnow()
|
||||
|
||||
# Set permissions based on role
|
||||
if "permissions" not in membership_data:
|
||||
membership_data["permissions"] = self._get_default_permissions(
|
||||
membership_data["role"]
|
||||
)
|
||||
|
||||
# If reactivating existing membership
|
||||
if existing_membership and not existing_membership.is_active:
|
||||
# Update existing membership
|
||||
update_data = {
|
||||
key: value for key, value in membership_data.items()
|
||||
if key not in ["tenant_id", "user_id"]
|
||||
}
|
||||
membership = await self.update(existing_membership.id, update_data)
|
||||
else:
|
||||
# Create new membership
|
||||
membership = await self.create(membership_data)
|
||||
|
||||
logger.info("Tenant membership created",
|
||||
membership_id=membership.id,
|
||||
tenant_id=membership.tenant_id,
|
||||
user_id=membership.user_id,
|
||||
role=membership.role)
|
||||
|
||||
return membership
|
||||
|
||||
except (ValidationError, DuplicateRecordError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create membership",
|
||||
tenant_id=membership_data.get("tenant_id"),
|
||||
user_id=membership_data.get("user_id"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create membership: {str(e)}")
|
||||
|
||||
async def get_membership(self, tenant_id: str, user_id: str) -> Optional[TenantMember]:
|
||||
"""Get specific membership by tenant and user"""
|
||||
try:
|
||||
# Validate that user_id is a proper UUID format for actual users
|
||||
# Service names like 'inventory-service' should be handled differently
|
||||
import uuid
|
||||
try:
|
||||
uuid.UUID(user_id)
|
||||
is_valid_uuid = True
|
||||
except ValueError:
|
||||
is_valid_uuid = False
|
||||
|
||||
# For internal service access, return None to indicate no user membership
|
||||
# Service access should be handled at the API layer
|
||||
if not is_valid_uuid:
|
||||
if is_internal_service(user_id):
|
||||
# This is a known internal service request, return None
|
||||
# Service access is granted at the API endpoint level
|
||||
logger.debug("Internal service detected in membership lookup",
|
||||
service=user_id,
|
||||
tenant_id=tenant_id)
|
||||
return None
|
||||
elif user_id == "unknown-service":
|
||||
# Special handling for 'unknown-service' which commonly occurs in demo sessions
|
||||
# This happens when service identification fails during demo operations
|
||||
logger.warning("Demo session service identification issue",
|
||||
service=user_id,
|
||||
tenant_id=tenant_id,
|
||||
message="Service not properly identified - likely demo session context")
|
||||
return None
|
||||
else:
|
||||
# This is an unknown service
|
||||
# Return None to prevent database errors, but log a warning
|
||||
logger.warning("Unknown service detected in membership lookup",
|
||||
service=user_id,
|
||||
tenant_id=tenant_id,
|
||||
message="Service not in internal services registry")
|
||||
return None
|
||||
|
||||
memberships = await self.get_multi(
|
||||
filters={
|
||||
"tenant_id": tenant_id,
|
||||
"user_id": user_id
|
||||
},
|
||||
limit=1,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return memberships[0] if memberships else None
|
||||
except Exception as e:
|
||||
logger.error("Failed to get membership",
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get membership: {str(e)}")
|
||||
|
||||
async def get_tenant_members(
|
||||
self,
|
||||
tenant_id: str,
|
||||
active_only: bool = True,
|
||||
role: str = None,
|
||||
include_user_info: bool = False
|
||||
) -> List[TenantMember]:
|
||||
"""Get all members of a tenant with optional user info enrichment"""
|
||||
try:
|
||||
filters = {"tenant_id": tenant_id}
|
||||
|
||||
if active_only:
|
||||
filters["is_active"] = True
|
||||
|
||||
if role:
|
||||
filters["role"] = role
|
||||
|
||||
members = await self.get_multi(
|
||||
filters=filters,
|
||||
order_by="joined_at",
|
||||
order_desc=False
|
||||
)
|
||||
|
||||
# If include_user_info is True, enrich with user data from auth service
|
||||
if include_user_info and members:
|
||||
members = await self._enrich_members_with_user_info(members)
|
||||
|
||||
return members
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenant members",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get members: {str(e)}")
|
||||
|
||||
async def _enrich_members_with_user_info(self, members: List[TenantMember]) -> List[TenantMember]:
|
||||
"""Enrich member objects with user information from auth service using batch endpoint"""
|
||||
try:
|
||||
import httpx
|
||||
import os
|
||||
|
||||
if not members:
|
||||
return members
|
||||
|
||||
# Get unique user IDs
|
||||
user_ids = list(set([str(member.user_id) for member in members]))
|
||||
|
||||
if not user_ids:
|
||||
return members
|
||||
|
||||
# Fetch user data from auth service using batch endpoint
|
||||
# Using internal service communication
|
||||
auth_service_url = os.getenv('AUTH_SERVICE_URL', 'http://auth-service:8000')
|
||||
|
||||
user_data_map = {}
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
# Use batch endpoint for efficiency
|
||||
response = await client.post(
|
||||
f"{auth_service_url}/api/v1/auth/users/batch",
|
||||
json={"user_ids": user_ids},
|
||||
timeout=10.0,
|
||||
headers={"x-internal-service": "tenant-service"}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
batch_result = response.json()
|
||||
user_data_map = batch_result.get("users", {})
|
||||
logger.info(
|
||||
"Batch user fetch successful",
|
||||
requested_count=len(user_ids),
|
||||
found_count=batch_result.get("found_count", 0)
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Batch user fetch failed, falling back to individual calls",
|
||||
status_code=response.status_code
|
||||
)
|
||||
# Fallback to individual calls if batch fails
|
||||
for user_id in user_ids:
|
||||
try:
|
||||
response = await client.get(
|
||||
f"{auth_service_url}/api/v1/auth/users/{user_id}",
|
||||
timeout=5.0,
|
||||
headers={"x-internal-service": "tenant-service"}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
user_data = response.json()
|
||||
user_data_map[user_id] = user_data
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch user data for {user_id}", error=str(e))
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Batch user fetch failed, falling back to individual calls", error=str(e))
|
||||
# Fallback to individual calls
|
||||
for user_id in user_ids:
|
||||
try:
|
||||
response = await client.get(
|
||||
f"{auth_service_url}/api/v1/auth/users/{user_id}",
|
||||
timeout=5.0,
|
||||
headers={"x-internal-service": "tenant-service"}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
user_data = response.json()
|
||||
user_data_map[user_id] = user_data
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch user data for {user_id}", error=str(e))
|
||||
continue
|
||||
|
||||
# Enrich members with user data
|
||||
for member in members:
|
||||
user_id_str = str(member.user_id)
|
||||
if user_id_str in user_data_map and user_data_map[user_id_str] is not None:
|
||||
user_data = user_data_map[user_id_str]
|
||||
# Add user fields as attributes to the member object
|
||||
member.user_email = user_data.get("email")
|
||||
member.user_full_name = user_data.get("full_name")
|
||||
member.user = user_data # Store full user object for compatibility
|
||||
else:
|
||||
# Set defaults for missing users
|
||||
member.user_email = None
|
||||
member.user_full_name = "Unknown User"
|
||||
member.user = None
|
||||
|
||||
return members
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to enrich members with user info", error=str(e))
|
||||
# Return members without enrichment if it fails
|
||||
return members
|
||||
|
||||
async def get_user_memberships(
|
||||
self,
|
||||
user_id: str,
|
||||
active_only: bool = True
|
||||
) -> List[TenantMember]:
|
||||
"""Get all tenants a user is a member of"""
|
||||
try:
|
||||
filters = {"user_id": user_id}
|
||||
|
||||
if active_only:
|
||||
filters["is_active"] = True
|
||||
|
||||
return await self.get_multi(
|
||||
filters=filters,
|
||||
order_by="joined_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get user memberships",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get memberships: {str(e)}")
|
||||
|
||||
async def verify_user_access(
|
||||
self,
|
||||
user_id: str,
|
||||
tenant_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Verify if user has access to tenant and return access details"""
|
||||
try:
|
||||
membership = await self.get_membership(tenant_id, user_id)
|
||||
|
||||
if not membership or not membership.is_active:
|
||||
return {
|
||||
"has_access": False,
|
||||
"role": "none",
|
||||
"permissions": []
|
||||
}
|
||||
|
||||
# Parse permissions
|
||||
permissions = []
|
||||
if membership.permissions:
|
||||
try:
|
||||
permissions = json.loads(membership.permissions)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("Invalid permissions JSON for membership",
|
||||
membership_id=membership.id)
|
||||
permissions = self._get_default_permissions(membership.role)
|
||||
|
||||
return {
|
||||
"has_access": True,
|
||||
"role": membership.role,
|
||||
"permissions": permissions,
|
||||
"membership_id": str(membership.id),
|
||||
"joined_at": membership.joined_at.isoformat() if membership.joined_at else None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to verify user access",
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"has_access": False,
|
||||
"role": "none",
|
||||
"permissions": []
|
||||
}
|
||||
|
||||
async def update_member_role(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
new_role: str,
|
||||
updated_by: str = None
|
||||
) -> Optional[TenantMember]:
|
||||
"""Update member role and permissions"""
|
||||
try:
|
||||
valid_roles = ["owner", "admin", "member", "viewer"]
|
||||
if new_role not in valid_roles:
|
||||
raise ValidationError(f"Invalid role. Must be one of: {valid_roles}")
|
||||
|
||||
membership = await self.get_membership(tenant_id, user_id)
|
||||
if not membership:
|
||||
raise ValidationError("Membership not found")
|
||||
|
||||
# Get new permissions based on role
|
||||
new_permissions = self._get_default_permissions(new_role)
|
||||
|
||||
updated_membership = await self.update(membership.id, {
|
||||
"role": new_role,
|
||||
"permissions": json.dumps(new_permissions)
|
||||
})
|
||||
|
||||
logger.info("Member role updated",
|
||||
membership_id=membership.id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
old_role=membership.role,
|
||||
new_role=new_role,
|
||||
updated_by=updated_by)
|
||||
|
||||
return updated_membership
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to update member role",
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
new_role=new_role,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update role: {str(e)}")
|
||||
|
||||
async def deactivate_membership(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
deactivated_by: str = None
|
||||
) -> Optional[TenantMember]:
|
||||
"""Deactivate a membership (remove user from tenant)"""
|
||||
try:
|
||||
membership = await self.get_membership(tenant_id, user_id)
|
||||
if not membership:
|
||||
raise ValidationError("Membership not found")
|
||||
|
||||
# Don't allow deactivating the owner
|
||||
if membership.role == "owner":
|
||||
raise ValidationError("Cannot deactivate the owner membership")
|
||||
|
||||
updated_membership = await self.update(membership.id, {
|
||||
"is_active": False
|
||||
})
|
||||
|
||||
logger.info("Membership deactivated",
|
||||
membership_id=membership.id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
deactivated_by=deactivated_by)
|
||||
|
||||
return updated_membership
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to deactivate membership",
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to deactivate membership: {str(e)}")
|
||||
|
||||
async def reactivate_membership(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
reactivated_by: str = None
|
||||
) -> Optional[TenantMember]:
|
||||
"""Reactivate a deactivated membership"""
|
||||
try:
|
||||
membership = await self.get_membership(tenant_id, user_id)
|
||||
if not membership:
|
||||
raise ValidationError("Membership not found")
|
||||
|
||||
updated_membership = await self.update(membership.id, {
|
||||
"is_active": True,
|
||||
"joined_at": datetime.utcnow() # Update join date
|
||||
})
|
||||
|
||||
logger.info("Membership reactivated",
|
||||
membership_id=membership.id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
reactivated_by=reactivated_by)
|
||||
|
||||
return updated_membership
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to reactivate membership",
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to reactivate membership: {str(e)}")
|
||||
|
||||
async def get_membership_statistics(self, tenant_id: str) -> Dict[str, Any]:
|
||||
"""Get membership statistics for a tenant"""
|
||||
try:
|
||||
# Get counts by role
|
||||
role_query = text("""
|
||||
SELECT role, COUNT(*) as count
|
||||
FROM tenant_members
|
||||
WHERE tenant_id = :tenant_id AND is_active = true
|
||||
GROUP BY role
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(role_query, {"tenant_id": tenant_id})
|
||||
members_by_role = {row.role: row.count for row in result.fetchall()}
|
||||
|
||||
# Get basic counts
|
||||
total_members = await self.count(filters={"tenant_id": tenant_id})
|
||||
active_members = await self.count(filters={
|
||||
"tenant_id": tenant_id,
|
||||
"is_active": True
|
||||
})
|
||||
|
||||
# Get recent activity (members joined in last 30 days)
|
||||
thirty_days_ago = datetime.utcnow() - timedelta(days=30)
|
||||
recent_joins = len(await self.get_multi(
|
||||
filters={
|
||||
"tenant_id": tenant_id,
|
||||
"is_active": True
|
||||
},
|
||||
limit=1000 # High limit to get accurate count
|
||||
))
|
||||
|
||||
# Filter for recent joins (manual filtering since we can't use date range in filters easily)
|
||||
recent_members = 0
|
||||
all_active_members = await self.get_tenant_members(tenant_id, active_only=True)
|
||||
for member in all_active_members:
|
||||
if member.joined_at and member.joined_at >= thirty_days_ago:
|
||||
recent_members += 1
|
||||
|
||||
return {
|
||||
"total_members": total_members,
|
||||
"active_members": active_members,
|
||||
"inactive_members": total_members - active_members,
|
||||
"members_by_role": members_by_role,
|
||||
"recent_joins_30d": recent_members
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get membership statistics",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"total_members": 0,
|
||||
"active_members": 0,
|
||||
"inactive_members": 0,
|
||||
"members_by_role": {},
|
||||
"recent_joins_30d": 0
|
||||
}
|
||||
|
||||
def _get_default_permissions(self, role: str) -> str:
|
||||
"""Get default permissions JSON string for a role"""
|
||||
permission_map = {
|
||||
"owner": ["read", "write", "admin", "delete"],
|
||||
"admin": ["read", "write", "admin"],
|
||||
"member": ["read", "write"],
|
||||
"viewer": ["read"]
|
||||
}
|
||||
|
||||
permissions = permission_map.get(role, ["read"])
|
||||
return json.dumps(permissions)
|
||||
|
||||
async def bulk_update_permissions(
|
||||
self,
|
||||
tenant_id: str,
|
||||
role_permissions: Dict[str, List[str]]
|
||||
) -> int:
|
||||
"""Bulk update permissions for all members of specific roles"""
|
||||
try:
|
||||
updated_count = 0
|
||||
|
||||
for role, permissions in role_permissions.items():
|
||||
members = await self.get_tenant_members(
|
||||
tenant_id, active_only=True, role=role
|
||||
)
|
||||
|
||||
for member in members:
|
||||
await self.update(member.id, {
|
||||
"permissions": json.dumps(permissions)
|
||||
})
|
||||
updated_count += 1
|
||||
|
||||
logger.info("Bulk updated member permissions",
|
||||
tenant_id=tenant_id,
|
||||
updated_count=updated_count,
|
||||
roles=list(role_permissions.keys()))
|
||||
|
||||
return updated_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to bulk update permissions",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Bulk permission update failed: {str(e)}")
|
||||
|
||||
async def cleanup_inactive_memberships(self, days_old: int = 180) -> int:
|
||||
"""Clean up old inactive memberships"""
|
||||
try:
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days_old)
|
||||
|
||||
query_text = """
|
||||
DELETE FROM tenant_members
|
||||
WHERE is_active = false
|
||||
AND created_at < :cutoff_date
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up inactive memberships",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup inactive memberships",
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Cleanup failed: {str(e)}")
|
||||
680
services/tenant/app/repositories/tenant_repository.py
Normal file
680
services/tenant/app/repositories/tenant_repository.py
Normal file
@@ -0,0 +1,680 @@
|
||||
"""
|
||||
Tenant Repository
|
||||
Repository for tenant operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, text, and_
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
import uuid
|
||||
|
||||
from .base import TenantBaseRepository
|
||||
from app.models.tenants import Tenant, Subscription
|
||||
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TenantRepository(TenantBaseRepository):
|
||||
"""Repository for tenant operations"""
|
||||
|
||||
def __init__(self, model_class, session: AsyncSession, cache_ttl: Optional[int] = 600):
|
||||
# Tenants are relatively stable, longer cache time (10 minutes)
|
||||
super().__init__(model_class, session, cache_ttl)
|
||||
|
||||
async def create_tenant(self, tenant_data: Dict[str, Any]) -> Tenant:
|
||||
"""Create a new tenant with validation"""
|
||||
try:
|
||||
# Validate tenant data
|
||||
validation_result = self._validate_tenant_data(
|
||||
tenant_data,
|
||||
["name", "address", "postal_code", "owner_id"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid tenant data: {validation_result['errors']}")
|
||||
|
||||
# Generate subdomain if not provided
|
||||
if "subdomain" not in tenant_data or not tenant_data["subdomain"]:
|
||||
subdomain = await self._generate_unique_subdomain(tenant_data["name"])
|
||||
tenant_data["subdomain"] = subdomain
|
||||
else:
|
||||
# Check if provided subdomain is unique
|
||||
existing_tenant = await self.get_by_subdomain(tenant_data["subdomain"])
|
||||
if existing_tenant:
|
||||
raise DuplicateRecordError(f"Subdomain {tenant_data['subdomain']} already exists")
|
||||
|
||||
# Set default values
|
||||
if "business_type" not in tenant_data:
|
||||
tenant_data["business_type"] = "bakery"
|
||||
if "city" not in tenant_data:
|
||||
tenant_data["city"] = "Madrid"
|
||||
if "is_active" not in tenant_data:
|
||||
tenant_data["is_active"] = True
|
||||
if "ml_model_trained" not in tenant_data:
|
||||
tenant_data["ml_model_trained"] = False
|
||||
|
||||
# Create tenant
|
||||
tenant = await self.create(tenant_data)
|
||||
|
||||
logger.info("Tenant created successfully",
|
||||
tenant_id=tenant.id,
|
||||
name=tenant.name,
|
||||
subdomain=tenant.subdomain,
|
||||
owner_id=tenant.owner_id)
|
||||
|
||||
return tenant
|
||||
|
||||
except (ValidationError, DuplicateRecordError):
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create tenant",
|
||||
name=tenant_data.get("name"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create tenant: {str(e)}")
|
||||
|
||||
async def get_by_subdomain(self, subdomain: str) -> Optional[Tenant]:
|
||||
"""Get tenant by subdomain"""
|
||||
try:
|
||||
return await self.get_by_field("subdomain", subdomain)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenant by subdomain",
|
||||
subdomain=subdomain,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get tenant: {str(e)}")
|
||||
|
||||
async def get_tenants_by_owner(self, owner_id: str) -> List[Tenant]:
|
||||
"""Get all tenants owned by a user"""
|
||||
try:
|
||||
return await self.get_multi(
|
||||
filters={"owner_id": owner_id, "is_active": True},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenants by owner",
|
||||
owner_id=owner_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get tenants: {str(e)}")
|
||||
|
||||
async def get_active_tenants(self, skip: int = 0, limit: int = 100) -> List[Tenant]:
|
||||
"""Get all active tenants"""
|
||||
return await self.get_active_records(skip=skip, limit=limit)
|
||||
|
||||
async def search_tenants(
|
||||
self,
|
||||
search_term: str,
|
||||
business_type: str = None,
|
||||
city: str = None,
|
||||
skip: int = 0,
|
||||
limit: int = 50
|
||||
) -> List[Tenant]:
|
||||
"""Search tenants by name, address, or other criteria"""
|
||||
try:
|
||||
# Build search conditions
|
||||
conditions = ["is_active = true"]
|
||||
params = {"skip": skip, "limit": limit}
|
||||
|
||||
# Add text search
|
||||
conditions.append("(LOWER(name) LIKE LOWER(:search_term) OR LOWER(address) LIKE LOWER(:search_term))")
|
||||
params["search_term"] = f"%{search_term}%"
|
||||
|
||||
# Add business type filter
|
||||
if business_type:
|
||||
conditions.append("business_type = :business_type")
|
||||
params["business_type"] = business_type
|
||||
|
||||
# Add city filter
|
||||
if city:
|
||||
conditions.append("LOWER(city) = LOWER(:city)")
|
||||
params["city"] = city
|
||||
|
||||
query_text = f"""
|
||||
SELECT * FROM tenants
|
||||
WHERE {' AND '.join(conditions)}
|
||||
ORDER BY name ASC
|
||||
LIMIT :limit OFFSET :skip
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), params)
|
||||
|
||||
tenants = []
|
||||
for row in result.fetchall():
|
||||
record_dict = dict(row._mapping)
|
||||
tenant = self.model(**record_dict)
|
||||
tenants.append(tenant)
|
||||
|
||||
return tenants
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to search tenants",
|
||||
search_term=search_term,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def update_tenant_model_status(
|
||||
self,
|
||||
tenant_id: str,
|
||||
ml_model_trained: bool,
|
||||
last_training_date: datetime = None
|
||||
) -> Optional[Tenant]:
|
||||
"""Update tenant model training status"""
|
||||
try:
|
||||
update_data = {
|
||||
"ml_model_trained": ml_model_trained,
|
||||
"updated_at": datetime.utcnow()
|
||||
}
|
||||
|
||||
if last_training_date:
|
||||
update_data["last_training_date"] = last_training_date
|
||||
elif ml_model_trained:
|
||||
update_data["last_training_date"] = datetime.utcnow()
|
||||
|
||||
updated_tenant = await self.update(tenant_id, update_data)
|
||||
|
||||
logger.info("Tenant model status updated",
|
||||
tenant_id=tenant_id,
|
||||
ml_model_trained=ml_model_trained,
|
||||
last_training_date=last_training_date)
|
||||
|
||||
return updated_tenant
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update tenant model status",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update model status: {str(e)}")
|
||||
|
||||
|
||||
async def get_tenants_by_location(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
radius_km: float = 10.0,
|
||||
limit: int = 50
|
||||
) -> List[Tenant]:
|
||||
"""Get tenants within a geographic radius"""
|
||||
try:
|
||||
# Using Haversine formula for distance calculation
|
||||
query_text = """
|
||||
SELECT *,
|
||||
(6371 * acos(
|
||||
cos(radians(:latitude)) *
|
||||
cos(radians(latitude)) *
|
||||
cos(radians(longitude) - radians(:longitude)) +
|
||||
sin(radians(:latitude)) *
|
||||
sin(radians(latitude))
|
||||
)) AS distance_km
|
||||
FROM tenants
|
||||
WHERE is_active = true
|
||||
AND latitude IS NOT NULL
|
||||
AND longitude IS NOT NULL
|
||||
HAVING distance_km <= :radius_km
|
||||
ORDER BY distance_km ASC
|
||||
LIMIT :limit
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {
|
||||
"latitude": latitude,
|
||||
"longitude": longitude,
|
||||
"radius_km": radius_km,
|
||||
"limit": limit
|
||||
})
|
||||
|
||||
tenants = []
|
||||
for row in result.fetchall():
|
||||
# Create tenant object (excluding the calculated distance_km field)
|
||||
record_dict = dict(row._mapping)
|
||||
record_dict.pop("distance_km", None) # Remove calculated field
|
||||
tenant = self.model(**record_dict)
|
||||
tenants.append(tenant)
|
||||
|
||||
return tenants
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenants by location",
|
||||
latitude=latitude,
|
||||
longitude=longitude,
|
||||
radius_km=radius_km,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def get_tenant_statistics(self) -> Dict[str, Any]:
|
||||
"""Get global tenant statistics"""
|
||||
try:
|
||||
# Get basic counts
|
||||
total_tenants = await self.count()
|
||||
active_tenants = await self.count(filters={"is_active": True})
|
||||
|
||||
# Get tenants by business type
|
||||
business_type_query = text("""
|
||||
SELECT business_type, COUNT(*) as count
|
||||
FROM tenants
|
||||
WHERE is_active = true
|
||||
GROUP BY business_type
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(business_type_query)
|
||||
business_type_stats = {row.business_type: row.count for row in result.fetchall()}
|
||||
|
||||
# Get tenants by subscription tier - now from subscriptions table
|
||||
tier_query = text("""
|
||||
SELECT s.plan as subscription_tier, COUNT(*) as count
|
||||
FROM tenants t
|
||||
LEFT JOIN subscriptions s ON t.id = s.tenant_id AND s.status = 'active'
|
||||
WHERE t.is_active = true
|
||||
GROUP BY s.plan
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
tier_result = await self.session.execute(tier_query)
|
||||
tier_stats = {}
|
||||
for row in tier_result.fetchall():
|
||||
tier = row.subscription_tier if row.subscription_tier else "no_subscription"
|
||||
tier_stats[tier] = row.count
|
||||
|
||||
# Get model training statistics
|
||||
model_query = text("""
|
||||
SELECT
|
||||
COUNT(CASE WHEN ml_model_trained = true THEN 1 END) as trained_count,
|
||||
COUNT(CASE WHEN ml_model_trained = false THEN 1 END) as untrained_count,
|
||||
AVG(EXTRACT(EPOCH FROM (NOW() - last_training_date))/86400) as avg_days_since_training
|
||||
FROM tenants
|
||||
WHERE is_active = true
|
||||
""")
|
||||
|
||||
model_result = await self.session.execute(model_query)
|
||||
model_row = model_result.fetchone()
|
||||
|
||||
# Get recent registrations (last 30 days)
|
||||
thirty_days_ago = datetime.utcnow() - timedelta(days=30)
|
||||
recent_registrations = await self.count(filters={
|
||||
"created_at": f">= '{thirty_days_ago.isoformat()}'"
|
||||
})
|
||||
|
||||
return {
|
||||
"total_tenants": total_tenants,
|
||||
"active_tenants": active_tenants,
|
||||
"inactive_tenants": total_tenants - active_tenants,
|
||||
"tenants_by_business_type": business_type_stats,
|
||||
"tenants_by_subscription": tier_stats,
|
||||
"model_training": {
|
||||
"trained_tenants": int(model_row.trained_count or 0),
|
||||
"untrained_tenants": int(model_row.untrained_count or 0),
|
||||
"avg_days_since_training": float(model_row.avg_days_since_training or 0)
|
||||
} if model_row else {
|
||||
"trained_tenants": 0,
|
||||
"untrained_tenants": 0,
|
||||
"avg_days_since_training": 0.0
|
||||
},
|
||||
"recent_registrations_30d": recent_registrations
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenant statistics", error=str(e))
|
||||
return {
|
||||
"total_tenants": 0,
|
||||
"active_tenants": 0,
|
||||
"inactive_tenants": 0,
|
||||
"tenants_by_business_type": {},
|
||||
"tenants_by_subscription": {},
|
||||
"model_training": {
|
||||
"trained_tenants": 0,
|
||||
"untrained_tenants": 0,
|
||||
"avg_days_since_training": 0.0
|
||||
},
|
||||
"recent_registrations_30d": 0
|
||||
}
|
||||
|
||||
async def _generate_unique_subdomain(self, name: str) -> str:
|
||||
"""Generate a unique subdomain from tenant name"""
|
||||
try:
|
||||
# Clean the name to create a subdomain
|
||||
subdomain = name.lower().replace(' ', '-')
|
||||
# Remove accents
|
||||
subdomain = subdomain.replace('á', 'a').replace('é', 'e').replace('í', 'i').replace('ó', 'o').replace('ú', 'u')
|
||||
subdomain = subdomain.replace('ñ', 'n')
|
||||
# Keep only alphanumeric and hyphens
|
||||
subdomain = ''.join(c for c in subdomain if c.isalnum() or c == '-')
|
||||
# Remove multiple consecutive hyphens
|
||||
while '--' in subdomain:
|
||||
subdomain = subdomain.replace('--', '-')
|
||||
# Remove leading/trailing hyphens
|
||||
subdomain = subdomain.strip('-')
|
||||
|
||||
# Ensure minimum length
|
||||
if len(subdomain) < 3:
|
||||
subdomain = f"tenant-{subdomain}"
|
||||
|
||||
# Check if subdomain exists
|
||||
existing_tenant = await self.get_by_subdomain(subdomain)
|
||||
if not existing_tenant:
|
||||
return subdomain
|
||||
|
||||
# If it exists, add a unique suffix
|
||||
counter = 1
|
||||
while True:
|
||||
candidate = f"{subdomain}-{counter}"
|
||||
existing_tenant = await self.get_by_subdomain(candidate)
|
||||
if not existing_tenant:
|
||||
return candidate
|
||||
counter += 1
|
||||
|
||||
# Prevent infinite loop
|
||||
if counter > 9999:
|
||||
return f"{subdomain}-{uuid.uuid4().hex[:6]}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to generate unique subdomain",
|
||||
name=name,
|
||||
error=str(e))
|
||||
# Fallback to UUID-based subdomain
|
||||
return f"tenant-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
async def deactivate_tenant(self, tenant_id: str) -> Optional[Tenant]:
|
||||
"""Deactivate a tenant"""
|
||||
return await self.deactivate_record(tenant_id)
|
||||
|
||||
async def activate_tenant(self, tenant_id: str) -> Optional[Tenant]:
|
||||
"""Activate a tenant"""
|
||||
return await self.activate_record(tenant_id)
|
||||
|
||||
async def get_child_tenants(self, parent_tenant_id: str) -> List[Tenant]:
|
||||
"""Get all child tenants for a parent tenant"""
|
||||
try:
|
||||
return await self.get_multi(
|
||||
filters={"parent_tenant_id": parent_tenant_id, "is_active": True},
|
||||
order_by="created_at",
|
||||
order_desc=False
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get child tenants",
|
||||
parent_tenant_id=parent_tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get child tenants: {str(e)}")
|
||||
|
||||
async def get(self, record_id: Any) -> Optional[Tenant]:
|
||||
"""Get tenant by ID - alias for get_by_id for compatibility"""
|
||||
return await self.get_by_id(record_id)
|
||||
|
||||
async def get_child_tenant_count(self, parent_tenant_id: str) -> int:
|
||||
"""Get count of child tenants for a parent tenant"""
|
||||
try:
|
||||
child_tenants = await self.get_child_tenants(parent_tenant_id)
|
||||
return len(child_tenants)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get child tenant count",
|
||||
parent_tenant_id=parent_tenant_id,
|
||||
error=str(e))
|
||||
return 0
|
||||
|
||||
async def get_user_tenants_with_hierarchy(self, user_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all tenants a user has access to, organized in hierarchy.
|
||||
Returns parent tenants with their children nested.
|
||||
"""
|
||||
try:
|
||||
# Get all tenants where user is owner or member
|
||||
query_text = """
|
||||
SELECT DISTINCT t.*
|
||||
FROM tenants t
|
||||
LEFT JOIN tenant_members tm ON t.id = tm.tenant_id
|
||||
WHERE (t.owner_id = :user_id OR tm.user_id = :user_id)
|
||||
AND t.is_active = true
|
||||
ORDER BY t.tenant_type DESC, t.created_at ASC
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"user_id": user_id})
|
||||
|
||||
tenants = []
|
||||
for row in result.fetchall():
|
||||
record_dict = dict(row._mapping)
|
||||
tenant = self.model(**record_dict)
|
||||
tenants.append(tenant)
|
||||
|
||||
# Organize into hierarchy
|
||||
tenant_hierarchy = []
|
||||
parent_map = {}
|
||||
|
||||
# First pass: collect all parent/standalone tenants
|
||||
for tenant in tenants:
|
||||
if tenant.tenant_type in ['parent', 'standalone']:
|
||||
tenant_dict = {
|
||||
'id': str(tenant.id),
|
||||
'name': tenant.name,
|
||||
'subdomain': tenant.subdomain,
|
||||
'tenant_type': tenant.tenant_type,
|
||||
'business_type': tenant.business_type,
|
||||
'business_model': tenant.business_model,
|
||||
'city': tenant.city,
|
||||
'is_active': tenant.is_active,
|
||||
'children': [] if tenant.tenant_type == 'parent' else None
|
||||
}
|
||||
tenant_hierarchy.append(tenant_dict)
|
||||
parent_map[str(tenant.id)] = tenant_dict
|
||||
|
||||
# Second pass: attach children to their parents
|
||||
for tenant in tenants:
|
||||
if tenant.tenant_type == 'child' and tenant.parent_tenant_id:
|
||||
parent_id = str(tenant.parent_tenant_id)
|
||||
if parent_id in parent_map:
|
||||
child_dict = {
|
||||
'id': str(tenant.id),
|
||||
'name': tenant.name,
|
||||
'subdomain': tenant.subdomain,
|
||||
'tenant_type': 'child',
|
||||
'parent_tenant_id': parent_id,
|
||||
'city': tenant.city,
|
||||
'is_active': tenant.is_active
|
||||
}
|
||||
parent_map[parent_id]['children'].append(child_dict)
|
||||
|
||||
return tenant_hierarchy
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get user tenants with hierarchy",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def get_tenants_by_session_id(self, session_id: str) -> List[Tenant]:
|
||||
"""
|
||||
Get tenants associated with a specific demo session using the demo_session_id field.
|
||||
"""
|
||||
try:
|
||||
return await self.get_multi(
|
||||
filters={
|
||||
"demo_session_id": session_id,
|
||||
"is_active": True
|
||||
},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenants by session ID",
|
||||
session_id=session_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get tenants by session ID: {str(e)}")
|
||||
|
||||
async def get_professional_demo_tenants(self, session_id: str) -> List[Tenant]:
|
||||
"""
|
||||
Get professional demo tenants filtered by session.
|
||||
|
||||
Args:
|
||||
session_id: Required demo session ID to filter tenants
|
||||
|
||||
Returns:
|
||||
List of professional demo tenants for this specific session
|
||||
"""
|
||||
try:
|
||||
filters = {
|
||||
"business_model": "professional_bakery",
|
||||
"is_demo": True,
|
||||
"is_active": True,
|
||||
"demo_session_id": session_id # Always filter by session
|
||||
}
|
||||
|
||||
return await self.get_multi(
|
||||
filters=filters,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get professional demo tenants",
|
||||
session_id=session_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get professional demo tenants: {str(e)}")
|
||||
|
||||
async def get_enterprise_demo_tenants(self, session_id: str) -> List[Tenant]:
|
||||
"""
|
||||
Get enterprise demo tenants (parent and children) filtered by session.
|
||||
|
||||
Args:
|
||||
session_id: Required demo session ID to filter tenants
|
||||
|
||||
Returns:
|
||||
List of enterprise demo tenants (1 parent + 3 children) for this specific session
|
||||
"""
|
||||
try:
|
||||
# Get enterprise demo parent tenants for this session
|
||||
parent_tenants = await self.get_multi(
|
||||
filters={
|
||||
"tenant_type": "parent",
|
||||
"is_demo": True,
|
||||
"is_active": True,
|
||||
"demo_session_id": session_id # Always filter by session
|
||||
},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
|
||||
# Get child tenants for the enterprise demo session
|
||||
child_tenants = await self.get_multi(
|
||||
filters={
|
||||
"tenant_type": "child",
|
||||
"is_demo": True,
|
||||
"is_active": True,
|
||||
"demo_session_id": session_id # Always filter by session
|
||||
},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
|
||||
# Combine parent and child tenants
|
||||
return parent_tenants + child_tenants
|
||||
except Exception as e:
|
||||
logger.error("Failed to get enterprise demo tenants",
|
||||
session_id=session_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get enterprise demo tenants: {str(e)}")
|
||||
|
||||
async def get_by_customer_id(self, customer_id: str) -> Optional[Tenant]:
|
||||
"""
|
||||
Get tenant by Stripe customer ID
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
|
||||
Returns:
|
||||
Tenant object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
# Find tenant by joining with subscriptions table
|
||||
# Tenant doesn't have customer_id directly, so we need to find via subscription
|
||||
query = select(Tenant).join(
|
||||
Subscription, Subscription.tenant_id == Tenant.id
|
||||
).where(Subscription.customer_id == customer_id)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
tenant = result.scalar_one_or_none()
|
||||
|
||||
if tenant:
|
||||
logger.debug("Found tenant by customer_id",
|
||||
customer_id=customer_id,
|
||||
tenant_id=str(tenant.id))
|
||||
return tenant
|
||||
else:
|
||||
logger.debug("No tenant found for customer_id",
|
||||
customer_id=customer_id)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting tenant by customer_id",
|
||||
customer_id=customer_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get tenant by customer_id: {str(e)}")
|
||||
|
||||
async def get_user_primary_tenant(self, user_id: str) -> Optional[Tenant]:
|
||||
"""
|
||||
Get the primary tenant for a user (the tenant they own)
|
||||
|
||||
Args:
|
||||
user_id: User ID to find primary tenant for
|
||||
|
||||
Returns:
|
||||
Tenant object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
logger.debug("Getting primary tenant for user", user_id=user_id)
|
||||
|
||||
# Query for tenant where user is the owner
|
||||
query = select(Tenant).where(Tenant.owner_id == user_id)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
tenant = result.scalar_one_or_none()
|
||||
|
||||
if tenant:
|
||||
logger.debug("Found primary tenant for user",
|
||||
user_id=user_id,
|
||||
tenant_id=str(tenant.id))
|
||||
return tenant
|
||||
else:
|
||||
logger.debug("No primary tenant found for user", user_id=user_id)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting primary tenant for user",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get primary tenant for user: {str(e)}")
|
||||
|
||||
async def get_any_user_tenant(self, user_id: str) -> Optional[Tenant]:
|
||||
"""
|
||||
Get any tenant that the user has access to (via tenant_members)
|
||||
|
||||
Args:
|
||||
user_id: User ID to find accessible tenants for
|
||||
|
||||
Returns:
|
||||
Tenant object if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
logger.debug("Getting any accessible tenant for user", user_id=user_id)
|
||||
|
||||
# Query for tenant members where user has access
|
||||
from app.models.tenants import TenantMember
|
||||
|
||||
query = select(Tenant).join(
|
||||
TenantMember, Tenant.id == TenantMember.tenant_id
|
||||
).where(TenantMember.user_id == user_id)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
tenant = result.scalar_one_or_none()
|
||||
|
||||
if tenant:
|
||||
logger.debug("Found accessible tenant for user",
|
||||
user_id=user_id,
|
||||
tenant_id=str(tenant.id))
|
||||
return tenant
|
||||
else:
|
||||
logger.debug("No accessible tenants found for user", user_id=user_id)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting accessible tenant for user",
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get accessible tenant for user: {str(e)}")
|
||||
@@ -0,0 +1,82 @@
|
||||
# services/tenant/app/repositories/tenant_settings_repository.py
|
||||
"""
|
||||
Tenant Settings Repository
|
||||
Data access layer for tenant settings
|
||||
"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
import structlog
|
||||
|
||||
from ..models.tenant_settings import TenantSettings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TenantSettingsRepository:
|
||||
"""Repository for TenantSettings data access"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def get_by_tenant_id(self, tenant_id: UUID) -> Optional[TenantSettings]:
|
||||
"""
|
||||
Get tenant settings by tenant ID
|
||||
|
||||
Args:
|
||||
tenant_id: UUID of the tenant
|
||||
|
||||
Returns:
|
||||
TenantSettings or None if not found
|
||||
"""
|
||||
result = await self.db.execute(
|
||||
select(TenantSettings).where(TenantSettings.tenant_id == tenant_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create(self, settings: TenantSettings) -> TenantSettings:
|
||||
"""
|
||||
Create new tenant settings
|
||||
|
||||
Args:
|
||||
settings: TenantSettings instance to create
|
||||
|
||||
Returns:
|
||||
Created TenantSettings instance
|
||||
"""
|
||||
self.db.add(settings)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(settings)
|
||||
return settings
|
||||
|
||||
async def update(self, settings: TenantSettings) -> TenantSettings:
|
||||
"""
|
||||
Update tenant settings
|
||||
|
||||
Args:
|
||||
settings: TenantSettings instance with updates
|
||||
|
||||
Returns:
|
||||
Updated TenantSettings instance
|
||||
"""
|
||||
await self.db.commit()
|
||||
await self.db.refresh(settings)
|
||||
return settings
|
||||
|
||||
async def delete(self, tenant_id: UUID) -> None:
|
||||
"""
|
||||
Delete tenant settings
|
||||
|
||||
Args:
|
||||
tenant_id: UUID of the tenant
|
||||
"""
|
||||
result = await self.db.execute(
|
||||
select(TenantSettings).where(TenantSettings.tenant_id == tenant_id)
|
||||
)
|
||||
settings = result.scalar_one_or_none()
|
||||
|
||||
if settings:
|
||||
await self.db.delete(settings)
|
||||
await self.db.commit()
|
||||
0
services/tenant/app/schemas/__init__.py
Normal file
0
services/tenant/app/schemas/__init__.py
Normal file
89
services/tenant/app/schemas/tenant_locations.py
Normal file
89
services/tenant/app/schemas/tenant_locations.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""
|
||||
Tenant Location Schemas
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
class TenantLocationBase(BaseModel):
|
||||
"""Base schema for tenant location"""
|
||||
name: str = Field(..., min_length=1, max_length=200)
|
||||
location_type: str = Field(..., pattern=r'^(central_production|retail_outlet|warehouse|store|branch)$')
|
||||
address: str = Field(..., min_length=10, max_length=500)
|
||||
city: str = Field(default="Madrid", max_length=100)
|
||||
postal_code: str = Field(..., min_length=3, max_length=10)
|
||||
latitude: Optional[float] = Field(None, ge=-90, le=90)
|
||||
longitude: Optional[float] = Field(None, ge=-180, le=180)
|
||||
contact_person: Optional[str] = Field(None, max_length=200)
|
||||
contact_phone: Optional[str] = Field(None, max_length=20)
|
||||
contact_email: Optional[str] = Field(None, max_length=255)
|
||||
is_active: bool = True
|
||||
delivery_windows: Optional[Dict[str, Any]] = None
|
||||
operational_hours: Optional[Dict[str, Any]] = None
|
||||
capacity: Optional[int] = Field(None, ge=0)
|
||||
max_delivery_radius_km: Optional[float] = Field(None, ge=0)
|
||||
delivery_schedule_config: Optional[Dict[str, Any]] = None
|
||||
metadata: Optional[Dict[str, Any]] = Field(None)
|
||||
|
||||
|
||||
class TenantLocationCreate(TenantLocationBase):
|
||||
"""Schema for creating a tenant location"""
|
||||
tenant_id: str # This will be validated as UUID in the API layer
|
||||
|
||||
|
||||
class TenantLocationUpdate(BaseModel):
|
||||
"""Schema for updating a tenant location"""
|
||||
name: Optional[str] = Field(None, min_length=1, max_length=200)
|
||||
location_type: Optional[str] = Field(None, pattern=r'^(central_production|retail_outlet|warehouse|store|branch)$')
|
||||
address: Optional[str] = Field(None, min_length=10, max_length=500)
|
||||
city: Optional[str] = Field(None, max_length=100)
|
||||
postal_code: Optional[str] = Field(None, min_length=3, max_length=10)
|
||||
latitude: Optional[float] = Field(None, ge=-90, le=90)
|
||||
longitude: Optional[float] = Field(None, ge=-180, le=180)
|
||||
contact_person: Optional[str] = Field(None, max_length=200)
|
||||
contact_phone: Optional[str] = Field(None, max_length=20)
|
||||
contact_email: Optional[str] = Field(None, max_length=255)
|
||||
is_active: Optional[bool] = None
|
||||
delivery_windows: Optional[Dict[str, Any]] = None
|
||||
operational_hours: Optional[Dict[str, Any]] = None
|
||||
capacity: Optional[int] = Field(None, ge=0)
|
||||
max_delivery_radius_km: Optional[float] = Field(None, ge=0)
|
||||
delivery_schedule_config: Optional[Dict[str, Any]] = None
|
||||
metadata: Optional[Dict[str, Any]] = Field(None)
|
||||
|
||||
|
||||
class TenantLocationResponse(TenantLocationBase):
|
||||
"""Schema for tenant location response"""
|
||||
id: str
|
||||
tenant_id: str
|
||||
created_at: datetime
|
||||
updated_at: Optional[datetime]
|
||||
|
||||
@field_validator('id', 'tenant_id', mode='before')
|
||||
@classmethod
|
||||
def convert_uuid_to_string(cls, v):
|
||||
"""Convert UUID objects to strings for JSON serialization"""
|
||||
if isinstance(v, UUID):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
populate_by_name = True
|
||||
|
||||
|
||||
class TenantLocationsResponse(BaseModel):
|
||||
"""Schema for multiple tenant locations response"""
|
||||
locations: List[TenantLocationResponse]
|
||||
total: int
|
||||
|
||||
|
||||
class TenantLocationTypeFilter(BaseModel):
|
||||
"""Schema for filtering locations by type"""
|
||||
location_types: List[str] = Field(
|
||||
default=["central_production", "retail_outlet", "warehouse", "store", "branch"],
|
||||
description="List of location types to include"
|
||||
)
|
||||
316
services/tenant/app/schemas/tenant_settings.py
Normal file
316
services/tenant/app/schemas/tenant_settings.py
Normal file
@@ -0,0 +1,316 @@
|
||||
# services/tenant/app/schemas/tenant_settings.py
|
||||
"""
|
||||
Tenant Settings Schemas
|
||||
Pydantic models for API request/response validation
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
|
||||
# ================================================================
|
||||
# SETTING CATEGORY SCHEMAS
|
||||
# ================================================================
|
||||
|
||||
class ProcurementSettings(BaseModel):
|
||||
"""Procurement and auto-approval settings"""
|
||||
auto_approve_enabled: bool = True
|
||||
auto_approve_threshold_eur: float = Field(500.0, ge=0, le=10000)
|
||||
auto_approve_min_supplier_score: float = Field(0.80, ge=0.0, le=1.0)
|
||||
require_approval_new_suppliers: bool = True
|
||||
require_approval_critical_items: bool = True
|
||||
procurement_lead_time_days: int = Field(3, ge=1, le=30)
|
||||
demand_forecast_days: int = Field(14, ge=1, le=90)
|
||||
safety_stock_percentage: float = Field(20.0, ge=0.0, le=100.0)
|
||||
po_approval_reminder_hours: int = Field(24, ge=1, le=168)
|
||||
po_critical_escalation_hours: int = Field(12, ge=1, le=72)
|
||||
use_reorder_rules: bool = Field(True, description="Use ingredient reorder point and reorder quantity in procurement calculations")
|
||||
economic_rounding: bool = Field(True, description="Round order quantities to economic multiples (reorder_quantity or supplier minimum_order_quantity)")
|
||||
respect_storage_limits: bool = Field(True, description="Enforce max_stock_level constraints on orders")
|
||||
use_supplier_minimums: bool = Field(True, description="Respect supplier minimum_order_quantity and minimum_order_amount")
|
||||
optimize_price_tiers: bool = Field(True, description="Optimize order quantities to capture volume discount price tiers")
|
||||
|
||||
|
||||
class InventorySettings(BaseModel):
|
||||
"""Inventory management settings"""
|
||||
low_stock_threshold: int = Field(10, ge=1, le=1000)
|
||||
reorder_point: int = Field(20, ge=1, le=1000)
|
||||
reorder_quantity: int = Field(50, ge=1, le=1000)
|
||||
expiring_soon_days: int = Field(7, ge=1, le=30)
|
||||
expiration_warning_days: int = Field(3, ge=1, le=14)
|
||||
quality_score_threshold: float = Field(8.0, ge=0.0, le=10.0)
|
||||
temperature_monitoring_enabled: bool = True
|
||||
refrigeration_temp_min: float = Field(1.0, ge=-5.0, le=10.0)
|
||||
refrigeration_temp_max: float = Field(4.0, ge=-5.0, le=10.0)
|
||||
freezer_temp_min: float = Field(-20.0, ge=-30.0, le=0.0)
|
||||
freezer_temp_max: float = Field(-15.0, ge=-30.0, le=0.0)
|
||||
room_temp_min: float = Field(18.0, ge=10.0, le=35.0)
|
||||
room_temp_max: float = Field(25.0, ge=10.0, le=35.0)
|
||||
temp_deviation_alert_minutes: int = Field(15, ge=1, le=60)
|
||||
critical_temp_deviation_minutes: int = Field(5, ge=1, le=30)
|
||||
|
||||
@validator('refrigeration_temp_max')
|
||||
def validate_refrigeration_range(cls, v, values):
|
||||
if 'refrigeration_temp_min' in values and v <= values['refrigeration_temp_min']:
|
||||
raise ValueError('refrigeration_temp_max must be greater than refrigeration_temp_min')
|
||||
return v
|
||||
|
||||
@validator('freezer_temp_max')
|
||||
def validate_freezer_range(cls, v, values):
|
||||
if 'freezer_temp_min' in values and v <= values['freezer_temp_min']:
|
||||
raise ValueError('freezer_temp_max must be greater than freezer_temp_min')
|
||||
return v
|
||||
|
||||
@validator('room_temp_max')
|
||||
def validate_room_range(cls, v, values):
|
||||
if 'room_temp_min' in values and v <= values['room_temp_min']:
|
||||
raise ValueError('room_temp_max must be greater than room_temp_min')
|
||||
return v
|
||||
|
||||
|
||||
class ProductionSettings(BaseModel):
|
||||
"""Production settings"""
|
||||
planning_horizon_days: int = Field(7, ge=1, le=30)
|
||||
minimum_batch_size: float = Field(1.0, ge=0.1, le=100.0)
|
||||
maximum_batch_size: float = Field(100.0, ge=1.0, le=1000.0)
|
||||
production_buffer_percentage: float = Field(10.0, ge=0.0, le=50.0)
|
||||
working_hours_per_day: int = Field(12, ge=1, le=24)
|
||||
max_overtime_hours: int = Field(4, ge=0, le=12)
|
||||
capacity_utilization_target: float = Field(0.85, ge=0.5, le=1.0)
|
||||
capacity_warning_threshold: float = Field(0.95, ge=0.7, le=1.0)
|
||||
quality_check_enabled: bool = True
|
||||
minimum_yield_percentage: float = Field(85.0, ge=50.0, le=100.0)
|
||||
quality_score_threshold: float = Field(8.0, ge=0.0, le=10.0)
|
||||
schedule_optimization_enabled: bool = True
|
||||
prep_time_buffer_minutes: int = Field(30, ge=0, le=120)
|
||||
cleanup_time_buffer_minutes: int = Field(15, ge=0, le=120)
|
||||
labor_cost_per_hour_eur: float = Field(15.0, ge=5.0, le=100.0)
|
||||
overhead_cost_percentage: float = Field(20.0, ge=0.0, le=50.0)
|
||||
|
||||
@validator('maximum_batch_size')
|
||||
def validate_batch_size_range(cls, v, values):
|
||||
if 'minimum_batch_size' in values and v <= values['minimum_batch_size']:
|
||||
raise ValueError('maximum_batch_size must be greater than minimum_batch_size')
|
||||
return v
|
||||
|
||||
@validator('capacity_warning_threshold')
|
||||
def validate_capacity_threshold(cls, v, values):
|
||||
if 'capacity_utilization_target' in values and v <= values['capacity_utilization_target']:
|
||||
raise ValueError('capacity_warning_threshold must be greater than capacity_utilization_target')
|
||||
return v
|
||||
|
||||
|
||||
class SupplierSettings(BaseModel):
|
||||
"""Supplier management settings"""
|
||||
default_payment_terms_days: int = Field(30, ge=1, le=90)
|
||||
default_delivery_days: int = Field(3, ge=1, le=30)
|
||||
excellent_delivery_rate: float = Field(95.0, ge=90.0, le=100.0)
|
||||
good_delivery_rate: float = Field(90.0, ge=80.0, le=99.0)
|
||||
excellent_quality_rate: float = Field(98.0, ge=90.0, le=100.0)
|
||||
good_quality_rate: float = Field(95.0, ge=80.0, le=99.0)
|
||||
critical_delivery_delay_hours: int = Field(24, ge=1, le=168)
|
||||
critical_quality_rejection_rate: float = Field(10.0, ge=0.0, le=50.0)
|
||||
high_cost_variance_percentage: float = Field(15.0, ge=0.0, le=100.0)
|
||||
|
||||
@validator('good_delivery_rate')
|
||||
def validate_delivery_rates(cls, v, values):
|
||||
if 'excellent_delivery_rate' in values and v >= values['excellent_delivery_rate']:
|
||||
raise ValueError('good_delivery_rate must be less than excellent_delivery_rate')
|
||||
return v
|
||||
|
||||
@validator('good_quality_rate')
|
||||
def validate_quality_rates(cls, v, values):
|
||||
if 'excellent_quality_rate' in values and v >= values['excellent_quality_rate']:
|
||||
raise ValueError('good_quality_rate must be less than excellent_quality_rate')
|
||||
return v
|
||||
|
||||
|
||||
class POSSettings(BaseModel):
|
||||
"""POS integration settings"""
|
||||
sync_interval_minutes: int = Field(5, ge=1, le=60)
|
||||
auto_sync_products: bool = True
|
||||
auto_sync_transactions: bool = True
|
||||
|
||||
|
||||
class OrderSettings(BaseModel):
|
||||
"""Order and business rules settings"""
|
||||
max_discount_percentage: float = Field(50.0, ge=0.0, le=100.0)
|
||||
default_delivery_window_hours: int = Field(48, ge=1, le=168)
|
||||
dynamic_pricing_enabled: bool = False
|
||||
discount_enabled: bool = True
|
||||
delivery_tracking_enabled: bool = True
|
||||
|
||||
|
||||
class ReplenishmentSettings(BaseModel):
|
||||
"""Replenishment planning settings"""
|
||||
projection_horizon_days: int = Field(7, ge=1, le=30)
|
||||
service_level: float = Field(0.95, ge=0.0, le=1.0)
|
||||
buffer_days: int = Field(1, ge=0, le=14)
|
||||
enable_auto_replenishment: bool = True
|
||||
min_order_quantity: float = Field(1.0, ge=0.1, le=1000.0)
|
||||
max_order_quantity: float = Field(1000.0, ge=1.0, le=10000.0)
|
||||
demand_forecast_days: int = Field(14, ge=1, le=90)
|
||||
|
||||
|
||||
class SafetyStockSettings(BaseModel):
|
||||
"""Safety stock settings"""
|
||||
service_level: float = Field(0.95, ge=0.0, le=1.0)
|
||||
method: str = Field("statistical", description="Method for safety stock calculation")
|
||||
min_safety_stock: float = Field(0.0, ge=0.0, le=1000.0)
|
||||
max_safety_stock: float = Field(100.0, ge=0.0, le=1000.0)
|
||||
reorder_point_calculation: str = Field("safety_stock_plus_lead_time_demand", description="Method for reorder point calculation")
|
||||
|
||||
|
||||
class MOQSettings(BaseModel):
|
||||
"""MOQ aggregation settings"""
|
||||
consolidation_window_days: int = Field(7, ge=1, le=30)
|
||||
allow_early_ordering: bool = True
|
||||
enable_batch_optimization: bool = True
|
||||
min_batch_size: float = Field(1.0, ge=0.1, le=1000.0)
|
||||
max_batch_size: float = Field(1000.0, ge=1.0, le=10000.0)
|
||||
|
||||
|
||||
class SupplierSelectionSettings(BaseModel):
|
||||
"""Supplier selection settings"""
|
||||
price_weight: float = Field(0.40, ge=0.0, le=1.0)
|
||||
lead_time_weight: float = Field(0.20, ge=0.0, le=1.0)
|
||||
quality_weight: float = Field(0.20, ge=0.0, le=1.0)
|
||||
reliability_weight: float = Field(0.20, ge=0.0, le=1.0)
|
||||
diversification_threshold: int = Field(1000, ge=0, le=1000)
|
||||
max_single_percentage: float = Field(0.70, ge=0.0, le=1.0)
|
||||
enable_supplier_score_optimization: bool = True
|
||||
|
||||
@validator('price_weight', 'lead_time_weight', 'quality_weight', 'reliability_weight')
|
||||
def validate_weights_sum(cls, v, values):
|
||||
weights = [values.get('price_weight', 0.40), values.get('lead_time_weight', 0.20),
|
||||
values.get('quality_weight', 0.20), values.get('reliability_weight', 0.20)]
|
||||
total = sum(weights)
|
||||
if total > 1.0:
|
||||
raise ValueError('Weights must sum to 1.0 or less')
|
||||
return v
|
||||
|
||||
|
||||
class MLInsightsSettings(BaseModel):
|
||||
"""ML Insights configuration settings"""
|
||||
# Inventory ML (Safety Stock Optimization)
|
||||
inventory_lookback_days: int = Field(90, ge=30, le=365, description="Days of demand history for safety stock analysis")
|
||||
inventory_min_history_days: int = Field(30, ge=7, le=180, description="Minimum days of history required")
|
||||
|
||||
# Production ML (Yield Prediction)
|
||||
production_lookback_days: int = Field(90, ge=30, le=365, description="Days of production history for yield analysis")
|
||||
production_min_history_runs: int = Field(30, ge=10, le=100, description="Minimum production runs required")
|
||||
|
||||
# Procurement ML (Supplier Analysis & Price Forecasting)
|
||||
supplier_analysis_lookback_days: int = Field(180, ge=30, le=730, description="Days of order history for supplier analysis")
|
||||
supplier_analysis_min_orders: int = Field(10, ge=5, le=100, description="Minimum orders required for analysis")
|
||||
price_forecast_lookback_days: int = Field(180, ge=90, le=730, description="Days of price history for forecasting")
|
||||
price_forecast_horizon_days: int = Field(30, ge=7, le=90, description="Days to forecast ahead")
|
||||
|
||||
# Forecasting ML (Dynamic Rules)
|
||||
rules_generation_lookback_days: int = Field(90, ge=30, le=365, description="Days of sales history for rule learning")
|
||||
rules_generation_min_samples: int = Field(10, ge=5, le=100, description="Minimum samples required for rule generation")
|
||||
|
||||
# Global ML Settings
|
||||
enable_ml_insights: bool = Field(True, description="Enable/disable ML insights generation")
|
||||
ml_insights_auto_trigger: bool = Field(False, description="Automatically trigger ML insights in daily workflow")
|
||||
ml_confidence_threshold: float = Field(0.80, ge=0.0, le=1.0, description="Minimum confidence threshold for ML recommendations")
|
||||
|
||||
|
||||
class NotificationSettings(BaseModel):
|
||||
"""Notification and communication settings"""
|
||||
# WhatsApp Configuration (Shared Account Model)
|
||||
whatsapp_enabled: bool = Field(False, description="Enable WhatsApp notifications for this tenant")
|
||||
whatsapp_phone_number_id: str = Field("", description="Meta WhatsApp Phone Number ID (from shared master account)")
|
||||
whatsapp_display_phone_number: str = Field("", description="Display format for UI (e.g., '+34 612 345 678')")
|
||||
whatsapp_default_language: str = Field("es", description="Default language for WhatsApp templates")
|
||||
|
||||
# Email Configuration
|
||||
email_enabled: bool = Field(True, description="Enable email notifications for this tenant")
|
||||
email_from_address: str = Field("", description="Custom from email address (optional)")
|
||||
email_from_name: str = Field("", description="Custom from name (optional)")
|
||||
email_reply_to: str = Field("", description="Reply-to email address (optional)")
|
||||
|
||||
# Notification Preferences
|
||||
enable_po_notifications: bool = Field(True, description="Enable purchase order notifications")
|
||||
enable_inventory_alerts: bool = Field(True, description="Enable inventory alerts")
|
||||
enable_production_alerts: bool = Field(True, description="Enable production alerts")
|
||||
enable_forecast_alerts: bool = Field(True, description="Enable forecast alerts")
|
||||
|
||||
# Notification Channels
|
||||
po_notification_channels: list[str] = Field(["email"], description="Channels for PO notifications (email, whatsapp)")
|
||||
inventory_alert_channels: list[str] = Field(["email"], description="Channels for inventory alerts")
|
||||
production_alert_channels: list[str] = Field(["email"], description="Channels for production alerts")
|
||||
forecast_alert_channels: list[str] = Field(["email"], description="Channels for forecast alerts")
|
||||
|
||||
@validator('po_notification_channels', 'inventory_alert_channels', 'production_alert_channels', 'forecast_alert_channels')
|
||||
def validate_channels(cls, v):
|
||||
"""Validate that channels are valid"""
|
||||
valid_channels = ["email", "whatsapp", "sms", "push"]
|
||||
for channel in v:
|
||||
if channel not in valid_channels:
|
||||
raise ValueError(f"Invalid channel: {channel}. Must be one of {valid_channels}")
|
||||
return v
|
||||
|
||||
@validator('whatsapp_phone_number_id')
|
||||
def validate_phone_number_id(cls, v, values):
|
||||
"""Validate phone number ID is provided if WhatsApp is enabled"""
|
||||
if values.get('whatsapp_enabled') and not v:
|
||||
raise ValueError("whatsapp_phone_number_id is required when WhatsApp is enabled")
|
||||
return v
|
||||
|
||||
|
||||
# ================================================================
|
||||
# REQUEST/RESPONSE SCHEMAS
|
||||
# ================================================================
|
||||
|
||||
class TenantSettingsResponse(BaseModel):
|
||||
"""Response schema for tenant settings"""
|
||||
id: UUID
|
||||
tenant_id: UUID
|
||||
procurement_settings: ProcurementSettings
|
||||
inventory_settings: InventorySettings
|
||||
production_settings: ProductionSettings
|
||||
supplier_settings: SupplierSettings
|
||||
pos_settings: POSSettings
|
||||
order_settings: OrderSettings
|
||||
replenishment_settings: ReplenishmentSettings
|
||||
safety_stock_settings: SafetyStockSettings
|
||||
moq_settings: MOQSettings
|
||||
supplier_selection_settings: SupplierSelectionSettings
|
||||
ml_insights_settings: MLInsightsSettings
|
||||
notification_settings: NotificationSettings
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class TenantSettingsUpdate(BaseModel):
|
||||
"""Schema for updating tenant settings"""
|
||||
procurement_settings: Optional[ProcurementSettings] = None
|
||||
inventory_settings: Optional[InventorySettings] = None
|
||||
production_settings: Optional[ProductionSettings] = None
|
||||
supplier_settings: Optional[SupplierSettings] = None
|
||||
pos_settings: Optional[POSSettings] = None
|
||||
order_settings: Optional[OrderSettings] = None
|
||||
replenishment_settings: Optional[ReplenishmentSettings] = None
|
||||
safety_stock_settings: Optional[SafetyStockSettings] = None
|
||||
moq_settings: Optional[MOQSettings] = None
|
||||
supplier_selection_settings: Optional[SupplierSelectionSettings] = None
|
||||
ml_insights_settings: Optional[MLInsightsSettings] = None
|
||||
notification_settings: Optional[NotificationSettings] = None
|
||||
|
||||
|
||||
class CategoryUpdateRequest(BaseModel):
|
||||
"""Schema for updating a single category"""
|
||||
settings: dict
|
||||
|
||||
|
||||
class CategoryResetResponse(BaseModel):
|
||||
"""Response schema for category reset"""
|
||||
category: str
|
||||
settings: dict
|
||||
message: str
|
||||
386
services/tenant/app/schemas/tenants.py
Normal file
386
services/tenant/app/schemas/tenants.py
Normal file
@@ -0,0 +1,386 @@
|
||||
# services/tenant/app/schemas/tenants.py
|
||||
"""
|
||||
Tenant schemas - FIXED VERSION
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, ValidationInfo
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
import re
|
||||
|
||||
class BakeryRegistration(BaseModel):
|
||||
"""Bakery registration schema"""
|
||||
name: str = Field(..., min_length=2, max_length=200)
|
||||
address: str = Field(..., min_length=10, max_length=500)
|
||||
city: str = Field(default="Madrid", max_length=100)
|
||||
postal_code: str = Field(..., pattern=r"^\d{5}$")
|
||||
phone: str = Field(..., min_length=9, max_length=20)
|
||||
business_type: str = Field(default="bakery")
|
||||
business_model: Optional[str] = Field(default="individual_bakery")
|
||||
coupon_code: Optional[str] = Field(None, max_length=50, description="Promotional coupon code")
|
||||
# Subscription linking fields (for new multi-phase registration architecture)
|
||||
subscription_id: Optional[str] = Field(None, description="Existing subscription ID to link to this tenant")
|
||||
link_existing_subscription: Optional[bool] = Field(False, description="Flag to link an existing subscription during tenant creation")
|
||||
|
||||
@field_validator('phone')
|
||||
@classmethod
|
||||
def validate_spanish_phone(cls, v):
|
||||
"""Validate Spanish phone number"""
|
||||
# Remove spaces and common separators
|
||||
phone = re.sub(r'[\s\-\(\)]', '', v)
|
||||
|
||||
# Spanish mobile: +34 6/7/8/9 + 8 digits
|
||||
# Spanish landline: +34 9 + 8 digits
|
||||
patterns = [
|
||||
r'^(\+34|0034|34)?[6789]\d{8}$', # Mobile
|
||||
r'^(\+34|0034|34)?9\d{8}$', # Landline
|
||||
]
|
||||
|
||||
if not any(re.match(pattern, phone) for pattern in patterns):
|
||||
raise ValueError('Invalid Spanish phone number')
|
||||
return v
|
||||
|
||||
@field_validator('business_type')
|
||||
@classmethod
|
||||
def validate_business_type(cls, v):
|
||||
valid_types = ['bakery', 'coffee_shop', 'pastry_shop', 'restaurant']
|
||||
if v not in valid_types:
|
||||
raise ValueError(f'Business type must be one of: {valid_types}')
|
||||
return v
|
||||
|
||||
@field_validator('business_model')
|
||||
@classmethod
|
||||
def validate_business_model(cls, v):
|
||||
if v is None:
|
||||
return v
|
||||
valid_models = ['individual_bakery', 'central_baker_satellite', 'retail_bakery', 'hybrid_bakery']
|
||||
if v not in valid_models:
|
||||
raise ValueError(f'Business model must be one of: {valid_models}')
|
||||
return v
|
||||
|
||||
class TenantResponse(BaseModel):
|
||||
"""Tenant response schema - Updated to use subscription relationship"""
|
||||
id: str # ✅ Keep as str for Pydantic validation
|
||||
name: str
|
||||
subdomain: Optional[str]
|
||||
business_type: str
|
||||
business_model: Optional[str]
|
||||
tenant_type: Optional[str] = "standalone" # standalone, parent, or child
|
||||
parent_tenant_id: Optional[str] = None # For child tenants
|
||||
address: str
|
||||
city: str
|
||||
postal_code: str
|
||||
# Regional/Localization settings
|
||||
timezone: Optional[str] = "Europe/Madrid"
|
||||
currency: Optional[str] = "EUR" # Currency code: EUR, USD, GBP
|
||||
language: Optional[str] = "es" # Language code: es, en, eu
|
||||
phone: Optional[str]
|
||||
is_active: bool
|
||||
subscription_plan: Optional[str] = None # Populated from subscription relationship or service
|
||||
ml_model_trained: bool
|
||||
last_training_date: Optional[datetime]
|
||||
owner_id: str # ✅ Keep as str for Pydantic validation
|
||||
created_at: datetime
|
||||
|
||||
# ✅ FIX: Add custom validator to convert UUID to string
|
||||
@field_validator('id', 'owner_id', 'parent_tenant_id', mode='before')
|
||||
@classmethod
|
||||
def convert_uuid_to_string(cls, v):
|
||||
"""Convert UUID objects to strings for JSON serialization"""
|
||||
if isinstance(v, UUID):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class TenantAccessResponse(BaseModel):
|
||||
"""Tenant access verification response"""
|
||||
has_access: bool
|
||||
role: str
|
||||
permissions: List[str]
|
||||
|
||||
class TenantMemberResponse(BaseModel):
|
||||
"""Tenant member response - FIXED VERSION with enriched user data"""
|
||||
id: str
|
||||
user_id: str
|
||||
role: str
|
||||
is_active: bool
|
||||
joined_at: Optional[datetime]
|
||||
# Enriched user fields (populated via service layer)
|
||||
user_email: Optional[str] = None
|
||||
user_full_name: Optional[str] = None
|
||||
user: Optional[Dict[str, Any]] = None # Full user object for compatibility
|
||||
|
||||
# ✅ FIX: Add custom validator to convert UUID to string
|
||||
@field_validator('id', 'user_id', mode='before')
|
||||
@classmethod
|
||||
def convert_uuid_to_string(cls, v):
|
||||
"""Convert UUID objects to strings for JSON serialization"""
|
||||
if isinstance(v, UUID):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class TenantUpdate(BaseModel):
|
||||
"""Tenant update schema"""
|
||||
name: Optional[str] = Field(None, min_length=2, max_length=200)
|
||||
address: Optional[str] = Field(None, min_length=10, max_length=500)
|
||||
phone: Optional[str] = None
|
||||
business_type: Optional[str] = None
|
||||
business_model: Optional[str] = None
|
||||
# Regional/Localization settings
|
||||
timezone: Optional[str] = None
|
||||
currency: Optional[str] = Field(None, pattern=r'^(EUR|USD|GBP)$') # Currency code
|
||||
language: Optional[str] = Field(None, pattern=r'^(es|en|eu)$') # Language code
|
||||
|
||||
class TenantListResponse(BaseModel):
|
||||
"""Response schema for listing tenants"""
|
||||
tenants: List[TenantResponse]
|
||||
total: int
|
||||
page: int
|
||||
per_page: int
|
||||
has_next: bool
|
||||
has_prev: bool
|
||||
|
||||
class TenantMemberInvitation(BaseModel):
|
||||
"""Schema for inviting a member to a tenant"""
|
||||
email: str = Field(..., pattern=r'^[^@]+@[^@]+\.[^@]+$')
|
||||
role: str = Field(..., pattern=r'^(admin|member|viewer)$')
|
||||
message: Optional[str] = Field(None, max_length=500)
|
||||
|
||||
class TenantMemberUpdate(BaseModel):
|
||||
"""Schema for updating tenant member"""
|
||||
role: Optional[str] = Field(None, pattern=r'^(owner|admin|member|viewer)$')
|
||||
is_active: Optional[bool] = None
|
||||
|
||||
class AddMemberWithUserCreate(BaseModel):
|
||||
"""Schema for adding member with optional user creation (pilot phase)"""
|
||||
# For existing users
|
||||
user_id: Optional[str] = Field(None, description="ID of existing user to add")
|
||||
|
||||
# For new user creation
|
||||
create_user: bool = Field(False, description="Whether to create a new user")
|
||||
email: Optional[str] = Field(None, description="Email for new user (if create_user=True)")
|
||||
full_name: Optional[str] = Field(None, min_length=2, max_length=100, description="Full name for new user")
|
||||
password: Optional[str] = Field(None, min_length=8, max_length=128, description="Password for new user")
|
||||
phone: Optional[str] = Field(None, description="Phone number for new user")
|
||||
language: Optional[str] = Field("es", pattern="^(es|en|eu)$", description="Preferred language")
|
||||
timezone: Optional[str] = Field("Europe/Madrid", description="User timezone")
|
||||
|
||||
# Common fields
|
||||
role: str = Field(..., pattern=r'^(admin|member|viewer)$', description="Role in the tenant")
|
||||
|
||||
@field_validator('email', 'full_name', 'password')
|
||||
@classmethod
|
||||
def validate_user_creation_fields(cls, v, info: ValidationInfo):
|
||||
"""Validate that required fields are present when creating a user"""
|
||||
if info.data.get('create_user') and info.field_name in ['email', 'full_name', 'password']:
|
||||
if not v:
|
||||
raise ValueError(f"{info.field_name} is required when create_user is True")
|
||||
return v
|
||||
|
||||
@field_validator('user_id')
|
||||
@classmethod
|
||||
def validate_user_id_or_create(cls, v, info: ValidationInfo):
|
||||
"""Ensure either user_id or create_user is provided"""
|
||||
if not v and not info.data.get('create_user'):
|
||||
raise ValueError("Either user_id or create_user must be provided")
|
||||
if v and info.data.get('create_user'):
|
||||
raise ValueError("Cannot specify both user_id and create_user")
|
||||
return v
|
||||
|
||||
class TenantSubscriptionUpdate(BaseModel):
|
||||
"""Schema for updating tenant subscription"""
|
||||
plan: str = Field(..., pattern=r'^(basic|professional|enterprise)$')
|
||||
billing_cycle: str = Field(default="monthly", pattern=r'^(monthly|yearly)$')
|
||||
|
||||
class TenantStatsResponse(BaseModel):
|
||||
"""Tenant statistics response"""
|
||||
tenant_id: str
|
||||
total_members: int
|
||||
active_members: int
|
||||
total_predictions: int
|
||||
models_trained: int
|
||||
last_training_date: Optional[datetime]
|
||||
subscription_plan: str
|
||||
subscription_status: str
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# ENTERPRISE CHILD TENANT SCHEMAS
|
||||
# ============================================================================
|
||||
|
||||
class ChildTenantCreate(BaseModel):
|
||||
"""Schema for creating a child tenant in enterprise hierarchy - Updated to match tenant model"""
|
||||
name: str = Field(..., min_length=2, max_length=200, description="Child tenant name (e.g., 'Madrid - Salamanca')")
|
||||
city: str = Field(..., min_length=2, max_length=100, description="City where the outlet is located")
|
||||
zone: Optional[str] = Field(None, max_length=100, description="Zone or neighborhood")
|
||||
address: str = Field(..., min_length=10, max_length=500, description="Full address of the outlet")
|
||||
postal_code: str = Field(..., pattern=r"^\d{5}$", description="5-digit postal code")
|
||||
location_code: str = Field(..., min_length=1, max_length=10, description="Short location code (e.g., MAD, BCN)")
|
||||
|
||||
# Coordinates (can be geocoded from address if not provided)
|
||||
latitude: Optional[float] = Field(None, ge=-90, le=90, description="Latitude coordinate")
|
||||
longitude: Optional[float] = Field(None, ge=-180, le=180, description="Longitude coordinate")
|
||||
|
||||
# Contact info (inherits from parent if not provided)
|
||||
phone: Optional[str] = Field(None, min_length=9, max_length=20, description="Contact phone")
|
||||
email: Optional[str] = Field(None, description="Contact email")
|
||||
|
||||
# Business info
|
||||
business_type: Optional[str] = Field(None, max_length=100, description="Type of business")
|
||||
business_model: Optional[str] = Field(None, max_length=100, description="Business model")
|
||||
|
||||
# Timezone configuration
|
||||
timezone: Optional[str] = Field(None, max_length=50, description="Timezone for scheduling")
|
||||
|
||||
# Additional metadata
|
||||
metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata for the child tenant")
|
||||
|
||||
@field_validator('location_code')
|
||||
@classmethod
|
||||
def validate_location_code(cls, v):
|
||||
"""Ensure location code is uppercase and alphanumeric"""
|
||||
if not v.replace('-', '').replace('_', '').isalnum():
|
||||
raise ValueError('Location code must be alphanumeric (with optional hyphens/underscores)')
|
||||
return v.upper()
|
||||
|
||||
@field_validator('phone')
|
||||
@classmethod
|
||||
def validate_phone(cls, v):
|
||||
"""Validate Spanish phone number if provided"""
|
||||
if v is None:
|
||||
return v
|
||||
phone = re.sub(r'[\s\-\(\)]', '', v)
|
||||
patterns = [
|
||||
r'^(\+34|0034|34)?[6789]\d{8}$', # Mobile
|
||||
r'^(\+34|0034|34)?9\d{8}$', # Landline
|
||||
]
|
||||
if not any(re.match(pattern, phone) for pattern in patterns):
|
||||
raise ValueError('Invalid Spanish phone number')
|
||||
return v
|
||||
|
||||
@field_validator('business_type')
|
||||
@classmethod
|
||||
def validate_business_type(cls, v):
|
||||
"""Validate business type if provided"""
|
||||
if v is None:
|
||||
return v
|
||||
valid_types = ['bakery', 'coffee_shop', 'pastry_shop', 'restaurant']
|
||||
if v not in valid_types:
|
||||
raise ValueError(f'Business type must be one of: {valid_types}')
|
||||
return v
|
||||
|
||||
@field_validator('business_model')
|
||||
@classmethod
|
||||
def validate_business_model(cls, v):
|
||||
"""Validate business model if provided"""
|
||||
if v is None:
|
||||
return v
|
||||
valid_models = ['individual_bakery', 'central_baker_satellite', 'retail_bakery', 'hybrid_bakery']
|
||||
if v not in valid_models:
|
||||
raise ValueError(f'Business model must be one of: {valid_models}')
|
||||
return v
|
||||
|
||||
@field_validator('timezone')
|
||||
@classmethod
|
||||
def validate_timezone(cls, v):
|
||||
"""Validate timezone if provided"""
|
||||
if v is None:
|
||||
return v
|
||||
# Basic timezone validation - should match common timezone formats
|
||||
if not re.match(r'^[A-Za-z_+/]+$', v):
|
||||
raise ValueError('Invalid timezone format')
|
||||
return v
|
||||
|
||||
|
||||
class BulkChildTenantsCreate(BaseModel):
|
||||
"""Schema for bulk creating child tenants during onboarding"""
|
||||
child_tenants: List[ChildTenantCreate] = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=50,
|
||||
description="List of child tenants to create (1-50)"
|
||||
)
|
||||
|
||||
# Optional: Auto-configure distribution routes
|
||||
auto_configure_distribution: bool = Field(
|
||||
True,
|
||||
description="Whether to automatically set up distribution routes between parent and children"
|
||||
)
|
||||
|
||||
@field_validator('child_tenants')
|
||||
@classmethod
|
||||
def validate_unique_location_codes(cls, v):
|
||||
"""Ensure all location codes are unique within the batch"""
|
||||
location_codes = [ct.location_code for ct in v]
|
||||
if len(location_codes) != len(set(location_codes)):
|
||||
raise ValueError('Location codes must be unique within the batch')
|
||||
return v
|
||||
|
||||
|
||||
class ChildTenantResponse(TenantResponse):
|
||||
"""Response schema for child tenant - extends TenantResponse"""
|
||||
location_code: Optional[str] = None
|
||||
zone: Optional[str] = None
|
||||
hierarchy_path: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class BulkChildTenantsResponse(BaseModel):
|
||||
"""Response schema for bulk child tenant creation"""
|
||||
parent_tenant_id: str
|
||||
created_count: int
|
||||
failed_count: int
|
||||
created_tenants: List[ChildTenantResponse]
|
||||
failed_tenants: List[Dict[str, Any]] = Field(
|
||||
default_factory=list,
|
||||
description="List of failed tenants with error details"
|
||||
)
|
||||
distribution_configured: bool = False
|
||||
|
||||
@field_validator('parent_tenant_id', mode='before')
|
||||
@classmethod
|
||||
def convert_uuid_to_string(cls, v):
|
||||
"""Convert UUID objects to strings for JSON serialization"""
|
||||
if isinstance(v, UUID):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
class TenantHierarchyResponse(BaseModel):
|
||||
"""Response schema for tenant hierarchy information"""
|
||||
tenant_id: str
|
||||
tenant_type: str = Field(..., description="Type: standalone, parent, or child")
|
||||
parent_tenant_id: Optional[str] = Field(None, description="Parent tenant ID if this is a child")
|
||||
hierarchy_path: Optional[str] = Field(None, description="Materialized path for hierarchy queries")
|
||||
child_count: int = Field(0, description="Number of child tenants (for parent tenants)")
|
||||
hierarchy_level: int = Field(0, description="Level in hierarchy: 0=parent, 1=child, 2=grandchild, etc.")
|
||||
|
||||
@field_validator('tenant_id', 'parent_tenant_id', mode='before')
|
||||
@classmethod
|
||||
def convert_uuid_to_string(cls, v):
|
||||
"""Convert UUID objects to strings for JSON serialization"""
|
||||
if v is None:
|
||||
return v
|
||||
if isinstance(v, UUID):
|
||||
return str(v)
|
||||
return v
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class TenantSearchRequest(BaseModel):
|
||||
"""Tenant search request schema"""
|
||||
query: Optional[str] = None
|
||||
business_type: Optional[str] = None
|
||||
city: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
limit: int = Field(default=50, ge=1, le=100)
|
||||
offset: int = Field(default=0, ge=0)
|
||||
19
services/tenant/app/services/__init__.py
Normal file
19
services/tenant/app/services/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
Tenant Service Layer
|
||||
Business logic services for tenant operations
|
||||
"""
|
||||
|
||||
from .tenant_service import TenantService, EnhancedTenantService
|
||||
from .subscription_service import SubscriptionService
|
||||
from .payment_service import PaymentService
|
||||
from .coupon_service import CouponService
|
||||
from .subscription_orchestration_service import SubscriptionOrchestrationService
|
||||
|
||||
__all__ = [
|
||||
"TenantService",
|
||||
"EnhancedTenantService",
|
||||
"SubscriptionService",
|
||||
"PaymentService",
|
||||
"CouponService",
|
||||
"SubscriptionOrchestrationService"
|
||||
]
|
||||
108
services/tenant/app/services/coupon_service.py
Normal file
108
services/tenant/app/services/coupon_service.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""
|
||||
Coupon Service - Coupon Operations
|
||||
This service handles ONLY coupon validation and redemption
|
||||
NO payment provider interactions, NO subscription logic
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.repositories.coupon_repository import CouponRepository
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class CouponService:
|
||||
"""Service for handling coupon validation and redemption"""
|
||||
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db_session = db_session
|
||||
self.coupon_repo = CouponRepository(db_session)
|
||||
|
||||
async def validate_coupon_code(
|
||||
self,
|
||||
coupon_code: str,
|
||||
tenant_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Validate a coupon code for a tenant
|
||||
|
||||
Args:
|
||||
coupon_code: Coupon code to validate
|
||||
tenant_id: Tenant ID
|
||||
|
||||
Returns:
|
||||
Dictionary with validation results
|
||||
"""
|
||||
try:
|
||||
validation = await self.coupon_repo.validate_coupon(coupon_code, tenant_id)
|
||||
|
||||
return {
|
||||
"valid": validation.valid,
|
||||
"error_message": validation.error_message,
|
||||
"discount_preview": validation.discount_preview,
|
||||
"coupon": {
|
||||
"code": validation.coupon.code,
|
||||
"discount_type": validation.coupon.discount_type.value,
|
||||
"discount_value": validation.coupon.discount_value
|
||||
} if validation.coupon else None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to validate coupon", error=str(e), coupon_code=coupon_code)
|
||||
return {
|
||||
"valid": False,
|
||||
"error_message": "Error al validar el cupón",
|
||||
"discount_preview": None,
|
||||
"coupon": None
|
||||
}
|
||||
|
||||
async def redeem_coupon(
|
||||
self,
|
||||
coupon_code: str,
|
||||
tenant_id: str,
|
||||
base_trial_days: int = 0
|
||||
) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
|
||||
"""
|
||||
Redeem a coupon for a tenant
|
||||
|
||||
Args:
|
||||
coupon_code: Coupon code to redeem
|
||||
tenant_id: Tenant ID
|
||||
base_trial_days: Base trial days without coupon
|
||||
|
||||
Returns:
|
||||
Tuple of (success, discount_applied, error_message)
|
||||
"""
|
||||
try:
|
||||
success, redemption, error = await self.coupon_repo.redeem_coupon(
|
||||
coupon_code,
|
||||
tenant_id,
|
||||
base_trial_days
|
||||
)
|
||||
|
||||
if success and redemption:
|
||||
return True, redemption.discount_applied, None
|
||||
else:
|
||||
return False, None, error
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to redeem coupon", error=str(e), coupon_code=coupon_code)
|
||||
return False, None, f"Error al aplicar el cupón: {str(e)}"
|
||||
|
||||
async def get_coupon_by_code(self, coupon_code: str) -> Optional[Any]:
|
||||
"""
|
||||
Get coupon details by code
|
||||
|
||||
Args:
|
||||
coupon_code: Coupon code to retrieve
|
||||
|
||||
Returns:
|
||||
Coupon object or None
|
||||
"""
|
||||
try:
|
||||
return await self.coupon_repo.get_coupon_by_code(coupon_code)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get coupon by code", error=str(e), coupon_code=coupon_code)
|
||||
return None
|
||||
365
services/tenant/app/services/network_alerts_service.py
Normal file
365
services/tenant/app/services/network_alerts_service.py
Normal file
@@ -0,0 +1,365 @@
|
||||
# services/tenant/app/services/network_alerts_service.py
|
||||
"""
|
||||
Network Alerts Service
|
||||
Business logic for aggregating and managing alerts across enterprise networks
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
import uuid
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class NetworkAlertsService:
|
||||
"""
|
||||
Service for aggregating and managing alerts across enterprise networks
|
||||
"""
|
||||
|
||||
def __init__(self, tenant_client, alerts_client):
|
||||
self.tenant_client = tenant_client
|
||||
self.alerts_client = alerts_client
|
||||
|
||||
async def get_child_tenants(self, parent_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all child tenants for a parent tenant
|
||||
"""
|
||||
try:
|
||||
# Get child tenants from tenant service
|
||||
children = await self.tenant_client.get_child_tenants(parent_id)
|
||||
|
||||
# Enrich with tenant details
|
||||
enriched_children = []
|
||||
for child in children:
|
||||
child_details = await self.tenant_client.get_tenant(child['id'])
|
||||
enriched_children.append({
|
||||
'id': child['id'],
|
||||
'name': child_details.get('name', f"Outlet {child['id']}"),
|
||||
'subdomain': child_details.get('subdomain'),
|
||||
'city': child_details.get('city')
|
||||
})
|
||||
|
||||
return enriched_children
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get child tenants, parent_id={parent_id}, error={str(e)}")
|
||||
raise Exception(f"Failed to get child tenants: {str(e)}")
|
||||
|
||||
async def get_alerts_for_tenant(self, tenant_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get alerts for a specific tenant
|
||||
"""
|
||||
try:
|
||||
# In a real implementation, this would call the alert service
|
||||
# For demo purposes, we'll simulate some alert data
|
||||
|
||||
# Simulate different types of alerts based on tenant type
|
||||
simulated_alerts = []
|
||||
|
||||
# Generate some sample alerts
|
||||
alert_types = ['inventory', 'production', 'delivery', 'equipment', 'quality']
|
||||
severities = ['critical', 'high', 'medium', 'low']
|
||||
|
||||
for i in range(3): # Generate 3 sample alerts per tenant
|
||||
alert = {
|
||||
'alert_id': str(uuid.uuid4()),
|
||||
'tenant_id': tenant_id,
|
||||
'alert_type': alert_types[i % len(alert_types)],
|
||||
'severity': severities[i % len(severities)],
|
||||
'title': f"{alert_types[i % len(alert_types)].title()} Alert Detected",
|
||||
'message': f"Sample {alert_types[i % len(alert_types)]} alert for tenant {tenant_id}",
|
||||
'timestamp': (datetime.now() - timedelta(hours=i)).isoformat(),
|
||||
'status': 'active' if i < 2 else 'resolved',
|
||||
'source_system': f"{alert_types[i % len(alert_types)]}-service",
|
||||
'related_entity_id': f"entity-{i+1}",
|
||||
'related_entity_type': alert_types[i % len(alert_types)]
|
||||
}
|
||||
simulated_alerts.append(alert)
|
||||
|
||||
return simulated_alerts
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get alerts for tenant, tenant_id={tenant_id}, error={str(e)}")
|
||||
raise Exception(f"Failed to get alerts: {str(e)}")
|
||||
|
||||
async def get_network_alerts(self, parent_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get all alerts across the network
|
||||
"""
|
||||
try:
|
||||
# Get all child tenants
|
||||
child_tenants = await self.get_child_tenants(parent_id)
|
||||
|
||||
# Aggregate alerts from all child tenants
|
||||
all_alerts = []
|
||||
|
||||
for child in child_tenants:
|
||||
child_id = child['id']
|
||||
child_alerts = await self.get_alerts_for_tenant(child_id)
|
||||
all_alerts.extend(child_alerts)
|
||||
|
||||
return all_alerts
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get network alerts", parent_id=parent_id, error=str(e))
|
||||
raise Exception(f"Failed to get network alerts: {str(e)}")
|
||||
|
||||
async def detect_alert_correlations(
|
||||
self,
|
||||
alerts: List[Dict[str, Any]],
|
||||
min_correlation_strength: float = 0.7
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Detect correlations between alerts
|
||||
"""
|
||||
try:
|
||||
# Simple correlation detection (in real implementation, this would be more sophisticated)
|
||||
correlations = []
|
||||
|
||||
# Group alerts by type and time proximity
|
||||
alert_groups = {}
|
||||
|
||||
for alert in alerts:
|
||||
alert_type = alert['alert_type']
|
||||
timestamp = alert['timestamp']
|
||||
|
||||
# Use timestamp as key for grouping (simplified)
|
||||
if alert_type not in alert_groups:
|
||||
alert_groups[alert_type] = []
|
||||
|
||||
alert_groups[alert_type].append(alert)
|
||||
|
||||
# Create correlation groups
|
||||
for alert_type, group in alert_groups.items():
|
||||
if len(group) >= 2: # Only create correlations for groups with 2+ alerts
|
||||
primary_alert = group[0]
|
||||
related_alerts = group[1:]
|
||||
|
||||
correlation = {
|
||||
'correlation_id': str(uuid.uuid4()),
|
||||
'primary_alert': primary_alert,
|
||||
'related_alerts': related_alerts,
|
||||
'correlation_type': 'temporal',
|
||||
'correlation_strength': 0.85,
|
||||
'impact_analysis': f"Multiple {alert_type} alerts detected within short timeframe"
|
||||
}
|
||||
|
||||
if correlation['correlation_strength'] >= min_correlation_strength:
|
||||
correlations.append(correlation)
|
||||
|
||||
return correlations
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to detect alert correlations", error=str(e))
|
||||
raise Exception(f"Failed to detect correlations: {str(e)}")
|
||||
|
||||
async def acknowledge_alert(self, parent_id: str, alert_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Acknowledge an alert
|
||||
"""
|
||||
try:
|
||||
# In a real implementation, this would update the alert status
|
||||
# For demo purposes, we'll simulate the operation
|
||||
|
||||
logger.info("Alert acknowledged", parent_id=parent_id, alert_id=alert_id)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'alert_id': alert_id,
|
||||
'status': 'acknowledged'
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to acknowledge alert", parent_id=parent_id, alert_id=alert_id, error=str(e))
|
||||
raise Exception(f"Failed to acknowledge alert: {str(e)}")
|
||||
|
||||
async def resolve_alert(self, parent_id: str, alert_id: str, resolution_notes: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Resolve an alert
|
||||
"""
|
||||
try:
|
||||
# In a real implementation, this would update the alert status
|
||||
# For demo purposes, we'll simulate the operation
|
||||
|
||||
logger.info("Alert resolved", parent_id=parent_id, alert_id=alert_id, notes=resolution_notes)
|
||||
|
||||
return {
|
||||
'success': True,
|
||||
'alert_id': alert_id,
|
||||
'status': 'resolved',
|
||||
'resolution_notes': resolution_notes
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to resolve alert", parent_id=parent_id, alert_id=alert_id, error=str(e))
|
||||
raise Exception(f"Failed to resolve alert: {str(e)}")
|
||||
|
||||
async def get_alert_trends(self, parent_id: str, days: int = 30) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get alert trends over time
|
||||
"""
|
||||
try:
|
||||
# Simulate trend data
|
||||
trends = []
|
||||
end_date = datetime.now()
|
||||
|
||||
# Generate daily trend data
|
||||
for i in range(days):
|
||||
date = end_date - timedelta(days=i)
|
||||
|
||||
# Simulate varying alert counts with weekly pattern
|
||||
base_count = 5
|
||||
weekly_variation = int((i % 7) * 1.5) # Higher on weekdays
|
||||
daily_noise = (i % 3 - 1) # Daily noise
|
||||
|
||||
alert_count = max(1, base_count + weekly_variation + daily_noise)
|
||||
|
||||
trends.append({
|
||||
'date': date.strftime('%Y-%m-%d'),
|
||||
'total_alerts': alert_count,
|
||||
'critical_alerts': max(0, int(alert_count * 0.1)),
|
||||
'high_alerts': max(0, int(alert_count * 0.2)),
|
||||
'medium_alerts': max(0, int(alert_count * 0.4)),
|
||||
'low_alerts': max(0, int(alert_count * 0.3))
|
||||
})
|
||||
|
||||
# Sort by date (oldest first)
|
||||
trends.sort(key=lambda x: x['date'])
|
||||
|
||||
return trends
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get alert trends", parent_id=parent_id, error=str(e))
|
||||
raise Exception(f"Failed to get alert trends: {str(e)}")
|
||||
|
||||
async def get_prioritized_alerts(self, parent_id: str, limit: int = 10) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get prioritized alerts based on impact and urgency
|
||||
"""
|
||||
try:
|
||||
# Get all network alerts
|
||||
all_alerts = await self.get_network_alerts(parent_id)
|
||||
|
||||
if not all_alerts:
|
||||
return []
|
||||
|
||||
# Simple prioritization (in real implementation, this would use ML)
|
||||
# Priority based on severity and recency
|
||||
severity_scores = {'critical': 4, 'high': 3, 'medium': 2, 'low': 1}
|
||||
|
||||
for alert in all_alerts:
|
||||
severity_score = severity_scores.get(alert['severity'], 1)
|
||||
# Add recency score (newer alerts get higher priority)
|
||||
timestamp = datetime.fromisoformat(alert['timestamp'])
|
||||
recency_score = min(3, (datetime.now() - timestamp).days + 1)
|
||||
|
||||
alert['priority_score'] = severity_score * recency_score
|
||||
|
||||
# Sort by priority score (highest first)
|
||||
all_alerts.sort(key=lambda x: x['priority_score'], reverse=True)
|
||||
|
||||
# Return top N alerts
|
||||
prioritized = all_alerts[:limit]
|
||||
|
||||
# Remove priority score from response
|
||||
for alert in prioritized:
|
||||
alert.pop('priority_score', None)
|
||||
|
||||
return prioritized
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get prioritized alerts", parent_id=parent_id, error=str(e))
|
||||
raise Exception(f"Failed to get prioritized alerts: {str(e)}")
|
||||
|
||||
|
||||
# Helper class for alert analysis
|
||||
class AlertAnalyzer:
|
||||
"""
|
||||
Helper class for analyzing alert patterns
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def calculate_alert_severity_score(alert: Dict[str, Any]) -> float:
|
||||
"""
|
||||
Calculate severity score for an alert
|
||||
"""
|
||||
severity_scores = {'critical': 1.0, 'high': 0.75, 'medium': 0.5, 'low': 0.25}
|
||||
return severity_scores.get(alert['severity'], 0.25)
|
||||
|
||||
@staticmethod
|
||||
def detect_alert_patterns(alerts: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Detect patterns in alert data
|
||||
"""
|
||||
if not alerts:
|
||||
return {'patterns': [], 'anomalies': []}
|
||||
|
||||
patterns = []
|
||||
anomalies = []
|
||||
|
||||
# Simple pattern detection
|
||||
alert_types = [a['alert_type'] for a in alerts]
|
||||
type_counts = {}
|
||||
|
||||
for alert_type in alert_types:
|
||||
type_counts[alert_type] = type_counts.get(alert_type, 0) + 1
|
||||
|
||||
# Detect if one type dominates
|
||||
total_alerts = len(alerts)
|
||||
for alert_type, count in type_counts.items():
|
||||
if count / total_alerts > 0.6: # More than 60% of one type
|
||||
patterns.append({
|
||||
'type': 'dominant_alert_type',
|
||||
'pattern': f'{alert_type} alerts dominate ({count}/{total_alerts})',
|
||||
'confidence': 0.85
|
||||
})
|
||||
|
||||
return {'patterns': patterns, 'anomalies': anomalies}
|
||||
|
||||
|
||||
# Helper class for alert correlation
|
||||
class AlertCorrelator:
|
||||
"""
|
||||
Helper class for correlating alerts
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def calculate_correlation_strength(alert1: Dict[str, Any], alert2: Dict[str, Any]) -> float:
|
||||
"""
|
||||
Calculate correlation strength between two alerts
|
||||
"""
|
||||
# Simple correlation based on type and time proximity
|
||||
same_type = 1.0 if alert1['alert_type'] == alert2['alert_type'] else 0.3
|
||||
|
||||
time1 = datetime.fromisoformat(alert1['timestamp'])
|
||||
time2 = datetime.fromisoformat(alert2['timestamp'])
|
||||
time_diff_hours = abs((time2 - time1).total_seconds() / 3600)
|
||||
|
||||
# Time proximity score (higher for closer times)
|
||||
time_proximity = max(0, 1.0 - min(1.0, time_diff_hours / 24)) # Decays over 24 hours
|
||||
|
||||
return same_type * time_proximity
|
||||
|
||||
|
||||
# Helper class for alert prioritization
|
||||
class AlertPrioritizer:
|
||||
"""
|
||||
Helper class for prioritizing alerts
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def calculate_priority_score(alert: Dict[str, Any]) -> float:
|
||||
"""
|
||||
Calculate priority score for an alert
|
||||
"""
|
||||
# Base score from severity
|
||||
severity_scores = {'critical': 100, 'high': 75, 'medium': 50, 'low': 25}
|
||||
base_score = severity_scores.get(alert['severity'], 25)
|
||||
|
||||
# Add recency bonus (newer alerts get higher priority)
|
||||
timestamp = datetime.fromisoformat(alert['timestamp'])
|
||||
hours_old = (datetime.now() - timestamp).total_seconds() / 3600
|
||||
recency_bonus = max(0, 50 - hours_old) # Decays over 50 hours
|
||||
|
||||
return base_score + recency_bonus
|
||||
1316
services/tenant/app/services/payment_service.py
Normal file
1316
services/tenant/app/services/payment_service.py
Normal file
File diff suppressed because it is too large
Load Diff
358
services/tenant/app/services/registration_state_service.py
Normal file
358
services/tenant/app/services/registration_state_service.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""
|
||||
Registration State Management Service
|
||||
Tracks registration progress and handles state transitions
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from typing import Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from uuid import uuid4
|
||||
from shared.exceptions.registration_exceptions import (
|
||||
RegistrationStateError,
|
||||
InvalidStateTransitionError
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class RegistrationState(Enum):
|
||||
"""Registration process states"""
|
||||
INITIATED = "initiated"
|
||||
PAYMENT_VERIFICATION_PENDING = "payment_verification_pending"
|
||||
PAYMENT_VERIFIED = "payment_verified"
|
||||
SUBSCRIPTION_CREATED = "subscription_created"
|
||||
USER_CREATED = "user_created"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class RegistrationStateService:
|
||||
"""
|
||||
Registration State Management Service
|
||||
Tracks and manages registration process state
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize state service"""
|
||||
# In production, this would use a database
|
||||
self.registration_states = {}
|
||||
|
||||
async def create_registration_state(
|
||||
self,
|
||||
email: str,
|
||||
user_data: Dict[str, Any]
|
||||
) -> str:
|
||||
"""
|
||||
Create new registration state
|
||||
|
||||
Args:
|
||||
email: User email
|
||||
user_data: Registration data
|
||||
|
||||
Returns:
|
||||
Registration state ID
|
||||
"""
|
||||
try:
|
||||
state_id = str(uuid4())
|
||||
|
||||
registration_state = {
|
||||
'state_id': state_id,
|
||||
'email': email,
|
||||
'current_state': RegistrationState.INITIATED.value,
|
||||
'created_at': datetime.now().isoformat(),
|
||||
'updated_at': datetime.now().isoformat(),
|
||||
'user_data': user_data,
|
||||
'setup_intent_id': None,
|
||||
'customer_id': None,
|
||||
'subscription_id': None,
|
||||
'error': None
|
||||
}
|
||||
|
||||
self.registration_states[state_id] = registration_state
|
||||
|
||||
logger.info("Registration state created",
|
||||
state_id=state_id,
|
||||
email=email,
|
||||
current_state=RegistrationState.INITIATED.value)
|
||||
|
||||
return state_id
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create registration state",
|
||||
error=str(e),
|
||||
email=email,
|
||||
exc_info=True)
|
||||
raise RegistrationStateError(f"State creation failed: {str(e)}") from e
|
||||
|
||||
async def transition_state(
|
||||
self,
|
||||
state_id: str,
|
||||
new_state: RegistrationState,
|
||||
context: Optional[Dict[str, Any]] = None
|
||||
) -> None:
|
||||
"""
|
||||
Transition registration to new state with validation
|
||||
|
||||
Args:
|
||||
state_id: Registration state ID
|
||||
new_state: New state to transition to
|
||||
context: Additional context data
|
||||
|
||||
Raises:
|
||||
InvalidStateTransitionError: If transition is invalid
|
||||
RegistrationStateError: If transition fails
|
||||
"""
|
||||
try:
|
||||
if state_id not in self.registration_states:
|
||||
raise RegistrationStateError(f"Registration state {state_id} not found")
|
||||
|
||||
current_state = self.registration_states[state_id]['current_state']
|
||||
|
||||
# Validate state transition
|
||||
if not self._is_valid_transition(current_state, new_state.value):
|
||||
raise InvalidStateTransitionError(
|
||||
f"Invalid transition from {current_state} to {new_state.value}"
|
||||
)
|
||||
|
||||
# Update state
|
||||
self.registration_states[state_id]['current_state'] = new_state.value
|
||||
self.registration_states[state_id]['updated_at'] = datetime.now().isoformat()
|
||||
|
||||
# Update context data
|
||||
if context:
|
||||
self.registration_states[state_id].update(context)
|
||||
|
||||
logger.info("Registration state transitioned",
|
||||
state_id=state_id,
|
||||
from_state=current_state,
|
||||
to_state=new_state.value)
|
||||
|
||||
except InvalidStateTransitionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("State transition failed",
|
||||
error=str(e),
|
||||
state_id=state_id,
|
||||
from_state=current_state,
|
||||
to_state=new_state.value,
|
||||
exc_info=True)
|
||||
raise RegistrationStateError(f"State transition failed: {str(e)}") from e
|
||||
|
||||
async def update_state_context(
|
||||
self,
|
||||
state_id: str,
|
||||
context: Dict[str, Any]
|
||||
) -> None:
|
||||
"""
|
||||
Update state context data
|
||||
|
||||
Args:
|
||||
state_id: Registration state ID
|
||||
context: Context data to update
|
||||
|
||||
Raises:
|
||||
RegistrationStateError: If update fails
|
||||
"""
|
||||
try:
|
||||
if state_id not in self.registration_states:
|
||||
raise RegistrationStateError(f"Registration state {state_id} not found")
|
||||
|
||||
self.registration_states[state_id].update(context)
|
||||
self.registration_states[state_id]['updated_at'] = datetime.now().isoformat()
|
||||
|
||||
logger.debug("Registration state context updated",
|
||||
state_id=state_id,
|
||||
context_keys=list(context.keys()))
|
||||
|
||||
except Exception as e:
|
||||
logger.error("State context update failed",
|
||||
error=str(e),
|
||||
state_id=state_id,
|
||||
exc_info=True)
|
||||
raise RegistrationStateError(f"State context update failed: {str(e)}") from e
|
||||
|
||||
async def get_registration_state(
|
||||
self,
|
||||
state_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get registration state by ID
|
||||
|
||||
Args:
|
||||
state_id: Registration state ID
|
||||
|
||||
Returns:
|
||||
Registration state data
|
||||
|
||||
Raises:
|
||||
RegistrationStateError: If state not found
|
||||
"""
|
||||
try:
|
||||
if state_id not in self.registration_states:
|
||||
raise RegistrationStateError(f"Registration state {state_id} not found")
|
||||
|
||||
return self.registration_states[state_id]
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get registration state",
|
||||
error=str(e),
|
||||
state_id=state_id,
|
||||
exc_info=True)
|
||||
raise RegistrationStateError(f"State retrieval failed: {str(e)}") from e
|
||||
|
||||
async def rollback_state(
|
||||
self,
|
||||
state_id: str,
|
||||
target_state: RegistrationState
|
||||
) -> None:
|
||||
"""
|
||||
Rollback registration to previous state
|
||||
|
||||
Args:
|
||||
state_id: Registration state ID
|
||||
target_state: State to rollback to
|
||||
|
||||
Raises:
|
||||
RegistrationStateError: If rollback fails
|
||||
"""
|
||||
try:
|
||||
if state_id not in self.registration_states:
|
||||
raise RegistrationStateError(f"Registration state {state_id} not found")
|
||||
|
||||
current_state = self.registration_states[state_id]['current_state']
|
||||
|
||||
# Only allow rollback to earlier states
|
||||
if not self._can_rollback(current_state, target_state.value):
|
||||
raise InvalidStateTransitionError(
|
||||
f"Cannot rollback from {current_state} to {target_state.value}"
|
||||
)
|
||||
|
||||
# Update state
|
||||
self.registration_states[state_id]['current_state'] = target_state.value
|
||||
self.registration_states[state_id]['updated_at'] = datetime.now().isoformat()
|
||||
self.registration_states[state_id]['error'] = "Registration rolled back"
|
||||
|
||||
logger.warning("Registration state rolled back",
|
||||
state_id=state_id,
|
||||
from_state=current_state,
|
||||
to_state=target_state.value)
|
||||
|
||||
except InvalidStateTransitionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("State rollback failed",
|
||||
error=str(e),
|
||||
state_id=state_id,
|
||||
from_state=current_state,
|
||||
to_state=target_state.value,
|
||||
exc_info=True)
|
||||
raise RegistrationStateError(f"State rollback failed: {str(e)}") from e
|
||||
|
||||
async def mark_registration_failed(
|
||||
self,
|
||||
state_id: str,
|
||||
error: str
|
||||
) -> None:
|
||||
"""
|
||||
Mark registration as failed
|
||||
|
||||
Args:
|
||||
state_id: Registration state ID
|
||||
error: Error message
|
||||
|
||||
Raises:
|
||||
RegistrationStateError: If operation fails
|
||||
"""
|
||||
try:
|
||||
if state_id not in self.registration_states:
|
||||
raise RegistrationStateError(f"Registration state {state_id} not found")
|
||||
|
||||
self.registration_states[state_id]['current_state'] = RegistrationState.FAILED.value
|
||||
self.registration_states[state_id]['error'] = error
|
||||
self.registration_states[state_id]['updated_at'] = datetime.now().isoformat()
|
||||
|
||||
logger.error("Registration marked as failed",
|
||||
state_id=state_id,
|
||||
error=error)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to mark registration as failed",
|
||||
error=str(e),
|
||||
state_id=state_id,
|
||||
exc_info=True)
|
||||
raise RegistrationStateError(f"Mark failed operation failed: {str(e)}") from e
|
||||
|
||||
def _is_valid_transition(self, current_state: str, new_state: str) -> bool:
|
||||
"""
|
||||
Validate state transition
|
||||
|
||||
Args:
|
||||
current_state: Current state
|
||||
new_state: New state
|
||||
|
||||
Returns:
|
||||
True if transition is valid
|
||||
"""
|
||||
# Define valid state transitions
|
||||
valid_transitions = {
|
||||
RegistrationState.INITIATED.value: [
|
||||
RegistrationState.PAYMENT_VERIFICATION_PENDING.value,
|
||||
RegistrationState.FAILED.value
|
||||
],
|
||||
RegistrationState.PAYMENT_VERIFICATION_PENDING.value: [
|
||||
RegistrationState.PAYMENT_VERIFIED.value,
|
||||
RegistrationState.FAILED.value
|
||||
],
|
||||
RegistrationState.PAYMENT_VERIFIED.value: [
|
||||
RegistrationState.SUBSCRIPTION_CREATED.value,
|
||||
RegistrationState.FAILED.value
|
||||
],
|
||||
RegistrationState.SUBSCRIPTION_CREATED.value: [
|
||||
RegistrationState.USER_CREATED.value,
|
||||
RegistrationState.FAILED.value
|
||||
],
|
||||
RegistrationState.USER_CREATED.value: [
|
||||
RegistrationState.COMPLETED.value,
|
||||
RegistrationState.FAILED.value
|
||||
],
|
||||
RegistrationState.COMPLETED.value: [],
|
||||
RegistrationState.FAILED.value: []
|
||||
}
|
||||
|
||||
return new_state in valid_transitions.get(current_state, [])
|
||||
|
||||
def _can_rollback(self, current_state: str, target_state: str) -> bool:
|
||||
"""
|
||||
Check if rollback to target state is allowed
|
||||
|
||||
Args:
|
||||
current_state: Current state
|
||||
target_state: Target state for rollback
|
||||
|
||||
Returns:
|
||||
True if rollback is allowed
|
||||
"""
|
||||
# Define state order for rollback validation
|
||||
state_order = [
|
||||
RegistrationState.INITIATED.value,
|
||||
RegistrationState.PAYMENT_VERIFICATION_PENDING.value,
|
||||
RegistrationState.PAYMENT_VERIFIED.value,
|
||||
RegistrationState.SUBSCRIPTION_CREATED.value,
|
||||
RegistrationState.USER_CREATED.value,
|
||||
RegistrationState.COMPLETED.value
|
||||
]
|
||||
|
||||
try:
|
||||
current_index = state_order.index(current_state)
|
||||
target_index = state_order.index(target_state)
|
||||
|
||||
# Can only rollback to earlier states
|
||||
return target_index < current_index
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
# Singleton instance for dependency injection
|
||||
registration_state_service = RegistrationStateService()
|
||||
258
services/tenant/app/services/subscription_cache.py
Normal file
258
services/tenant/app/services/subscription_cache.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""
|
||||
Subscription Cache Service
|
||||
Provides Redis-based caching for subscription data with 10-minute TTL
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from typing import Optional, Dict, Any
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.repositories import SubscriptionRepository
|
||||
from app.models.tenants import Subscription
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Cache TTL in seconds (10 minutes)
|
||||
SUBSCRIPTION_CACHE_TTL = 600
|
||||
|
||||
|
||||
class SubscriptionCacheService:
|
||||
"""Service for cached subscription lookups"""
|
||||
|
||||
def __init__(self, redis_client=None, database_manager=None):
|
||||
self.redis = redis_client
|
||||
self.database_manager = database_manager
|
||||
|
||||
async def ensure_database_manager(self):
|
||||
"""Ensure database manager is properly initialized"""
|
||||
if self.database_manager is None:
|
||||
from app.core.config import settings
|
||||
from shared.database.base import create_database_manager
|
||||
self.database_manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
|
||||
|
||||
async def get_tenant_tier_cached(self, tenant_id: str) -> str:
|
||||
"""
|
||||
Get tenant subscription tier with Redis caching
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
|
||||
Returns:
|
||||
Subscription tier (starter, professional, enterprise)
|
||||
"""
|
||||
try:
|
||||
# Ensure database manager is initialized
|
||||
await self.ensure_database_manager()
|
||||
|
||||
cache_key = f"subscription:tier:{tenant_id}"
|
||||
|
||||
# Try to get from cache
|
||||
if self.redis:
|
||||
try:
|
||||
cached_tier = await self.redis.get(cache_key)
|
||||
if cached_tier:
|
||||
logger.debug("Subscription tier cache hit", tenant_id=tenant_id, tier=cached_tier)
|
||||
return cached_tier.decode('utf-8') if isinstance(cached_tier, bytes) else cached_tier
|
||||
except Exception as e:
|
||||
logger.warning("Redis cache read failed, falling back to database",
|
||||
tenant_id=tenant_id, error=str(e))
|
||||
|
||||
# Cache miss or Redis unavailable - fetch from database
|
||||
logger.debug("Subscription tier cache miss", tenant_id=tenant_id)
|
||||
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
subscription_repo = SubscriptionRepository(Subscription, db_session)
|
||||
subscription = await subscription_repo.get_active_subscription(tenant_id)
|
||||
|
||||
if not subscription:
|
||||
logger.warning("No active subscription found, returning starter tier",
|
||||
tenant_id=tenant_id)
|
||||
return "starter"
|
||||
|
||||
tier = subscription.plan
|
||||
|
||||
# Cache the result
|
||||
if self.redis:
|
||||
try:
|
||||
await self.redis.setex(cache_key, SUBSCRIPTION_CACHE_TTL, tier)
|
||||
logger.debug("Cached subscription tier", tenant_id=tenant_id, tier=tier)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to cache subscription tier",
|
||||
tenant_id=tenant_id, error=str(e))
|
||||
|
||||
return tier
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get subscription tier",
|
||||
tenant_id=tenant_id, error=str(e))
|
||||
return "starter" # Fallback to starter on error
|
||||
|
||||
async def get_tenant_subscription_cached(self, tenant_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get full tenant subscription with Redis caching
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
|
||||
Returns:
|
||||
Subscription data as dictionary or None
|
||||
"""
|
||||
try:
|
||||
# Ensure database manager is initialized
|
||||
await self.ensure_database_manager()
|
||||
|
||||
cache_key = f"subscription:full:{tenant_id}"
|
||||
|
||||
# Try to get from cache
|
||||
if self.redis:
|
||||
try:
|
||||
cached_data = await self.redis.get(cache_key)
|
||||
if cached_data:
|
||||
logger.debug("Subscription cache hit", tenant_id=tenant_id)
|
||||
data = json.loads(cached_data.decode('utf-8') if isinstance(cached_data, bytes) else cached_data)
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.warning("Redis cache read failed, falling back to database",
|
||||
tenant_id=tenant_id, error=str(e))
|
||||
|
||||
# Cache miss or Redis unavailable - fetch from database
|
||||
logger.debug("Subscription cache miss", tenant_id=tenant_id)
|
||||
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
subscription_repo = SubscriptionRepository(Subscription, db_session)
|
||||
subscription = await subscription_repo.get_active_subscription(tenant_id)
|
||||
|
||||
if not subscription:
|
||||
logger.warning("No active subscription found", tenant_id=tenant_id)
|
||||
return None
|
||||
|
||||
# Convert to dictionary
|
||||
subscription_data = {
|
||||
"id": str(subscription.id),
|
||||
"tenant_id": str(subscription.tenant_id),
|
||||
"plan": subscription.plan,
|
||||
"status": subscription.status,
|
||||
"monthly_price": subscription.monthly_price,
|
||||
"billing_cycle": subscription.billing_cycle,
|
||||
"next_billing_date": subscription.next_billing_date.isoformat() if subscription.next_billing_date else None,
|
||||
"trial_ends_at": subscription.trial_ends_at.isoformat() if subscription.trial_ends_at else None,
|
||||
"cancelled_at": subscription.cancelled_at.isoformat() if subscription.cancelled_at else None,
|
||||
"cancellation_effective_date": subscription.cancellation_effective_date.isoformat() if subscription.cancellation_effective_date else None,
|
||||
"max_users": subscription.max_users,
|
||||
"max_locations": subscription.max_locations,
|
||||
"max_products": subscription.max_products,
|
||||
"features": subscription.features,
|
||||
"created_at": subscription.created_at.isoformat() if subscription.created_at else None,
|
||||
"updated_at": subscription.updated_at.isoformat() if subscription.updated_at else None
|
||||
}
|
||||
|
||||
# Cache the result
|
||||
if self.redis:
|
||||
try:
|
||||
await self.redis.setex(
|
||||
cache_key,
|
||||
SUBSCRIPTION_CACHE_TTL,
|
||||
json.dumps(subscription_data)
|
||||
)
|
||||
logger.debug("Cached subscription data", tenant_id=tenant_id)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to cache subscription data",
|
||||
tenant_id=tenant_id, error=str(e))
|
||||
|
||||
return subscription_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get subscription",
|
||||
tenant_id=tenant_id, error=str(e))
|
||||
return None
|
||||
|
||||
async def invalidate_subscription_cache(self, tenant_id: str) -> None:
|
||||
"""
|
||||
Invalidate subscription cache for a tenant
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
"""
|
||||
try:
|
||||
if not self.redis:
|
||||
logger.debug("Redis not available, skipping cache invalidation",
|
||||
tenant_id=tenant_id)
|
||||
return
|
||||
|
||||
tier_key = f"subscription:tier:{tenant_id}"
|
||||
full_key = f"subscription:full:{tenant_id}"
|
||||
|
||||
# Delete both cache keys
|
||||
await self.redis.delete(tier_key, full_key)
|
||||
|
||||
logger.info("Invalidated subscription cache",
|
||||
tenant_id=tenant_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to invalidate subscription cache",
|
||||
tenant_id=tenant_id, error=str(e))
|
||||
|
||||
async def warm_cache(self, tenant_id: str) -> None:
|
||||
"""
|
||||
Pre-warm the cache by loading subscription data
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
"""
|
||||
try:
|
||||
logger.debug("Warming subscription cache", tenant_id=tenant_id)
|
||||
|
||||
# Load both tier and full subscription to cache
|
||||
await self.get_tenant_tier_cached(tenant_id)
|
||||
await self.get_tenant_subscription_cached(tenant_id)
|
||||
|
||||
logger.info("Subscription cache warmed", tenant_id=tenant_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning("Failed to warm subscription cache",
|
||||
tenant_id=tenant_id, error=str(e))
|
||||
|
||||
|
||||
# Singleton instance for easy access
|
||||
_cache_service_instance: Optional[SubscriptionCacheService] = None
|
||||
|
||||
|
||||
def get_subscription_cache_service(redis_client=None) -> SubscriptionCacheService:
|
||||
"""
|
||||
Get or create subscription cache service singleton
|
||||
|
||||
Args:
|
||||
redis_client: Optional Redis client
|
||||
|
||||
Returns:
|
||||
SubscriptionCacheService instance
|
||||
"""
|
||||
global _cache_service_instance
|
||||
|
||||
if _cache_service_instance is None:
|
||||
from shared.redis_utils import initialize_redis
|
||||
from app.core.config import settings
|
||||
import asyncio
|
||||
|
||||
# Initialize Redis client if not provided
|
||||
redis_client_instance = None
|
||||
if redis_client is None:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if not loop.is_running():
|
||||
redis_client_instance = asyncio.run(initialize_redis(settings.REDIS_URL))
|
||||
else:
|
||||
# If event loop is running, we can't use asyncio.run
|
||||
# This is a limitation, but we'll handle it by not initializing Redis here
|
||||
pass
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
redis_client_instance = redis_client
|
||||
|
||||
_cache_service_instance = SubscriptionCacheService(redis_client=redis_client_instance)
|
||||
elif redis_client is not None and _cache_service_instance.redis is None:
|
||||
_cache_service_instance.redis = redis_client
|
||||
|
||||
return _cache_service_instance
|
||||
705
services/tenant/app/services/subscription_limit_service.py
Normal file
705
services/tenant/app/services/subscription_limit_service.py
Normal file
@@ -0,0 +1,705 @@
|
||||
"""
|
||||
Subscription Limit Service
|
||||
Service for validating tenant actions against subscription limits and features
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from typing import Dict, Any, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from fastapi import HTTPException, status
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.repositories import SubscriptionRepository, TenantRepository, TenantMemberRepository
|
||||
from app.models.tenants import Subscription, Tenant, TenantMember
|
||||
from shared.database.exceptions import DatabaseError
|
||||
from shared.database.base import create_database_manager
|
||||
from shared.subscription.plans import SubscriptionPlanMetadata, get_training_job_quota, get_forecast_quota
|
||||
from shared.clients.recipes_client import create_recipes_client
|
||||
from shared.clients.suppliers_client import create_suppliers_client
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class SubscriptionLimitService:
|
||||
"""Service for validating subscription limits and features"""
|
||||
|
||||
def __init__(self, database_manager=None, redis_client=None):
|
||||
self.database_manager = database_manager or create_database_manager()
|
||||
self.redis = redis_client
|
||||
|
||||
async def _init_repositories(self, session):
|
||||
"""Initialize repositories with session"""
|
||||
self.subscription_repo = SubscriptionRepository(Subscription, session)
|
||||
self.tenant_repo = TenantRepository(Tenant, session)
|
||||
self.member_repo = TenantMemberRepository(TenantMember, session)
|
||||
return {
|
||||
'subscription': self.subscription_repo,
|
||||
'tenant': self.tenant_repo,
|
||||
'member': self.member_repo
|
||||
}
|
||||
|
||||
async def get_tenant_subscription_limits(self, tenant_id: str) -> Dict[str, Any]:
|
||||
"""Get current subscription limits for a tenant"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
|
||||
subscription = await self.subscription_repo.get_active_subscription(tenant_id)
|
||||
if not subscription:
|
||||
# Return basic limits if no subscription
|
||||
return {
|
||||
"plan": "starter",
|
||||
"max_users": 5,
|
||||
"max_locations": 1,
|
||||
"max_products": 50,
|
||||
"features": {
|
||||
"inventory_management": "basic",
|
||||
"demand_prediction": "basic",
|
||||
"production_reports": "basic",
|
||||
"analytics": "basic",
|
||||
"support": "email",
|
||||
"ai_model_configuration": "basic" # Added AI model configuration for all tiers
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
"plan": subscription.plan,
|
||||
"max_users": subscription.max_users,
|
||||
"max_locations": subscription.max_locations,
|
||||
"max_products": subscription.max_products,
|
||||
"features": subscription.features or {}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get subscription limits",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
# Return basic limits on error
|
||||
return {
|
||||
"plan": "starter",
|
||||
"max_users": 5,
|
||||
"max_locations": 1,
|
||||
"max_products": 50,
|
||||
"features": {}
|
||||
}
|
||||
|
||||
async def can_add_location(self, tenant_id: str) -> Dict[str, Any]:
|
||||
"""Check if tenant can add another location"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
|
||||
# Get subscription limits
|
||||
subscription = await self.subscription_repo.get_active_subscription(tenant_id)
|
||||
if not subscription:
|
||||
return {"can_add": False, "reason": "No active subscription"}
|
||||
|
||||
# Check if unlimited locations (-1)
|
||||
if subscription.max_locations == -1:
|
||||
return {"can_add": True, "reason": "Unlimited locations allowed"}
|
||||
|
||||
# Count current locations
|
||||
# Currently, each tenant has 1 location (their primary bakery location)
|
||||
# This is stored in tenant.address, tenant.city, tenant.postal_code
|
||||
# If multi-location support is added in the future, this would query a locations table
|
||||
current_locations = 1 # Each tenant has one primary location
|
||||
|
||||
can_add = current_locations < subscription.max_locations
|
||||
return {
|
||||
"can_add": can_add,
|
||||
"current_count": current_locations,
|
||||
"max_allowed": subscription.max_locations,
|
||||
"reason": "Within limits" if can_add else f"Maximum {subscription.max_locations} locations allowed for {subscription.plan} plan"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to check location limits",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {"can_add": False, "reason": "Error checking limits"}
|
||||
|
||||
async def can_add_product(self, tenant_id: str) -> Dict[str, Any]:
|
||||
"""Check if tenant can add another product"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
|
||||
# Get subscription limits
|
||||
subscription = await self.subscription_repo.get_active_subscription(tenant_id)
|
||||
if not subscription:
|
||||
return {"can_add": False, "reason": "No active subscription"}
|
||||
|
||||
# Check if unlimited products (-1)
|
||||
if subscription.max_products == -1:
|
||||
return {"can_add": True, "reason": "Unlimited products allowed"}
|
||||
|
||||
# Count current products from inventory service
|
||||
current_products = await self._get_ingredient_count(tenant_id)
|
||||
|
||||
can_add = current_products < subscription.max_products
|
||||
return {
|
||||
"can_add": can_add,
|
||||
"current_count": current_products,
|
||||
"max_allowed": subscription.max_products,
|
||||
"reason": "Within limits" if can_add else f"Maximum {subscription.max_products} products allowed for {subscription.plan} plan"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to check product limits",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {"can_add": False, "reason": "Error checking limits"}
|
||||
|
||||
async def can_add_user(self, tenant_id: str) -> Dict[str, Any]:
|
||||
"""Check if tenant can add another user/member"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
|
||||
# Get subscription limits
|
||||
subscription = await self.subscription_repo.get_active_subscription(tenant_id)
|
||||
if not subscription:
|
||||
return {"can_add": False, "reason": "No active subscription"}
|
||||
|
||||
# Check if unlimited users (-1)
|
||||
if subscription.max_users == -1:
|
||||
return {"can_add": True, "reason": "Unlimited users allowed"}
|
||||
|
||||
# Count current active members
|
||||
members = await self.member_repo.get_tenant_members(tenant_id, active_only=True)
|
||||
current_users = len(members)
|
||||
|
||||
can_add = current_users < subscription.max_users
|
||||
return {
|
||||
"can_add": can_add,
|
||||
"current_count": current_users,
|
||||
"max_allowed": subscription.max_users,
|
||||
"reason": "Within limits" if can_add else f"Maximum {subscription.max_users} users allowed for {subscription.plan} plan"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to check user limits",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {"can_add": False, "reason": "Error checking limits"}
|
||||
|
||||
async def can_add_recipe(self, tenant_id: str) -> Dict[str, Any]:
|
||||
"""Check if tenant can add another recipe"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
|
||||
subscription = await self.subscription_repo.get_active_subscription(tenant_id)
|
||||
if not subscription:
|
||||
return {"can_add": False, "reason": "No active subscription"}
|
||||
|
||||
# Get recipe limit from plan
|
||||
recipes_limit = await self._get_limit_from_plan(subscription.plan, 'recipes')
|
||||
|
||||
# Check if unlimited (-1 or None)
|
||||
if recipes_limit is None or recipes_limit == -1:
|
||||
return {"can_add": True, "reason": "Unlimited recipes allowed"}
|
||||
|
||||
# Count current recipes from recipes service
|
||||
current_recipes = await self._get_recipe_count(tenant_id)
|
||||
|
||||
can_add = current_recipes < recipes_limit
|
||||
return {
|
||||
"can_add": can_add,
|
||||
"current_count": current_recipes,
|
||||
"max_allowed": recipes_limit,
|
||||
"reason": "Within limits" if can_add else f"Maximum {recipes_limit} recipes allowed for {subscription.plan} plan"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to check recipe limits",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {"can_add": False, "reason": "Error checking limits"}
|
||||
|
||||
async def can_add_supplier(self, tenant_id: str) -> Dict[str, Any]:
|
||||
"""Check if tenant can add another supplier"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
|
||||
subscription = await self.subscription_repo.get_active_subscription(tenant_id)
|
||||
if not subscription:
|
||||
return {"can_add": False, "reason": "No active subscription"}
|
||||
|
||||
# Get supplier limit from plan
|
||||
suppliers_limit = await self._get_limit_from_plan(subscription.plan, 'suppliers')
|
||||
|
||||
# Check if unlimited (-1 or None)
|
||||
if suppliers_limit is None or suppliers_limit == -1:
|
||||
return {"can_add": True, "reason": "Unlimited suppliers allowed"}
|
||||
|
||||
# Count current suppliers from suppliers service
|
||||
current_suppliers = await self._get_supplier_count(tenant_id)
|
||||
|
||||
can_add = current_suppliers < suppliers_limit
|
||||
return {
|
||||
"can_add": can_add,
|
||||
"current_count": current_suppliers,
|
||||
"max_allowed": suppliers_limit,
|
||||
"reason": "Within limits" if can_add else f"Maximum {suppliers_limit} suppliers allowed for {subscription.plan} plan"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to check supplier limits",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {"can_add": False, "reason": "Error checking limits"}
|
||||
|
||||
async def has_feature(self, tenant_id: str, feature: str) -> Dict[str, Any]:
|
||||
"""Check if tenant has access to a specific feature"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
|
||||
subscription = await self.subscription_repo.get_active_subscription(tenant_id)
|
||||
if not subscription:
|
||||
return {"has_feature": False, "reason": "No active subscription"}
|
||||
|
||||
features = subscription.features or {}
|
||||
has_feature = feature in features
|
||||
|
||||
return {
|
||||
"has_feature": has_feature,
|
||||
"feature_value": features.get(feature),
|
||||
"plan": subscription.plan,
|
||||
"reason": "Feature available" if has_feature else f"Feature '{feature}' not available in {subscription.plan} plan"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to check feature access",
|
||||
tenant_id=tenant_id,
|
||||
feature=feature,
|
||||
error=str(e))
|
||||
return {"has_feature": False, "reason": "Error checking feature access"}
|
||||
|
||||
async def get_feature_level(self, tenant_id: str, feature: str) -> Optional[str]:
|
||||
"""Get the level/type of a feature for a tenant"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
|
||||
subscription = await self.subscription_repo.get_active_subscription(tenant_id)
|
||||
if not subscription:
|
||||
return None
|
||||
|
||||
features = subscription.features or {}
|
||||
return features.get(feature)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get feature level",
|
||||
tenant_id=tenant_id,
|
||||
feature=feature,
|
||||
error=str(e))
|
||||
return None
|
||||
|
||||
async def validate_plan_upgrade(self, tenant_id: str, new_plan: str) -> Dict[str, Any]:
|
||||
"""Validate if a tenant can upgrade to a new plan"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
|
||||
# Get current subscription
|
||||
current_subscription = await self.subscription_repo.get_active_subscription(tenant_id)
|
||||
if not current_subscription:
|
||||
return {"can_upgrade": True, "reason": "No current subscription, can start with any plan"}
|
||||
|
||||
# Define plan hierarchy
|
||||
plan_hierarchy = {"starter": 1, "professional": 2, "enterprise": 3}
|
||||
|
||||
current_level = plan_hierarchy.get(current_subscription.plan, 0)
|
||||
new_level = plan_hierarchy.get(new_plan, 0)
|
||||
|
||||
if new_level == 0:
|
||||
return {"can_upgrade": False, "reason": f"Invalid plan: {new_plan}"}
|
||||
|
||||
# Check current usage against new plan limits
|
||||
from shared.subscription.plans import SubscriptionPlanMetadata, PlanPricing
|
||||
new_plan_config = SubscriptionPlanMetadata.get_plan_info(new_plan)
|
||||
|
||||
# Get the max_users limit from the plan limits
|
||||
plan_limits = new_plan_config.get('limits', {})
|
||||
max_users_limit = plan_limits.get('users', 5) # Default to 5 if not specified
|
||||
# Convert "Unlimited" string to None for comparison
|
||||
if max_users_limit == "Unlimited":
|
||||
max_users_limit = None
|
||||
elif max_users_limit is None:
|
||||
max_users_limit = -1 # Use -1 to represent unlimited in the comparison
|
||||
|
||||
# Check if current usage fits new plan
|
||||
members = await self.member_repo.get_tenant_members(tenant_id, active_only=True)
|
||||
current_users = len(members)
|
||||
|
||||
if max_users_limit is not None and max_users_limit != -1 and current_users > max_users_limit:
|
||||
return {
|
||||
"can_upgrade": False,
|
||||
"reason": f"Current usage ({current_users} users) exceeds {new_plan} plan limits ({max_users_limit} users)"
|
||||
}
|
||||
|
||||
return {
|
||||
"can_upgrade": True,
|
||||
"current_plan": current_subscription.plan,
|
||||
"new_plan": new_plan,
|
||||
"price_change": float(PlanPricing.get_price(new_plan)) - current_subscription.monthly_price,
|
||||
"new_features": new_plan_config.get("features", []),
|
||||
"reason": "Upgrade validation successful"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to validate plan upgrade",
|
||||
tenant_id=tenant_id,
|
||||
new_plan=new_plan,
|
||||
error=str(e))
|
||||
return {"can_upgrade": False, "reason": "Error validating upgrade"}
|
||||
|
||||
async def get_usage_summary(self, tenant_id: str) -> Dict[str, Any]:
|
||||
"""Get a summary of current usage vs limits for a tenant - ALL 9 METRICS"""
|
||||
try:
|
||||
async with self.database_manager.get_session() as db_session:
|
||||
await self._init_repositories(db_session)
|
||||
|
||||
subscription = await self.subscription_repo.get_active_subscription(tenant_id)
|
||||
if not subscription:
|
||||
logger.info("No subscription found, returning mock data", tenant_id=tenant_id)
|
||||
return {
|
||||
"plan": "demo",
|
||||
"monthly_price": 0,
|
||||
"status": "active",
|
||||
"billing_cycle": "monthly",
|
||||
"usage": {
|
||||
"users": {
|
||||
"current": 1,
|
||||
"limit": 5,
|
||||
"unlimited": False,
|
||||
"usage_percentage": 20.0
|
||||
},
|
||||
"locations": {
|
||||
"current": 1,
|
||||
"limit": 1,
|
||||
"unlimited": False,
|
||||
"usage_percentage": 100.0
|
||||
},
|
||||
"products": {
|
||||
"current": 0,
|
||||
"limit": 50,
|
||||
"unlimited": False,
|
||||
"usage_percentage": 0.0
|
||||
},
|
||||
"recipes": {
|
||||
"current": 0,
|
||||
"limit": 100,
|
||||
"unlimited": False,
|
||||
"usage_percentage": 0.0
|
||||
},
|
||||
"suppliers": {
|
||||
"current": 0,
|
||||
"limit": 20,
|
||||
"unlimited": False,
|
||||
"usage_percentage": 0.0
|
||||
},
|
||||
"training_jobs_today": {
|
||||
"current": 0,
|
||||
"limit": 2,
|
||||
"unlimited": False,
|
||||
"usage_percentage": 0.0
|
||||
},
|
||||
"forecasts_today": {
|
||||
"current": 0,
|
||||
"limit": 10,
|
||||
"unlimited": False,
|
||||
"usage_percentage": 0.0
|
||||
},
|
||||
"api_calls_this_hour": {
|
||||
"current": 0,
|
||||
"limit": 100,
|
||||
"unlimited": False,
|
||||
"usage_percentage": 0.0
|
||||
},
|
||||
"file_storage_used_gb": {
|
||||
"current": 0.0,
|
||||
"limit": 1.0,
|
||||
"unlimited": False,
|
||||
"usage_percentage": 0.0
|
||||
}
|
||||
},
|
||||
"features": {},
|
||||
"next_billing_date": None,
|
||||
"trial_ends_at": None
|
||||
}
|
||||
|
||||
# Get current usage - Team & Organization
|
||||
members = await self.member_repo.get_tenant_members(tenant_id, active_only=True)
|
||||
current_users = len(members)
|
||||
current_locations = 1 # Each tenant has one primary location
|
||||
|
||||
# Get current usage - Products & Inventory (parallel calls for performance)
|
||||
import asyncio
|
||||
current_products, current_recipes, current_suppliers = await asyncio.gather(
|
||||
self._get_ingredient_count(tenant_id),
|
||||
self._get_recipe_count(tenant_id),
|
||||
self._get_supplier_count(tenant_id)
|
||||
)
|
||||
|
||||
# Get current usage - IA & Analytics + API & Storage (parallel Redis calls for performance)
|
||||
training_jobs_usage, forecasts_usage, api_calls_usage, storage_usage = await asyncio.gather(
|
||||
self._get_training_jobs_today(tenant_id, subscription.plan),
|
||||
self._get_forecasts_today(tenant_id, subscription.plan),
|
||||
self._get_api_calls_this_hour(tenant_id, subscription.plan),
|
||||
self._get_file_storage_usage_gb(tenant_id, subscription.plan)
|
||||
)
|
||||
|
||||
# Get limits from subscription
|
||||
recipes_limit = await self._get_limit_from_plan(subscription.plan, 'recipes')
|
||||
suppliers_limit = await self._get_limit_from_plan(subscription.plan, 'suppliers')
|
||||
|
||||
return {
|
||||
"plan": subscription.plan,
|
||||
"monthly_price": subscription.monthly_price,
|
||||
"status": subscription.status,
|
||||
"billing_cycle": subscription.billing_cycle or "monthly",
|
||||
"usage": {
|
||||
# Team & Organization
|
||||
"users": {
|
||||
"current": current_users,
|
||||
"limit": subscription.max_users,
|
||||
"unlimited": subscription.max_users == -1,
|
||||
"usage_percentage": 0 if subscription.max_users == -1 else self._calculate_percentage(current_users, subscription.max_users)
|
||||
},
|
||||
"locations": {
|
||||
"current": current_locations,
|
||||
"limit": subscription.max_locations,
|
||||
"unlimited": subscription.max_locations == -1,
|
||||
"usage_percentage": 0 if subscription.max_locations == -1 else self._calculate_percentage(current_locations, subscription.max_locations)
|
||||
},
|
||||
# Products & Inventory
|
||||
"products": {
|
||||
"current": current_products,
|
||||
"limit": subscription.max_products,
|
||||
"unlimited": subscription.max_products == -1,
|
||||
"usage_percentage": 0 if subscription.max_products == -1 else self._calculate_percentage(current_products, subscription.max_products)
|
||||
},
|
||||
"recipes": {
|
||||
"current": current_recipes,
|
||||
"limit": recipes_limit,
|
||||
"unlimited": recipes_limit is None,
|
||||
"usage_percentage": self._calculate_percentage(current_recipes, recipes_limit)
|
||||
},
|
||||
"suppliers": {
|
||||
"current": current_suppliers,
|
||||
"limit": suppliers_limit,
|
||||
"unlimited": suppliers_limit is None,
|
||||
"usage_percentage": self._calculate_percentage(current_suppliers, suppliers_limit)
|
||||
},
|
||||
# IA & Analytics (Daily quotas)
|
||||
"training_jobs_today": training_jobs_usage,
|
||||
"forecasts_today": forecasts_usage,
|
||||
# API & Storage
|
||||
"api_calls_this_hour": api_calls_usage,
|
||||
"file_storage_used_gb": storage_usage
|
||||
},
|
||||
"features": subscription.features or {},
|
||||
"next_billing_date": subscription.next_billing_date.isoformat() if subscription.next_billing_date else None,
|
||||
"trial_ends_at": subscription.trial_ends_at.isoformat() if subscription.trial_ends_at else None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get usage summary",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {"error": "Failed to get usage summary"}
|
||||
|
||||
async def _get_ingredient_count(self, tenant_id: str) -> int:
|
||||
"""Get ingredient count from inventory service using shared client"""
|
||||
try:
|
||||
from app.core.config import settings
|
||||
from shared.clients.inventory_client import create_inventory_client
|
||||
|
||||
# Use the shared inventory client with proper authentication
|
||||
inventory_client = create_inventory_client(settings, service_name="tenant")
|
||||
count = await inventory_client.count_ingredients(tenant_id)
|
||||
|
||||
logger.info(
|
||||
"Retrieved ingredient count via inventory client",
|
||||
tenant_id=tenant_id,
|
||||
count=count
|
||||
)
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error getting ingredient count via inventory client",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
# Return 0 as fallback to avoid breaking subscription display
|
||||
return 0
|
||||
|
||||
async def _get_recipe_count(self, tenant_id: str) -> int:
|
||||
"""Get recipe count from recipes service using shared client"""
|
||||
try:
|
||||
from app.core.config import settings
|
||||
|
||||
# Use the shared recipes client with proper authentication and resilience
|
||||
recipes_client = create_recipes_client(settings, service_name="tenant")
|
||||
count = await recipes_client.count_recipes(tenant_id)
|
||||
|
||||
logger.info(
|
||||
"Retrieved recipe count via recipes client",
|
||||
tenant_id=tenant_id,
|
||||
count=count
|
||||
)
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error getting recipe count via recipes client",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
# Return 0 as fallback to avoid breaking subscription display
|
||||
return 0
|
||||
|
||||
async def _get_supplier_count(self, tenant_id: str) -> int:
|
||||
"""Get supplier count from suppliers service using shared client"""
|
||||
try:
|
||||
from app.core.config import settings
|
||||
|
||||
# Use the shared suppliers client with proper authentication and resilience
|
||||
suppliers_client = create_suppliers_client(settings, service_name="tenant")
|
||||
count = await suppliers_client.count_suppliers(tenant_id)
|
||||
|
||||
logger.info(
|
||||
"Retrieved supplier count via suppliers client",
|
||||
tenant_id=tenant_id,
|
||||
count=count
|
||||
)
|
||||
return count
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error getting supplier count via suppliers client",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e)
|
||||
)
|
||||
# Return 0 as fallback to avoid breaking subscription display
|
||||
return 0
|
||||
|
||||
async def _get_redis_quota(self, quota_key: str) -> int:
|
||||
"""Get current count from Redis quota key"""
|
||||
try:
|
||||
if not self.redis:
|
||||
# Try to initialize Redis client if not available
|
||||
from app.core.config import settings
|
||||
import shared.redis_utils
|
||||
self.redis = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
|
||||
|
||||
if not self.redis:
|
||||
return 0
|
||||
|
||||
current = await self.redis.get(quota_key)
|
||||
return int(current) if current else 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting Redis quota", key=quota_key, error=str(e))
|
||||
return 0
|
||||
|
||||
async def _get_training_jobs_today(self, tenant_id: str, plan: str) -> Dict[str, Any]:
|
||||
"""Get training jobs usage for today"""
|
||||
try:
|
||||
date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
quota_key = f"quota:daily:training_jobs:{tenant_id}:{date_str}"
|
||||
current_count = await self._get_redis_quota(quota_key)
|
||||
|
||||
limit = get_training_job_quota(plan)
|
||||
|
||||
return {
|
||||
"current": current_count,
|
||||
"limit": limit,
|
||||
"unlimited": limit is None,
|
||||
"usage_percentage": self._calculate_percentage(current_count, limit)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting training jobs today", tenant_id=tenant_id, error=str(e))
|
||||
return {"current": 0, "limit": None, "unlimited": True, "usage_percentage": 0.0}
|
||||
|
||||
async def _get_forecasts_today(self, tenant_id: str, plan: str) -> Dict[str, Any]:
|
||||
"""Get forecast generation usage for today"""
|
||||
try:
|
||||
date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
quota_key = f"quota:daily:forecast_generation:{tenant_id}:{date_str}"
|
||||
current_count = await self._get_redis_quota(quota_key)
|
||||
|
||||
limit = get_forecast_quota(plan)
|
||||
|
||||
return {
|
||||
"current": current_count,
|
||||
"limit": limit,
|
||||
"unlimited": limit is None,
|
||||
"usage_percentage": self._calculate_percentage(current_count, limit)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting forecasts today", tenant_id=tenant_id, error=str(e))
|
||||
return {"current": 0, "limit": None, "unlimited": True, "usage_percentage": 0.0}
|
||||
|
||||
async def _get_api_calls_this_hour(self, tenant_id: str, plan: str) -> Dict[str, Any]:
|
||||
"""Get API calls usage for current hour"""
|
||||
try:
|
||||
hour_str = datetime.now(timezone.utc).strftime('%Y-%m-%d-%H')
|
||||
quota_key = f"quota:hourly:api_calls:{tenant_id}:{hour_str}"
|
||||
current_count = await self._get_redis_quota(quota_key)
|
||||
|
||||
plan_metadata = SubscriptionPlanMetadata.PLANS.get(plan, {})
|
||||
limit = plan_metadata.get('limits', {}).get('api_calls_per_hour')
|
||||
|
||||
return {
|
||||
"current": current_count,
|
||||
"limit": limit,
|
||||
"unlimited": limit is None,
|
||||
"usage_percentage": self._calculate_percentage(current_count, limit)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting API calls this hour", tenant_id=tenant_id, error=str(e))
|
||||
return {"current": 0, "limit": None, "unlimited": True, "usage_percentage": 0.0}
|
||||
|
||||
async def _get_file_storage_usage_gb(self, tenant_id: str, plan: str) -> Dict[str, Any]:
|
||||
"""Get file storage usage in GB"""
|
||||
try:
|
||||
storage_key = f"storage:total_bytes:{tenant_id}"
|
||||
total_bytes = await self._get_redis_quota(storage_key)
|
||||
total_gb = round(total_bytes / (1024 ** 3), 2) if total_bytes > 0 else 0.0
|
||||
|
||||
plan_metadata = SubscriptionPlanMetadata.PLANS.get(plan, {})
|
||||
limit = plan_metadata.get('limits', {}).get('file_storage_gb')
|
||||
|
||||
return {
|
||||
"current": total_gb,
|
||||
"limit": limit,
|
||||
"unlimited": limit is None,
|
||||
"usage_percentage": self._calculate_percentage(total_gb, limit)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error getting file storage usage", tenant_id=tenant_id, error=str(e))
|
||||
return {"current": 0.0, "limit": None, "unlimited": True, "usage_percentage": 0.0}
|
||||
|
||||
def _calculate_percentage(self, current: float, limit: Optional[int]) -> float:
|
||||
"""Calculate usage percentage"""
|
||||
if limit is None or limit == -1:
|
||||
return 0.0
|
||||
if limit == 0:
|
||||
return 0.0
|
||||
return round((current / limit) * 100, 1)
|
||||
|
||||
async def _get_limit_from_plan(self, plan: str, limit_key: str) -> Optional[int]:
|
||||
"""Get limit value from plan metadata"""
|
||||
plan_metadata = SubscriptionPlanMetadata.PLANS.get(plan, {})
|
||||
limit = plan_metadata.get('limits', {}).get(limit_key)
|
||||
return limit if limit != -1 else None
|
||||
2153
services/tenant/app/services/subscription_orchestration_service.py
Normal file
2153
services/tenant/app/services/subscription_orchestration_service.py
Normal file
File diff suppressed because it is too large
Load Diff
792
services/tenant/app/services/subscription_service.py
Normal file
792
services/tenant/app/services/subscription_service.py
Normal file
@@ -0,0 +1,792 @@
|
||||
"""
|
||||
Subscription Service - State Manager
|
||||
This service handles ONLY subscription database operations and state management
|
||||
NO payment provider interactions, NO orchestration, NO coupon logic
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from uuid import UUID
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models.tenants import Subscription, Tenant
|
||||
from app.repositories.subscription_repository import SubscriptionRepository
|
||||
from app.core.config import settings
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
from shared.subscription.plans import PlanPricing, QuotaLimits, SubscriptionPlanMetadata
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class SubscriptionService:
|
||||
"""Service for managing subscription state and database operations ONLY"""
|
||||
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db_session = db_session
|
||||
self.subscription_repo = SubscriptionRepository(Subscription, db_session)
|
||||
|
||||
async def create_subscription_record(
|
||||
self,
|
||||
tenant_id: str,
|
||||
subscription_id: str,
|
||||
customer_id: str,
|
||||
plan: str,
|
||||
status: str,
|
||||
trial_period_days: Optional[int] = None,
|
||||
billing_interval: str = "monthly"
|
||||
) -> Subscription:
|
||||
"""
|
||||
Create a local subscription record in the database
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
subscription_id: Payment provider subscription ID
|
||||
customer_id: Payment provider customer ID
|
||||
plan: Subscription plan
|
||||
status: Subscription status
|
||||
trial_period_days: Optional trial period in days
|
||||
billing_interval: Billing interval (monthly or yearly)
|
||||
|
||||
Returns:
|
||||
Created Subscription object
|
||||
"""
|
||||
try:
|
||||
tenant_uuid = UUID(tenant_id)
|
||||
|
||||
# Verify tenant exists
|
||||
query = select(Tenant).where(Tenant.id == tenant_uuid)
|
||||
result = await self.db_session.execute(query)
|
||||
tenant = result.scalar_one_or_none()
|
||||
|
||||
if not tenant:
|
||||
raise ValidationError(f"Tenant not found: {tenant_id}")
|
||||
|
||||
# Create local subscription record
|
||||
subscription_data = {
|
||||
'tenant_id': str(tenant_id),
|
||||
'subscription_id': subscription_id,
|
||||
'customer_id': customer_id,
|
||||
'plan': plan,
|
||||
'status': status,
|
||||
'created_at': datetime.now(timezone.utc),
|
||||
'billing_cycle': billing_interval
|
||||
}
|
||||
|
||||
# Add trial-related data if applicable
|
||||
if trial_period_days and trial_period_days > 0:
|
||||
from datetime import timedelta
|
||||
trial_ends_at = datetime.now(timezone.utc) + timedelta(days=trial_period_days)
|
||||
subscription_data['trial_ends_at'] = trial_ends_at
|
||||
|
||||
# Check if subscription with this subscription_id already exists to prevent duplicates
|
||||
existing_subscription = await self.subscription_repo.get_by_provider_id(subscription_id)
|
||||
if existing_subscription:
|
||||
# Update the existing subscription instead of creating a duplicate
|
||||
updated_subscription = await self.subscription_repo.update(str(existing_subscription.id), subscription_data)
|
||||
|
||||
logger.info("Existing subscription updated",
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
plan=plan)
|
||||
|
||||
return updated_subscription
|
||||
else:
|
||||
# Create new subscription
|
||||
created_subscription = await self.subscription_repo.create(subscription_data)
|
||||
|
||||
logger.info("subscription_record_created",
|
||||
tenant_id=tenant_id,
|
||||
subscription_id=subscription_id,
|
||||
plan=plan)
|
||||
|
||||
return created_subscription
|
||||
|
||||
except ValidationError as ve:
|
||||
logger.error(f"create_subscription_record_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
|
||||
raise ve
|
||||
except Exception as e:
|
||||
logger.error(f"create_subscription_record_failed, tenant_id={tenant_id}, error={str(e)}")
|
||||
raise DatabaseError(f"Failed to create subscription record: {str(e)}")
|
||||
|
||||
async def update_subscription_status(
|
||||
self,
|
||||
tenant_id: str,
|
||||
status: str,
|
||||
stripe_data: Optional[Dict[str, Any]] = None
|
||||
) -> Subscription:
|
||||
"""
|
||||
Update subscription status in database
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
status: New subscription status
|
||||
stripe_data: Optional data from Stripe to update
|
||||
|
||||
Returns:
|
||||
Updated Subscription object
|
||||
"""
|
||||
try:
|
||||
tenant_uuid = UUID(tenant_id)
|
||||
|
||||
# Get subscription from repository
|
||||
subscription = await self.subscription_repo.get_by_tenant_id(str(tenant_uuid))
|
||||
|
||||
if not subscription:
|
||||
raise ValidationError(f"Subscription not found for tenant {tenant_id}")
|
||||
|
||||
# Prepare update data
|
||||
update_data = {
|
||||
'status': status,
|
||||
'updated_at': datetime.now(timezone.utc)
|
||||
}
|
||||
|
||||
# Include Stripe data if provided
|
||||
if stripe_data:
|
||||
# Note: current_period_start and current_period_end are not in the local model
|
||||
# These would need to be stored separately or handled differently
|
||||
# For now, we'll skip storing these Stripe-specific fields in the local model
|
||||
pass
|
||||
|
||||
# Update status flags based on status value
|
||||
if status == 'active':
|
||||
update_data['is_active'] = True
|
||||
update_data['cancelled_at'] = None
|
||||
elif status in ['canceled', 'past_due', 'unpaid', 'inactive']:
|
||||
update_data['is_active'] = False
|
||||
elif status == 'pending_cancellation':
|
||||
update_data['is_active'] = True # Still active until effective date
|
||||
|
||||
updated_subscription = await self.subscription_repo.update(str(subscription.id), update_data)
|
||||
|
||||
# Invalidate subscription cache
|
||||
await self._invalidate_cache(tenant_id)
|
||||
|
||||
logger.info("subscription_status_updated",
|
||||
tenant_id=tenant_id,
|
||||
old_status=subscription.status,
|
||||
new_status=status)
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except ValidationError as ve:
|
||||
logger.error(f"update_subscription_status_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
|
||||
raise ve
|
||||
except Exception as e:
|
||||
logger.error(f"update_subscription_status_failed, tenant_id={tenant_id}, error={str(e)}")
|
||||
raise DatabaseError(f"Failed to update subscription status: {str(e)}")
|
||||
|
||||
async def get_subscription_by_tenant_id(
|
||||
self,
|
||||
tenant_id: str
|
||||
) -> Optional[Subscription]:
|
||||
"""
|
||||
Get subscription by tenant ID
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
|
||||
Returns:
|
||||
Subscription object or None
|
||||
"""
|
||||
try:
|
||||
tenant_uuid = UUID(tenant_id)
|
||||
return await self.subscription_repo.get_by_tenant_id(str(tenant_uuid))
|
||||
except Exception as e:
|
||||
logger.error("get_subscription_by_tenant_id_failed",
|
||||
error=str(e), tenant_id=tenant_id)
|
||||
return None
|
||||
|
||||
async def get_subscription_by_provider_id(
|
||||
self,
|
||||
subscription_id: str
|
||||
) -> Optional[Subscription]:
|
||||
"""
|
||||
Get subscription by payment provider subscription ID
|
||||
|
||||
Args:
|
||||
subscription_id: Payment provider subscription ID
|
||||
|
||||
Returns:
|
||||
Subscription object or None
|
||||
"""
|
||||
try:
|
||||
return await self.subscription_repo.get_by_provider_id(subscription_id)
|
||||
except Exception as e:
|
||||
logger.error("get_subscription_by_provider_id_failed",
|
||||
error=str(e), subscription_id=subscription_id)
|
||||
return None
|
||||
|
||||
async def get_subscriptions_by_customer_id(self, customer_id: str) -> List[Subscription]:
|
||||
"""
|
||||
Get subscriptions by Stripe customer ID
|
||||
|
||||
Args:
|
||||
customer_id: Stripe customer ID
|
||||
|
||||
Returns:
|
||||
List of Subscription objects
|
||||
"""
|
||||
try:
|
||||
return await self.subscription_repo.get_by_customer_id(customer_id)
|
||||
except Exception as e:
|
||||
logger.error("get_subscriptions_by_customer_id_failed",
|
||||
error=str(e), customer_id=customer_id)
|
||||
return []
|
||||
|
||||
async def cancel_subscription(
|
||||
self,
|
||||
tenant_id: str,
|
||||
reason: str = ""
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Mark subscription as pending cancellation in database
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID to cancel subscription for
|
||||
reason: Optional cancellation reason
|
||||
|
||||
Returns:
|
||||
Dictionary with cancellation details
|
||||
"""
|
||||
try:
|
||||
tenant_uuid = UUID(tenant_id)
|
||||
|
||||
# Get subscription from repository
|
||||
subscription = await self.subscription_repo.get_by_tenant_id(str(tenant_uuid))
|
||||
|
||||
if not subscription:
|
||||
raise ValidationError(f"Subscription not found for tenant {tenant_id}")
|
||||
|
||||
if subscription.status in ['pending_cancellation', 'inactive']:
|
||||
raise ValidationError(f"Subscription is already {subscription.status}")
|
||||
|
||||
# Calculate cancellation effective date (end of billing period)
|
||||
cancellation_effective_date = subscription.next_billing_date or (
|
||||
datetime.now(timezone.utc) + timedelta(days=30)
|
||||
)
|
||||
|
||||
# Update subscription status in database
|
||||
update_data = {
|
||||
'status': 'pending_cancellation',
|
||||
'cancelled_at': datetime.now(timezone.utc),
|
||||
'cancellation_effective_date': cancellation_effective_date
|
||||
}
|
||||
|
||||
updated_subscription = await self.subscription_repo.update(str(subscription.id), update_data)
|
||||
|
||||
# Invalidate subscription cache
|
||||
await self._invalidate_cache(tenant_id)
|
||||
|
||||
days_remaining = (cancellation_effective_date - datetime.now(timezone.utc)).days
|
||||
|
||||
logger.info(
|
||||
"subscription_cancelled",
|
||||
tenant_id=str(tenant_id),
|
||||
effective_date=cancellation_effective_date.isoformat(),
|
||||
reason=reason[:200] if reason else None
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Subscription cancelled successfully. You will have read-only access until the end of your billing period.",
|
||||
"status": "pending_cancellation",
|
||||
"cancellation_effective_date": cancellation_effective_date.isoformat(),
|
||||
"days_remaining": days_remaining,
|
||||
"read_only_mode_starts": cancellation_effective_date.isoformat()
|
||||
}
|
||||
|
||||
except ValidationError as ve:
|
||||
logger.error(f"subscription_cancellation_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
|
||||
raise ve
|
||||
except Exception as e:
|
||||
logger.error(f"subscription_cancellation_failed, tenant_id={tenant_id}, error={str(e)}")
|
||||
raise DatabaseError(f"Failed to cancel subscription: {str(e)}")
|
||||
|
||||
async def reactivate_subscription(
|
||||
self,
|
||||
tenant_id: str,
|
||||
plan: str = "starter"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Reactivate a cancelled or inactive subscription
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID to reactivate subscription for
|
||||
plan: Plan to reactivate with
|
||||
|
||||
Returns:
|
||||
Dictionary with reactivation details
|
||||
"""
|
||||
try:
|
||||
tenant_uuid = UUID(tenant_id)
|
||||
|
||||
# Get subscription from repository
|
||||
subscription = await self.subscription_repo.get_by_tenant_id(str(tenant_uuid))
|
||||
|
||||
if not subscription:
|
||||
raise ValidationError(f"Subscription not found for tenant {tenant_id}")
|
||||
|
||||
if subscription.status not in ['pending_cancellation', 'inactive']:
|
||||
raise ValidationError(f"Cannot reactivate subscription with status: {subscription.status}")
|
||||
|
||||
# Update subscription status and plan
|
||||
update_data = {
|
||||
'status': 'active',
|
||||
'plan': plan,
|
||||
'cancelled_at': None,
|
||||
'cancellation_effective_date': None
|
||||
}
|
||||
|
||||
if subscription.status == 'inactive':
|
||||
update_data['next_billing_date'] = datetime.now(timezone.utc) + timedelta(days=30)
|
||||
|
||||
updated_subscription = await self.subscription_repo.update(str(subscription.id), update_data)
|
||||
|
||||
# Invalidate subscription cache
|
||||
await self._invalidate_cache(tenant_id)
|
||||
|
||||
logger.info(
|
||||
"subscription_reactivated",
|
||||
tenant_id=str(tenant_id),
|
||||
new_plan=plan
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Subscription reactivated successfully",
|
||||
"status": "active",
|
||||
"plan": plan,
|
||||
"next_billing_date": updated_subscription.next_billing_date.isoformat() if updated_subscription.next_billing_date else None
|
||||
}
|
||||
|
||||
except ValidationError as ve:
|
||||
logger.error(f"subscription_reactivation_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
|
||||
raise ve
|
||||
except Exception as e:
|
||||
logger.error(f"subscription_reactivation_failed, tenant_id={tenant_id}, error={str(e)}")
|
||||
raise DatabaseError(f"Failed to reactivate subscription: {str(e)}")
|
||||
|
||||
async def get_subscription_status(
|
||||
self,
|
||||
tenant_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Get current subscription status including read-only mode info
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID to get status for
|
||||
|
||||
Returns:
|
||||
Dictionary with subscription status details
|
||||
"""
|
||||
try:
|
||||
tenant_uuid = UUID(tenant_id)
|
||||
|
||||
# Get subscription from repository
|
||||
subscription = await self.subscription_repo.get_by_tenant_id(str(tenant_uuid))
|
||||
|
||||
if not subscription:
|
||||
raise ValidationError(f"Subscription not found for tenant {tenant_id}")
|
||||
|
||||
is_read_only = subscription.status in ['pending_cancellation', 'inactive']
|
||||
days_until_inactive = None
|
||||
|
||||
if subscription.status == 'pending_cancellation' and subscription.cancellation_effective_date:
|
||||
days_until_inactive = (subscription.cancellation_effective_date - datetime.now(timezone.utc)).days
|
||||
|
||||
return {
|
||||
"tenant_id": str(tenant_id),
|
||||
"status": subscription.status,
|
||||
"plan": subscription.plan,
|
||||
"is_read_only": is_read_only,
|
||||
"cancellation_effective_date": subscription.cancellation_effective_date.isoformat() if subscription.cancellation_effective_date else None,
|
||||
"days_until_inactive": days_until_inactive
|
||||
}
|
||||
|
||||
except ValidationError as ve:
|
||||
logger.error("get_subscription_status_validation_failed",
|
||||
error=str(ve), tenant_id=tenant_id)
|
||||
raise ve
|
||||
except Exception as e:
|
||||
logger.error(f"get_subscription_status_failed, tenant_id={tenant_id}, error={str(e)}")
|
||||
raise DatabaseError(f"Failed to get subscription status: {str(e)}")
|
||||
|
||||
async def update_subscription_plan_record(
|
||||
self,
|
||||
tenant_id: str,
|
||||
new_plan: str,
|
||||
new_status: str,
|
||||
new_period_start: datetime,
|
||||
new_period_end: datetime,
|
||||
billing_cycle: str = "monthly",
|
||||
proration_details: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Update local subscription plan record in database
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
new_plan: New plan name
|
||||
new_status: New subscription status
|
||||
new_period_start: New period start date
|
||||
new_period_end: New period end date
|
||||
billing_cycle: Billing cycle for the new plan
|
||||
proration_details: Proration details from payment provider
|
||||
|
||||
Returns:
|
||||
Dictionary with update results
|
||||
"""
|
||||
try:
|
||||
tenant_uuid = UUID(tenant_id)
|
||||
|
||||
# Get current subscription
|
||||
subscription = await self.subscription_repo.get_by_tenant_id(str(tenant_uuid))
|
||||
if not subscription:
|
||||
raise ValidationError(f"Subscription not found for tenant {tenant_id}")
|
||||
|
||||
# Update local subscription record
|
||||
update_data = {
|
||||
'plan': new_plan,
|
||||
'status': new_status,
|
||||
'updated_at': datetime.now(timezone.utc)
|
||||
}
|
||||
|
||||
# Note: current_period_start and current_period_end are not in the local model
|
||||
# These Stripe-specific fields would need to be stored separately
|
||||
|
||||
updated_subscription = await self.subscription_repo.update(str(subscription.id), update_data)
|
||||
|
||||
# Invalidate subscription cache
|
||||
await self._invalidate_cache(tenant_id)
|
||||
|
||||
logger.info(
|
||||
"subscription_plan_record_updated",
|
||||
tenant_id=str(tenant_id),
|
||||
old_plan=subscription.plan,
|
||||
new_plan=new_plan,
|
||||
proration_amount=proration_details.get("net_amount", 0) if proration_details else 0
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Subscription plan record updated to {new_plan}",
|
||||
"old_plan": subscription.plan,
|
||||
"new_plan": new_plan,
|
||||
"proration_details": proration_details,
|
||||
"new_status": new_status,
|
||||
"new_period_end": new_period_end.isoformat()
|
||||
}
|
||||
|
||||
except ValidationError as ve:
|
||||
logger.error(f"update_subscription_plan_record_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
|
||||
raise ve
|
||||
except Exception as e:
|
||||
logger.error(f"update_subscription_plan_record_failed, tenant_id={tenant_id}, error={str(e)}")
|
||||
raise DatabaseError(f"Failed to update subscription plan record: {str(e)}")
|
||||
|
||||
async def update_billing_cycle_record(
|
||||
self,
|
||||
tenant_id: str,
|
||||
new_billing_cycle: str,
|
||||
new_status: str,
|
||||
new_period_start: datetime,
|
||||
new_period_end: datetime,
|
||||
current_plan: str,
|
||||
proration_details: Optional[Dict[str, Any]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Update local billing cycle record in database
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
new_billing_cycle: New billing cycle ('monthly' or 'yearly')
|
||||
new_status: New subscription status
|
||||
new_period_start: New period start date
|
||||
new_period_end: New period end date
|
||||
current_plan: Current plan name
|
||||
proration_details: Proration details from payment provider
|
||||
|
||||
Returns:
|
||||
Dictionary with billing cycle update results
|
||||
"""
|
||||
try:
|
||||
tenant_uuid = UUID(tenant_id)
|
||||
|
||||
# Get current subscription
|
||||
subscription = await self.subscription_repo.get_by_tenant_id(str(tenant_uuid))
|
||||
if not subscription:
|
||||
raise ValidationError(f"Subscription not found for tenant {tenant_id}")
|
||||
|
||||
# Update local subscription record
|
||||
update_data = {
|
||||
'status': new_status,
|
||||
'updated_at': datetime.now(timezone.utc)
|
||||
}
|
||||
|
||||
# Note: current_period_start and current_period_end are not in the local model
|
||||
# These Stripe-specific fields would need to be stored separately
|
||||
|
||||
updated_subscription = await self.subscription_repo.update(str(subscription.id), update_data)
|
||||
|
||||
# Invalidate subscription cache
|
||||
await self._invalidate_cache(tenant_id)
|
||||
|
||||
old_billing_cycle = getattr(subscription, 'billing_cycle', 'monthly')
|
||||
|
||||
logger.info(
|
||||
"subscription_billing_cycle_record_updated",
|
||||
tenant_id=str(tenant_id),
|
||||
old_billing_cycle=old_billing_cycle,
|
||||
new_billing_cycle=new_billing_cycle,
|
||||
proration_amount=proration_details.get("net_amount", 0) if proration_details else 0
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Billing cycle record changed to {new_billing_cycle}",
|
||||
"old_billing_cycle": old_billing_cycle,
|
||||
"new_billing_cycle": new_billing_cycle,
|
||||
"proration_details": proration_details,
|
||||
"new_status": new_status,
|
||||
"new_period_end": new_period_end.isoformat()
|
||||
}
|
||||
|
||||
except ValidationError as ve:
|
||||
logger.error(f"change_billing_cycle_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
|
||||
raise ve
|
||||
except Exception as e:
|
||||
logger.error(f"change_billing_cycle_failed, tenant_id={tenant_id}, error={str(e)}")
|
||||
raise DatabaseError(f"Failed to change billing cycle: {str(e)}")
|
||||
|
||||
async def _invalidate_cache(self, tenant_id: str):
|
||||
"""Helper method to invalidate subscription cache"""
|
||||
try:
|
||||
from app.services.subscription_cache import get_subscription_cache_service
|
||||
import shared.redis_utils
|
||||
|
||||
redis_client = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
|
||||
cache_service = get_subscription_cache_service(redis_client)
|
||||
await cache_service.invalidate_subscription_cache(str(tenant_id))
|
||||
|
||||
logger.info(
|
||||
"Subscription cache invalidated",
|
||||
tenant_id=str(tenant_id)
|
||||
)
|
||||
except Exception as cache_error:
|
||||
logger.error(
|
||||
"Failed to invalidate subscription cache",
|
||||
tenant_id=str(tenant_id),
|
||||
error=str(cache_error)
|
||||
)
|
||||
|
||||
async def validate_subscription_change(
|
||||
self,
|
||||
tenant_id: str,
|
||||
new_plan: str
|
||||
) -> bool:
|
||||
"""
|
||||
Validate if a subscription change is allowed
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
new_plan: New plan to validate
|
||||
|
||||
Returns:
|
||||
True if change is allowed
|
||||
"""
|
||||
try:
|
||||
subscription = await self.get_subscription_by_tenant_id(tenant_id)
|
||||
|
||||
if not subscription:
|
||||
return False
|
||||
|
||||
# Can't change if already pending cancellation or inactive
|
||||
if subscription.status in ['pending_cancellation', 'inactive']:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error("validate_subscription_change_failed",
|
||||
error=str(e), tenant_id=tenant_id)
|
||||
return False
|
||||
|
||||
# ========================================================================
|
||||
# TENANT-INDEPENDENT SUBSCRIPTION METHODS (New Architecture)
|
||||
# ========================================================================
|
||||
|
||||
async def create_tenant_independent_subscription_record(
|
||||
self,
|
||||
subscription_id: str,
|
||||
customer_id: str,
|
||||
plan: str,
|
||||
status: str,
|
||||
trial_period_days: Optional[int] = None,
|
||||
billing_interval: str = "monthly",
|
||||
user_id: str = None
|
||||
) -> Subscription:
|
||||
"""
|
||||
Create a tenant-independent subscription record in the database
|
||||
|
||||
This subscription is not linked to any tenant and will be linked during onboarding
|
||||
|
||||
Args:
|
||||
subscription_id: Payment provider subscription ID
|
||||
customer_id: Payment provider customer ID
|
||||
plan: Subscription plan
|
||||
status: Subscription status
|
||||
trial_period_days: Optional trial period in days
|
||||
billing_interval: Billing interval (monthly or yearly)
|
||||
user_id: User ID who created this subscription
|
||||
|
||||
Returns:
|
||||
Created Subscription object
|
||||
"""
|
||||
try:
|
||||
# Create tenant-independent subscription record
|
||||
subscription_data = {
|
||||
'subscription_id': subscription_id,
|
||||
'customer_id': customer_id,
|
||||
'plan': plan, # Repository expects 'plan', not 'plan_id'
|
||||
'status': status,
|
||||
'created_at': datetime.now(timezone.utc),
|
||||
'billing_cycle': billing_interval,
|
||||
'user_id': user_id,
|
||||
'is_tenant_linked': False,
|
||||
'tenant_linking_status': 'pending'
|
||||
}
|
||||
|
||||
# Add trial-related data if applicable
|
||||
if trial_period_days and trial_period_days > 0:
|
||||
from datetime import timedelta
|
||||
trial_ends_at = datetime.now(timezone.utc) + timedelta(days=trial_period_days)
|
||||
subscription_data['trial_ends_at'] = trial_ends_at
|
||||
|
||||
created_subscription = await self.subscription_repo.create_tenant_independent_subscription(subscription_data)
|
||||
|
||||
logger.info("tenant_independent_subscription_record_created",
|
||||
subscription_id=subscription_id,
|
||||
user_id=user_id,
|
||||
plan=plan)
|
||||
|
||||
return created_subscription
|
||||
|
||||
except ValidationError as ve:
|
||||
logger.error("create_tenant_independent_subscription_record_validation_failed",
|
||||
error=str(ve), user_id=user_id)
|
||||
raise ve
|
||||
except Exception as e:
|
||||
logger.error("create_tenant_independent_subscription_record_failed",
|
||||
error=str(e), user_id=user_id)
|
||||
raise DatabaseError(f"Failed to create tenant-independent subscription record: {str(e)}")
|
||||
|
||||
async def get_pending_tenant_linking_subscriptions(self) -> List[Subscription]:
|
||||
"""Get all subscriptions waiting to be linked to tenants"""
|
||||
try:
|
||||
return await self.subscription_repo.get_pending_tenant_linking_subscriptions()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get pending tenant linking subscriptions: {str(e)}")
|
||||
raise DatabaseError(f"Failed to get pending subscriptions: {str(e)}")
|
||||
|
||||
async def get_pending_subscriptions_by_user(self, user_id: str) -> List[Subscription]:
|
||||
"""Get pending tenant linking subscriptions for a specific user"""
|
||||
try:
|
||||
return await self.subscription_repo.get_pending_subscriptions_by_user(user_id)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get pending subscriptions by user",
|
||||
user_id=user_id, error=str(e))
|
||||
raise DatabaseError(f"Failed to get pending subscriptions: {str(e)}")
|
||||
|
||||
async def link_subscription_to_tenant(
|
||||
self,
|
||||
subscription_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str
|
||||
) -> Subscription:
|
||||
"""
|
||||
Link a pending subscription to a tenant
|
||||
|
||||
This completes the registration flow by associating the subscription
|
||||
created during registration with the tenant created during onboarding
|
||||
|
||||
Args:
|
||||
subscription_id: Subscription ID to link
|
||||
tenant_id: Tenant ID to link to
|
||||
user_id: User ID performing the linking (for validation)
|
||||
|
||||
Returns:
|
||||
Updated Subscription object
|
||||
"""
|
||||
try:
|
||||
return await self.subscription_repo.link_subscription_to_tenant(
|
||||
subscription_id, tenant_id, user_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to link subscription to tenant",
|
||||
subscription_id=subscription_id,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to link subscription to tenant: {str(e)}")
|
||||
|
||||
async def cleanup_orphaned_subscriptions(self, days_old: int = 30) -> int:
|
||||
"""Clean up subscriptions that were never linked to tenants"""
|
||||
try:
|
||||
return await self.subscription_repo.cleanup_orphaned_subscriptions(days_old)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to cleanup orphaned subscriptions: {str(e)}")
|
||||
raise DatabaseError(f"Failed to cleanup orphaned subscriptions: {str(e)}")
|
||||
|
||||
async def update_subscription_info(
|
||||
self,
|
||||
subscription_id: str,
|
||||
update_data: Dict[str, Any]
|
||||
) -> Subscription:
|
||||
"""
|
||||
Update subscription-related information (3DS flags, status, etc.)
|
||||
|
||||
This is useful for updating tenant-independent subscriptions during registration.
|
||||
|
||||
Args:
|
||||
subscription_id: Subscription ID
|
||||
update_data: Dictionary with fields to update
|
||||
|
||||
Returns:
|
||||
Updated Subscription object
|
||||
"""
|
||||
try:
|
||||
# Filter allowed fields
|
||||
allowed_fields = {
|
||||
'plan', 'status', 'is_tenant_linked', 'tenant_linking_status',
|
||||
'threeds_authentication_required', 'threeds_authentication_required_at',
|
||||
'threeds_authentication_completed', 'threeds_authentication_completed_at',
|
||||
'last_threeds_setup_intent_id', 'threeds_action_type'
|
||||
}
|
||||
|
||||
filtered_data = {k: v for k, v in update_data.items() if k in allowed_fields}
|
||||
|
||||
if not filtered_data:
|
||||
logger.warning("No valid subscription info fields provided for update",
|
||||
subscription_id=subscription_id)
|
||||
return await self.subscription_repo.get_by_id(subscription_id)
|
||||
|
||||
updated_subscription = await self.subscription_repo.update(subscription_id, filtered_data)
|
||||
|
||||
if not updated_subscription:
|
||||
raise ValidationError(f"Subscription not found: {subscription_id}")
|
||||
|
||||
logger.info("Subscription info updated",
|
||||
subscription_id=subscription_id,
|
||||
updated_fields=list(filtered_data.keys()))
|
||||
|
||||
return updated_subscription
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to update subscription info",
|
||||
subscription_id=subscription_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to update subscription info: {str(e)}")
|
||||
1609
services/tenant/app/services/tenant_service.py
Normal file
1609
services/tenant/app/services/tenant_service.py
Normal file
File diff suppressed because it is too large
Load Diff
293
services/tenant/app/services/tenant_settings_service.py
Normal file
293
services/tenant/app/services/tenant_settings_service.py
Normal file
@@ -0,0 +1,293 @@
|
||||
# services/tenant/app/services/tenant_settings_service.py
|
||||
"""
|
||||
Tenant Settings Service
|
||||
Business logic for managing tenant-specific operational settings
|
||||
"""
|
||||
|
||||
import structlog
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from uuid import UUID
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from ..models.tenant_settings import TenantSettings
|
||||
from ..repositories.tenant_settings_repository import TenantSettingsRepository
|
||||
from ..schemas.tenant_settings import (
|
||||
TenantSettingsUpdate,
|
||||
ProcurementSettings,
|
||||
InventorySettings,
|
||||
ProductionSettings,
|
||||
SupplierSettings,
|
||||
POSSettings,
|
||||
OrderSettings,
|
||||
ReplenishmentSettings,
|
||||
SafetyStockSettings,
|
||||
MOQSettings,
|
||||
SupplierSelectionSettings,
|
||||
NotificationSettings
|
||||
)
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TenantSettingsService:
|
||||
"""
|
||||
Service for managing tenant settings
|
||||
Handles validation, CRUD operations, and default value management
|
||||
"""
|
||||
|
||||
# Map category names to schema validators
|
||||
CATEGORY_SCHEMAS = {
|
||||
"procurement": ProcurementSettings,
|
||||
"inventory": InventorySettings,
|
||||
"production": ProductionSettings,
|
||||
"supplier": SupplierSettings,
|
||||
"pos": POSSettings,
|
||||
"order": OrderSettings,
|
||||
"replenishment": ReplenishmentSettings,
|
||||
"safety_stock": SafetyStockSettings,
|
||||
"moq": MOQSettings,
|
||||
"supplier_selection": SupplierSelectionSettings,
|
||||
"notification": NotificationSettings
|
||||
}
|
||||
|
||||
# Map category names to database column names
|
||||
CATEGORY_COLUMNS = {
|
||||
"procurement": "procurement_settings",
|
||||
"inventory": "inventory_settings",
|
||||
"production": "production_settings",
|
||||
"supplier": "supplier_settings",
|
||||
"pos": "pos_settings",
|
||||
"order": "order_settings",
|
||||
"replenishment": "replenishment_settings",
|
||||
"safety_stock": "safety_stock_settings",
|
||||
"moq": "moq_settings",
|
||||
"supplier_selection": "supplier_selection_settings",
|
||||
"notification": "notification_settings"
|
||||
}
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
self.repository = TenantSettingsRepository(db)
|
||||
|
||||
async def get_settings(self, tenant_id: UUID) -> TenantSettings:
|
||||
"""
|
||||
Get tenant settings, creating defaults if they don't exist
|
||||
|
||||
Args:
|
||||
tenant_id: UUID of the tenant
|
||||
|
||||
Returns:
|
||||
TenantSettings object
|
||||
|
||||
Raises:
|
||||
HTTPException: If tenant not found
|
||||
"""
|
||||
try:
|
||||
# Try to get existing settings using repository
|
||||
settings = await self.repository.get_by_tenant_id(tenant_id)
|
||||
|
||||
logger.info(f"Existing settings lookup for tenant {tenant_id}: {'found' if settings else 'not found'}")
|
||||
|
||||
# Create default settings if they don't exist
|
||||
if not settings:
|
||||
logger.info(f"Creating default settings for tenant {tenant_id}")
|
||||
settings = await self._create_default_settings(tenant_id)
|
||||
logger.info(f"Successfully created default settings for tenant {tenant_id}")
|
||||
|
||||
return settings
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get or create tenant settings, tenant_id={tenant_id}, error={str(e)}", exc_info=True)
|
||||
# Re-raise as HTTPException to match the expected behavior
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get tenant settings: {str(e)}"
|
||||
)
|
||||
|
||||
async def update_settings(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
updates: TenantSettingsUpdate
|
||||
) -> TenantSettings:
|
||||
"""
|
||||
Update tenant settings
|
||||
|
||||
Args:
|
||||
tenant_id: UUID of the tenant
|
||||
updates: TenantSettingsUpdate object with new values
|
||||
|
||||
Returns:
|
||||
Updated TenantSettings object
|
||||
"""
|
||||
settings = await self.get_settings(tenant_id)
|
||||
|
||||
# Update each category if provided
|
||||
if updates.procurement_settings is not None:
|
||||
settings.procurement_settings = updates.procurement_settings.dict()
|
||||
|
||||
if updates.inventory_settings is not None:
|
||||
settings.inventory_settings = updates.inventory_settings.dict()
|
||||
|
||||
if updates.production_settings is not None:
|
||||
settings.production_settings = updates.production_settings.dict()
|
||||
|
||||
if updates.supplier_settings is not None:
|
||||
settings.supplier_settings = updates.supplier_settings.dict()
|
||||
|
||||
if updates.pos_settings is not None:
|
||||
settings.pos_settings = updates.pos_settings.dict()
|
||||
|
||||
if updates.order_settings is not None:
|
||||
settings.order_settings = updates.order_settings.dict()
|
||||
|
||||
if updates.replenishment_settings is not None:
|
||||
settings.replenishment_settings = updates.replenishment_settings.dict()
|
||||
|
||||
if updates.safety_stock_settings is not None:
|
||||
settings.safety_stock_settings = updates.safety_stock_settings.dict()
|
||||
|
||||
if updates.moq_settings is not None:
|
||||
settings.moq_settings = updates.moq_settings.dict()
|
||||
|
||||
if updates.supplier_selection_settings is not None:
|
||||
settings.supplier_selection_settings = updates.supplier_selection_settings.dict()
|
||||
|
||||
return await self.repository.update(settings)
|
||||
|
||||
async def get_category(self, tenant_id: UUID, category: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Get settings for a specific category
|
||||
|
||||
Args:
|
||||
tenant_id: UUID of the tenant
|
||||
category: Category name (procurement, inventory, production, etc.)
|
||||
|
||||
Returns:
|
||||
Dictionary with category settings
|
||||
|
||||
Raises:
|
||||
HTTPException: If category is invalid
|
||||
"""
|
||||
if category not in self.CATEGORY_COLUMNS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid category: {category}. Valid categories: {', '.join(self.CATEGORY_COLUMNS.keys())}"
|
||||
)
|
||||
|
||||
settings = await self.get_settings(tenant_id)
|
||||
column_name = self.CATEGORY_COLUMNS[category]
|
||||
|
||||
return getattr(settings, column_name)
|
||||
|
||||
async def update_category(
|
||||
self,
|
||||
tenant_id: UUID,
|
||||
category: str,
|
||||
updates: Dict[str, Any]
|
||||
) -> TenantSettings:
|
||||
"""
|
||||
Update settings for a specific category
|
||||
|
||||
Args:
|
||||
tenant_id: UUID of the tenant
|
||||
category: Category name
|
||||
updates: Dictionary with new values
|
||||
|
||||
Returns:
|
||||
Updated TenantSettings object
|
||||
|
||||
Raises:
|
||||
HTTPException: If category is invalid or validation fails
|
||||
"""
|
||||
if category not in self.CATEGORY_COLUMNS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid category: {category}"
|
||||
)
|
||||
|
||||
# Validate updates using the appropriate schema
|
||||
schema = self.CATEGORY_SCHEMAS[category]
|
||||
try:
|
||||
validated_data = schema(**updates)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=f"Validation error: {str(e)}"
|
||||
)
|
||||
|
||||
# Get existing settings and update the category
|
||||
settings = await self.get_settings(tenant_id)
|
||||
column_name = self.CATEGORY_COLUMNS[category]
|
||||
setattr(settings, column_name, validated_data.dict())
|
||||
|
||||
return await self.repository.update(settings)
|
||||
|
||||
async def reset_category(self, tenant_id: UUID, category: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Reset a category to default values
|
||||
|
||||
Args:
|
||||
tenant_id: UUID of the tenant
|
||||
category: Category name
|
||||
|
||||
Returns:
|
||||
Dictionary with reset category settings
|
||||
|
||||
Raises:
|
||||
HTTPException: If category is invalid
|
||||
"""
|
||||
if category not in self.CATEGORY_COLUMNS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid category: {category}"
|
||||
)
|
||||
|
||||
# Get default settings for the category
|
||||
defaults = TenantSettings.get_default_settings()
|
||||
column_name = self.CATEGORY_COLUMNS[category]
|
||||
default_category_settings = defaults[column_name]
|
||||
|
||||
# Update the category with defaults
|
||||
settings = await self.get_settings(tenant_id)
|
||||
setattr(settings, column_name, default_category_settings)
|
||||
|
||||
await self.repository.update(settings)
|
||||
|
||||
return default_category_settings
|
||||
|
||||
async def _create_default_settings(self, tenant_id: UUID) -> TenantSettings:
|
||||
"""
|
||||
Create default settings for a new tenant
|
||||
|
||||
Args:
|
||||
tenant_id: UUID of the tenant
|
||||
|
||||
Returns:
|
||||
Newly created TenantSettings object
|
||||
"""
|
||||
defaults = TenantSettings.get_default_settings()
|
||||
|
||||
settings = TenantSettings(
|
||||
tenant_id=tenant_id,
|
||||
procurement_settings=defaults["procurement_settings"],
|
||||
inventory_settings=defaults["inventory_settings"],
|
||||
production_settings=defaults["production_settings"],
|
||||
supplier_settings=defaults["supplier_settings"],
|
||||
pos_settings=defaults["pos_settings"],
|
||||
order_settings=defaults["order_settings"],
|
||||
replenishment_settings=defaults["replenishment_settings"],
|
||||
safety_stock_settings=defaults["safety_stock_settings"],
|
||||
moq_settings=defaults["moq_settings"],
|
||||
supplier_selection_settings=defaults["supplier_selection_settings"]
|
||||
)
|
||||
|
||||
return await self.repository.create(settings)
|
||||
|
||||
async def delete_settings(self, tenant_id: UUID) -> None:
|
||||
"""
|
||||
Delete tenant settings (used when tenant is deleted)
|
||||
|
||||
Args:
|
||||
tenant_id: UUID of the tenant
|
||||
"""
|
||||
await self.repository.delete(tenant_id)
|
||||
Reference in New Issue
Block a user