Initial commit - production deployment

This commit is contained in:
2026-01-21 17:17:16 +01:00
commit c23d00dd92
2289 changed files with 638440 additions and 0 deletions

View File

View File

@@ -0,0 +1,8 @@
"""
Tenant API Package
API endpoints for tenant management
"""
from . import tenants
__all__ = ["tenants"]

View File

@@ -0,0 +1,359 @@
"""
Enterprise Upgrade API
Endpoints for upgrading tenants to enterprise tier and managing child outlets
"""
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from typing import Dict, Any, Optional
import uuid
from datetime import datetime, date
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.tenants import Tenant
from app.models.tenant_location import TenantLocation
from app.services.tenant_service import EnhancedTenantService
from app.core.config import settings
from shared.auth.tenant_access import verify_tenant_permission_dep
from shared.auth.decorators import get_current_user_dep
from shared.clients.subscription_client import SubscriptionServiceClient, get_subscription_service_client
from shared.subscription.plans import SubscriptionTier, QuotaLimits
from shared.database.base import create_database_manager
import structlog
logger = structlog.get_logger()
router = APIRouter()
# Dependency injection for enhanced tenant service
def get_enhanced_tenant_service():
    """Build an EnhancedTenantService backed by this service's database manager.

    FastAPI dependency factory.

    Raises:
        HTTPException: 500 if the database manager or service cannot be built.
    """
    try:
        # `settings` is already imported at module level; the original
        # re-imported it locally for no reason.
        database_manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
        return EnhancedTenantService(database_manager)
    except Exception as e:
        logger.error("Failed to create enhanced tenant service", error=str(e))
        # Chain the cause so the original failure is preserved in tracebacks.
        raise HTTPException(status_code=500, detail="Service initialization failed") from e
# Pydantic models for request bodies
class EnterpriseUpgradeRequest(BaseModel):
    """Request body for upgrading a tenant to the enterprise tier."""

    # Name of the central production facility created during the upgrade.
    location_name: Optional[str] = Field(default="Central Production Facility")
    # Address fields are optional; the upgrade endpoint falls back to the
    # tenant's own address/city/postal_code/coordinates when these are omitted.
    address: Optional[str] = None
    city: Optional[str] = None
    postal_code: Optional[str] = None
    latitude: Optional[float] = None
    longitude: Optional[float] = None
    # Stored as the location's `capacity` field (kilograms per day, presumably
    # — unit not stated in SOURCE; confirm against TenantLocation model).
    production_capacity_kg: Optional[int] = Field(default=1000)
class ChildOutletRequest(BaseModel):
    """Request body for creating a child outlet under a parent tenant."""

    # Required identity/location fields for the new outlet.
    name: str
    subdomain: str
    address: str
    city: Optional[str] = None
    postal_code: str
    latitude: Optional[float] = None
    longitude: Optional[float] = None
    # Contact fields fall back to the parent tenant's values when omitted.
    phone: Optional[str] = None
    email: Optional[str] = None
    # Passed through as the outlet location's `delivery_windows` value.
    delivery_days: Optional[list] = None
@router.post("/tenants/{tenant_id}/upgrade-to-enterprise")
async def upgrade_to_enterprise(
    tenant_id: str,
    upgrade_data: EnterpriseUpgradeRequest,
    subscription_client: SubscriptionServiceClient = Depends(get_subscription_service_client),
    current_user: Dict[str, Any] = Depends(get_current_user_dep)
):
    """
    Upgrade a tenant to enterprise tier with central production facility.

    Steps:
      1. Verify the tenant exists and its current plan allows the upgrade.
      2. Mark the tenant as a 'parent' (hierarchy root).
      3. Create a central-production location for it.
      4. Switch the subscription to the enterprise plan.

    Raises:
        HTTPException: 404 if the tenant does not exist, 400 if the current
            plan cannot be upgraded, 500 on unexpected failures.
    """
    try:
        from app.core.database import database_manager
        from app.repositories.tenant_repository import TenantRepository
        from app.repositories.tenant_location_repository import TenantLocationRepository

        # Fetch the tenant and validate that the upgrade is allowed.
        async with database_manager.get_session() as session:
            tenant_repo = TenantRepository(Tenant, session)
            tenant = await tenant_repo.get_by_id(tenant_id)
            if not tenant:
                raise HTTPException(status_code=404, detail="Tenant not found")

        # Only starter/professional tenants may be upgraded to enterprise.
        current_subscription = await subscription_client.get_subscription(tenant_id)
        if current_subscription['plan'] not in [SubscriptionTier.STARTER.value, SubscriptionTier.PROFESSIONAL.value]:
            raise HTTPException(status_code=400, detail="Only starter and professional tier tenants can be upgraded to enterprise")

        # Role verification is handled by the current_user dependency.

        # Promote the tenant to a hierarchy root ('parent').
        async with database_manager.get_session() as session:
            tenant_repo = TenantRepository(Tenant, session)
            updated_tenant = await tenant_repo.update(
                tenant_id,
                {
                    'tenant_type': 'parent',
                    'hierarchy_path': f"{tenant_id}"  # Root path
                }
            )
            await session.commit()

        # Create the central production location, falling back to the
        # tenant's own address fields when the request omits them.
        location_data = {
            'tenant_id': tenant_id,
            'name': upgrade_data.location_name,
            'location_type': 'central_production',
            'address': upgrade_data.address or tenant.address,
            'city': upgrade_data.city or tenant.city,
            'postal_code': upgrade_data.postal_code or tenant.postal_code,
            'latitude': upgrade_data.latitude or tenant.latitude,
            'longitude': upgrade_data.longitude or tenant.longitude,
            'capacity': upgrade_data.production_capacity_kg,
            'is_active': True
        }
        async with database_manager.get_session() as session:
            location_repo = TenantLocationRepository(session)
            created_location = await location_repo.create_location(location_data)
            await session.commit()

        # Finally, move the subscription to the enterprise plan.
        await subscription_client.update_subscription_plan(
            tenant_id=tenant_id,
            new_plan=SubscriptionTier.ENTERPRISE.value
        )
        return {
            'success': True,
            'tenant': updated_tenant,
            'production_location': created_location,
            'message': 'Tenant successfully upgraded to enterprise tier'
        }
    except HTTPException:
        # Preserve the intentional 404/400 responses; the original code let
        # the generic handler below rewrap them as opaque 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to upgrade tenant: {str(e)}")
@router.post("/tenants/{parent_id}/add-child-outlet")
async def add_child_outlet(
    parent_id: str,
    child_data: ChildOutletRequest,
    subscription_client: SubscriptionServiceClient = Depends(get_subscription_service_client),
    current_user: Dict[str, Any] = Depends(get_current_user_dep)
):
    """
    Add a new child outlet to a parent tenant.

    Validates that the target tenant is a 'parent', that its subscription
    tier allows child tenants, and that the child quota is not exhausted;
    then creates the child tenant, a retail-outlet location for it, and a
    child subscription inherited from the parent.

    Raises:
        HTTPException: 400 if the parent is missing or not a parent tenant,
            403 on subscription/quota violations, 500 on unexpected failures.
    """
    try:
        from app.core.database import database_manager
        from app.repositories.tenant_repository import TenantRepository
        from app.repositories.tenant_location_repository import TenantLocationRepository
        from shared.clients import get_tenant_client
        from shared.subscription.plans import PlanFeatures

        # Load the parent and snapshot the fields the child inherits.
        async with database_manager.get_session() as session:
            tenant_repo = TenantRepository(Tenant, session)
            parent_tenant = await tenant_repo.get_by_id(parent_id)
            if not parent_tenant:
                raise HTTPException(status_code=400, detail="Parent tenant not found")
            parent_dict = {
                'id': str(parent_tenant.id),
                'name': parent_tenant.name,
                'tenant_type': parent_tenant.tenant_type,
                'subscription_tier': parent_tenant.subscription_tier,
                'business_type': parent_tenant.business_type,
                'business_model': parent_tenant.business_model,
                'city': parent_tenant.city,
                'phone': parent_tenant.phone,
                'email': parent_tenant.email,
                'owner_id': parent_tenant.owner_id
            }
        if parent_dict.get('tenant_type') != 'parent':
            raise HTTPException(status_code=400, detail="Tenant is not a parent type")

        # Child outlets are an enterprise-tier feature.
        tenant_client = get_tenant_client(config=settings, service_name="tenant-service")
        subscription = await tenant_client.get_tenant_subscription(parent_id)
        if not subscription:
            raise HTTPException(
                status_code=403,
                detail="No active subscription found for parent tenant"
            )
        tier = subscription.get("plan", "starter")
        if not PlanFeatures.validate_tenant_access(tier, "child"):
            raise HTTPException(
                status_code=403,
                detail=f"Creating child outlets requires Enterprise subscription. Current plan: {tier}"
            )

        # Enforce the plan's child-tenant quota (None means unlimited).
        async with database_manager.get_session() as session:
            tenant_repo = TenantRepository(Tenant, session)
            current_child_count = await tenant_repo.get_child_tenant_count(parent_id)
        max_children = QuotaLimits.get_limit("MAX_CHILD_TENANTS", tier)
        if max_children is not None and current_child_count >= max_children:
            raise HTTPException(
                status_code=403,
                detail=f"Child tenant limit reached. Current: {current_child_count}, Maximum: {max_children}"
            )

        # Build the child tenant, inheriting contact/business fields from the parent.
        child_id = str(uuid.uuid4())
        child_tenant_data = {
            'id': child_id,
            'name': child_data.name,
            'subdomain': child_data.subdomain,
            'business_type': parent_dict.get('business_type', 'bakery'),
            'business_model': parent_dict.get('business_model', 'retail_bakery'),
            'address': child_data.address,
            'city': child_data.city or parent_dict.get('city'),
            'postal_code': child_data.postal_code,
            'latitude': child_data.latitude,
            'longitude': child_data.longitude,
            'phone': child_data.phone or parent_dict.get('phone'),
            'email': child_data.email or parent_dict.get('email'),
            'parent_tenant_id': parent_id,
            'tenant_type': 'child',
            'hierarchy_path': f"{parent_id}.{child_id}",
            'owner_id': parent_dict.get('owner_id'),  # Same owner as parent
            'is_active': True
        }
        async with database_manager.get_session() as session:
            tenant_repo = TenantRepository(Tenant, session)
            created_child = await tenant_repo.create(child_tenant_data)
            await session.commit()
            created_child_dict = {
                'id': str(created_child.id),
                'name': created_child.name,
                'subdomain': created_child.subdomain
            }

        # Create the retail-outlet location for the child.
        location_data = {
            'tenant_id': uuid.UUID(child_id),
            'name': f"Outlet - {child_data.name}",
            'location_type': 'retail_outlet',
            'address': child_data.address,
            'city': child_data.city or parent_dict.get('city'),
            'postal_code': child_data.postal_code,
            'latitude': child_data.latitude,
            'longitude': child_data.longitude,
            'delivery_windows': child_data.delivery_days,
            'is_active': True
        }
        async with database_manager.get_session() as session:
            location_repo = TenantLocationRepository(session)
            created_location = await location_repo.create_location(location_data)
            await session.commit()
            location_dict = {
                'id': str(created_location.id) if created_location else None,
                'name': created_location.name if created_location else None
            }

        # Settings propagation from the parent would happen here via the
        # tenant-settings service (not implemented in this endpoint).

        # Create the child subscription inheriting from the parent.
        await subscription_client.create_child_subscription(
            child_tenant_id=child_id,
            parent_tenant_id=parent_id
        )
        return {
            'success': True,
            'child_tenant': created_child_dict,
            'location': location_dict,
            'message': 'Child outlet successfully added'
        }
    except HTTPException:
        # Preserve the intentional 400/403 responses; the original code let
        # the generic handler below rewrap them as opaque 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to add child outlet: {str(e)}")
@router.get("/tenants/{tenant_id}/hierarchy")
async def get_tenant_hierarchy(
    tenant_id: str,
    current_user: Dict[str, Any] = Depends(get_current_user_dep)
):
    """
    Get tenant hierarchy information.

    Returns the tenant's type, parent linkage and hierarchy path; for
    parent tenants, the number of child tenants is included.

    Raises:
        HTTPException: 404 if the tenant does not exist, 500 otherwise.
    """
    try:
        from app.core.database import database_manager
        from app.repositories.tenant_repository import TenantRepository

        async with database_manager.get_session() as session:
            tenant_repo = TenantRepository(Tenant, session)
            tenant = await tenant_repo.get_by_id(tenant_id)
            if not tenant:
                raise HTTPException(status_code=404, detail="Tenant not found")
            result = {
                'tenant_id': tenant_id,
                'name': tenant.name,
                'tenant_type': tenant.tenant_type,
                'parent_tenant_id': tenant.parent_tenant_id,
                'hierarchy_path': tenant.hierarchy_path,
                'is_parent': tenant.tenant_type == 'parent',
                'is_child': tenant.tenant_type == 'child'
            }
            # Parents additionally report how many children they have.
            if tenant.tenant_type == 'parent':
                result['child_tenant_count'] = await tenant_repo.get_child_tenant_count(tenant_id)
            return result
    except HTTPException:
        # Keep the intended 404 instead of converting it into a 500, which
        # is what the bare `except Exception` below would otherwise do.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get hierarchy: {str(e)}")
@router.get("/users/{user_id}/tenant-hierarchy")
async def get_user_accessible_tenant_hierarchy(
    user_id: str,
    current_user: Dict[str, Any] = Depends(get_current_user_dep)
):
    """
    Get all tenants a user has access to, organized in hierarchy.

    Raises:
        HTTPException: 500 if the lookup fails.
    """
    try:
        from app.core.database import database_manager
        from app.repositories.tenant_repository import TenantRepository

        # Fetch all tenants where the user has access, organized hierarchically.
        async with database_manager.get_session() as session:
            tenant_repo = TenantRepository(Tenant, session)
            user_tenants = await tenant_repo.get_user_tenants_with_hierarchy(user_id)
            return {
                'user_id': user_id,
                'tenants': user_tenants,
                'total_count': len(user_tenants)
            }
    except HTTPException:
        # Pass intentional HTTP errors through unchanged, matching the
        # error-handling convention of the sibling endpoints.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to get user hierarchy: {str(e)}")

View File

@@ -0,0 +1,827 @@
"""
Internal Demo Cloning API
Service-to-service endpoint for cloning tenant data
"""
from fastapi import APIRouter, Depends, HTTPException, Header
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
import structlog
import uuid
from datetime import datetime, timezone, timedelta
from typing import Optional
import os
import json
from pathlib import Path
from app.core.database import get_db
from app.models.tenants import Tenant, Subscription, TenantMember
from app.models.tenant_location import TenantLocation
from shared.utils.demo_dates import adjust_date_for_demo, resolve_time_marker
from app.core.config import settings
logger = structlog.get_logger()
router = APIRouter(prefix="/internal/demo", tags=["internal"])
# Base demo tenant IDs
DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"
def parse_date_field(
    field_value: object,
    session_time: datetime,
    field_name: str = "date"
) -> Optional[datetime]:
    """
    Parse a date field from JSON, supporting BASE_TS markers and ISO timestamps.

    Unparseable or unrecognized values are logged and mapped to None rather
    than raising, so one bad seed field cannot abort a clone.

    Args:
        field_value: The date field value (BASE_TS marker string, ISO-8601
            string, or None). Any other type is logged and yields None.
            (The original annotated this as the builtin ``any`` function,
            which is not a type.)
        session_time: Session creation time (timezone-aware UTC)
        field_name: Name of the field (for logging)

    Returns:
        Timezone-aware UTC datetime or None
    """
    if field_value is None:
        return None
    # Handle BASE_TS markers (relative offsets resolved against session_time).
    if isinstance(field_value, str) and field_value.startswith("BASE_TS"):
        try:
            return resolve_time_marker(field_value, session_time)
        except (ValueError, AttributeError) as e:
            logger.warning(
                "Failed to resolve BASE_TS marker",
                field_name=field_name,
                marker=field_value,
                error=str(e)
            )
            return None
    # Handle ISO timestamps (legacy format - convert to absolute datetime).
    # 'Z' is normalized to '+00:00' because fromisoformat() on older Python
    # versions does not accept the Zulu suffix.
    if isinstance(field_value, str) and ('T' in field_value or 'Z' in field_value):
        try:
            parsed_date = datetime.fromisoformat(field_value.replace('Z', '+00:00'))
            # Adjust relative to session time
            return adjust_date_for_demo(parsed_date, session_time)
        except (ValueError, AttributeError) as e:
            logger.warning(
                "Failed to parse ISO timestamp",
                field_name=field_name,
                value=field_value,
                error=str(e)
            )
            return None
    logger.warning(
        "Unknown date format",
        field_name=field_name,
        value=field_value,
        value_type=type(field_value).__name__
    )
    return None
@router.post("/clone")
async def clone_demo_data(
    base_tenant_id: str,
    virtual_tenant_id: str,
    demo_account_type: str,
    session_id: Optional[str] = None,
    db: AsyncSession = Depends(get_db)
):
    """
    Clone tenant service data for a virtual demo tenant.

    Creates the virtual tenant record used by the demo session, a
    subscription matching the demo account type, tenant-member records for
    the demo owner/staff users, and copies of the template tenant's
    locations. Idempotent: if the virtual tenant already exists, only a
    missing subscription is backfilled.

    Args:
        base_tenant_id: Template tenant UUID (source of cloned locations)
        virtual_tenant_id: Target virtual tenant UUID
        demo_account_type: Type of demo account
        session_id: Originating session ID for tracing
        db: Async database session (injected)

    Returns:
        Cloning status and record count
    """
    start_time = datetime.now(timezone.utc)
    # Seed-data dates (BASE_TS markers / ISO stamps) are resolved relative to
    # the moment this clone starts. The original referenced an undefined
    # `session_time`, which raised NameError on the subscription-backfill path.
    session_time = start_time
    logger.info(
        "Starting tenant data cloning",
        virtual_tenant_id=virtual_tenant_id,
        demo_account_type=demo_account_type,
        session_id=session_id
    )
    try:
        # Validate UUIDs
        virtual_uuid = uuid.UUID(virtual_tenant_id)
        # Check if tenant already exists (idempotency guard).
        result = await db.execute(
            select(Tenant).where(Tenant.id == virtual_uuid)
        )
        existing_tenant = result.scalars().first()
        if existing_tenant:
            logger.info(
                "Virtual tenant already exists",
                virtual_tenant_id=virtual_tenant_id,
                tenant_name=existing_tenant.name
            )
            # Ensure the tenant has an active subscription; backfill it from
            # seed data if missing.
            result = await db.execute(
                select(Subscription).where(
                    Subscription.tenant_id == virtual_uuid,
                    Subscription.status == "active"
                )
            )
            existing_subscription = result.scalars().first()
            if not existing_subscription:
                logger.info("Creating missing subscription for existing virtual tenant by copying from template",
                            virtual_tenant_id=virtual_tenant_id,
                            base_tenant_id=base_tenant_id)
                # Load subscription from seed data instead of cloning from template
                try:
                    from shared.utils.seed_data_paths import get_seed_data_path
                    if demo_account_type == "professional":
                        json_file = get_seed_data_path("professional", "01-tenant.json")
                    elif demo_account_type == "enterprise":
                        json_file = get_seed_data_path("enterprise", "01-tenant.json")
                    else:
                        raise ValueError(f"Invalid demo account type: {demo_account_type}")
                except ImportError:
                    # Fallback to the raw repository layout when the path
                    # helper is unavailable.
                    seed_data_dir = Path(__file__).parent.parent.parent.parent / "infrastructure" / "seed-data"
                    if demo_account_type == "professional":
                        json_file = seed_data_dir / "professional" / "01-tenant.json"
                    elif demo_account_type == "enterprise":
                        json_file = seed_data_dir / "enterprise" / "parent" / "01-tenant.json"
                    else:
                        raise ValueError(f"Invalid demo account type: {demo_account_type}")
                if json_file.exists():
                    # `json` is imported at module level; no local import needed.
                    with open(json_file, 'r', encoding='utf-8') as f:
                        seed_data = json.load(f)
                    subscription_data = seed_data.get('subscription')
                    if subscription_data:
                        # Build the subscription from the seed definition;
                        # date fields may be BASE_TS markers or ISO stamps.
                        subscription = Subscription(
                            tenant_id=virtual_uuid,
                            plan=subscription_data.get('plan', 'professional'),
                            status=subscription_data.get('status', 'active'),
                            monthly_price=subscription_data.get('monthly_price', 299.00),
                            billing_cycle=subscription_data.get('billing_cycle', 'monthly'),
                            max_users=subscription_data.get('max_users', 10),
                            max_locations=subscription_data.get('max_locations', 3),
                            max_products=subscription_data.get('max_products', 500),
                            features=subscription_data.get('features', {}),
                            trial_ends_at=parse_date_field(
                                subscription_data.get('trial_ends_at'),
                                session_time,
                                "trial_ends_at"
                            ),
                            next_billing_date=parse_date_field(
                                subscription_data.get('next_billing_date'),
                                session_time,
                                "next_billing_date"
                            ),
                            subscription_id=subscription_data.get('stripe_subscription_id'),
                            customer_id=subscription_data.get('stripe_customer_id'),
                            cancelled_at=parse_date_field(
                                subscription_data.get('cancelled_at'),
                                session_time,
                                "cancelled_at"
                            ),
                            cancellation_effective_date=parse_date_field(
                                subscription_data.get('cancellation_effective_date'),
                                session_time,
                                "cancellation_effective_date"
                            ),
                            is_tenant_linked=True  # Required for check constraint when tenant_id is set
                        )
                        db.add(subscription)
                        await db.commit()
                        logger.info("Subscription loaded from seed data successfully",
                                    virtual_tenant_id=virtual_tenant_id,
                                    plan=subscription.plan)
                    else:
                        logger.warning("No subscription found in seed data",
                                       virtual_tenant_id=virtual_tenant_id)
                else:
                    logger.warning("Seed data file not found, falling back to default subscription",
                                   file_path=str(json_file))
                    # Create a default subscription when seed data is unavailable.
                    subscription = Subscription(
                        tenant_id=virtual_uuid,
                        plan="professional" if demo_account_type == "professional" else "enterprise",
                        status="active",
                        monthly_price=299.00 if demo_account_type == "professional" else 799.00,
                        max_users=10 if demo_account_type == "professional" else 50,
                        max_locations=3 if demo_account_type == "professional" else -1,
                        max_products=500 if demo_account_type == "professional" else -1,
                        features={
                            "production_planning": True,
                            "procurement_management": True,
                            "inventory_management": True,
                            "sales_analytics": True,
                            "multi_location": True,
                            "advanced_reporting": True,
                            "api_access": True,
                            "priority_support": True
                        },
                        next_billing_date=datetime.now(timezone.utc) + timedelta(days=90),
                        is_tenant_linked=True  # Required for check constraint when tenant_id is set
                    )
                    db.add(subscription)
                    await db.commit()
                    logger.info("Default subscription created",
                                virtual_tenant_id=virtual_tenant_id,
                                plan=subscription.plan)
            # Return success - idempotent operation
            duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
            return {
                "service": "tenant",
                "status": "completed",
                "records_cloned": 0 if existing_subscription else 1,
                "duration_ms": duration_ms,
                "details": {
                    "tenant_already_exists": True,
                    "tenant_id": str(virtual_uuid),
                    "subscription_created": not existing_subscription
                }
            }

        # Create virtual tenant record with required fields.
        # Note: Use the actual demo user IDs from seed_demo_users.py —
        # these match the demo users created in the auth service.
        DEMO_OWNER_IDS = {
            "professional": "c1a2b3c4-d5e6-47a8-b9c0-d1e2f3a4b5c6",  # María García López
            "enterprise": "d2e3f4a5-b6c7-48d9-e0f1-a2b3c4d5e6f7"   # Carlos Martínez Ruiz
        }
        demo_owner_uuid = uuid.UUID(DEMO_OWNER_IDS.get(demo_account_type, DEMO_OWNER_IDS["professional"]))
        tenant = Tenant(
            id=virtual_uuid,
            name=f"Demo Tenant - {demo_account_type.replace('_', ' ').title()}",
            address="Calle Demo 123",  # Required field - provide demo address
            city="Madrid",
            postal_code="28001",
            business_type="bakery",
            is_demo=True,
            is_demo_template=False,
            demo_session_id=session_id,  # Link tenant to demo session
            business_model=demo_account_type,
            is_active=True,
            timezone="Europe/Madrid",
            owner_id=demo_owner_uuid,  # Required field - matches seed_demo_users.py
            tenant_type="parent" if demo_account_type in ["enterprise", "enterprise_parent"] else "standalone"
        )
        db.add(tenant)
        await db.flush()  # Flush to get the tenant ID

        # Determine subscription tier/limits from the demo account type.
        if demo_account_type == "professional":
            plan = "professional"
            max_locations = 3
        elif demo_account_type in ["enterprise", "enterprise_parent"]:
            plan = "enterprise"
            max_locations = -1  # Unlimited
        elif demo_account_type == "enterprise_child":
            plan = "enterprise"
            max_locations = 1
        else:
            plan = "starter"
            max_locations = 1
        demo_subscription = Subscription(
            tenant_id=tenant.id,
            plan=plan,  # Set appropriate tier based on demo account type
            status="active",
            monthly_price=0.0,  # Free for demo
            billing_cycle="monthly",
            max_users=-1,  # Unlimited for demo
            max_locations=max_locations,
            max_products=-1,  # Unlimited for demo
            features={},
            is_tenant_linked=True  # Required for check constraint when tenant_id is set
        )
        db.add(demo_subscription)

        # Helper function mapping a role to its serialized permission list.
        def get_permissions_for_role(role: str) -> str:
            permission_map = {
                "owner": ["read", "write", "admin", "delete"],
                "admin": ["read", "write", "admin"],
                "production_manager": ["read", "write"],
                "baker": ["read", "write"],
                "sales": ["read", "write"],
                "quality_control": ["read", "write"],
                "warehouse": ["read", "write"],
                "logistics": ["read", "write"],
                "procurement": ["read", "write"],
                "maintenance": ["read", "write"],
                "member": ["read", "write"],
                "viewer": ["read"]
            }
            permissions = permission_map.get(role, ["read"])
            return json.dumps(permissions)

        # Define staff users for each demo account type
        # (must match seed_demo_tenant_members.py).
        STAFF_USERS = {
            "professional": [
                # Owner
                {
                    "user_id": uuid.UUID("c1a2b3c4-d5e6-47a8-b9c0-d1e2f3a4b5c6"),
                    "role": "owner"
                },
                # Staff
                {
                    "user_id": uuid.UUID("50000000-0000-0000-0000-000000000001"),
                    "role": "baker"
                },
                {
                    "user_id": uuid.UUID("50000000-0000-0000-0000-000000000002"),
                    "role": "sales"
                },
                {
                    "user_id": uuid.UUID("50000000-0000-0000-0000-000000000003"),
                    "role": "quality_control"
                },
                {
                    "user_id": uuid.UUID("50000000-0000-0000-0000-000000000004"),
                    "role": "admin"
                },
                {
                    "user_id": uuid.UUID("50000000-0000-0000-0000-000000000005"),
                    "role": "warehouse"
                },
                {
                    "user_id": uuid.UUID("50000000-0000-0000-0000-000000000006"),
                    "role": "production_manager"
                }
            ],
            "enterprise": [
                # Owner
                {
                    "user_id": uuid.UUID("d2e3f4a5-b6c7-48d9-e0f1-a2b3c4d5e6f7"),
                    "role": "owner"
                },
                # Staff
                {
                    "user_id": uuid.UUID("50000000-0000-0000-0000-000000000011"),
                    "role": "production_manager"
                },
                {
                    "user_id": uuid.UUID("50000000-0000-0000-0000-000000000012"),
                    "role": "quality_control"
                },
                {
                    "user_id": uuid.UUID("50000000-0000-0000-0000-000000000013"),
                    "role": "logistics"
                },
                {
                    "user_id": uuid.UUID("50000000-0000-0000-0000-000000000014"),
                    "role": "sales"
                },
                {
                    "user_id": uuid.UUID("50000000-0000-0000-0000-000000000015"),
                    "role": "procurement"
                },
                {
                    "user_id": uuid.UUID("50000000-0000-0000-0000-000000000016"),
                    "role": "maintenance"
                }
            ]
        }
        # Get staff users for this demo account type.
        staff_users = STAFF_USERS.get(demo_account_type, [])
        # Create tenant member records for all users (owner + staff).
        members_created = 0
        for staff_member in staff_users:
            tenant_member = TenantMember(
                tenant_id=virtual_uuid,
                user_id=staff_member["user_id"],
                role=staff_member["role"],
                permissions=get_permissions_for_role(staff_member["role"]),
                is_active=True,
                invited_by=demo_owner_uuid,
                invited_at=datetime.now(timezone.utc),
                joined_at=datetime.now(timezone.utc),
                created_at=datetime.now(timezone.utc)
            )
            db.add(tenant_member)
            members_created += 1
        logger.info(
            "Created tenant members for virtual tenant",
            virtual_tenant_id=virtual_tenant_id,
            members_created=members_created
        )

        # Clone the template tenant's locations into the virtual tenant.
        base_uuid = uuid.UUID(base_tenant_id)
        location_result = await db.execute(
            select(TenantLocation).where(TenantLocation.tenant_id == base_uuid)
        )
        base_locations = location_result.scalars().all()
        records_cloned = 1 + members_created  # Tenant + TenantMembers
        for base_location in base_locations:
            virtual_location = TenantLocation(
                id=uuid.uuid4(),
                # Use the parsed UUID; the original passed the raw string here,
                # unlike every other tenant_id assignment in this function.
                tenant_id=virtual_uuid,
                name=base_location.name,
                location_type=base_location.location_type,
                address=base_location.address,
                city=base_location.city,
                postal_code=base_location.postal_code,
                latitude=base_location.latitude,
                longitude=base_location.longitude,
                capacity=base_location.capacity,
                delivery_windows=base_location.delivery_windows,
                operational_hours=base_location.operational_hours,
                max_delivery_radius_km=base_location.max_delivery_radius_km,
                delivery_schedule_config=base_location.delivery_schedule_config,
                is_active=base_location.is_active,
                contact_person=base_location.contact_person,
                contact_phone=base_location.contact_phone,
                contact_email=base_location.contact_email,
                metadata_=base_location.metadata_ if isinstance(base_location.metadata_, dict) else (base_location.metadata_ or {})
            )
            db.add(virtual_location)
            records_cloned += 1
        logger.info("Cloned TenantLocations", count=len(base_locations))

        # The subscription was already created above based on demo_account_type;
        # nothing is cloned from the template, avoiding duplicate subscriptions.
        await db.commit()
        await db.refresh(tenant)
        duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
        logger.info(
            "Virtual tenant created successfully with subscription",
            virtual_tenant_id=virtual_tenant_id,
            tenant_name=tenant.name,
            subscription_plan=plan,
            duration_ms=duration_ms
        )
        # Count the subscription too; the original recomputed the total here
        # and silently dropped the cloned locations from the count.
        records_cloned += 1
        return {
            "service": "tenant",
            "status": "completed",
            "records_cloned": records_cloned,
            "duration_ms": duration_ms,
            "details": {
                "tenant_id": str(tenant.id),
                "tenant_name": tenant.name,
                "business_model": tenant.business_model,
                "owner_id": str(demo_owner_uuid),
                "members_created": members_created,
                "locations_cloned": len(base_locations),
                "subscription_plan": plan,
                "subscription_created": True
            }
        }
    except ValueError as e:
        logger.error("Invalid UUID format", error=str(e), virtual_tenant_id=virtual_tenant_id)
        raise HTTPException(status_code=400, detail=f"Invalid UUID: {str(e)}")
    except Exception as e:
        logger.error(
            "Failed to clone tenant data",
            error=str(e),
            virtual_tenant_id=virtual_tenant_id,
            exc_info=True
        )
        # Rollback on error
        await db.rollback()
        return {
            "service": "tenant",
            "status": "failed",
            "records_cloned": 0,
            "duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
            "error": str(e)
        }
@router.post("/create-child")
async def create_child_outlet(
request: dict,
db: AsyncSession = Depends(get_db)
):
"""
Create a child outlet tenant for enterprise demos
Args:
request: JSON request body with child tenant details
Returns:
Creation status and tenant details
"""
# Extract parameters from request body
base_tenant_id = request.get("base_tenant_id")
virtual_tenant_id = request.get("virtual_tenant_id")
parent_tenant_id = request.get("parent_tenant_id")
child_name = request.get("child_name")
location = request.get("location", {})
session_id = request.get("session_id")
start_time = datetime.now(timezone.utc)
logger.info(
"Creating child outlet tenant",
virtual_tenant_id=virtual_tenant_id,
parent_tenant_id=parent_tenant_id,
child_name=child_name,
session_id=session_id
)
try:
# Validate UUIDs
virtual_uuid = uuid.UUID(virtual_tenant_id)
parent_uuid = uuid.UUID(parent_tenant_id)
# Check if child tenant already exists
result = await db.execute(select(Tenant).where(Tenant.id == virtual_uuid))
existing_tenant = result.scalars().first()
if existing_tenant:
logger.info(
"Child tenant already exists",
virtual_tenant_id=virtual_tenant_id,
tenant_name=existing_tenant.name
)
# Return existing tenant - idempotent operation
duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
return {
"service": "tenant",
"status": "completed",
"records_created": 0,
"duration_ms": duration_ms,
"details": {
"tenant_id": str(virtual_uuid),
"tenant_name": existing_tenant.name,
"already_exists": True
}
}
# Get parent tenant to retrieve the correct owner_id
parent_result = await db.execute(select(Tenant).where(Tenant.id == parent_uuid))
parent_tenant = parent_result.scalars().first()
if not parent_tenant:
logger.error("Parent tenant not found", parent_tenant_id=parent_tenant_id)
return {
"service": "tenant",
"status": "failed",
"records_created": 0,
"duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
"error": f"Parent tenant {parent_tenant_id} not found"
}
# Use the parent's owner_id for the child tenant (enterprise demo owner)
parent_owner_id = parent_tenant.owner_id
# Create child tenant with parent relationship
child_tenant = Tenant(
id=virtual_uuid,
name=child_name,
address=location.get("address", f"Calle Outlet {location.get('city', 'Madrid')}"),
city=location.get("city", "Madrid"),
postal_code=location.get("postal_code", "28001"),
business_type="bakery",
is_demo=True,
is_demo_template=False,
demo_session_id=session_id, # Link child tenant to demo session
business_model="retail_outlet",
is_active=True,
timezone="Europe/Madrid",
# Set parent relationship
parent_tenant_id=parent_uuid,
tenant_type="child",
hierarchy_path=f"{str(parent_uuid)}.{str(virtual_uuid)}",
# Owner ID - MUST match the parent tenant owner (enterprise demo owner)
# This ensures the parent owner can see and access child tenants
owner_id=parent_owner_id
)
db.add(child_tenant)
await db.flush() # Flush to get the tenant ID
# Create TenantLocation for this retail outlet
child_location = TenantLocation(
id=uuid.uuid4(),
tenant_id=virtual_uuid,
name=f"{child_name} - Retail Outlet",
location_type="retail_outlet",
address=location.get("address", f"Calle Outlet {location.get('city', 'Madrid')}"),
city=location.get("city", "Madrid"),
postal_code=location.get("postal_code", "28001"),
latitude=location.get("latitude"),
longitude=location.get("longitude"),
delivery_windows={
"monday": "07:00-10:00",
"wednesday": "07:00-10:00",
"friday": "07:00-10:00"
},
operational_hours={
"monday": "07:00-21:00",
"tuesday": "07:00-21:00",
"wednesday": "07:00-21:00",
"thursday": "07:00-21:00",
"friday": "07:00-21:00",
"saturday": "08:00-21:00",
"sunday": "09:00-21:00"
},
delivery_schedule_config={
"delivery_days": ["monday", "wednesday", "friday"],
"time_window": "07:00-10:00"
},
is_active=True
)
db.add(child_location)
logger.info("Created TenantLocation for child", child_id=str(virtual_uuid), location_name=child_location.name)
# Create parent tenant lookup to get the correct plan for the child
parent_result = await db.execute(
select(Subscription).where(
Subscription.tenant_id == parent_uuid,
Subscription.status == "active"
)
)
parent_subscription = parent_result.scalars().first()
# Child inherits the same plan as parent
parent_plan = parent_subscription.plan if parent_subscription else "enterprise"
child_subscription = Subscription(
tenant_id=child_tenant.id,
plan=parent_plan, # Child inherits the same plan as parent
status="active",
monthly_price=0.0, # Free for demo
billing_cycle="monthly",
max_users=10, # Demo limits
max_locations=1, # Single location for outlet
max_products=200,
features={},
is_tenant_linked=True # Required for check constraint when tenant_id is set
)
db.add(child_subscription)
# Create basic tenant members like parent
import json
# Use the parent's owner_id (already retrieved above)
# This ensures consistency between tenant.owner_id and TenantMember records
# Create tenant member for owner
child_owner_member = TenantMember(
tenant_id=virtual_uuid,
user_id=parent_owner_id,
role="owner",
permissions=json.dumps(["read", "write", "admin", "delete"]),
is_active=True,
invited_by=parent_owner_id,
invited_at=datetime.now(timezone.utc),
joined_at=datetime.now(timezone.utc),
created_at=datetime.now(timezone.utc)
)
db.add(child_owner_member)
# Create staff members for the outlet from parent enterprise users
# Use parent's enterprise staff (from enterprise/parent/02-auth.json)
staff_users = [
{
"user_id": uuid.UUID("f6c54d0f-5899-4952-ad94-7a492c07167a"), # Laura López - Logistics
"role": "logistics_coord"
},
{
"user_id": uuid.UUID("80765906-0074-4206-8f58-5867df1975fd"), # José Martínez - Quality
"role": "quality_control"
},
{
"user_id": uuid.UUID("701cb9d2-6049-4bb9-8d3a-1b3bd3aae45f"), # Francisco Moreno - Warehouse
"role": "warehouse_supervisor"
}
]
members_created = 1 # Start with owner
for staff_member in staff_users:
tenant_member = TenantMember(
tenant_id=virtual_uuid,
user_id=staff_member["user_id"],
role=staff_member["role"],
permissions=json.dumps(["read", "write"]) if staff_member["role"] != "admin" else json.dumps(["read", "write", "admin"]),
is_active=True,
invited_by=parent_owner_id,
invited_at=datetime.now(timezone.utc),
joined_at=datetime.now(timezone.utc),
created_at=datetime.now(timezone.utc)
)
db.add(tenant_member)
members_created += 1
await db.commit()
await db.refresh(child_tenant)
duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
logger.info(
"Child outlet created successfully",
virtual_tenant_id=str(virtual_tenant_id),
parent_tenant_id=str(parent_tenant_id),
child_name=child_name,
owner_id=str(parent_owner_id),
duration_ms=duration_ms
)
return {
"service": "tenant",
"status": "completed",
"records_created": 2 + members_created, # Tenant + Subscription + Members
"duration_ms": duration_ms,
"details": {
"tenant_id": str(child_tenant.id),
"tenant_name": child_tenant.name,
"parent_tenant_id": str(parent_tenant_id),
"location": location,
"members_created": members_created,
"subscription_plan": "enterprise"
}
}
except ValueError as e:
logger.error("Invalid UUID format", error=str(e), virtual_tenant_id=virtual_tenant_id)
raise HTTPException(status_code=400, detail=f"Invalid UUID: {str(e)}")
except Exception as e:
logger.error(
"Failed to create child outlet",
error=str(e),
virtual_tenant_id=virtual_tenant_id,
parent_tenant_id=parent_tenant_id,
exc_info=True
)
# Rollback on error
await db.rollback()
return {
"service": "tenant",
"status": "failed",
"records_created": 0,
"duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
"error": str(e)
}
@router.get("/clone/health")
async def clone_health_check():
    """
    Availability probe for the internal cloning endpoint.

    The orchestrator polls this route to confirm the tenant service is
    ready to accept clone requests before dispatching work.
    """
    status_payload = {
        "service": "tenant",
        "clone_endpoint": "available",
        "version": "2.0.0",
    }
    return status_payload

View File

@@ -0,0 +1,445 @@
"""
Network Alerts API
Endpoints for aggregating and managing alerts across enterprise networks
"""
from fastapi import APIRouter, Depends, HTTPException, Query
from typing import List, Dict, Any, Optional
from datetime import datetime
from pydantic import BaseModel, Field
import structlog
from app.services.network_alerts_service import NetworkAlertsService
from shared.auth.tenant_access import verify_tenant_permission_dep
from shared.clients import get_tenant_client, get_alerts_client
from app.core.config import settings
logger = structlog.get_logger()
router = APIRouter()
# Pydantic models for request/response
class NetworkAlert(BaseModel):
    """A single alert originating from one child tenant, tagged with that tenant's id and name."""
    alert_id: str = Field(..., description="Unique alert ID")
    tenant_id: str = Field(..., description="Tenant ID where alert originated")
    tenant_name: str = Field(..., description="Tenant name")
    alert_type: str = Field(..., description="Type of alert: inventory, production, delivery, etc.")
    severity: str = Field(..., description="Severity: critical, high, medium, low")
    title: str = Field(..., description="Alert title")
    message: str = Field(..., description="Alert message")
    timestamp: str = Field(..., description="Alert timestamp")
    status: str = Field(..., description="Alert status: active, acknowledged, resolved")
    source_system: str = Field(..., description="System that generated the alert")
    related_entity_id: Optional[str] = Field(None, description="ID of related entity (product, route, etc.)")
    related_entity_type: Optional[str] = Field(None, description="Type of related entity")
class AlertSeveritySummary(BaseModel):
    """Counts of alerts grouped by severity level."""
    critical_count: int = Field(..., description="Number of critical alerts")
    high_count: int = Field(..., description="Number of high severity alerts")
    medium_count: int = Field(..., description="Number of medium severity alerts")
    low_count: int = Field(..., description="Number of low severity alerts")
    total_alerts: int = Field(..., description="Total number of alerts")
class AlertTypeSummary(BaseModel):
    """Counts of alerts grouped by alert type; types outside the known set fall into other_alerts."""
    inventory_alerts: int = Field(..., description="Inventory-related alerts")
    production_alerts: int = Field(..., description="Production-related alerts")
    delivery_alerts: int = Field(..., description="Delivery-related alerts")
    equipment_alerts: int = Field(..., description="Equipment-related alerts")
    quality_alerts: int = Field(..., description="Quality-related alerts")
    other_alerts: int = Field(..., description="Other types of alerts")
class NetworkAlertsSummary(BaseModel):
    """Aggregate alert metrics across all child tenants of an enterprise network."""
    total_alerts: int = Field(..., description="Total alerts across network")
    active_alerts: int = Field(..., description="Currently active alerts")
    acknowledged_alerts: int = Field(..., description="Acknowledged alerts")
    resolved_alerts: int = Field(..., description="Resolved alerts")
    severity_summary: AlertSeveritySummary = Field(..., description="Alerts by severity")
    type_summary: AlertTypeSummary = Field(..., description="Alerts by type")
    most_recent_alert: Optional[NetworkAlert] = Field(None, description="Most recent alert")
class AlertCorrelation(BaseModel):
    """A group of related alerts: one primary alert plus the alerts correlated with it."""
    correlation_id: str = Field(..., description="Correlation group ID")
    primary_alert: NetworkAlert = Field(..., description="Primary alert in the group")
    related_alerts: List[NetworkAlert] = Field(..., description="Alerts correlated with primary alert")
    correlation_type: str = Field(..., description="Type of correlation: causal, temporal, spatial")
    correlation_strength: float = Field(..., description="Correlation strength (0-1)")
    impact_analysis: str = Field(..., description="Analysis of combined impact")
async def get_network_alerts_service() -> NetworkAlertsService:
    """Dependency factory: wire a NetworkAlertsService with tenant and alerts clients."""
    return NetworkAlertsService(
        get_tenant_client(settings, "tenant-service"),
        get_alerts_client(settings, "tenant-service"),
    )
@router.get("/tenants/{parent_id}/network/alerts",
            response_model=List[NetworkAlert],
            summary="Get aggregated alerts across network")
async def get_network_alerts(
    parent_id: str,
    severity: Optional[str] = Query(None, description="Filter by severity: critical, high, medium, low"),
    alert_type: Optional[str] = Query(None, description="Filter by alert type"),
    status: Optional[str] = Query(None, description="Filter by status: active, acknowledged, resolved"),
    limit: int = Query(100, description="Maximum number of alerts to return"),
    network_alerts_service: NetworkAlertsService = Depends(get_network_alerts_service),
    verified_tenant: str = Depends(verify_tenant_permission_dep)
):
    """
    Get aggregated alerts across all child tenants in a parent network.

    Collects alerts from every child tenant, tags each with its tenant id and
    name, applies the optional severity/type/status filters, and returns the
    alerts ordered by severity (critical first) then recency (newest first).

    Raises:
        HTTPException 403: if the tenant is not a parent tenant.
        HTTPException 500: on unexpected aggregation failures.
    """
    try:
        # Only parent tenants may see the network-wide view
        tenant_info = await network_alerts_service.tenant_client.get_tenant(parent_id)
        if tenant_info.get('tenant_type') != 'parent':
            raise HTTPException(
                status_code=403,
                detail="Only parent tenants can access network alerts"
            )
        # Get all child tenants
        child_tenants = await network_alerts_service.get_child_tenants(parent_id)
        if not child_tenants:
            return []
        # Aggregate alerts from all child tenants
        all_alerts = []
        for child in child_tenants:
            child_id = child['id']
            child_name = child['name']
            child_alerts = await network_alerts_service.get_alerts_for_tenant(child_id)
            # Enrich with tenant information, filling safe defaults for missing fields
            for alert in child_alerts:
                enriched_alert = {
                    'alert_id': alert.get('alert_id', str(uuid.uuid4())),
                    'tenant_id': child_id,
                    'tenant_name': child_name,
                    'alert_type': alert.get('alert_type', 'unknown'),
                    'severity': alert.get('severity', 'medium'),
                    'title': alert.get('title', 'No title'),
                    'message': alert.get('message', 'No message'),
                    'timestamp': alert.get('timestamp', datetime.now().isoformat()),
                    'status': alert.get('status', 'active'),
                    'source_system': alert.get('source_system', 'unknown'),
                    'related_entity_id': alert.get('related_entity_id'),
                    'related_entity_type': alert.get('related_entity_type')
                }
                # Apply optional filters before accumulating
                if severity and enriched_alert['severity'] != severity:
                    continue
                if alert_type and enriched_alert['alert_type'] != alert_type:
                    continue
                if status and enriched_alert['status'] != status:
                    continue
                all_alerts.append(enriched_alert)
        # Sort by severity (critical first) and timestamp (newest first).
        # Timestamps are ISO-8601 strings which sort lexicographically; the
        # previous -int(timestamp) key raised ValueError on every request.
        severity_order = {'critical': 1, 'high': 2, 'medium': 3, 'low': 4}
        all_alerts.sort(key=lambda a: a['timestamp'] or '', reverse=True)
        # Python's sort is stable, so recency order is kept within each severity
        all_alerts.sort(key=lambda a: severity_order.get(a['severity'], 5))
        return all_alerts[:limit]
    except HTTPException:
        # Re-raise intentional HTTP errors (e.g. the 403) instead of masking them as 500
        raise
    except Exception as e:
        logger.error("Failed to get network alerts", parent_id=parent_id, error=str(e))
        raise HTTPException(status_code=500, detail=f"Failed to get network alerts: {str(e)}")
@router.get("/tenants/{parent_id}/network/alerts/summary",
            response_model=NetworkAlertsSummary,
            summary="Get network alerts summary")
async def get_network_alerts_summary(
    parent_id: str,
    network_alerts_service: NetworkAlertsService = Depends(get_network_alerts_service),
    verified_tenant: str = Depends(verify_tenant_permission_dep)
):
    """
    Get summary of alerts across the network.

    Aggregates status, severity, and type counts across all child tenants,
    plus the most recent alert (by ISO timestamp).

    Raises:
        HTTPException 403: if the tenant is not a parent tenant.
        HTTPException 500: on unexpected aggregation failures.
    """
    try:
        # Only parent tenants may see the network-wide view
        tenant_info = await network_alerts_service.tenant_client.get_tenant(parent_id)
        if tenant_info.get('tenant_type') != 'parent':
            raise HTTPException(
                status_code=403,
                detail="Only parent tenants can access network alerts summary"
            )
        all_alerts = await network_alerts_service.get_network_alerts(parent_id)
        if not all_alerts:
            # Return an explicit all-zero summary rather than an error
            return NetworkAlertsSummary(
                total_alerts=0,
                active_alerts=0,
                acknowledged_alerts=0,
                resolved_alerts=0,
                severity_summary=AlertSeveritySummary(
                    critical_count=0,
                    high_count=0,
                    medium_count=0,
                    low_count=0,
                    total_alerts=0
                ),
                type_summary=AlertTypeSummary(
                    inventory_alerts=0,
                    production_alerts=0,
                    delivery_alerts=0,
                    equipment_alerts=0,
                    quality_alerts=0,
                    other_alerts=0
                ),
                most_recent_alert=None
            )
        # Status counts
        active_alerts = sum(1 for a in all_alerts if a['status'] == 'active')
        acknowledged_alerts = sum(1 for a in all_alerts if a['status'] == 'acknowledged')
        resolved_alerts = sum(1 for a in all_alerts if a['status'] == 'resolved')
        # Severity breakdown
        severity_summary = AlertSeveritySummary(
            critical_count=sum(1 for a in all_alerts if a['severity'] == 'critical'),
            high_count=sum(1 for a in all_alerts if a['severity'] == 'high'),
            medium_count=sum(1 for a in all_alerts if a['severity'] == 'medium'),
            low_count=sum(1 for a in all_alerts if a['severity'] == 'low'),
            total_alerts=len(all_alerts)
        )
        # Type breakdown; anything outside the known types lands in other_alerts
        known_types = ['inventory', 'production', 'delivery', 'equipment', 'quality']
        type_summary = AlertTypeSummary(
            inventory_alerts=sum(1 for a in all_alerts if a['alert_type'] == 'inventory'),
            production_alerts=sum(1 for a in all_alerts if a['alert_type'] == 'production'),
            delivery_alerts=sum(1 for a in all_alerts if a['alert_type'] == 'delivery'),
            equipment_alerts=sum(1 for a in all_alerts if a['alert_type'] == 'equipment'),
            quality_alerts=sum(1 for a in all_alerts if a['alert_type'] == 'quality'),
            other_alerts=sum(1 for a in all_alerts if a['alert_type'] not in known_types)
        )
        # ISO-8601 strings compare lexicographically, so max() yields the newest
        most_recent_alert = max(all_alerts, key=lambda x: x['timestamp'])
        return NetworkAlertsSummary(
            total_alerts=len(all_alerts),
            active_alerts=active_alerts,
            acknowledged_alerts=acknowledged_alerts,
            resolved_alerts=resolved_alerts,
            severity_summary=severity_summary,
            type_summary=type_summary,
            most_recent_alert=most_recent_alert
        )
    except HTTPException:
        # Re-raise intentional HTTP errors (the 403) instead of masking them as 500
        raise
    except Exception as e:
        logger.error("Failed to get network alerts summary", parent_id=parent_id, error=str(e))
        raise HTTPException(status_code=500, detail=f"Failed to get alerts summary: {str(e)}")
@router.get("/tenants/{parent_id}/network/alerts/correlations",
            response_model=List[AlertCorrelation],
            summary="Get correlated alert groups")
async def get_correlated_alerts(
    parent_id: str,
    min_correlation_strength: float = Query(0.7, ge=0.5, le=1.0, description="Minimum correlation strength"),
    network_alerts_service: NetworkAlertsService = Depends(get_network_alerts_service),
    verified_tenant: str = Depends(verify_tenant_permission_dep)
):
    """
    Get groups of correlated alerts.

    Identifies alerts that are related or have cascading effects across the
    network, keeping only groups at or above the requested strength.

    Raises:
        HTTPException 403: if the tenant is not a parent tenant.
        HTTPException 500: on unexpected failures.
    """
    try:
        # Only parent tenants may see the network-wide view
        tenant_info = await network_alerts_service.tenant_client.get_tenant(parent_id)
        if tenant_info.get('tenant_type') != 'parent':
            raise HTTPException(
                status_code=403,
                detail="Only parent tenants can access alert correlations"
            )
        all_alerts = await network_alerts_service.get_network_alerts(parent_id)
        if not all_alerts:
            return []
        # Correlation detection is delegated to the service layer
        correlations = await network_alerts_service.detect_alert_correlations(
            all_alerts, min_correlation_strength
        )
        return correlations
    except HTTPException:
        # Re-raise intentional HTTP errors (the 403) instead of masking them as 500
        raise
    except Exception as e:
        logger.error("Failed to get correlated alerts", parent_id=parent_id, error=str(e))
        raise HTTPException(status_code=500, detail=f"Failed to get alert correlations: {str(e)}")
@router.post("/tenants/{parent_id}/network/alerts/{alert_id}/acknowledge",
             summary="Acknowledge network alert")
async def acknowledge_network_alert(
    parent_id: str,
    alert_id: str,
    network_alerts_service: NetworkAlertsService = Depends(get_network_alerts_service),
    verified_tenant: str = Depends(verify_tenant_permission_dep)
):
    """
    Acknowledge a network alert.

    Marks an alert as acknowledged to indicate it's being addressed.

    Raises:
        HTTPException 403: if the tenant is not a parent tenant.
        HTTPException 500: on unexpected failures.
    """
    try:
        # Only parent tenants may manage network-wide alerts
        tenant_info = await network_alerts_service.tenant_client.get_tenant(parent_id)
        if tenant_info.get('tenant_type') != 'parent':
            raise HTTPException(
                status_code=403,
                detail="Only parent tenants can acknowledge network alerts"
            )
        # Delegate the state change; the response body is fixed on success
        await network_alerts_service.acknowledge_alert(parent_id, alert_id)
        return {
            'success': True,
            'alert_id': alert_id,
            'status': 'acknowledged',
            'message': 'Alert acknowledged successfully'
        }
    except HTTPException:
        # Re-raise intentional HTTP errors (the 403) instead of masking them as 500
        raise
    except Exception as e:
        logger.error("Failed to acknowledge alert", parent_id=parent_id, alert_id=alert_id, error=str(e))
        raise HTTPException(status_code=500, detail=f"Failed to acknowledge alert: {str(e)}")
@router.post("/tenants/{parent_id}/network/alerts/{alert_id}/resolve",
             summary="Resolve network alert")
async def resolve_network_alert(
    parent_id: str,
    alert_id: str,
    resolution_notes: Optional[str] = Query(None, description="Notes about resolution"),
    network_alerts_service: NetworkAlertsService = Depends(get_network_alerts_service),
    verified_tenant: str = Depends(verify_tenant_permission_dep)
):
    """
    Resolve a network alert.

    Marks an alert as resolved after the issue has been addressed.

    Raises:
        HTTPException 403: if the tenant is not a parent tenant.
        HTTPException 500: on unexpected failures.
    """
    try:
        # Only parent tenants may manage network-wide alerts
        tenant_info = await network_alerts_service.tenant_client.get_tenant(parent_id)
        if tenant_info.get('tenant_type') != 'parent':
            raise HTTPException(
                status_code=403,
                detail="Only parent tenants can resolve network alerts"
            )
        # Delegate the state change; the response echoes the notes on success
        await network_alerts_service.resolve_alert(parent_id, alert_id, resolution_notes)
        return {
            'success': True,
            'alert_id': alert_id,
            'status': 'resolved',
            'resolution_notes': resolution_notes,
            'message': 'Alert resolved successfully'
        }
    except HTTPException:
        # Re-raise intentional HTTP errors (the 403) instead of masking them as 500
        raise
    except Exception as e:
        logger.error("Failed to resolve alert", parent_id=parent_id, alert_id=alert_id, error=str(e))
        raise HTTPException(status_code=500, detail=f"Failed to resolve alert: {str(e)}")
@router.get("/tenants/{parent_id}/network/alerts/trends",
            summary="Get alert trends over time")
async def get_alert_trends(
    parent_id: str,
    days: int = Query(30, ge=7, le=365, description="Number of days to analyze"),
    network_alerts_service: NetworkAlertsService = Depends(get_network_alerts_service),
    verified_tenant: str = Depends(verify_tenant_permission_dep)
):
    """
    Get alert trends over time.

    Analyzes how alert patterns change over time to identify systemic issues.

    Raises:
        HTTPException 403: if the tenant is not a parent tenant.
        HTTPException 500: on unexpected failures.
    """
    try:
        # Only parent tenants may see the network-wide view
        tenant_info = await network_alerts_service.tenant_client.get_tenant(parent_id)
        if tenant_info.get('tenant_type') != 'parent':
            raise HTTPException(
                status_code=403,
                detail="Only parent tenants can access alert trends"
            )
        # Trend computation is delegated to the service layer
        trends = await network_alerts_service.get_alert_trends(parent_id, days)
        return {
            'success': True,
            'trends': trends,
            'period': f'Last {days} days'
        }
    except HTTPException:
        # Re-raise intentional HTTP errors (the 403) instead of masking them as 500
        raise
    except Exception as e:
        logger.error("Failed to get alert trends", parent_id=parent_id, error=str(e))
        raise HTTPException(status_code=500, detail=f"Failed to get alert trends: {str(e)}")
@router.get("/tenants/{parent_id}/network/alerts/prioritization",
            summary="Get prioritized alerts")
async def get_prioritized_alerts(
    parent_id: str,
    limit: int = Query(10, description="Maximum number of alerts to return"),
    network_alerts_service: NetworkAlertsService = Depends(get_network_alerts_service),
    verified_tenant: str = Depends(verify_tenant_permission_dep)
):
    """
    Get prioritized alerts based on impact and urgency.

    Delegates prioritization to the service layer, which ranks alerts by
    potential business impact and urgency.

    Raises:
        HTTPException 403: if the tenant is not a parent tenant.
        HTTPException 500: on unexpected failures.
    """
    try:
        # Only parent tenants may see the network-wide view
        tenant_info = await network_alerts_service.tenant_client.get_tenant(parent_id)
        if tenant_info.get('tenant_type') != 'parent':
            raise HTTPException(
                status_code=403,
                detail="Only parent tenants can access prioritized alerts"
            )
        prioritized_alerts = await network_alerts_service.get_prioritized_alerts(parent_id, limit)
        return {
            'success': True,
            'prioritized_alerts': prioritized_alerts,
            'message': f'Top {len(prioritized_alerts)} prioritized alerts'
        }
    except HTTPException:
        # Re-raise intentional HTTP errors (the 403) instead of masking them as 500
        raise
    except Exception as e:
        logger.error("Failed to get prioritized alerts", parent_id=parent_id, error=str(e))
        raise HTTPException(status_code=500, detail=f"Failed to get prioritized alerts: {str(e)}")
# Module-level imports intentionally placed at the end of the file; they execute
# at import time, before any endpoint runs (datetime is also imported at the top)
from datetime import datetime, timedelta
import uuid

View File

@@ -0,0 +1,129 @@
"""
Onboarding Status API
Provides lightweight onboarding status checks by aggregating counts from multiple services
"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
import structlog
import asyncio
import httpx
import os
from app.core.database import get_db
from app.core.config import settings
from shared.auth.decorators import get_current_tenant_id_dep
from shared.routing.route_builder import RouteBuilder
logger = structlog.get_logger()
router = APIRouter()
route_builder = RouteBuilder("tenants")
def _count_from_service(result) -> int:
    """Extract the 'count' field from one gathered service response, falling back to 0 on any failure."""
    if isinstance(result, Exception):
        return 0
    if result.status_code != 200:
        return 0
    return result.json().get("count", 0)


@router.get(route_builder.build_base_route("{tenant_id}/onboarding/status", include_tenant_prefix=False))
async def get_onboarding_status(
    tenant_id: str,
    db: AsyncSession = Depends(get_db)
):
    """
    Get lightweight onboarding status by fetching counts from each service.

    Returns:
        - ingredients_count: Number of active ingredients
        - suppliers_count: Number of active suppliers
        - recipes_count: Number of active recipes
        - has_minimum_setup: Boolean indicating if minimum requirements are met
        - progress_percentage: Overall onboarding progress (0-100)
        - requirements: Per-requirement current/minimum/met breakdown

    Raises:
        HTTPException 500: if status aggregation fails entirely.
    """
    try:
        # Service URLs from environment
        inventory_url = os.getenv("INVENTORY_SERVICE_URL", "http://inventory-service:8000")
        suppliers_url = os.getenv("SUPPLIERS_SERVICE_URL", "http://suppliers-service:8000")
        recipes_url = os.getenv("RECIPES_SERVICE_URL", "http://recipes-service:8000")
        # Fetch counts from all services in parallel; a failing service
        # contributes a count of 0 instead of failing the whole request
        async with httpx.AsyncClient(timeout=10.0) as client:
            results = await asyncio.gather(
                *(
                    client.get(f"{url}/internal/count", params={"tenant_id": tenant_id})
                    for url in (inventory_url, suppliers_url, recipes_url)
                ),
                return_exceptions=True
            )
        ingredients_count = _count_from_service(results[0])
        suppliers_count = _count_from_service(results[1])
        recipes_count = _count_from_service(results[2])
        # Minimum setup requirements: 3 ingredients, 1 supplier, 1 recipe
        has_minimum_ingredients = ingredients_count >= 3
        has_minimum_suppliers = suppliers_count >= 1
        has_minimum_recipes = recipes_count >= 1
        has_minimum_setup = all([
            has_minimum_ingredients,
            has_minimum_suppliers,
            has_minimum_recipes
        ])
        # Progress percentage: each requirement contributes roughly a third
        # (33 + 33 + 34 so a complete setup reads exactly 100)
        progress = 0
        if has_minimum_ingredients:
            progress += 33
        if has_minimum_suppliers:
            progress += 33
        if has_minimum_recipes:
            progress += 34
        return {
            "ingredients_count": ingredients_count,
            "suppliers_count": suppliers_count,
            "recipes_count": recipes_count,
            "has_minimum_setup": has_minimum_setup,
            "progress_percentage": progress,
            "requirements": {
                "ingredients": {
                    "current": ingredients_count,
                    "minimum": 3,
                    "met": has_minimum_ingredients
                },
                "suppliers": {
                    "current": suppliers_count,
                    "minimum": 1,
                    "met": has_minimum_suppliers
                },
                "recipes": {
                    "current": recipes_count,
                    "minimum": 1,
                    "met": has_minimum_recipes
                }
            }
        }
    except Exception as e:
        logger.error("Failed to get onboarding status", tenant_id=tenant_id, error=str(e))
        raise HTTPException(
            status_code=500,
            detail=f"Failed to get onboarding status: {str(e)}"
        )

View File

@@ -0,0 +1,330 @@
"""
Subscription Plans API
Public endpoint for fetching available subscription plans
"""
from fastapi import APIRouter, HTTPException
from typing import Dict, Any
import structlog
from shared.subscription.plans import (
SubscriptionTier,
SubscriptionPlanMetadata,
PlanPricing,
QuotaLimits,
PlanFeatures,
FeatureCategories,
UserFacingFeatures
)
logger = structlog.get_logger()
router = APIRouter(prefix="/plans", tags=["subscription-plans"])
@router.get("", response_model=Dict[str, Any])
async def get_available_plans():
    """
    Return metadata for every subscription tier.

    **Public endpoint** - No authentication required

    The response maps each tier value to its display name, translation keys,
    pricing (Decimal converted to float for JSON serialization), trial length,
    features, limits, and sales flags.
    """
    try:
        plans_data: Dict[str, Any] = {}
        for tier in SubscriptionTier:
            meta = SubscriptionPlanMetadata.PLANS[tier]
            # Decimal prices are converted to float so the dict serializes to JSON
            plans_data[tier.value] = {
                "name": meta["name"],
                "description_key": meta["description_key"],
                "tagline_key": meta["tagline_key"],
                "popular": meta["popular"],
                "monthly_price": float(meta["monthly_price"]),
                "yearly_price": float(meta["yearly_price"]),
                "trial_days": meta["trial_days"],
                "features": meta["features"],
                "hero_features": meta.get("hero_features", []),
                "roi_badge": meta.get("roi_badge"),
                "business_metrics": meta.get("business_metrics"),
                "limits": meta["limits"],
                "support_key": meta["support_key"],
                "recommended_for_key": meta["recommended_for_key"],
                "contact_sales": meta.get("contact_sales", False),
                "custom_pricing": meta.get("custom_pricing", False),
            }
        logger.info("subscription_plans_fetched", tier_count=len(plans_data))
        return {"plans": plans_data}
    except Exception as e:
        logger.error("failed_to_fetch_plans", error=str(e))
        raise HTTPException(
            status_code=500,
            detail="Failed to fetch subscription plans"
        )
@router.get("/{tier}", response_model=Dict[str, Any])
async def get_plan_by_tier(tier: str):
    """
    Get metadata for a specific subscription tier

    **Public endpoint** - No authentication required

    NOTE(review): this parameterized route is registered before the static
    routes ``/feature-categories``, ``/feature-descriptions`` and ``/compare``
    later in this module; FastAPI matches routes in declaration order, so those
    paths are captured here and return 404 as unknown tiers — confirm the
    intended route ordering.

    Args:
        tier: Subscription tier (starter, professional, enterprise)

    Returns:
        Plan metadata for the specified tier

    Raises:
        404: If tier is not found
        500: On unexpected lookup/serialization failures
    """
    try:
        # Validate tier string against the enum; unknown values raise ValueError
        tier_enum = SubscriptionTier(tier.lower())
        metadata = SubscriptionPlanMetadata.PLANS[tier_enum]
        # Decimal prices converted to float for JSON serialization
        plan_data = {
            "tier": tier_enum.value,
            "name": metadata["name"],
            "description_key": metadata["description_key"],
            "tagline_key": metadata["tagline_key"],
            "popular": metadata["popular"],
            "monthly_price": float(metadata["monthly_price"]),
            "yearly_price": float(metadata["yearly_price"]),
            "trial_days": metadata["trial_days"],
            "features": metadata["features"],
            "hero_features": metadata.get("hero_features", []),
            "roi_badge": metadata.get("roi_badge"),
            "business_metrics": metadata.get("business_metrics"),
            "limits": metadata["limits"],
            "support_key": metadata["support_key"],
            "recommended_for_key": metadata["recommended_for_key"],
            "contact_sales": metadata.get("contact_sales", False),
            "custom_pricing": metadata.get("custom_pricing", False),
        }
        logger.info("subscription_plan_fetched", tier=tier)
        return plan_data
    except ValueError:
        # Unknown tier -> 404; raised from inside this handler, so the generic
        # Exception clause below does not intercept it
        raise HTTPException(
            status_code=404,
            detail=f"Subscription tier '{tier}' not found"
        )
    except Exception as e:
        logger.error("failed_to_fetch_plan", tier=tier, error=str(e))
        raise HTTPException(
            status_code=500,
            detail="Failed to fetch subscription plan"
        )
@router.get("/{tier}/features")
async def get_plan_features(tier: str):
    """
    List every feature key available in a subscription tier.

    **Public endpoint** - No authentication required

    Args:
        tier: Subscription tier (starter, professional, enterprise)

    Raises:
        404: when the tier string does not match a known tier.
    """
    try:
        tier_enum = SubscriptionTier(tier.lower())
        available = PlanFeatures.get_features(tier_enum.value)
        return {
            "tier": tier_enum.value,
            "features": available,
            "feature_count": len(available),
        }
    except ValueError:
        raise HTTPException(
            status_code=404,
            detail=f"Subscription tier '{tier}' not found"
        )
@router.get("/{tier}/limits")
async def get_plan_limits(tier: str):
    """
    Return every quota limit for a subscription tier, grouped by functional area.

    **Public endpoint** - No authentication required

    Args:
        tier: Subscription tier (starter, professional, enterprise)

    Raises:
        404: when the tier string does not match a known tier.
    """
    # Guard clause: only the enum conversion can raise ValueError here
    try:
        tier_enum = SubscriptionTier(tier.lower())
    except ValueError:
        raise HTTPException(
            status_code=404,
            detail=f"Subscription tier '{tier}' not found"
        )
    return {
        "tier": tier_enum.value,
        "team_and_organization": {
            "max_users": QuotaLimits.MAX_USERS[tier_enum],
            "max_locations": QuotaLimits.MAX_LOCATIONS[tier_enum],
        },
        "product_and_inventory": {
            "max_products": QuotaLimits.MAX_PRODUCTS[tier_enum],
            "max_recipes": QuotaLimits.MAX_RECIPES[tier_enum],
            "max_suppliers": QuotaLimits.MAX_SUPPLIERS[tier_enum],
        },
        "ml_and_analytics": {
            "training_jobs_per_day": QuotaLimits.TRAINING_JOBS_PER_DAY[tier_enum],
            "forecast_generation_per_day": QuotaLimits.FORECAST_GENERATION_PER_DAY[tier_enum],
            "dataset_size_rows": QuotaLimits.DATASET_SIZE_ROWS[tier_enum],
            "forecast_horizon_days": QuotaLimits.FORECAST_HORIZON_DAYS[tier_enum],
            "historical_data_access_days": QuotaLimits.HISTORICAL_DATA_ACCESS_DAYS[tier_enum],
        },
        "import_export": {
            "bulk_import_rows": QuotaLimits.BULK_IMPORT_ROWS[tier_enum],
            "bulk_export_rows": QuotaLimits.BULK_EXPORT_ROWS[tier_enum],
        },
        "integrations": {
            "pos_sync_interval_minutes": QuotaLimits.POS_SYNC_INTERVAL_MINUTES[tier_enum],
            "api_calls_per_hour": QuotaLimits.API_CALLS_PER_HOUR[tier_enum],
            "webhook_endpoints": QuotaLimits.WEBHOOK_ENDPOINTS[tier_enum],
        },
        "storage": {
            "file_storage_gb": QuotaLimits.FILE_STORAGE_GB[tier_enum],
            "report_retention_days": QuotaLimits.REPORT_RETENTION_DAYS[tier_enum],
        }
    }
@router.get("/feature-categories")
async def get_feature_categories():
    """
    List all feature categories with their icons and translation keys.

    **Public endpoint** - No authentication required
    """
    try:
        categories = FeatureCategories.CATEGORIES
    except Exception as e:
        logger.error("failed_to_fetch_feature_categories", error=str(e))
        raise HTTPException(
            status_code=500,
            detail="Failed to fetch feature categories"
        )
    return {"categories": categories}
@router.get("/feature-descriptions")
async def get_feature_descriptions():
    """
    List user-facing feature descriptions keyed by feature identifier.

    **Public endpoint** - No authentication required
    """
    try:
        display_map = UserFacingFeatures.FEATURE_DISPLAY
    except Exception as e:
        logger.error("failed_to_fetch_feature_descriptions", error=str(e))
        raise HTTPException(
            status_code=500,
            detail="Failed to fetch feature descriptions"
        )
    return {"features": display_map}
@router.get("/compare")
async def compare_plans():
    """
    Get plan comparison data for all tiers

    **Public endpoint** - No authentication required

    Returns:
        Comparison matrix of all plans: pricing (with yearly savings
        percentage), the first ten features, and key quota limits per tier.

    NOTE(review): the parameterized route ``/{tier}`` is registered earlier in
    this module and FastAPI matches in declaration order, so "/plans/compare"
    may be captured by it — confirm intended route ordering.
    """
    try:
        comparison = {
            "tiers": ["starter", "professional", "enterprise"],
            "pricing": {},
            "key_features": {},
            "key_limits": {}
        }
        for tier in SubscriptionTier:
            metadata = SubscriptionPlanMetadata.PLANS[tier]
            monthly = float(metadata["monthly_price"])
            yearly = float(metadata["yearly_price"])
            annual_at_monthly_rate = monthly * 12
            # Guard against division by zero for tiers priced at 0
            # (e.g. custom-priced / contact-sales plans)
            if annual_at_monthly_rate:
                savings = round((annual_at_monthly_rate - yearly) / annual_at_monthly_rate * 100)
            else:
                savings = 0
            comparison["pricing"][tier.value] = {
                "monthly": monthly,
                "yearly": yearly,
                "savings_percentage": savings
            }
            # Key features (first 10)
            comparison["key_features"][tier.value] = metadata["features"][:10]
            # Key limits
            comparison["key_limits"][tier.value] = {
                "users": metadata["limits"]["users"],
                "locations": metadata["limits"]["locations"],
                "products": metadata["limits"]["products"],
                "forecasts_per_day": metadata["limits"]["forecasts_per_day"],
                "training_jobs_per_day": QuotaLimits.TRAINING_JOBS_PER_DAY[tier],
            }
        return comparison
    except Exception as e:
        logger.error("failed_to_compare_plans", error=str(e))
        raise HTTPException(
            status_code=500,
            detail="Failed to generate plan comparison"
        )

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,595 @@
"""
Tenant Hierarchy API - Handles parent-child tenant relationships
"""
from fastapi import APIRouter, Depends, HTTPException, status, Path
from typing import List, Dict, Any
from uuid import UUID
from app.schemas.tenants import (
TenantResponse,
ChildTenantCreate,
BulkChildTenantsCreate,
BulkChildTenantsResponse,
ChildTenantResponse,
TenantHierarchyResponse
)
from app.services.tenant_service import EnhancedTenantService
from app.repositories.tenant_repository import TenantRepository
from shared.auth.decorators import get_current_user_dep
from shared.routing.route_builder import RouteBuilder
from shared.database.base import create_database_manager
from shared.monitoring.metrics import track_endpoint_metrics
import structlog
logger = structlog.get_logger()
router = APIRouter()
route_builder = RouteBuilder("tenants")
# Dependency injection for enhanced tenant service
def get_enhanced_tenant_service():
    """Build an EnhancedTenantService wired to the configured service database."""
    try:
        from app.core.config import settings
        manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
        service = EnhancedTenantService(manager)
    except Exception as exc:
        # Surface initialization failures as a 500 rather than a raw traceback
        logger.error("Failed to create enhanced tenant service", error=str(exc))
        raise HTTPException(status_code=500, detail="Service initialization failed")
    return service
@router.get(route_builder.build_base_route("{tenant_id}/children", include_tenant_prefix=False), response_model=List[TenantResponse])
@track_endpoint_metrics("tenant_children_list")
async def get_tenant_children(
    tenant_id: UUID = Path(..., description="Parent Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """
    Get all child tenants for a parent tenant.

    Returns all active child tenants associated with the specified parent
    tenant. Service-to-service callers bypass the per-user access check;
    regular users must have access to the parent tenant.

    Raises:
        HTTPException: 403 when the caller lacks access to the parent tenant,
            500 on unexpected errors.
    """
    try:
        logger.info(
            "Get tenant children request received",
            tenant_id=str(tenant_id),
            user_id=current_user.get("user_id"),
            user_type=current_user.get("type", "user"),
            is_service=current_user.get("type") == "service",
            role=current_user.get("role"),
            service_name=current_user.get("service", "none")
        )
        # Skip access check for service-to-service calls
        is_service_call = current_user.get("type") == "service"
        if not is_service_call:
            # Verify user has access to the parent tenant
            access_info = await tenant_service.verify_user_access(current_user["user_id"], str(tenant_id))
            if not access_info.has_access:
                logger.warning(
                    "Access denied to parent tenant",
                    tenant_id=str(tenant_id),
                    user_id=current_user.get("user_id")
                )
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Access denied to parent tenant"
                )
        else:
            logger.debug(
                "Service-to-service call - bypassing access check",
                service=current_user.get("service"),
                tenant_id=str(tenant_id)
            )
        # Get child tenants from repository
        from app.models.tenants import Tenant
        async with tenant_service.database_manager.get_session() as session:
            tenant_repo = TenantRepository(Tenant, session)
            child_tenants = await tenant_repo.get_child_tenants(str(tenant_id))
            logger.debug(
                "Get tenant children successful",
                tenant_id=str(tenant_id),
                child_count=len(child_tenants)
            )
            # Convert to plain dicts while still in session to avoid lazy-load issues
            child_dicts = []
            for child in child_tenants:
                # Read the (optionally pre-populated) cache entry straight from the
                # instance __dict__ so the subscription relationship is never
                # lazy-loaded. dict.get cannot raise, so the previous bare
                # `except:` guard here was dead code (and bare excepts also
                # swallow SystemExit/KeyboardInterrupt) — it has been removed.
                # Default 'enterprise': children inherit the parent's plan.
                sub_tier = child.__dict__.get('_subscription_tier_cache', 'enterprise')
                child_dict = {
                    'id': str(child.id),
                    'name': child.name,
                    'subdomain': child.subdomain,
                    'business_type': child.business_type,
                    'business_model': child.business_model,
                    'address': child.address,
                    'city': child.city,
                    'postal_code': child.postal_code,
                    'latitude': child.latitude,
                    'longitude': child.longitude,
                    'phone': child.phone,
                    'email': child.email,
                    'timezone': child.timezone,
                    'owner_id': str(child.owner_id),
                    'parent_tenant_id': str(child.parent_tenant_id) if child.parent_tenant_id else None,
                    'tenant_type': child.tenant_type,
                    'hierarchy_path': child.hierarchy_path,
                    'subscription_tier': sub_tier,  # Use the safely retrieved value
                    'ml_model_trained': child.ml_model_trained,
                    'last_training_date': child.last_training_date,
                    'is_active': child.is_active,
                    'is_demo': child.is_demo,
                    'demo_session_id': child.demo_session_id,
                    'created_at': child.created_at,
                    'updated_at': child.updated_at
                }
                child_dicts.append(child_dict)
        # Convert to Pydantic models outside the session without from_attributes
        child_responses = [TenantResponse(**child_dict) for child_dict in child_dicts]
        return child_responses
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Get tenant children failed",
                    tenant_id=str(tenant_id),
                    user_id=current_user.get("user_id"),
                    error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Get tenant children failed"
        )
@router.get(route_builder.build_base_route("{tenant_id}/children/count", include_tenant_prefix=False))
@track_endpoint_metrics("tenant_children_count")
async def get_tenant_children_count(
    tenant_id: UUID = Path(..., description="Parent Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """
    Get count of child tenants for a parent tenant.

    Returns the number of active child tenants associated with the given
    parent tenant. Service-to-service callers bypass the per-user access
    check.
    """
    try:
        logger.info(
            "Get tenant children count request received",
            tenant_id=str(tenant_id),
            user_id=current_user.get("user_id")
        )
        if current_user.get("type") == "service":
            # Trusted internal caller: no per-user access verification.
            logger.debug(
                "Service-to-service call - bypassing access check",
                service=current_user.get("service"),
                tenant_id=str(tenant_id)
            )
        else:
            access = await tenant_service.verify_user_access(
                current_user["user_id"], str(tenant_id)
            )
            if not access.has_access:
                logger.warning(
                    "Access denied to parent tenant",
                    tenant_id=str(tenant_id),
                    user_id=current_user.get("user_id")
                )
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Access denied to parent tenant"
                )
        from app.models.tenants import Tenant
        async with tenant_service.database_manager.get_session() as session:
            repo = TenantRepository(Tenant, session)
            total = await repo.get_child_tenant_count(str(tenant_id))
        logger.debug(
            "Get tenant children count successful",
            tenant_id=str(tenant_id),
            child_count=total
        )
        return {
            "parent_tenant_id": str(tenant_id),
            "child_count": total
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Get tenant children count failed",
                    tenant_id=str(tenant_id),
                    user_id=current_user.get("user_id"),
                    error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Get tenant children count failed"
        )
@router.get(route_builder.build_base_route("{tenant_id}/hierarchy", include_tenant_prefix=False), response_model=TenantHierarchyResponse)
@track_endpoint_metrics("tenant_hierarchy")
async def get_tenant_hierarchy(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """
    Get tenant hierarchy information.

    Returns hierarchy metadata for a tenant: tenant type (standalone,
    parent, child), parent tenant ID (for children), the materialized
    hierarchy path, the number of children (for parent/standalone tenants)
    and the hierarchy level (depth in the tree).

    Used by the authentication layer for hierarchical access control and by
    enterprise features for network management.
    """
    try:
        logger.info(
            "Get tenant hierarchy request received",
            tenant_id=str(tenant_id),
            user_id=current_user.get("user_id"),
            user_type=current_user.get("type", "user"),
            is_service=current_user.get("type") == "service"
        )
        from app.models.tenants import Tenant
        async with tenant_service.database_manager.get_session() as session:
            repo = TenantRepository(Tenant, session)
            # Fetch first so a missing tenant yields 404 before any 403.
            tenant = await repo.get(str(tenant_id))
            if not tenant:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail=f"Tenant {tenant_id} not found"
                )
            if current_user.get("type") == "service":
                # Trusted internal caller: no per-user access verification.
                logger.debug(
                    "Service-to-service call - bypassing access check",
                    service=current_user.get("service"),
                    tenant_id=str(tenant_id)
                )
            else:
                access = await tenant_service.verify_user_access(current_user["user_id"], str(tenant_id))
                if not access.has_access:
                    logger.warning(
                        "Access denied to tenant for hierarchy query",
                        tenant_id=str(tenant_id),
                        user_id=current_user.get("user_id")
                    )
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail="Access denied to tenant"
                    )
            # Only parent/standalone tenants can own children.
            children_total = 0
            if tenant.tenant_type in ["parent", "standalone"]:
                children_total = await repo.get_child_tenant_count(str(tenant_id))
            # Depth = number of '.' separators in the materialized path
            # ("parent_id" -> 0, "parent_id.child_id" -> 1, ...).
            depth = tenant.hierarchy_path.count('.') if tenant.hierarchy_path else 0
            parent_id = str(tenant.parent_tenant_id) if tenant.parent_tenant_id else None
            response = TenantHierarchyResponse(
                tenant_id=str(tenant.id),
                tenant_type=tenant.tenant_type or "standalone",
                parent_tenant_id=parent_id,
                hierarchy_path=tenant.hierarchy_path,
                child_count=children_total,
                hierarchy_level=depth
            )
            logger.info(
                "Get tenant hierarchy successful",
                tenant_id=str(tenant_id),
                tenant_type=tenant.tenant_type,
                parent_tenant_id=parent_id,
                child_count=children_total,
                hierarchy_level=depth
            )
            return response
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Get tenant hierarchy failed",
                    tenant_id=str(tenant_id),
                    user_id=current_user.get("user_id"),
                    error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Get tenant hierarchy failed"
        )
@router.post("/api/v1/tenants/{tenant_id}/bulk-children", response_model=BulkChildTenantsResponse)
@track_endpoint_metrics("bulk_create_child_tenants")
async def bulk_create_child_tenants(
    request: BulkChildTenantsCreate,
    tenant_id: str = Path(..., description="Parent tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """
    Bulk create child tenants for enterprise onboarding.

    Creates multiple child tenants (outlets/branches) under an enterprise
    parent tenant and establishes the parent-child relationship. Designed
    for the onboarding flow when an enterprise customer registers their
    network of locations.

    Features:
    - Creates child tenants with proper hierarchy
    - Inherits subscription from parent
    - Optionally configures distribution routes
    - Returns detailed per-child success/failure information

    Raises:
        HTTPException: 404 if the parent tenant does not exist, 403 if the
            caller is not an owner/admin, 400 if the parent is not
            enterprise tier, 500 on unexpected errors.
    """
    try:
        logger.info(
            "Bulk child tenant creation request received",
            parent_tenant_id=tenant_id,
            child_count=len(request.child_tenants),
            user_id=current_user.get("user_id")
        )
        # Verify parent tenant exists and user has access
        async with tenant_service.database_manager.get_session() as session:
            from app.models.tenants import Tenant
            tenant_repo = TenantRepository(Tenant, session)
            parent_tenant = await tenant_repo.get_by_id(tenant_id)
            if not parent_tenant:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="Parent tenant not found"
                )
            # Verify user has access to parent tenant (owners/admins only)
            access_info = await tenant_service.verify_user_access(
                current_user["user_id"],
                tenant_id
            )
            if not access_info.has_access or access_info.role not in ["owner", "admin"]:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Only tenant owners/admins can create child tenants"
                )
            # Verify parent is enterprise tier
            parent_subscription = await tenant_service.subscription_repo.get_active_subscription(tenant_id)
            if not parent_subscription or parent_subscription.plan != "enterprise":
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Only enterprise tier tenants can have child tenants"
                )
            # Promote the parent from standalone and seed its hierarchy path.
            if parent_tenant.tenant_type == "standalone":
                parent_tenant.tenant_type = "parent"
                parent_tenant.hierarchy_path = str(parent_tenant.id)
                await session.commit()
                await session.refresh(parent_tenant)
            # Create child tenants
            created_tenants = []
            failed_tenants = []
            for child_data in request.child_tenants:
                # BUGFIX: the try/except must wrap the `begin_nested()` block,
                # not sit inside it. A SAVEPOINT is only rolled back when the
                # exception propagates OUT of the `async with`; the previous
                # code caught the exception inside the block, so the context
                # manager released a half-applied savepoint instead of rolling
                # it back. With this structure a failing child is fully rolled
                # back while the rest of the batch is unaffected.
                try:
                    async with session.begin_nested():
                        # Create child tenant with full tenant model fields
                        child_tenant = Tenant(
                            name=child_data.name,
                            subdomain=None,  # Child tenants typically don't have subdomains
                            business_type=child_data.business_type or parent_tenant.business_type,
                            business_model=child_data.business_model or "retail_bakery",  # Child outlets are typically retail
                            address=child_data.address,
                            city=child_data.city,
                            postal_code=child_data.postal_code,
                            latitude=child_data.latitude,
                            longitude=child_data.longitude,
                            phone=child_data.phone or parent_tenant.phone,
                            email=child_data.email or parent_tenant.email,
                            timezone=child_data.timezone or parent_tenant.timezone,
                            owner_id=parent_tenant.owner_id,
                            parent_tenant_id=parent_tenant.id,
                            tenant_type="child",
                            hierarchy_path=f"{parent_tenant.hierarchy_path}",  # Will be updated after flush
                            is_active=True,
                            is_demo=parent_tenant.is_demo,
                            demo_session_id=parent_tenant.demo_session_id,
                            demo_expires_at=parent_tenant.demo_expires_at,
                            metadata_={
                                "location_code": child_data.location_code,
                                "zone": child_data.zone,
                                **(child_data.metadata or {})
                            }
                        )
                        session.add(child_tenant)
                        await session.flush()  # Get the ID without committing
                        # Update hierarchy_path now that we have the child tenant ID
                        child_tenant.hierarchy_path = f"{parent_tenant.hierarchy_path}.{str(child_tenant.id)}"
                        # Create TenantLocation record for the child
                        from app.models.tenant_location import TenantLocation
                        location = TenantLocation(
                            tenant_id=child_tenant.id,
                            name=child_data.name,
                            city=child_data.city,
                            address=child_data.address,
                            postal_code=child_data.postal_code,
                            latitude=child_data.latitude,
                            longitude=child_data.longitude,
                            is_active=True,
                            location_type="retail"
                        )
                        session.add(location)
                        # Inherit subscription from parent
                        from app.models.tenants import Subscription
                        from sqlalchemy import select
                        parent_subscription_result = await session.execute(
                            select(Subscription).where(
                                Subscription.tenant_id == parent_tenant.id,
                                Subscription.status == "active"
                            )
                        )
                        parent_sub = parent_subscription_result.scalar_one_or_none()
                        if parent_sub:
                            child_subscription = Subscription(
                                tenant_id=child_tenant.id,
                                plan=parent_sub.plan,
                                status="active",
                                billing_cycle=parent_sub.billing_cycle,
                                monthly_price=0,  # Child tenants don't pay separately
                                trial_ends_at=parent_sub.trial_ends_at
                            )
                            session.add(child_subscription)
                        await session.flush()
                        # Refresh objects to get their final state
                        await session.refresh(child_tenant)
                        await session.refresh(location)
                        # Build response inside the savepoint: if validation
                        # fails here, this child is rolled back too.
                        created_tenants.append(ChildTenantResponse(
                            id=str(child_tenant.id),
                            name=child_tenant.name,
                            subdomain=child_tenant.subdomain,
                            business_type=child_tenant.business_type,
                            business_model=child_tenant.business_model,
                            tenant_type=child_tenant.tenant_type,
                            parent_tenant_id=str(child_tenant.parent_tenant_id),
                            address=child_tenant.address,
                            city=child_tenant.city,
                            postal_code=child_tenant.postal_code,
                            phone=child_tenant.phone,
                            is_active=child_tenant.is_active,
                            subscription_plan="enterprise",
                            ml_model_trained=child_tenant.ml_model_trained,
                            last_training_date=child_tenant.last_training_date,
                            owner_id=str(child_tenant.owner_id),
                            created_at=child_tenant.created_at,
                            location_code=child_data.location_code,
                            zone=child_data.zone,
                            hierarchy_path=child_tenant.hierarchy_path
                        ))
                        logger.info(
                            "Child tenant created successfully",
                            child_tenant_id=str(child_tenant.id),
                            child_name=child_tenant.name,
                            location_code=child_data.location_code
                        )
                except Exception as child_error:
                    # The savepoint for this child has been rolled back; the
                    # remaining children are unaffected.
                    logger.error(
                        "Failed to create child tenant",
                        child_name=child_data.name,
                        error=str(child_error)
                    )
                    failed_tenants.append({
                        "name": child_data.name,
                        "location_code": child_data.location_code,
                        "error": str(child_error)
                    })
            # Commit all successful child tenant creations
            await session.commit()
            # TODO: Configure distribution routes if requested
            distribution_configured = False
            if request.auto_configure_distribution and len(created_tenants) > 0:
                try:
                    # This would call the distribution service to set up routes
                    # For now, we'll skip this and just log
                    logger.info(
                        "Distribution route configuration requested",
                        parent_tenant_id=tenant_id,
                        child_count=len(created_tenants)
                    )
                    # distribution_configured = await configure_distribution_routes(...)
                except Exception as dist_error:
                    logger.warning(
                        "Failed to configure distribution routes",
                        error=str(dist_error)
                    )
            logger.info(
                "Bulk child tenant creation completed",
                parent_tenant_id=tenant_id,
                created_count=len(created_tenants),
                failed_count=len(failed_tenants)
            )
            return BulkChildTenantsResponse(
                parent_tenant_id=tenant_id,
                created_count=len(created_tenants),
                failed_count=len(failed_tenants),
                created_tenants=created_tenants,
                failed_tenants=failed_tenants,
                distribution_configured=distribution_configured
            )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(
            "Bulk child tenant creation failed",
            parent_tenant_id=tenant_id,
            user_id=current_user.get("user_id"),
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Bulk child tenant creation failed: {str(e)}"
        )
# Register the router in the main app
def register_hierarchy_routes(app):
    """Register hierarchy routes with the main application.

    Args:
        app: The FastAPI application instance to mount the router on.
    """
    # The RouteBuilder previously constructed here was never used — the
    # router is mounted directly under the fixed /api/v1 prefix.
    app.include_router(
        router,
        prefix="/api/v1",
        tags=["tenant-hierarchy"]
    )

View File

@@ -0,0 +1,628 @@
"""
Tenant Locations API - Handles tenant location operations
"""
import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Path, Query
from typing import List, Dict, Any, Optional
from uuid import UUID
from app.schemas.tenant_locations import (
TenantLocationCreate,
TenantLocationUpdate,
TenantLocationResponse,
TenantLocationsResponse,
TenantLocationTypeFilter
)
from app.repositories.tenant_location_repository import TenantLocationRepository
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import admin_role_required
from shared.monitoring.metrics import track_endpoint_metrics
from shared.routing.route_builder import RouteBuilder
logger = structlog.get_logger()
router = APIRouter()
route_builder = RouteBuilder("tenants")
# Dependency injection for tenant location repository
async def get_tenant_location_repository():
    """Yield a TenantLocationRepository bound to a managed database session.

    The setup code is guarded separately from the `yield`. In a FastAPI
    dependency with `yield`, exceptions raised by the endpoint are thrown
    back into this generator at the yield point; previously the yield sat
    inside `except Exception`, so an endpoint's own HTTPException (e.g. a
    404) was swallowed and re-raised as a generic 500
    "Service initialization failed". Now only genuine setup failures map
    to that 500, and endpoint errors propagate untouched.

    Raises:
        HTTPException: 500 when the database session cannot be created.
    """
    try:
        from app.core.database import database_manager
        session_ctx = database_manager.get_session()
    except Exception as e:
        logger.error("Failed to create tenant location repository", error=str(e))
        raise HTTPException(status_code=500, detail="Service initialization failed")
    # Use async context manager properly to ensure session is closed
    async with session_ctx as session:
        yield TenantLocationRepository(session)
@router.get(route_builder.build_base_route("{tenant_id}/locations", include_tenant_prefix=False), response_model=TenantLocationsResponse)
@track_endpoint_metrics("tenant_locations_list")
async def get_tenant_locations(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    location_types: str = Query(None, description="Comma-separated list of location types to filter"),
    is_active: Optional[bool] = Query(None, description="Filter by active status"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    location_repo: TenantLocationRepository = Depends(get_tenant_location_repository)
):
    """
    Get all locations for a tenant.

    Optionally filters by location type (comma-separated list, e.g.
    "central_production,retail_outlet") and/or active status. Tenant access
    is assumed to have been validated by the gateway.
    """
    def _as_response(loc):
        # Build the payload by hand: the ORM column is named `metadata_`
        # (to dodge SQLAlchemy's reserved `metadata` attribute), so plain
        # from-attributes validation would not populate `metadata`.
        return TenantLocationResponse.model_validate({
            'id': str(loc.id),
            'tenant_id': str(loc.tenant_id),
            'name': loc.name,
            'location_type': loc.location_type,
            'address': loc.address,
            'city': loc.city,
            'postal_code': loc.postal_code,
            'latitude': loc.latitude,
            'longitude': loc.longitude,
            'contact_person': loc.contact_person,
            'contact_phone': loc.contact_phone,
            'contact_email': loc.contact_email,
            'is_active': loc.is_active,
            'delivery_windows': loc.delivery_windows,
            'operational_hours': loc.operational_hours,
            'capacity': loc.capacity,
            'max_delivery_radius_km': loc.max_delivery_radius_km,
            'delivery_schedule_config': loc.delivery_schedule_config,
            'metadata': loc.metadata_,
            'created_at': loc.created_at,
            'updated_at': loc.updated_at
        })

    try:
        logger.info(
            "Get tenant locations request received",
            tenant_id=str(tenant_id),
            user_id=current_user.get("user_id"),
            location_types=location_types,
            is_active=is_active
        )
        tid = str(tenant_id)
        if location_types:
            # Filter by the requested location types.
            wanted = [t.strip() for t in location_types.split(",")]
            found = await location_repo.get_locations_by_tenant_with_type(tid, wanted)
        elif is_active is True:
            found = await location_repo.get_active_locations_by_tenant(tid)
        elif is_active is False:
            # No repo method for inactive-only: fetch all, filter in memory.
            found = [loc for loc in await location_repo.get_locations_by_tenant(tid) if not loc.is_active]
        else:
            found = await location_repo.get_locations_by_tenant(tid)
        logger.debug(
            "Get tenant locations successful",
            tenant_id=str(tenant_id),
            location_count=len(found)
        )
        payload = [_as_response(loc) for loc in found]
        return TenantLocationsResponse(locations=payload, total=len(payload))
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Get tenant locations failed",
                    tenant_id=str(tenant_id),
                    user_id=current_user.get("user_id"),
                    error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Get tenant locations failed"
        )
@router.get(route_builder.build_base_route("{tenant_id}/locations/{location_id}", include_tenant_prefix=False), response_model=TenantLocationResponse)
@track_endpoint_metrics("tenant_location_get")
async def get_tenant_location(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    location_id: UUID = Path(..., description="Location ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    location_repo: TenantLocationRepository = Depends(get_tenant_location_repository)
):
    """
    Get a specific location for a tenant.

    Responds 404 both when the location does not exist and when it belongs
    to a different tenant, so ownership is not leaked to the caller.
    """
    try:
        logger.info(
            "Get tenant location request received",
            tenant_id=str(tenant_id),
            location_id=str(location_id),
            user_id=current_user.get("user_id")
        )
        record = await location_repo.get_location_by_id(str(location_id))
        if not record:
            logger.warning(
                "Location not found",
                tenant_id=str(tenant_id),
                location_id=str(location_id),
                user_id=current_user.get("user_id")
            )
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Location not found"
            )
        if str(record.tenant_id) != str(tenant_id):
            logger.warning(
                "Location does not belong to tenant",
                tenant_id=str(tenant_id),
                location_id=str(location_id),
                location_tenant_id=str(record.tenant_id),
                user_id=current_user.get("user_id")
            )
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Location not found"
            )
        logger.debug(
            "Get tenant location successful",
            tenant_id=str(tenant_id),
            location_id=str(location_id),
            user_id=current_user.get("user_id")
        )
        # Manual payload: the ORM column is `metadata_` (SQLAlchemy reserves
        # `metadata`), so from-attributes validation would miss it.
        payload = {
            'id': str(record.id),
            'tenant_id': str(record.tenant_id),
            'name': record.name,
            'location_type': record.location_type,
            'address': record.address,
            'city': record.city,
            'postal_code': record.postal_code,
            'latitude': record.latitude,
            'longitude': record.longitude,
            'contact_person': record.contact_person,
            'contact_phone': record.contact_phone,
            'contact_email': record.contact_email,
            'is_active': record.is_active,
            'delivery_windows': record.delivery_windows,
            'operational_hours': record.operational_hours,
            'capacity': record.capacity,
            'max_delivery_radius_km': record.max_delivery_radius_km,
            'delivery_schedule_config': record.delivery_schedule_config,
            'metadata': record.metadata_,
            'created_at': record.created_at,
            'updated_at': record.updated_at
        }
        return TenantLocationResponse.model_validate(payload)
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Get tenant location failed",
                    tenant_id=str(tenant_id),
                    location_id=str(location_id),
                    user_id=current_user.get("user_id"),
                    error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Get tenant location failed"
        )
@router.post(route_builder.build_base_route("{tenant_id}/locations", include_tenant_prefix=False), response_model=TenantLocationResponse)
@admin_role_required
async def create_tenant_location(
    location_data: TenantLocationCreate,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    location_repo: TenantLocationRepository = Depends(get_tenant_location_repository)
):
    """
    Create a new location for a tenant.

    Requires admin or owner privileges. The tenant ID in the URL path is
    authoritative: it must match the payload and is the value persisted.
    """
    try:
        logger.info(
            "Create tenant location request received",
            tenant_id=str(tenant_id),
            user_id=current_user.get("user_id")
        )
        if str(tenant_id) != location_data.tenant_id:
            logger.warning(
                "Tenant ID mismatch",
                path_tenant_id=str(tenant_id),
                data_tenant_id=location_data.tenant_id,
                user_id=current_user.get("user_id")
            )
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Tenant ID in path does not match data"
            )
        # Only persist explicitly-set fields; force tenant_id from the path.
        fields = location_data.model_dump(exclude_unset=True)
        fields['tenant_id'] = str(tenant_id)
        new_location = await location_repo.create_location(fields)
        logger.info(
            "Created tenant location successfully",
            tenant_id=str(tenant_id),
            location_id=str(new_location.id),
            user_id=current_user.get("user_id")
        )
        # Manual payload: the ORM column is `metadata_` (SQLAlchemy reserves
        # `metadata`), so from-attributes validation would miss it.
        payload = {
            'id': str(new_location.id),
            'tenant_id': str(new_location.tenant_id),
            'name': new_location.name,
            'location_type': new_location.location_type,
            'address': new_location.address,
            'city': new_location.city,
            'postal_code': new_location.postal_code,
            'latitude': new_location.latitude,
            'longitude': new_location.longitude,
            'contact_person': new_location.contact_person,
            'contact_phone': new_location.contact_phone,
            'contact_email': new_location.contact_email,
            'is_active': new_location.is_active,
            'delivery_windows': new_location.delivery_windows,
            'operational_hours': new_location.operational_hours,
            'capacity': new_location.capacity,
            'max_delivery_radius_km': new_location.max_delivery_radius_km,
            'delivery_schedule_config': new_location.delivery_schedule_config,
            'metadata': new_location.metadata_,
            'created_at': new_location.created_at,
            'updated_at': new_location.updated_at
        }
        return TenantLocationResponse.model_validate(payload)
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Create tenant location failed",
                    tenant_id=str(tenant_id),
                    user_id=current_user.get("user_id"),
                    error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Create tenant location failed"
        )
@router.put(route_builder.build_base_route("{tenant_id}/locations/{location_id}", include_tenant_prefix=False), response_model=TenantLocationResponse)
@admin_role_required
async def update_tenant_location(
    update_data: TenantLocationUpdate,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    location_id: UUID = Path(..., description="Location ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    location_repo: TenantLocationRepository = Depends(get_tenant_location_repository)
):
    """
    Update a tenant location.

    Requires admin or owner privileges. Responds 404 when the location is
    missing or belongs to a different tenant (ownership is not leaked).
    """
    try:
        logger.info(
            "Update tenant location request received",
            tenant_id=str(tenant_id),
            location_id=str(location_id),
            user_id=current_user.get("user_id")
        )
        # Confirm existence and ownership before applying changes.
        current = await location_repo.get_location_by_id(str(location_id))
        if not current:
            logger.warning(
                "Location not found for update",
                tenant_id=str(tenant_id),
                location_id=str(location_id),
                user_id=current_user.get("user_id")
            )
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Location not found"
            )
        if str(current.tenant_id) != str(tenant_id):
            logger.warning(
                "Location does not belong to tenant for update",
                tenant_id=str(tenant_id),
                location_id=str(location_id),
                location_tenant_id=str(current.tenant_id),
                user_id=current_user.get("user_id")
            )
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Location not found"
            )
        # Only apply explicitly-set fields.
        changes = update_data.model_dump(exclude_unset=True)
        refreshed = await location_repo.update_location(str(location_id), changes)
        if not refreshed:
            # Deleted between the verification read and the update.
            logger.error(
                "Failed to update location (not found after verification)",
                tenant_id=str(tenant_id),
                location_id=str(location_id),
                user_id=current_user.get("user_id")
            )
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Location not found"
            )
        logger.info(
            "Updated tenant location successfully",
            tenant_id=str(tenant_id),
            location_id=str(location_id),
            user_id=current_user.get("user_id")
        )
        # Manual payload: the ORM column is `metadata_` (SQLAlchemy reserves
        # `metadata`), so from-attributes validation would miss it.
        payload = {
            'id': str(refreshed.id),
            'tenant_id': str(refreshed.tenant_id),
            'name': refreshed.name,
            'location_type': refreshed.location_type,
            'address': refreshed.address,
            'city': refreshed.city,
            'postal_code': refreshed.postal_code,
            'latitude': refreshed.latitude,
            'longitude': refreshed.longitude,
            'contact_person': refreshed.contact_person,
            'contact_phone': refreshed.contact_phone,
            'contact_email': refreshed.contact_email,
            'is_active': refreshed.is_active,
            'delivery_windows': refreshed.delivery_windows,
            'operational_hours': refreshed.operational_hours,
            'capacity': refreshed.capacity,
            'max_delivery_radius_km': refreshed.max_delivery_radius_km,
            'delivery_schedule_config': refreshed.delivery_schedule_config,
            'metadata': refreshed.metadata_,
            'created_at': refreshed.created_at,
            'updated_at': refreshed.updated_at
        }
        return TenantLocationResponse.model_validate(payload)
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Update tenant location failed",
                    tenant_id=str(tenant_id),
                    location_id=str(location_id),
                    user_id=current_user.get("user_id"),
                    error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Update tenant location failed"
        )
@router.delete(route_builder.build_base_route("{tenant_id}/locations/{location_id}", include_tenant_prefix=False))
@admin_role_required
async def delete_tenant_location(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    location_id: UUID = Path(..., description="Location ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    location_repo: TenantLocationRepository = Depends(get_tenant_location_repository)
):
    """
    Remove a location from a tenant. Requires admin or owner privileges.

    The location must both exist and belong to the given tenant; an
    ownership mismatch is reported as 404 (not 403) so callers cannot
    probe for other tenants' location IDs.
    """
    # Stringified identifiers reused by every log call below.
    tenant_key = str(tenant_id)
    location_key = str(location_id)
    requester = current_user.get("user_id")
    try:
        logger.info(
            "Delete tenant location request received",
            tenant_id=tenant_key,
            location_id=location_key,
            user_id=requester
        )
        location = await location_repo.get_location_by_id(location_key)
        if not location:
            logger.warning(
                "Location not found for deletion",
                tenant_id=tenant_key,
                location_id=location_key,
                user_id=requester
            )
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Location not found"
            )
        if str(location.tenant_id) != tenant_key:
            # Deliberately indistinguishable from "not found" for the caller.
            logger.warning(
                "Location does not belong to tenant for deletion",
                tenant_id=tenant_key,
                location_id=location_key,
                location_tenant_id=str(location.tenant_id),
                user_id=requester
            )
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Location not found"
            )
        was_deleted = await location_repo.delete_location(location_key)
        if not was_deleted:
            # Row vanished between the existence check and the delete.
            logger.warning(
                "Location not found for deletion (race condition)",
                tenant_id=tenant_key,
                location_id=location_key,
                user_id=requester
            )
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Location not found"
            )
        logger.info(
            "Deleted tenant location successfully",
            tenant_id=tenant_key,
            location_id=location_key,
            user_id=requester
        )
        return {
            "message": "Location deleted successfully",
            "location_id": location_key
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Delete tenant location failed",
                    tenant_id=tenant_key,
                    location_id=location_key,
                    user_id=requester,
                    error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Delete tenant location failed"
        )
@router.get(route_builder.build_base_route("{tenant_id}/locations/type/{location_type}", include_tenant_prefix=False), response_model=TenantLocationsResponse)
@track_endpoint_metrics("tenant_locations_by_type")
async def get_tenant_locations_by_type(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    location_type: str = Path(..., description="Location type to filter by", pattern=r'^(central_production|retail_outlet|warehouse|store|branch)$'),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    location_repo: TenantLocationRepository = Depends(get_tenant_location_repository)
):
    """
    List every location of a single type belonging to a tenant.

    The allowed types are constrained by the path-parameter regex, so an
    unknown type never reaches this handler.
    """
    def _serialize(loc):
        # Build a plain dict by hand: the ORM column is exposed as
        # ``metadata_`` (``metadata`` collides with SQLAlchemy's reserved
        # attribute), so validating the ORM object directly would miss it.
        return {
            'id': str(loc.id),
            'tenant_id': str(loc.tenant_id),
            'name': loc.name,
            'location_type': loc.location_type,
            'address': loc.address,
            'city': loc.city,
            'postal_code': loc.postal_code,
            'latitude': loc.latitude,
            'longitude': loc.longitude,
            'contact_person': loc.contact_person,
            'contact_phone': loc.contact_phone,
            'contact_email': loc.contact_email,
            'is_active': loc.is_active,
            'delivery_windows': loc.delivery_windows,
            'operational_hours': loc.operational_hours,
            'capacity': loc.capacity,
            'max_delivery_radius_km': loc.max_delivery_radius_km,
            'delivery_schedule_config': loc.delivery_schedule_config,
            'metadata': loc.metadata_,
            'created_at': loc.created_at,
            'updated_at': loc.updated_at
        }
    try:
        logger.info(
            "Get tenant locations by type request received",
            tenant_id=str(tenant_id),
            location_type=location_type,
            user_id=current_user.get("user_id")
        )
        # Repository accepts a list of types; we filter on exactly one here.
        matches = await location_repo.get_locations_by_tenant_with_type(str(tenant_id), [location_type])
        logger.debug(
            "Get tenant locations by type successful",
            tenant_id=str(tenant_id),
            location_type=location_type,
            location_count=len(matches)
        )
        responses = [TenantLocationResponse.model_validate(_serialize(loc)) for loc in matches]
        return TenantLocationsResponse(
            locations=responses,
            total=len(responses)
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Get tenant locations by type failed",
                    tenant_id=str(tenant_id),
                    location_type=location_type,
                    user_id=current_user.get("user_id"),
                    error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Get tenant locations by type failed"
        )

View File

@@ -0,0 +1,483 @@
"""
Tenant Member Management API - ATOMIC operations
Handles team member CRUD operations
"""
import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Path, Query
from typing import List, Dict, Any
from uuid import UUID
from app.schemas.tenants import TenantMemberResponse, AddMemberWithUserCreate, TenantResponse
from app.services.tenant_service import EnhancedTenantService
from shared.auth.decorators import get_current_user_dep
from shared.routing.route_builder import RouteBuilder
from shared.database.base import create_database_manager
from shared.monitoring.metrics import track_endpoint_metrics
logger = structlog.get_logger()
router = APIRouter()
route_builder = RouteBuilder("tenants")
# Dependency injection for enhanced tenant service
def get_enhanced_tenant_service():
    """
    FastAPI dependency: build an EnhancedTenantService for one request.

    A fresh database manager is created each time; any construction
    failure is logged and surfaced as a 500 so route handlers never
    receive a half-initialized service.
    """
    try:
        from app.core.config import settings

        manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
        return EnhancedTenantService(manager)
    except Exception as e:
        logger.error("Failed to create enhanced tenant service", error=str(e))
        raise HTTPException(status_code=500, detail="Service initialization failed")
@router.post(route_builder.build_base_route("{tenant_id}/members/with-user", include_tenant_prefix=False), response_model=TenantMemberResponse)
@track_endpoint_metrics("tenant_add_member_with_user_creation")
async def add_team_member_with_user_creation(
    member_data: AddMemberWithUserCreate,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """
    Add a team member to tenant with optional user creation (pilot phase).

    This endpoint supports two modes:
    1. Adding an existing user: Set user_id and create_user=False
    2. Creating a new user: Set create_user=True and provide email, full_name, password

    In pilot phase, this allows owners to directly create users with passwords.
    In production, this will be replaced with an invitation-based flow.

    Raises:
        HTTPException 402: the tenant's subscription seat limit is reached.
        HTTPException 500: user creation via the auth service failed, or the
            membership could not be created.
    """
    try:
        # CRITICAL: Check subscription limit before adding user
        from app.services.subscription_limit_service import SubscriptionLimitService
        limit_service = SubscriptionLimitService()
        limit_check = await limit_service.can_add_user(str(tenant_id))
        if not limit_check.get('can_add', False):
            logger.warning(
                "User limit exceeded",
                tenant_id=str(tenant_id),
                current=limit_check.get('current_count'),
                max=limit_check.get('max_allowed'),
                reason=limit_check.get('reason')
            )
            # 402 marks this as a billing/plan problem; the payload carries
            # the counts the frontend needs to render an upgrade prompt.
            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail={
                    "error": "user_limit_exceeded",
                    "message": limit_check.get('reason', 'User limit exceeded'),
                    "current_count": limit_check.get('current_count'),
                    "max_allowed": limit_check.get('max_allowed'),
                    "upgrade_required": True
                }
            )
        user_id_to_add = member_data.user_id
        # If create_user is True, create the user first via auth service
        if member_data.create_user:
            logger.info(
                "Creating new user before adding to tenant",
                tenant_id=str(tenant_id),
                email=member_data.email,
                requested_by=current_user["user_id"]
            )
            # Call auth service to create user
            from shared.clients.auth_client import AuthServiceClient
            from app.core.config import settings
            auth_client = AuthServiceClient(settings)
            # Map tenant role to user role
            # tenant roles: admin, member, viewer
            # user roles: admin, manager, user
            user_role_map = {
                "admin": "admin",
                "member": "manager",
                "viewer": "user"
            }
            # Unknown tenant roles fall back to the least-privileged "user".
            user_role = user_role_map.get(member_data.role, "user")
            try:
                user_create_data = {
                    "email": member_data.email,
                    "full_name": member_data.full_name,
                    "password": member_data.password,
                    "phone": member_data.phone,
                    "role": user_role,
                    "language": member_data.language or "es",
                    "timezone": member_data.timezone or "Europe/Madrid"
                }
                created_user = await auth_client.create_user_by_owner(user_create_data)
                user_id_to_add = created_user.get("id")
                logger.info(
                    "User created successfully",
                    user_id=user_id_to_add,
                    email=member_data.email,
                    tenant_id=str(tenant_id)
                )
            except Exception as auth_error:
                logger.error(
                    "Failed to create user via auth service",
                    error=str(auth_error),
                    email=member_data.email
                )
                raise HTTPException(
                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                    detail=f"Failed to create user account: {str(auth_error)}"
                )
        # Add the user (existing or newly created) to the tenant
        # NOTE(review): if this membership insert fails after a user was just
        # created above, the auth-service user is not rolled back — confirm
        # whether orphaned users need cleanup.
        result = await tenant_service.add_team_member(
            str(tenant_id),
            user_id_to_add,
            member_data.role,
            current_user["user_id"]
        )
        logger.info(
            "Team member added successfully",
            tenant_id=str(tenant_id),
            user_id=user_id_to_add,
            role=member_data.role,
            user_was_created=member_data.create_user
        )
        return result
    except HTTPException:
        raise
    except Exception as e:
        logger.error(
            "Add team member with user creation failed",
            tenant_id=str(tenant_id),
            error=str(e)
        )
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to add team member"
        )
@router.post(route_builder.build_base_route("{tenant_id}/members", include_tenant_prefix=False), response_model=TenantMemberResponse)
@track_endpoint_metrics("tenant_add_member")
async def add_team_member(
    user_id: str,
    role: str,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """
    Attach an already-existing user to a tenant (legacy endpoint).

    The tenant's subscription seat quota is checked first; an exhausted
    quota is rejected with 402 plus the data needed for an upgrade prompt.
    """
    try:
        # Quota gate: exceeding the plan's seat count is a billing issue,
        # hence 402 rather than 403.
        from app.services.subscription_limit_service import SubscriptionLimitService
        quota = await SubscriptionLimitService().can_add_user(str(tenant_id))
        if not quota.get('can_add', False):
            logger.warning(
                "User limit exceeded",
                tenant_id=str(tenant_id),
                current=quota.get('current_count'),
                max=quota.get('max_allowed'),
                reason=quota.get('reason')
            )
            raise HTTPException(
                status_code=status.HTTP_402_PAYMENT_REQUIRED,
                detail={
                    "error": "user_limit_exceeded",
                    "message": quota.get('reason', 'User limit exceeded'),
                    "current_count": quota.get('current_count'),
                    "max_allowed": quota.get('max_allowed'),
                    "upgrade_required": True
                }
            )
        membership = await tenant_service.add_team_member(
            str(tenant_id),
            user_id,
            role,
            current_user["user_id"]
        )
        logger.info(
            "Team member added successfully",
            tenant_id=str(tenant_id),
            user_id=user_id,
            role=role
        )
        return membership
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Add team member failed",
                    tenant_id=str(tenant_id),
                    user_id=user_id,
                    role=role,
                    error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to add team member"
        )
@router.get(route_builder.build_base_route("{tenant_id}/members", include_tenant_prefix=False), response_model=List[TenantMemberResponse])
async def get_team_members(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    active_only: bool = Query(True, description="Only return active members"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """
    Return the member roster for a tenant.

    By default inactive members are filtered out; pass
    ``active_only=false`` to include them.
    """
    try:
        # Access control is enforced inside the service using the caller id.
        return await tenant_service.get_team_members(
            str(tenant_id),
            current_user["user_id"],
            active_only=active_only
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Get team members failed",
                    tenant_id=str(tenant_id),
                    error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get team members"
        )
@router.put(route_builder.build_base_route("{tenant_id}/members/{member_user_id}/role", include_tenant_prefix=False), response_model=TenantMemberResponse)
@track_endpoint_metrics("tenant_update_member_role")
async def update_member_role(
    new_role: str,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    member_user_id: str = Path(..., description="Member user ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """
    Change an existing member's role within a tenant.

    Who may promote or demote whom is validated in the service layer,
    using the caller's identity.
    """
    try:
        return await tenant_service.update_member_role(
            str(tenant_id),
            member_user_id,
            new_role,
            current_user["user_id"]
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Update member role failed",
                    tenant_id=str(tenant_id),
                    member_user_id=member_user_id,
                    new_role=new_role,
                    error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to update member role"
        )
@router.delete(route_builder.build_base_route("{tenant_id}/members/{member_user_id}", include_tenant_prefix=False))
@track_endpoint_metrics("tenant_remove_member")
async def remove_team_member(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    member_user_id: str = Path(..., description="Member user ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """
    Detach a member from a tenant.

    Validation (existence, caller permissions) happens in the service
    layer; a False return is treated as an unexpected failure.
    """
    try:
        removed = await tenant_service.remove_team_member(
            str(tenant_id),
            member_user_id,
            current_user["user_id"]
        )
        # Guard clause: a falsy result means the service could not remove
        # the membership for a reason it did not raise about.
        if not removed:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to remove team member"
            )
        return {"success": True, "message": "Team member removed successfully"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Remove team member failed",
                    tenant_id=str(tenant_id),
                    member_user_id=member_user_id,
                    error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to remove team member"
        )
@router.delete(route_builder.build_base_route("user/{user_id}/memberships", include_tenant_prefix=False))
@track_endpoint_metrics("user_memberships_delete")
async def delete_user_memberships(
    user_id: str = Path(..., description="User ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """
    Remove every tenant membership belonging to a user.

    Internal-only endpoint: the auth service calls it while deleting a
    user account. Any non-service caller receives 403.
    """
    caller_is_service = current_user.get("type") == "service"
    logger.info(
        "Delete user memberships request received",
        user_id=user_id,
        requesting_service=current_user.get("service", "unknown"),
        is_service=caller_is_service
    )
    # Guard clause: human callers (even admins) may not use this endpoint.
    if not caller_is_service:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="This endpoint is only accessible to internal services"
        )
    try:
        summary = await tenant_service.delete_user_memberships(user_id)
        logger.info(
            "User memberships deleted successfully",
            user_id=user_id,
            deleted_count=summary.get("deleted_count"),
            total_memberships=summary.get("total_memberships")
        )
        return {
            "message": "User memberships deleted successfully",
            "summary": summary
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Delete user memberships failed",
                    user_id=user_id,
                    error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to delete user memberships"
        )
@router.post(route_builder.build_base_route("{tenant_id}/transfer-ownership", include_tenant_prefix=False), response_model=TenantResponse)
@track_endpoint_metrics("tenant_transfer_ownership")
async def transfer_ownership(
    new_owner_id: str,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """
    Transfer tenant ownership to another admin.

    Only the current owner or internal services can perform this action.

    Raises:
        HTTPException 404: the tenant does not exist.
        HTTPException 500: the transfer failed unexpectedly.
    """
    logger.info(
        "Transfer ownership request received",
        tenant_id=str(tenant_id),
        new_owner_id=new_owner_id,
        requesting_user=current_user.get("user_id"),
        is_service=current_user.get("type") == "service"
    )
    try:
        # Get current tenant to find current owner
        tenant_info = await tenant_service.get_tenant_by_id(str(tenant_id))
        if not tenant_info:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="Tenant not found"
            )
        current_owner_id = tenant_info.owner_id
        # NOTE(review): no caller-side authorization check is visible here;
        # presumably transfer_tenant_ownership validates requesting_user_id
        # against the owner/service rule stated above — confirm in the
        # service layer.
        result = await tenant_service.transfer_tenant_ownership(
            str(tenant_id),
            current_owner_id,
            new_owner_id,
            requesting_user_id=current_user.get("user_id")
        )
        logger.info(
            "Ownership transferred successfully",
            tenant_id=str(tenant_id),
            from_owner=current_owner_id,
            to_owner=new_owner_id
        )
        return result
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Transfer ownership failed",
                    tenant_id=str(tenant_id),
                    new_owner_id=new_owner_id,
                    error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to transfer ownership"
        )
@router.get(route_builder.build_base_route("{tenant_id}/admins", include_tenant_prefix=False), response_model=List[TenantMemberResponse])
@track_endpoint_metrics("tenant_get_admins")
async def get_tenant_admins(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """
    List the owner plus all admins of a tenant.

    Consumed by the auth service, which checks for other admins before it
    allows a tenant deletion to proceed.
    """
    logger.info(
        "Get tenant admins request received",
        tenant_id=str(tenant_id),
        requesting_user=current_user.get("user_id"),
        is_service=current_user.get("type") == "service"
    )
    try:
        admin_members = await tenant_service.get_tenant_admins(str(tenant_id))
        logger.info(
            "Retrieved tenant admins",
            tenant_id=str(tenant_id),
            admin_count=len(admin_members)
        )
        return admin_members
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Get tenant admins failed",
                    tenant_id=str(tenant_id),
                    error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get tenant admins"
        )

View File

@@ -0,0 +1,734 @@
"""
Tenant Operations API - BUSINESS operations
Handles complex tenant operations, registration, search, and analytics
NOTE: All subscription-related endpoints have been moved to subscription.py
as part of the architecture redesign for better separation of concerns.
"""
import structlog
from datetime import datetime
from fastapi import APIRouter, Depends, HTTPException, status, Path, Query
from typing import List, Dict, Any, Optional
from uuid import UUID
from app.schemas.tenants import (
BakeryRegistration, TenantResponse, TenantAccessResponse,
TenantSearchRequest
)
from app.services.tenant_service import EnhancedTenantService
from app.services.payment_service import PaymentService
from app.models import AuditLog
from shared.auth.decorators import (
get_current_user_dep,
require_admin_role_dep
)
from app.core.database import get_db
from sqlalchemy.ext.asyncio import AsyncSession
from shared.auth.access_control import owner_role_required, admin_role_required
from shared.routing.route_builder import RouteBuilder
from shared.database.base import create_database_manager
from shared.monitoring.metrics import track_endpoint_metrics
from shared.security import create_audit_logger, AuditSeverity, AuditAction
from shared.config.base import is_internal_service
logger = structlog.get_logger()
router = APIRouter()
route_builder = RouteBuilder("tenants")
# Initialize audit logger
audit_logger = create_audit_logger("tenant-service", AuditLog)
# Dependency injection for enhanced tenant service
def get_enhanced_tenant_service():
    """
    FastAPI dependency that supplies an EnhancedTenantService.

    Construction failures are logged and converted to a 500 response so
    endpoints never receive a partially built service.
    """
    try:
        from app.core.config import settings

        return EnhancedTenantService(
            create_database_manager(settings.DATABASE_URL, "tenant-service")
        )
    except Exception as e:
        logger.error("Failed to create enhanced tenant service", error=str(e))
        raise HTTPException(status_code=500, detail="Service initialization failed")
def get_payment_service():
    """
    FastAPI dependency that supplies a PaymentService.

    Any construction failure is logged and surfaced as a 500 response.
    """
    try:
        return PaymentService()
    except Exception as e:
        logger.error("Failed to create payment service", error=str(e))
        raise HTTPException(status_code=500, detail="Payment service initialization failed")
# ============================================================================
# TENANT REGISTRATION & ACCESS OPERATIONS
# ============================================================================
@router.post(route_builder.build_base_route("register", include_tenant_prefix=False), response_model=TenantResponse)
async def register_bakery(
    bakery_data: BakeryRegistration,
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service),
    payment_service: PaymentService = Depends(get_payment_service),
    db: AsyncSession = Depends(get_db)
):
    """
    Register a new bakery/tenant with enhanced validation and features.

    After the tenant record is created, exactly one post-creation path runs:
      1. link an existing subscription (``link_existing_subscription`` plus
         ``subscription_id`` supplied),
      2. validate and redeem a coupon code, or
      3. create a default "starter" trial subscription.

    Failures in paths 1 and 3 are logged but intentionally do not fail the
    registration; an invalid coupon in path 2 aborts with 400.

    Raises:
        HTTPException 400: the supplied coupon code is invalid.
        HTTPException 500: tenant creation failed unexpectedly.
    """
    try:
        result = await tenant_service.create_bakery(
            bakery_data,
            current_user["user_id"]
        )
        tenant_id = result.id
        if bakery_data.link_existing_subscription and bakery_data.subscription_id:
            logger.info("Linking existing subscription to new tenant",
                       tenant_id=tenant_id,
                       subscription_id=bakery_data.subscription_id,
                       user_id=current_user["user_id"])
            try:
                from app.services.subscription_service import SubscriptionService
                subscription_service = SubscriptionService(db)
                linking_result = await subscription_service.link_subscription_to_tenant(
                    subscription_id=bakery_data.subscription_id,
                    tenant_id=tenant_id,
                    user_id=current_user["user_id"]
                )
                logger.info("Subscription linked successfully during tenant registration",
                           tenant_id=tenant_id,
                           subscription_id=bakery_data.subscription_id)
            except Exception as linking_error:
                # Best-effort: a linking failure leaves the tenant without a
                # subscription but does not roll back the registration.
                logger.error("Error linking subscription during tenant registration",
                            tenant_id=tenant_id,
                            subscription_id=bakery_data.subscription_id,
                            error=str(linking_error))
        elif bakery_data.coupon_code:
            from app.services.coupon_service import CouponService
            coupon_service = CouponService(db)
            coupon_validation = await coupon_service.validate_coupon_code(
                bakery_data.coupon_code,
                tenant_id
            )
            if not coupon_validation["valid"]:
                logger.warning(
                    "Invalid coupon code provided during registration",
                    coupon_code=bakery_data.coupon_code,
                    error=coupon_validation["error_message"]
                )
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=coupon_validation["error_message"]
                )
            # Redeem exactly once. (Bug fix: a second redemption pass used to
            # run after these branches whenever validation succeeded,
            # redeeming the same coupon twice for the same tenant and
            # double-counting its usage.)
            success, discount, error = await coupon_service.redeem_coupon(
                bakery_data.coupon_code,
                tenant_id,
                base_trial_days=0
            )
            if success:
                logger.info("Coupon redeemed during registration",
                           coupon_code=bakery_data.coupon_code,
                           tenant_id=tenant_id)
            else:
                logger.warning("Failed to redeem coupon during registration",
                              coupon_code=bakery_data.coupon_code,
                              error=error)
        else:
            # No subscription to link and no coupon: provision the default
            # "starter" trial so the tenant is immediately usable.
            try:
                from app.repositories.subscription_repository import SubscriptionRepository
                from app.models.tenants import Subscription
                from datetime import datetime, timedelta, timezone
                from app.core.config import settings
                database_manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
                async with database_manager.get_session() as session:
                    subscription_repo = SubscriptionRepository(Subscription, session)
                    existing_subscription = await subscription_repo.get_by_tenant_id(str(result.id))
                    if existing_subscription:
                        logger.info(
                            "Tenant already has an active subscription, skipping default subscription creation",
                            tenant_id=str(result.id),
                            existing_plan=existing_subscription.plan,
                            subscription_id=str(existing_subscription.id)
                        )
                    else:
                        # Zero-day trial: the trial ends now, so billing is
                        # due immediately on the first cycle.
                        trial_end_date = datetime.now(timezone.utc)
                        next_billing_date = trial_end_date
                        await subscription_repo.create_subscription({
                            "tenant_id": str(result.id),
                            "plan": "starter",
                            "status": "trial",
                            "billing_cycle": "monthly",
                            "next_billing_date": next_billing_date,
                            "trial_ends_at": trial_end_date
                        })
                        await session.commit()
                        logger.info(
                            "Default subscription created for new tenant",
                            tenant_id=str(result.id),
                            plan="starter",
                            trial_days=0
                        )
            except Exception as subscription_error:
                # Best-effort: the tenant still registers; subscription setup
                # can be repaired out of band.
                logger.error(
                    "Failed to create default subscription for tenant",
                    tenant_id=str(result.id),
                    error=str(subscription_error)
                )
        logger.info("Bakery registered successfully",
                   name=bakery_data.name,
                   owner_email=current_user.get('email'),
                   tenant_id=result.id,
                   coupon_applied=bakery_data.coupon_code is not None)
        return result
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Bakery registration failed",
                    name=bakery_data.name,
                    owner_id=current_user["user_id"],
                    error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Bakery registration failed"
        )
@router.get(route_builder.build_base_route("{tenant_id}/my-access", include_tenant_prefix=False), response_model=TenantAccessResponse)
async def get_current_user_tenant_access(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """
    Return the calling user's access to a tenant, with role and permissions.

    Raises:
        HTTPException 500: the access lookup failed unexpectedly.
    """
    try:
        # Consistency fix: use the shared get_enhanced_tenant_service
        # dependency instead of constructing a service (and a database
        # manager) inline, matching every other endpoint in this module.
        return await tenant_service.verify_user_access(current_user["user_id"], str(tenant_id))
    except HTTPException:
        # Let deliberate HTTP errors pass through instead of flattening
        # them into a generic 500.
        raise
    except Exception as e:
        logger.error("Current user access verification failed",
                    user_id=current_user["user_id"],
                    tenant_id=str(tenant_id),
                    error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Access verification failed"
        )
@router.get(route_builder.build_base_route("{tenant_id}/access/{user_id}", include_tenant_prefix=False), response_model=TenantAccessResponse)
async def verify_tenant_access(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    user_id: str = Path(..., description="User ID")
):
    """
    Check whether a user may access a tenant, with role and permissions.

    Internal service identities bypass the membership lookup entirely and
    receive a blanket read/write grant.
    """
    # Fast path: trusted service-to-service calls never touch the database.
    if is_internal_service(user_id):
        logger.info("Service access granted", service=user_id, tenant_id=str(tenant_id))
        return TenantAccessResponse(
            has_access=True,
            role="service",
            permissions=["read", "write"]
        )
    try:
        from app.core.config import settings
        db_manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
        service = EnhancedTenantService(db_manager)
        return await service.verify_user_access(user_id, str(tenant_id))
    except Exception as e:
        logger.error("Access verification failed",
                    user_id=user_id,
                    tenant_id=str(tenant_id),
                    error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Access verification failed"
        )
# ============================================================================
# TENANT SEARCH & DISCOVERY OPERATIONS
# ============================================================================
@router.get(route_builder.build_base_route("subdomain/{subdomain}", include_tenant_prefix=False), response_model=TenantResponse)
@track_endpoint_metrics("tenant_get_by_subdomain")
async def get_tenant_by_subdomain(
    subdomain: str = Path(..., description="Tenant subdomain"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """
    Resolve a tenant from its subdomain and enforce caller access.

    Unknown subdomains yield 404; a known tenant the caller has no
    membership in yields 403.
    """
    tenant = await tenant_service.get_tenant_by_subdomain(subdomain)
    if not tenant:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Tenant not found"
        )
    grant = await tenant_service.verify_user_access(current_user["user_id"], tenant.id)
    if not grant.has_access:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied to tenant"
        )
    return tenant
@router.get(route_builder.build_base_route("user/{user_id}/owned", include_tenant_prefix=False), response_model=List[TenantResponse])
async def get_user_owned_tenants(
    user_id: str = Path(..., description="User ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """
    Get all tenants owned by a user with enhanced data.

    Authorization: callers may list their own tenants; admins may list
    anyone's; demo sessions are special-cased to receive virtual tenants.

    Raises:
        HTTPException 403: caller asked for another user's tenants without
            admin rights.
    """
    user_role = current_user.get('role', '').lower()
    # "Demo user" here means a demo session asking for the shared
    # "demo-user" pseudo-account — not merely having the is_demo flag.
    is_demo_user = current_user.get("is_demo", False) and user_id == "demo-user"
    if user_id != current_user["user_id"] and not is_demo_user and user_role != 'admin':
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Can only access your own tenants"
        )
    if current_user.get("is_demo", False):
        demo_session_id = current_user.get("demo_session_id")
        demo_account_type = current_user.get("demo_account_type", "")
        if demo_session_id:
            # Session-scoped sandbox: each demo session gets its own
            # virtual tenants.
            logger.info("Fetching virtual tenants for demo session",
                        demo_session_id=demo_session_id,
                        demo_account_type=demo_account_type)
            virtual_tenants = await tenant_service.get_virtual_tenants_for_session(demo_session_id, demo_account_type)
            return virtual_tenants
        else:
            # No session id: fall back to the shared demo tenants for this
            # account type.
            virtual_tenants = await tenant_service.get_demo_tenants_by_session_type(
                demo_account_type,
                str(current_user["user_id"])
            )
            return virtual_tenants
    # For the demo pseudo-account, resolve to the caller's real id;
    # otherwise use the requested user id as-is.
    actual_user_id = current_user["user_id"] if is_demo_user else user_id
    tenants = await tenant_service.get_user_tenants(actual_user_id)
    return tenants
@router.get(route_builder.build_base_route("search", include_tenant_prefix=False), response_model=List[TenantResponse])
@track_endpoint_metrics("tenant_search")
async def search_tenants(
    search_term: str = Query(..., description="Search term"),
    business_type: Optional[str] = Query(None, description="Business type filter"),
    city: Optional[str] = Query(None, description="City filter"),
    skip: int = Query(0, ge=0, description="Number of records to skip"),
    limit: int = Query(50, ge=1, le=100, description="Maximum number of records to return"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """
    Search tenants by term, with optional business-type and city filters.

    Pagination is bounded by the query-parameter validators; the service
    layer performs the actual lookup.
    """
    return await tenant_service.search_tenants(
        search_term=search_term,
        business_type=business_type,
        city=city,
        skip=skip,
        limit=limit
    )
@router.get(route_builder.build_base_route("nearby", include_tenant_prefix=False), response_model=List[TenantResponse])
@track_endpoint_metrics("tenant_get_nearby")
async def get_nearby_tenants(
    latitude: float = Query(..., description="Latitude coordinate"),
    longitude: float = Query(..., description="Longitude coordinate"),
    radius_km: float = Query(10.0, ge=0.1, le=100.0, description="Search radius in kilometers"),
    limit: int = Query(50, ge=1, le=100, description="Maximum number of results"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """
    Find tenants within a radius of a geographic point.

    Radius and result count are bounded by the query-parameter validators;
    the geospatial search itself lives in the service layer.
    """
    return await tenant_service.get_tenants_near_location(
        latitude=latitude,
        longitude=longitude,
        radius_km=radius_km,
        limit=limit
    )
@router.get(route_builder.build_base_route("users/{user_id}", include_tenant_prefix=False), response_model=List[TenantResponse])
@track_endpoint_metrics("tenant_get_user_tenants")
async def get_user_tenants(
    user_id: str = Path(..., description="User ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """Return the tenants owned by *user_id* (route kept for frontend compatibility)."""
    # Access allowed for: the user themself, service accounts, admins, or a
    # demo session asking for the "demo-user" alias.
    demo_alias = current_user.get("is_demo", False) and user_id == "demo-user"
    service_call = current_user.get("type") == "service"
    admin_call = current_user.get('role', '').lower() == 'admin'
    if user_id != current_user["user_id"] and not (service_call or demo_alias or admin_call):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Can only access your own tenants"
        )
    try:
        tenants = await tenant_service.get_user_tenants(user_id)
        logger.info("Retrieved user tenants", user_id=user_id, tenant_count=len(tenants))
        return tenants
    except Exception as e:
        logger.error("Get user tenants failed", user_id=user_id, error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get user tenants"
        )
@router.get(route_builder.build_base_route("members/user/{user_id}", include_tenant_prefix=False))
@track_endpoint_metrics("tenant_get_user_memberships")
async def get_user_memberships(
    user_id: str = Path(..., description="User ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """List all tenant memberships of *user_id* (consumed by the authentication service)."""
    # Same access rule as the other user-scoped routes: self, service account,
    # admin, or a demo session using the "demo-user" alias.
    demo_alias = current_user.get("is_demo", False) and user_id == "demo-user"
    service_call = current_user.get("type") == "service"
    admin_call = current_user.get('role', '').lower() == 'admin'
    if user_id != current_user["user_id"] and not (service_call or demo_alias or admin_call):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Can only access your own memberships"
        )
    try:
        memberships = await tenant_service.get_user_memberships(user_id)
        logger.info("Retrieved user memberships", user_id=user_id, membership_count=len(memberships))
        return memberships
    except Exception as e:
        logger.error("Get user memberships failed", user_id=user_id, error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get user memberships"
        )
# ============================================================================
# TENANT MODEL STATUS OPERATIONS
# ============================================================================
@router.put(route_builder.build_base_route("{tenant_id}/model-status", include_tenant_prefix=False))
@track_endpoint_metrics("tenant_update_model_status")
async def update_tenant_model_status(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    ml_model_trained: bool = Query(..., description="Whether model is trained"),
    last_training_date: Optional[datetime] = Query(None, description="Last training date"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """Record whether the tenant's ML model is trained and when it was last trained."""
    try:
        return await tenant_service.update_model_status(
            str(tenant_id),
            ml_model_trained,
            current_user["user_id"],
            last_training_date
        )
    except HTTPException:
        # Deliberate HTTP errors from the service pass through unchanged.
        raise
    except Exception as e:
        logger.error("Model status update failed",
                     tenant_id=str(tenant_id),
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to update model status"
        )
# ============================================================================
# TENANT ACTIVATION/DEACTIVATION OPERATIONS
# ============================================================================
@router.post(route_builder.build_base_route("{tenant_id}/deactivate", include_tenant_prefix=False))
@track_endpoint_metrics("tenant_deactivate")
@owner_role_required
async def deactivate_tenant(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """Deactivate a tenant (owner only) with enhanced validation"""
    try:
        success = await tenant_service.deactivate_tenant(
            str(tenant_id),
            current_user["user_id"]
        )
        if success:
            # Best-effort audit trail: the deactivation has already happened,
            # so an audit failure is only logged, never surfaced to the caller.
            try:
                from app.core.database import get_db_session
                async with get_db_session() as db:
                    await audit_logger.log_event(
                        db_session=db,
                        tenant_id=str(tenant_id),
                        user_id=current_user["user_id"],
                        action=AuditAction.DEACTIVATE.value,
                        resource_type="tenant",
                        resource_id=str(tenant_id),
                        severity=AuditSeverity.CRITICAL.value,
                        description=f"Owner {current_user.get('email', current_user['user_id'])} deactivated tenant",
                        endpoint="/{tenant_id}/deactivate",
                        method="POST"
                    )
            except Exception as audit_error:
                logger.warning("Failed to log audit event", error=str(audit_error))
            return {"success": True, "message": "Tenant deactivated successfully"}
        else:
            # Service reported failure without raising — map to a 500.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to deactivate tenant"
            )
    except HTTPException:
        # Re-raise deliberate HTTP errors (including the 500 raised above).
        raise
    except Exception as e:
        logger.error("Tenant deactivation failed",
                     tenant_id=str(tenant_id),
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to deactivate tenant"
        )
@router.post(route_builder.build_base_route("{tenant_id}/activate", include_tenant_prefix=False))
@track_endpoint_metrics("tenant_activate")
@owner_role_required
async def activate_tenant(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """Activate a previously deactivated tenant (owner only) with enhanced validation"""
    try:
        success = await tenant_service.activate_tenant(
            str(tenant_id),
            current_user["user_id"]
        )
        if success:
            # Best-effort audit trail: activation already succeeded, so an
            # audit failure is only logged, never surfaced to the caller.
            try:
                from app.core.database import get_db_session
                async with get_db_session() as db:
                    await audit_logger.log_event(
                        db_session=db,
                        tenant_id=str(tenant_id),
                        user_id=current_user["user_id"],
                        action=AuditAction.ACTIVATE.value,
                        resource_type="tenant",
                        resource_id=str(tenant_id),
                        severity=AuditSeverity.HIGH.value,
                        description=f"Owner {current_user.get('email', current_user['user_id'])} activated tenant",
                        endpoint="/{tenant_id}/activate",
                        method="POST"
                    )
            except Exception as audit_error:
                logger.warning("Failed to log audit event", error=str(audit_error))
            return {"success": True, "message": "Tenant activated successfully"}
        else:
            # Service reported failure without raising — map to a 500.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail="Failed to activate tenant"
            )
    except HTTPException:
        # Re-raise deliberate HTTP errors (including the 500 raised above).
        raise
    except Exception as e:
        logger.error("Tenant activation failed",
                     tenant_id=str(tenant_id),
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to activate tenant"
        )
# ============================================================================
# TENANT STATISTICS & ANALYTICS
# ============================================================================
@router.get(route_builder.build_base_route("statistics", include_tenant_prefix=False), dependencies=[Depends(require_admin_role_dep)])
@track_endpoint_metrics("tenant_get_statistics")
async def get_tenant_statistics(
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """Aggregate tenant statistics (admin only; enforced by the route dependency)."""
    try:
        return await tenant_service.get_tenant_statistics()
    except Exception as e:
        logger.error("Get tenant statistics failed", error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get tenant statistics"
        )
# ============================================================================
# USER-TENANT RELATIONSHIP OPERATIONS
# ============================================================================
@router.get(route_builder.build_base_route("users/{user_id}/primary-tenant", include_tenant_prefix=False))
async def get_user_primary_tenant(
    user_id: str,
    current_user: Dict[str, Any] = Depends(get_current_user_dep)
):
    """
    Get the primary tenant for a user.

    Used by the auth service during login to validate user subscriptions.
    Resolution order:
      1. the tenant the user owns (``is_owner`` true);
      2. otherwise, any tenant the user can access (``is_owner`` false);
      3. otherwise ``None`` (serialized as JSON ``null`` with a 200).

    Args:
        user_id: The user ID to look up.

    Returns:
        Dict with ``user_id``, ``tenant_id``, ``tenant_name``, ``tenant_type``
        and ``is_owner``, or ``None`` if the user has no tenant.

    Raises:
        HTTPException: 500 on any lookup failure (details are logged, not
        returned to the caller).
    """
    try:
        # Local imports keep router import time free of DB setup side effects.
        from app.core.database import database_manager
        from app.repositories.tenant_repository import TenantRepository
        from app.models.tenants import Tenant
        async with database_manager.get_session() as session:
            tenant_repo = TenantRepository(Tenant, session)
            # Preferred: the tenant this user owns.
            primary_tenant = await tenant_repo.get_user_primary_tenant(user_id)
            if primary_tenant:
                logger.info("Found primary tenant for user",
                            user_id=user_id,
                            tenant_id=str(primary_tenant.id),
                            tenant_name=primary_tenant.name)
                return {
                    'user_id': user_id,
                    'tenant_id': str(primary_tenant.id),
                    'tenant_name': primary_tenant.name,
                    'tenant_type': primary_tenant.tenant_type,
                    'is_owner': True
                }
            # Fallback: any tenant the user merely has access to.
            any_tenant = await tenant_repo.get_any_user_tenant(user_id)
            if any_tenant:
                logger.info("Found accessible tenant for user",
                            user_id=user_id,
                            tenant_id=str(any_tenant.id),
                            tenant_name=any_tenant.name)
                return {
                    'user_id': user_id,
                    'tenant_id': str(any_tenant.id),
                    'tenant_name': any_tenant.name,
                    'tenant_type': any_tenant.tenant_type,
                    'is_owner': False
                }
            logger.info("No tenant found for user", user_id=user_id)
            return None
    except Exception as e:
        # Structured logging for consistency with the rest of this module, and
        # a generic HTTP detail so internal error text is not leaked to clients.
        logger.error("Failed to get primary tenant", user_id=user_id, error=str(e))
        raise HTTPException(status_code=500, detail="Failed to get primary tenant")

View File

@@ -0,0 +1,186 @@
# services/tenant/app/api/tenant_settings.py
"""
Tenant Settings API Endpoints
REST API for managing tenant-specific operational settings
"""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from uuid import UUID
from typing import Dict, Any
from app.core.database import get_db
from shared.routing.route_builder import RouteBuilder
from ..services.tenant_settings_service import TenantSettingsService
from ..schemas.tenant_settings import (
TenantSettingsResponse,
TenantSettingsUpdate,
CategoryUpdateRequest,
CategoryResetResponse
)
router = APIRouter()
route_builder = RouteBuilder("tenants")
@router.get(
    "/{tenant_id}/settings",
    response_model=TenantSettingsResponse,
    summary="Get all tenant settings",
    description="Retrieve all operational settings for a tenant. Creates default settings if none exist."
)
async def get_tenant_settings(
    tenant_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """
    Return every setting category for the tenant.

    - **tenant_id**: UUID of the tenant

    When no settings row exists yet, the service creates one with default
    values and returns it.
    """
    return await TenantSettingsService(db).get_settings(tenant_id)
@router.put(
    "/{tenant_id}/settings",
    response_model=TenantSettingsResponse,
    summary="Update tenant settings",
    description="Update one or more setting categories for a tenant. Only provided categories are updated."
)
async def update_tenant_settings(
    tenant_id: UUID,
    updates: TenantSettingsUpdate,
    db: AsyncSession = Depends(get_db)
):
    """
    Partially update tenant settings.

    - **tenant_id**: UUID of the tenant
    - **updates**: setting categories to change; omitted categories are untouched

    Values are validated against their min/max constraints by the service.
    """
    return await TenantSettingsService(db).update_settings(tenant_id, updates)
@router.get(
    "/{tenant_id}/settings/{category}",
    response_model=Dict[str, Any],
    summary="Get settings for a specific category",
    description="Retrieve settings for a single category (procurement, inventory, production, supplier, pos, or order)"
)
async def get_category_settings(
    tenant_id: UUID,
    category: str,
    db: AsyncSession = Depends(get_db)
):
    """
    Return the settings of one category only.

    - **tenant_id**: UUID of the tenant
    - **category**: one of procurement, inventory, production, supplier, pos, order

    Category meanings:
    - procurement: auto-approval and procurement planning
    - inventory: stock thresholds and temperature monitoring
    - production: capacity, quality, and scheduling
    - supplier: payment terms and performance thresholds
    - pos: POS integration sync
    - order: discount and delivery
    """
    values = await TenantSettingsService(db).get_category(tenant_id, category)
    return {
        "tenant_id": str(tenant_id),
        "category": category,
        "settings": values
    }
@router.put(
    "/{tenant_id}/settings/{category}",
    response_model=TenantSettingsResponse,
    summary="Update settings for a specific category",
    description="Update all or some fields within a single category"
)
async def update_category_settings(
    tenant_id: UUID,
    category: str,
    request: CategoryUpdateRequest,
    db: AsyncSession = Depends(get_db)
):
    """
    Update fields within one settings category.

    - **tenant_id**: UUID of the tenant
    - **category**: category name
    - **request**: the settings payload to apply

    Only the named category changes; all values are validated by the service.
    """
    return await TenantSettingsService(db).update_category(tenant_id, category, request.settings)
@router.post(
    "/{tenant_id}/settings/{category}/reset",
    response_model=CategoryResetResponse,
    summary="Reset category to default values",
    description="Reset a specific category to its default values"
)
async def reset_category_settings(
    tenant_id: UUID,
    category: str,
    db: AsyncSession = Depends(get_db)
):
    """
    Reset one category back to defaults.

    - **tenant_id**: UUID of the tenant
    - **category**: category name

    Irreversible: every setting in the category returns to its default value.
    """
    defaults = await TenantSettingsService(db).reset_category(tenant_id, category)
    return CategoryResetResponse(
        category=category,
        settings=defaults,
        message=f"Category '{category}' has been reset to default values"
    )
@router.delete(
    "/{tenant_id}/settings",
    status_code=status.HTTP_204_NO_CONTENT,
    summary="Delete tenant settings",
    description="Delete all settings for a tenant (used when tenant is deleted)"
)
async def delete_tenant_settings(
    tenant_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """
    Remove every settings record for the tenant.

    - **tenant_id**: UUID of the tenant

    Normally invoked automatically as part of tenant deletion; responds 204.
    """
    await TenantSettingsService(db).delete_settings(tenant_id)
    return None

View File

@@ -0,0 +1,285 @@
"""
Tenant API - ATOMIC operations
Handles basic CRUD operations for tenants
"""
import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Path, Query
from typing import Dict, Any, List
from uuid import UUID
from app.schemas.tenants import TenantResponse, TenantUpdate
from app.services.tenant_service import EnhancedTenantService
from shared.auth.decorators import get_current_user_dep
from shared.auth.access_control import admin_role_required
from shared.routing.route_builder import RouteBuilder
from shared.database.base import create_database_manager
from shared.monitoring.metrics import track_endpoint_metrics
logger = structlog.get_logger()
router = APIRouter()
route_builder = RouteBuilder("tenants")
# Dependency injection for enhanced tenant service
def get_enhanced_tenant_service():
    """DI factory: build an EnhancedTenantService on its own database manager."""
    try:
        # Imported lazily so settings are only read when the dependency resolves.
        from app.core.config import settings
        manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
        return EnhancedTenantService(manager)
    except Exception as e:
        logger.error("Failed to create enhanced tenant service", error=str(e))
        raise HTTPException(status_code=500, detail="Service initialization failed")
@router.get(route_builder.build_base_route("", include_tenant_prefix=False), response_model=List[TenantResponse])
@track_endpoint_metrics("tenants_list")
async def get_active_tenants(
    skip: int = Query(0, ge=0, description="Number of records to skip"),
    limit: int = Query(100, ge=1, le=1000, description="Maximum number of records to return"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """Get all active tenants - Available to service accounts and admins"""
    is_service = current_user.get("type") == "service"
    logger.info(
        "Get active tenants request received",
        skip=skip,
        limit=limit,
        user_id=current_user.get("user_id"),
        user_type=current_user.get("type", "user"),
        is_service=is_service,
        role=current_user.get("role"),
        service_name=current_user.get("service", "none")
    )
    # Service accounts pass straight through; non-service callers are only
    # logged here (no extra role gate at present).
    if not is_service:
        logger.debug(
            "Non-service user requesting active tenants",
            user_id=current_user.get("user_id"),
            role=current_user.get("role")
        )
    tenants = await tenant_service.get_active_tenants(skip=skip, limit=limit)
    logger.debug(
        "Get active tenants successful",
        count=len(tenants),
        skip=skip,
        limit=limit
    )
    return tenants
@router.get(route_builder.build_base_route("{tenant_id}", include_tenant_prefix=False), response_model=TenantResponse)
@track_endpoint_metrics("tenant_get")
async def get_tenant(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """Fetch a single tenant by ID; 404 when it does not exist."""
    tenant_key = str(tenant_id)
    requester = current_user.get("user_id")
    logger.info(
        "Tenant GET request received",
        tenant_id=tenant_key,
        user_id=requester,
        user_type=current_user.get("type", "user"),
        is_service=current_user.get("type") == "service",
        role=current_user.get("role"),
        service_name=current_user.get("service", "none")
    )
    tenant = await tenant_service.get_tenant_by_id(tenant_key)
    if tenant is None:
        logger.warning(
            "Tenant not found",
            tenant_id=tenant_key,
            user_id=requester
        )
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Tenant not found"
        )
    logger.debug(
        "Tenant GET request successful",
        tenant_id=tenant_key,
        user_id=requester
    )
    return tenant
@router.put(route_builder.build_base_route("{tenant_id}", include_tenant_prefix=False), response_model=TenantResponse)
@admin_role_required
async def update_tenant(
    update_data: TenantUpdate,
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """Apply a TenantUpdate to the tenant (admin and above only)."""
    try:
        return await tenant_service.update_tenant(
            str(tenant_id),
            update_data,
            current_user["user_id"]
        )
    except HTTPException:
        # Deliberate HTTP errors pass through unchanged.
        raise
    except Exception as e:
        logger.error("Tenant update failed",
                     tenant_id=str(tenant_id),
                     user_id=current_user["user_id"],
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Tenant update failed"
        )
@router.get(route_builder.build_base_route("user/{user_id}/tenants", include_tenant_prefix=False), response_model=List[TenantResponse])
@track_endpoint_metrics("user_tenants_list")
async def get_user_tenants(
    user_id: str = Path(..., description="User ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """Get all tenants accessible by a user"""
    logger.info(
        "Get user tenants request received",
        user_id=user_id,
        requesting_user=current_user.get("user_id"),
        is_demo=current_user.get("is_demo", False)
    )
    # Allow demo users to access tenant information for demo-user
    is_demo_user = current_user.get("is_demo", False)
    is_service_account = current_user.get("type") == "service"
    # For demo sessions, when frontend requests with "demo-user", use the actual demo owner ID
    actual_user_id = user_id
    if is_demo_user and user_id == "demo-user":
        actual_user_id = current_user.get("user_id")
        logger.info(
            "Demo session: mapping demo-user to actual owner",
            requested_user_id=user_id,
            actual_user_id=actual_user_id
        )
    # Access rule: the user themself, a service account, or a demo session
    # using the "demo-user" alias; everyone else gets a 403.
    if current_user.get("user_id") != actual_user_id and not is_service_account and not (is_demo_user and user_id == "demo-user"):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Can only access own tenants"
        )
    try:
        # For demo sessions, use session-specific filtering to prevent cross-session data leakage
        if is_demo_user:
            demo_session_id = current_user.get("demo_session_id")
            demo_account_type = current_user.get("demo_account_type", "professional")
            logger.info(
                "Demo session detected for get_user_tenants",
                user_id=user_id,
                actual_user_id=actual_user_id,
                demo_session_id=demo_session_id,
                demo_account_type=demo_account_type,
                has_session_id=bool(demo_session_id)
            )
            if demo_session_id:
                # Get only tenants for this specific demo session
                tenants = await tenant_service.get_virtual_tenants_for_session(demo_session_id, demo_account_type)
                logger.info(
                    "Get demo session tenants successful",
                    user_id=user_id,
                    demo_session_id=demo_session_id,
                    demo_account_type=demo_account_type,
                    tenant_count=len(tenants),
                    tenant_ids=[str(t.id) for t in tenants] if tenants else []
                )
                return tenants
            else:
                # Demo token without a session id: fall through to the regular
                # ownership/membership lookup below.
                logger.warning(
                    "Demo user without session ID - falling back to regular user tenants",
                    user_id=actual_user_id
                )
        # Regular users or demo fallback: get tenants by ownership and membership
        tenants = await tenant_service.get_user_tenants(actual_user_id)
        logger.debug(
            "Get user tenants successful",
            user_id=user_id,
            tenant_count=len(tenants)
        )
        return tenants
    except HTTPException:
        raise
    except Exception as e:
        logger.error("Get user tenants failed",
                     user_id=user_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get user tenants"
        )
@router.delete(route_builder.build_base_route("{tenant_id}", include_tenant_prefix=False))
@track_endpoint_metrics("tenant_delete")
async def delete_tenant(
    tenant_id: UUID = Path(..., description="Tenant ID"),
    current_user: Dict[str, Any] = Depends(get_current_user_dep),
    tenant_service: EnhancedTenantService = Depends(get_enhanced_tenant_service)
):
    """Delete tenant and all associated data - ATOMIC operation (Owner/Admin or System only)"""
    logger.info(
        "Tenant DELETE request received",
        tenant_id=str(tenant_id),
        user_id=current_user.get("user_id"),
        user_type=current_user.get("type", "user"),
        is_service=current_user.get("type") == "service",
        role=current_user.get("role"),
        service_name=current_user.get("service", "none")
    )
    try:
        # Allow internal service calls to bypass admin check
        skip_admin_check = current_user.get("type") == "service"
        result = await tenant_service.delete_tenant(
            str(tenant_id),
            requesting_user_id=current_user.get("user_id"),
            skip_admin_check=skip_admin_check
        )
        logger.info(
            "Tenant DELETE request successful",
            tenant_id=str(tenant_id),
            user_id=current_user.get("user_id"),
            deleted_items=result.get("deleted_items")
        )
        # The service returns a summary dict (e.g. "deleted_items" above);
        # it is passed through to the caller unmodified.
        return {
            "message": "Tenant deleted successfully",
            "summary": result
        }
    except HTTPException:
        # Deliberate HTTP errors (e.g. authorization failures) pass through.
        raise
    except Exception as e:
        logger.error("Tenant deletion failed",
                     tenant_id=str(tenant_id),
                     user_id=current_user.get("user_id"),
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Tenant deletion failed"
        )

View File

@@ -0,0 +1,357 @@
"""
Usage Forecasting API
This endpoint predicts when a tenant will hit their subscription limits
based on historical usage growth rates.
"""
from datetime import datetime, timedelta
from typing import List, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
import redis.asyncio as redis
from shared.auth.decorators import get_current_user_dep
from app.core.config import settings
from app.core.database import database_manager
from app.services.subscription_limit_service import SubscriptionLimitService
router = APIRouter(prefix="/usage-forecast", tags=["usage-forecast"])
class UsageDataPoint(BaseModel):
    """Single usage data point: one metric value for one calendar day."""
    date: str  # ISO date (YYYY-MM-DD) the snapshot was taken
    value: int  # metric value recorded for that day
class MetricForecast(BaseModel):
    """Forecast for a single metric: current usage, limit, trend and breach prediction."""
    metric: str  # machine key, e.g. "users", "storage"
    label: str  # human-readable display name
    current: int  # current usage value
    limit: Optional[int]  # None = unlimited
    unit: str  # display suffix, e.g. "/day", " GB", or "" for plain counts
    daily_growth_rate: Optional[float]  # None if not enough data
    predicted_breach_date: Optional[str]  # ISO date string, None if unlimited or no breach
    days_until_breach: Optional[int]  # None if unlimited or no breach
    usage_percentage: float  # current/limit as a percentage (0.0 when unlimited)
    status: str  # 'safe', 'warning', 'critical', 'unlimited'
    trend_data: List[UsageDataPoint]  # 30-day history
class UsageForecastResponse(BaseModel):
    """Complete usage forecast response: one MetricForecast per tracked metric."""
    tenant_id: str  # tenant the forecast was computed for
    forecasted_at: str  # ISO timestamp of computation (UTC)
    metrics: List[MetricForecast]  # forecasts for every tracked metric
async def get_redis_client() -> redis.Redis:
    """Open an async Redis connection for usage tracking (caller is responsible for closing it)."""
    return redis.from_url(settings.REDIS_URL, encoding="utf-8", decode_responses=True)
async def get_usage_history(
    redis_client: redis.Redis,
    tenant_id: str,
    metric: str,
    days: int = 30
) -> List[UsageDataPoint]:
    """
    Get historical usage data for a metric from Redis.

    Usage data is stored with keys like:
        usage:daily:{tenant_id}:{metric}:{date}

    Args:
        redis_client: async Redis connection.
        tenant_id: tenant whose usage is read.
        metric: metric key, e.g. "users".
        days: how many days back to scan (default 30).

    Returns:
        Data points in chronological order (oldest first). Days with no
        stored value, unreadable values, or Redis errors are skipped.
    """
    import logging
    log = logging.getLogger(__name__)

    history: List[UsageDataPoint] = []
    today = datetime.utcnow().date()
    for offset in range(days):
        date_str = (today - timedelta(days=offset)).isoformat()
        key = f"usage:daily:{tenant_id}:{metric}:{date_str}"
        try:
            value = await redis_client.get(key)
            if value is not None:
                history.append(UsageDataPoint(date=date_str, value=int(value)))
        except Exception:
            # Use logging instead of print so failures reach the service's log
            # pipeline; one bad day must not abort the whole history fetch.
            log.exception("Error fetching usage for %s", key)
            continue
    # Keys were scanned newest-first, so reverse to chronological order.
    return list(reversed(history))
def calculate_growth_rate(history: List[UsageDataPoint]) -> Optional[float]:
    """
    Estimate the average daily increase via a least-squares slope.

    Returns the regression slope clamped to be non-negative, or None when
    fewer than 7 data points exist or the regression is degenerate.
    """
    n = len(history)
    if n < 7:  # Need at least a week of data for a meaningful trend
        return None
    xs = range(n)
    ys = [point.value for point in history]
    sum_x = sum(xs)
    sum_y = sum(ys)
    sum_xy = sum(x * y for x, y in zip(xs, ys))
    sum_xx = sum(x * x for x in xs)
    # Slope of the least-squares line through (index, value) pairs.
    denominator = n * sum_xx - sum_x ** 2
    if denominator == 0:
        return None
    slope = (n * sum_xy - sum_x * sum_y) / denominator
    # Negative trends are clamped: breach prediction only cares about growth.
    return max(slope, 0)
def predict_breach_date(
    current: int,
    limit: int,
    daily_growth_rate: float
) -> Optional[tuple[str, int]]:
    """
    Predict when usage will breach the limit.

    Args:
        current: current usage value.
        limit: subscription limit for this metric.
        daily_growth_rate: estimated daily increase (units/day).

    Returns:
        (breach_date_iso, days_until_breach), or None when growth is flat or
        negative, or when the breach lies more than a year out. If usage is
        already at/over the limit, returns today's date with 0 days.
    """
    import math

    if daily_growth_rate <= 0:
        return None
    remaining_capacity = limit - current
    if remaining_capacity <= 0:
        # Already at or over limit
        return datetime.utcnow().date().isoformat(), 0
    # Use ceil, not int() truncation: usage first reaches the limit on the day
    # where current + rate * days >= limit, i.e. days = ceil(remaining / rate).
    # Truncation predicted the breach one day early for fractional quotients.
    days_until_breach = math.ceil(remaining_capacity / daily_growth_rate)
    if days_until_breach > 365:  # Don't predict beyond 1 year
        return None
    breach_date = datetime.utcnow().date() + timedelta(days=days_until_breach)
    return breach_date.isoformat(), days_until_breach
def determine_status(usage_percentage: float, days_until_breach: Optional[int]) -> str:
    """
    Classify a metric as 'critical', 'warning' or 'safe'.

    Critical at >= 90% usage; warning at >= 80% usage or when a breach is
    predicted within 14 days; safe otherwise.
    """
    if usage_percentage >= 90:
        return 'critical'
    breach_is_near = days_until_breach is not None and days_until_breach <= 14
    if usage_percentage >= 80 or breach_is_near:
        return 'warning'
    return 'safe'
@router.get("", response_model=UsageForecastResponse)
async def get_usage_forecast(
    tenant_id: str = Query(..., description="Tenant ID"),
    current_user: dict = Depends(get_current_user_dep)
) -> UsageForecastResponse:
    """
    Get usage forecasts for all metrics

    Predicts when the tenant will hit their subscription limits based on
    historical usage growth rates from the past 30 days.
    Returns predictions for:
    - Users
    - Locations
    - Products
    - Recipes
    - Suppliers
    - Training jobs (daily)
    - Forecasts (daily)
    - API calls (hourly average converted to daily)
    - File storage

    Raises 404 when the tenant has no active subscription.
    """
    # Initialize services
    redis_client = await get_redis_client()
    limit_service = SubscriptionLimitService(database_manager=database_manager)
    try:
        # Get current usage summary (includes limits)
        usage_summary = await limit_service.get_usage_summary(tenant_id)
        if not usage_summary or 'error' in usage_summary:
            raise HTTPException(
                status_code=404,
                detail=f"No active subscription found for tenant {tenant_id}"
            )
        # Extract usage data
        usage = usage_summary.get('usage', {})
        # Define metrics to forecast. Each entry maps a Redis metric key to its
        # display label, the tenant's current value and limit, and a unit suffix.
        metric_configs = [
            {
                'key': 'users',
                'label': 'Users',
                'current': usage.get('users', {}).get('current', 0),
                'limit': usage.get('users', {}).get('limit'),
                'unit': ''
            },
            {
                'key': 'locations',
                'label': 'Locations',
                'current': usage.get('locations', {}).get('current', 0),
                'limit': usage.get('locations', {}).get('limit'),
                'unit': ''
            },
            {
                'key': 'products',
                'label': 'Products',
                'current': usage.get('products', {}).get('current', 0),
                'limit': usage.get('products', {}).get('limit'),
                'unit': ''
            },
            {
                'key': 'recipes',
                'label': 'Recipes',
                'current': usage.get('recipes', {}).get('current', 0),
                'limit': usage.get('recipes', {}).get('limit'),
                'unit': ''
            },
            {
                'key': 'suppliers',
                'label': 'Suppliers',
                'current': usage.get('suppliers', {}).get('current', 0),
                'limit': usage.get('suppliers', {}).get('limit'),
                'unit': ''
            },
            {
                'key': 'training_jobs',
                'label': 'Training Jobs',
                'current': usage.get('training_jobs_today', {}).get('current', 0),
                'limit': usage.get('training_jobs_today', {}).get('limit'),
                'unit': '/day'
            },
            {
                'key': 'forecasts',
                'label': 'Forecasts',
                'current': usage.get('forecasts_today', {}).get('current', 0),
                'limit': usage.get('forecasts_today', {}).get('limit'),
                'unit': '/day'
            },
            {
                'key': 'api_calls',
                'label': 'API Calls',
                'current': usage.get('api_calls_this_hour', {}).get('current', 0),
                'limit': usage.get('api_calls_this_hour', {}).get('limit'),
                'unit': '/hour'
            },
            {
                'key': 'storage',
                'label': 'File Storage',
                'current': int(usage.get('file_storage_used_gb', {}).get('current', 0)),
                'limit': usage.get('file_storage_used_gb', {}).get('limit'),
                'unit': ' GB'
            }
        ]
        forecasts: List[MetricForecast] = []
        for config in metric_configs:
            metric_key = config['key']
            current = config['current']
            limit = config['limit']
            # Get usage history
            history = await get_usage_history(redis_client, tenant_id, metric_key, days=30)
            # Calculate usage percentage
            if limit is None or limit == -1:
                # Both None and -1 encode "unlimited": no percentage, no breach.
                usage_percentage = 0.0
                status = 'unlimited'
                growth_rate = None
                breach_date = None
                days_until = None
            else:
                usage_percentage = (current / limit * 100) if limit > 0 else 0
                # Calculate growth rate
                growth_rate = calculate_growth_rate(history) if history else None
                # Predict breach
                if growth_rate is not None and growth_rate > 0:
                    breach_result = predict_breach_date(current, limit, growth_rate)
                    if breach_result:
                        breach_date, days_until = breach_result
                    else:
                        breach_date, days_until = None, None
                else:
                    breach_date, days_until = None, None
                # Determine status
                status = determine_status(usage_percentage, days_until)
            forecasts.append(MetricForecast(
                metric=metric_key,
                label=config['label'],
                current=current,
                limit=limit,
                unit=config['unit'],
                daily_growth_rate=growth_rate,
                predicted_breach_date=breach_date,
                days_until_breach=days_until,
                usage_percentage=round(usage_percentage, 1),
                status=status,
                trend_data=history[-30:]  # Last 30 days
            ))
        return UsageForecastResponse(
            tenant_id=tenant_id,
            forecasted_at=datetime.utcnow().isoformat(),
            metrics=forecasts
        )
    finally:
        # Connection is opened per-request above, so always release it.
        await redis_client.close()
@router.post("/track-usage")
async def track_daily_usage(
    tenant_id: str,
    metric: str,
    value: int,
    current_user: dict = Depends(get_current_user_dep)
):
    """
    Record a daily usage snapshot for one tenant metric.

    Called by other services to persist point-in-time usage values; each
    snapshot is keyed by tenant, metric and UTC date, and expires after
    60 days so old data ages out of Redis automatically.
    """
    # Retention window for daily snapshots: 60 days, in seconds.
    snapshot_ttl_seconds = 60 * 24 * 60 * 60
    redis_client = await get_redis_client()
    try:
        today = datetime.utcnow().date().isoformat()
        snapshot_key = f"usage:daily:{tenant_id}:{metric}:{today}"
        # setex writes the value and TTL atomically.
        await redis_client.setex(snapshot_key, snapshot_ttl_seconds, str(value))
        return {
            "success": True,
            "tenant_id": tenant_id,
            "metric": metric,
            "value": value,
            "date": today,
        }
    finally:
        await redis_client.close()

View File

@@ -0,0 +1,97 @@
"""
Webhook endpoints for handling payment provider events
These endpoints receive events from payment providers like Stripe
All event processing is handled by SubscriptionOrchestrationService
"""
import structlog
import stripe
from fastapi import APIRouter, Depends, HTTPException, status, Request
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.subscription_orchestration_service import SubscriptionOrchestrationService
from app.core.config import settings
from app.core.database import get_db
logger = structlog.get_logger()
router = APIRouter()
def get_subscription_orchestration_service(
    db: AsyncSession = Depends(get_db)
) -> SubscriptionOrchestrationService:
    """FastAPI dependency: build a SubscriptionOrchestrationService bound to the request's DB session."""
    try:
        service = SubscriptionOrchestrationService(db)
    except Exception as e:
        # Surface construction failures as a generic 500 rather than leaking internals.
        logger.error("Failed to create subscription orchestration service", error=str(e))
        raise HTTPException(status_code=500, detail="Service initialization failed")
    return service
@router.post("/webhooks/stripe")
async def stripe_webhook(
    request: Request,
    orchestration_service: SubscriptionOrchestrationService = Depends(get_subscription_orchestration_service)
):
    """
    Stripe webhook endpoint to handle payment events.

    Flow:
      1. Reject requests without a ``stripe-signature`` header (HTTP 400).
      2. Verify the raw payload's signature against ``STRIPE_WEBHOOK_SECRET``;
         bad signatures and malformed payloads are rejected with HTTP 400.
      3. Delegate the verified event to
         ``SubscriptionOrchestrationService.handle_payment_webhook``.

    Processing errors after successful verification still return HTTP 200
    (with ``success: False``) so Stripe does not retry events that will
    never succeed; only validation failures produce 4xx responses.
    """
    try:
        # The raw, unparsed body is required for signature verification.
        payload = await request.body()
        sig_header = request.headers.get('stripe-signature')
        if not sig_header:
            logger.error("Missing stripe-signature header")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Missing signature header"
            )
        # Verify the webhook signature so only genuine Stripe events are processed.
        try:
            event = stripe.Webhook.construct_event(
                payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
            )
        except stripe.error.SignatureVerificationError as e:
            logger.error("Invalid webhook signature", error=str(e))
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid signature"
            )
        except ValueError as e:
            # construct_event raises ValueError when the payload is not valid JSON.
            logger.error("Invalid payload", error=str(e))
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid payload"
            )
        # Get event type and data object from the verified event.
        event_type = event['type']
        event_data = event['data']['object']
        logger.info("Processing Stripe webhook event",
                    event_type=event_type,
                    event_id=event.get('id'))
        # Use orchestration service to handle the event; it reports what it did.
        result = await orchestration_service.handle_payment_webhook(event_type, event_data)
        logger.info("Webhook event processed via orchestration service",
                    event_type=event_type,
                    actions_taken=result.get("actions_taken", []))
        return {"success": True, "event_type": event_type, "actions_taken": result.get("actions_taken", [])}
    except HTTPException:
        # Validation failures (signature/payload/header) propagate as 4xx.
        raise
    except Exception as e:
        logger.error("Error processing Stripe webhook", error=str(e), exc_info=True)
        # Return 200 OK even on processing errors to prevent Stripe retries
        # Only return 4xx for signature verification failures
        return {
            "success": False,
            "error": "Webhook processing error",
            "details": str(e)
        }

View File

@@ -0,0 +1,308 @@
# services/tenant/app/api/whatsapp_admin.py
"""
WhatsApp Admin API Endpoints
Admin-only endpoints for managing WhatsApp phone number assignments
"""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from uuid import UUID
from typing import List, Optional
from pydantic import BaseModel, Field
import httpx
import os
from app.core.database import get_db
from app.models.tenant_settings import TenantSettings
from app.models.tenants import Tenant
router = APIRouter()
# ================================================================
# SCHEMAS
# ================================================================
class WhatsAppPhoneNumberInfo(BaseModel):
    """Information about a WhatsApp phone number from Meta API.

    Mirrors the fields requested from the Graph API ``phone_numbers`` edge.
    """
    id: str = Field(..., description="Phone Number ID")
    display_phone_number: str = Field(..., description="Display phone number (e.g., +34 612 345 678)")
    verified_name: str = Field(..., description="Verified business name")
    quality_rating: str = Field(..., description="Quality rating (GREEN, YELLOW, RED)")
class TenantWhatsAppStatus(BaseModel):
    """WhatsApp status for a tenant.

    Built from the tenant row plus its ``notification_settings`` JSON;
    phone fields are optional because not every tenant has an assignment.
    """
    tenant_id: UUID
    tenant_name: str
    whatsapp_enabled: bool
    phone_number_id: Optional[str] = None
    display_phone_number: Optional[str] = None
class AssignPhoneNumberRequest(BaseModel):
    """Request to assign a phone number (from the master account) to a tenant."""
    phone_number_id: str = Field(..., description="Meta WhatsApp Phone Number ID")
    display_phone_number: str = Field(..., description="Display format (e.g., '+34 612 345 678')")
class AssignPhoneNumberResponse(BaseModel):
    """Response after assigning (or unassigning) a phone number.

    On unassign, the phone fields echo the values that were removed.
    """
    success: bool
    message: str
    tenant_id: UUID
    phone_number_id: str
    display_phone_number: str
# ================================================================
# ENDPOINTS
# ================================================================
@router.get(
    "/admin/whatsapp/phone-numbers",
    response_model=List[WhatsAppPhoneNumberInfo],
    summary="List available WhatsApp phone numbers",
    description="Get all phone numbers available in the master WhatsApp Business Account"
)
async def list_available_phone_numbers():
    """
    List all phone numbers from the master WhatsApp Business Account.

    Requires:
        - WHATSAPP_BUSINESS_ACCOUNT_ID environment variable
        - WHATSAPP_ACCESS_TOKEN environment variable

    Returns:
        List of available phone numbers with their verification/quality status.

    Raises:
        HTTPException 500 if the master account is not configured.
        HTTPException 502 if Meta's API is unreachable or returns an error.
    """
    business_account_id = os.getenv("WHATSAPP_BUSINESS_ACCOUNT_ID")
    access_token = os.getenv("WHATSAPP_ACCESS_TOKEN")
    api_version = os.getenv("WHATSAPP_API_VERSION", "v18.0")
    if not business_account_id or not access_token:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="WhatsApp master account not configured. Set WHATSAPP_BUSINESS_ACCOUNT_ID and WHATSAPP_ACCESS_TOKEN environment variables."
        )
    try:
        # Fetch phone numbers from Meta Graph API
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(
                f"https://graph.facebook.com/{api_version}/{business_account_id}/phone_numbers",
                headers={"Authorization": f"Bearer {access_token}"},
                params={
                    "fields": "id,display_phone_number,verified_name,quality_rating"
                }
            )
        if response.status_code != 200:
            # Fix: Meta may return a non-JSON body on errors (e.g. HTML from a
            # proxy); don't let the JSON parse failure mask the upstream error.
            try:
                error_message = response.json().get('error', {}).get('message', 'Unknown error')
            except ValueError:
                error_message = 'Unknown error'
            raise HTTPException(
                status_code=status.HTTP_502_BAD_GATEWAY,
                detail=f"Meta API error: {error_message}"
            )
        data = response.json()
        # NOTE(review): only the first page is returned; Meta paginates via
        # data['paging'] — confirm whether >1 page of numbers is possible here.
        phone_numbers = data.get("data", [])
        return [
            WhatsAppPhoneNumberInfo(
                id=phone.get("id"),
                display_phone_number=phone.get("display_phone_number"),
                verified_name=phone.get("verified_name", ""),
                quality_rating=phone.get("quality_rating", "UNKNOWN")
            )
            for phone in phone_numbers
        ]
    except httpx.HTTPError as e:
        # Network-level failures (timeouts, DNS, connection errors).
        raise HTTPException(
            status_code=status.HTTP_502_BAD_GATEWAY,
            detail=f"Failed to fetch phone numbers from Meta: {str(e)}"
        )
@router.get(
    "/admin/whatsapp/tenants",
    response_model=List[TenantWhatsAppStatus],
    summary="List all tenants with WhatsApp status",
    description="Get WhatsApp configuration status for all tenants"
)
async def list_tenant_whatsapp_status(
    db: AsyncSession = Depends(get_db)
):
    """
    List all tenants with their WhatsApp configuration status.

    Returns, per tenant:
        - tenant_id: Tenant UUID
        - tenant_name: Tenant name
        - whatsapp_enabled: Whether WhatsApp is enabled
        - phone_number_id: Assigned phone number ID (if any)
        - display_phone_number: Display format (if any)
    """
    # Outer join so tenants without a settings row are still listed.
    query = select(Tenant, TenantSettings).outerjoin(
        TenantSettings,
        Tenant.id == TenantSettings.tenant_id
    )
    result = await db.execute(query)
    rows = result.all()
    tenant_statuses = []
    for tenant, settings in rows:
        # Fix: guard against a settings row whose notification_settings JSON
        # is NULL (matches the `or {}` pattern used by the other endpoints);
        # previously this raised AttributeError on `.get`.
        notification_settings = (settings.notification_settings or {}) if settings else {}
        tenant_statuses.append(
            TenantWhatsAppStatus(
                tenant_id=tenant.id,
                tenant_name=tenant.name,
                whatsapp_enabled=notification_settings.get("whatsapp_enabled", False),
                phone_number_id=notification_settings.get("whatsapp_phone_number_id", ""),
                display_phone_number=notification_settings.get("whatsapp_display_phone_number", "")
            )
        )
    return tenant_statuses
@router.post(
    "/admin/whatsapp/tenants/{tenant_id}/assign-phone",
    response_model=AssignPhoneNumberResponse,
    summary="Assign phone number to tenant",
    description="Assign a WhatsApp phone number from the master account to a tenant"
)
async def assign_phone_number_to_tenant(
    tenant_id: UUID,
    request: AssignPhoneNumberRequest,
    db: AsyncSession = Depends(get_db)
):
    """
    Assign a WhatsApp phone number to a tenant.

    - **tenant_id**: UUID of the tenant
    - **phone_number_id**: Meta Phone Number ID from master account
    - **display_phone_number**: Human-readable format (e.g., "+34 612 345 678")

    This will:
    1. Validate the tenant exists (404 otherwise)
    2. Check the phone number is not already assigned to another tenant (409)
    3. Update the tenant's notification settings
    4. Enable WhatsApp for the tenant
    """
    # Verify tenant exists
    tenant_query = select(Tenant).where(Tenant.id == tenant_id)
    tenant_result = await db.execute(tenant_query)
    tenant = tenant_result.scalar_one_or_none()
    if not tenant:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Tenant {tenant_id} not found"
        )
    # Check if phone number is already assigned to another tenant.
    # NOTE: scans every other tenant's settings row; acceptable for admin use.
    settings_query = select(TenantSettings).where(TenantSettings.tenant_id != tenant_id)
    settings_result = await db.execute(settings_query)
    all_settings = settings_result.scalars().all()
    for other_settings in all_settings:
        other_notifications = other_settings.notification_settings or {}
        if other_notifications.get("whatsapp_phone_number_id") == request.phone_number_id:
            # Get the other tenant's name for a helpful conflict message.
            other_tenant_query = select(Tenant).where(Tenant.id == other_settings.tenant_id)
            other_tenant_result = await db.execute(other_tenant_query)
            other_tenant = other_tenant_result.scalar_one_or_none()
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail=f"Phone number {request.display_phone_number} is already assigned to tenant '{other_tenant.name if other_tenant else 'Unknown'}'"
            )
    # Get or create tenant settings
    settings_query = select(TenantSettings).where(TenantSettings.tenant_id == tenant_id)
    settings_result = await db.execute(settings_query)
    settings = settings_result.scalar_one_or_none()
    if not settings:
        # Create default settings
        settings = TenantSettings(
            tenant_id=tenant_id,
            **TenantSettings.get_default_settings()
        )
        db.add(settings)
    # Fix: copy the JSON dict before mutating. Plain JSON columns do not track
    # in-place mutation, and reassigning the same (mutated) object can be
    # flushed as "no change" by SQLAlchemy, silently dropping the update.
    notification_settings = dict(settings.notification_settings or {})
    notification_settings["whatsapp_enabled"] = True
    notification_settings["whatsapp_phone_number_id"] = request.phone_number_id
    notification_settings["whatsapp_display_phone_number"] = request.display_phone_number
    settings.notification_settings = notification_settings
    await db.commit()
    await db.refresh(settings)
    return AssignPhoneNumberResponse(
        success=True,
        message=f"Phone number {request.display_phone_number} assigned to tenant '{tenant.name}'",
        tenant_id=tenant_id,
        phone_number_id=request.phone_number_id,
        display_phone_number=request.display_phone_number
    )
@router.delete(
    "/admin/whatsapp/tenants/{tenant_id}/unassign-phone",
    response_model=AssignPhoneNumberResponse,
    summary="Unassign phone number from tenant",
    description="Remove WhatsApp phone number assignment from a tenant"
)
async def unassign_phone_number_from_tenant(
    tenant_id: UUID,
    db: AsyncSession = Depends(get_db)
):
    """
    Unassign the WhatsApp phone number from a tenant.

    - **tenant_id**: UUID of the tenant

    This will:
    1. Clear the phone number assignment (echoed back in the response)
    2. Disable WhatsApp for the tenant

    Raises:
        HTTPException 404 if the tenant has no settings row.
    """
    # Get tenant settings
    settings_query = select(TenantSettings).where(TenantSettings.tenant_id == tenant_id)
    settings_result = await db.execute(settings_query)
    settings = settings_result.scalar_one_or_none()
    if not settings:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Settings not found for tenant {tenant_id}"
        )
    # Capture current values so the response can report what was removed.
    old_settings = settings.notification_settings or {}
    old_phone_id = old_settings.get("whatsapp_phone_number_id", "")
    old_display_phone = old_settings.get("whatsapp_display_phone_number", "")
    # Fix: build a fresh dict instead of mutating in place. Plain JSON columns
    # do not track in-place mutation, so reassigning the same (mutated) object
    # can be flushed as "no change" by SQLAlchemy and the unassignment lost.
    notification_settings = dict(old_settings)
    notification_settings["whatsapp_enabled"] = False
    notification_settings["whatsapp_phone_number_id"] = ""
    notification_settings["whatsapp_display_phone_number"] = ""
    settings.notification_settings = notification_settings
    await db.commit()
    return AssignPhoneNumberResponse(
        success=True,
        message=f"Phone number unassigned from tenant",
        tenant_id=tenant_id,
        phone_number_id=old_phone_id,
        display_phone_number=old_display_phone
    )

View File

View File

@@ -0,0 +1,133 @@
# ================================================================
# TENANT SERVICE CONFIGURATION
# services/tenant/app/core/config.py
# ================================================================
"""
Tenant service configuration
Multi-tenant management and subscription handling
"""
from shared.config.base import BaseServiceSettings
import os
from typing import Dict, Tuple, ClassVar
class TenantSettings(BaseServiceSettings):
    """Tenant service specific settings.

    All values default from environment variables so deployments can
    override without code changes; class attributes are evaluated once
    at import time.
    """
    # Service Identity
    APP_NAME: str = "Tenant Service"
    SERVICE_NAME: str = "tenant-service"
    DESCRIPTION: str = "Multi-tenant management and subscription service"
    # Database configuration (secure approach - build from components)
    @property
    def DATABASE_URL(self) -> str:
        """Build database URL from secure components.

        Prefers a full TENANT_DATABASE_URL override; otherwise assembles an
        asyncpg URL from TENANT_DB_* components.
        """
        # Try complete URL first (for backward compatibility)
        complete_url = os.getenv("TENANT_DATABASE_URL")
        if complete_url:
            return complete_url
        # Build from components (secure approach)
        # NOTE(review): hard-coded fallback password for local dev — confirm
        # production always sets TENANT_DB_PASSWORD.
        user = os.getenv("TENANT_DB_USER", "tenant_user")
        password = os.getenv("TENANT_DB_PASSWORD", "tenant_pass123")
        host = os.getenv("TENANT_DB_HOST", "localhost")
        port = os.getenv("TENANT_DB_PORT", "5432")
        name = os.getenv("TENANT_DB_NAME", "tenant_db")
        return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{name}"
    # Redis Database (dedicated for tenant data)
    REDIS_DB: int = 4
    # Service URLs for usage tracking
    RECIPES_SERVICE_URL: str = os.getenv("RECIPES_SERVICE_URL", "http://recipes-service:8004")
    SUPPLIERS_SERVICE_URL: str = os.getenv("SUPPLIERS_SERVICE_URL", "http://suppliers-service:8005")
    # Subscription Plans
    DEFAULT_PLAN: str = os.getenv("DEFAULT_PLAN", "basic")
    TRIAL_PERIOD_DAYS: int = int(os.getenv("TRIAL_PERIOD_DAYS", "0"))
    # Plan Limits (per-tier caps on locations, daily predictions, retention)
    BASIC_PLAN_LOCATIONS: int = int(os.getenv("BASIC_PLAN_LOCATIONS", "1"))
    BASIC_PLAN_PREDICTIONS_PER_DAY: int = int(os.getenv("BASIC_PLAN_PREDICTIONS_PER_DAY", "100"))
    BASIC_PLAN_DATA_RETENTION_DAYS: int = int(os.getenv("BASIC_PLAN_DATA_RETENTION_DAYS", "90"))
    PREMIUM_PLAN_LOCATIONS: int = int(os.getenv("PREMIUM_PLAN_LOCATIONS", "5"))
    PREMIUM_PLAN_PREDICTIONS_PER_DAY: int = int(os.getenv("PREMIUM_PLAN_PREDICTIONS_PER_DAY", "1000"))
    PREMIUM_PLAN_DATA_RETENTION_DAYS: int = int(os.getenv("PREMIUM_PLAN_DATA_RETENTION_DAYS", "365"))
    ENTERPRISE_PLAN_LOCATIONS: int = int(os.getenv("ENTERPRISE_PLAN_LOCATIONS", "50"))
    ENTERPRISE_PLAN_PREDICTIONS_PER_DAY: int = int(os.getenv("ENTERPRISE_PLAN_PREDICTIONS_PER_DAY", "10000"))
    ENTERPRISE_PLAN_DATA_RETENTION_DAYS: int = int(os.getenv("ENTERPRISE_PLAN_DATA_RETENTION_DAYS", "1095"))
    # Billing Configuration
    BILLING_ENABLED: bool = os.getenv("BILLING_ENABLED", "false").lower() == "true"
    BILLING_CURRENCY: str = os.getenv("BILLING_CURRENCY", "EUR")
    BILLING_CYCLE_DAYS: int = int(os.getenv("BILLING_CYCLE_DAYS", "30"))
    # Stripe Proration Configuration (values are Stripe proration_behavior strings)
    DEFAULT_PRORATION_BEHAVIOR: str = os.getenv("DEFAULT_PRORATION_BEHAVIOR", "create_prorations")
    UPGRADE_PRORATION_BEHAVIOR: str = os.getenv("UPGRADE_PRORATION_BEHAVIOR", "create_prorations")
    DOWNGRADE_PRORATION_BEHAVIOR: str = os.getenv("DOWNGRADE_PRORATION_BEHAVIOR", "none")
    BILLING_CYCLE_CHANGE_PRORATION: str = os.getenv("BILLING_CYCLE_CHANGE_PRORATION", "create_prorations")
    # Stripe Subscription Update Settings
    STRIPE_BILLING_CYCLE_ANCHOR: str = os.getenv("STRIPE_BILLING_CYCLE_ANCHOR", "unchanged")
    STRIPE_PAYMENT_BEHAVIOR: str = os.getenv("STRIPE_PAYMENT_BEHAVIOR", "error_if_incomplete")
    ALLOW_IMMEDIATE_SUBSCRIPTION_CHANGES: bool = os.getenv("ALLOW_IMMEDIATE_SUBSCRIPTION_CHANGES", "true").lower() == "true"
    # Resource Limits
    MAX_API_CALLS_PER_MINUTE: int = int(os.getenv("MAX_API_CALLS_PER_MINUTE", "100"))
    MAX_STORAGE_MB: int = int(os.getenv("MAX_STORAGE_MB", "1024"))
    MAX_CONCURRENT_REQUESTS: int = int(os.getenv("MAX_CONCURRENT_REQUESTS", "10"))
    # Spanish Business Configuration
    SPANISH_TAX_RATE: float = float(os.getenv("SPANISH_TAX_RATE", "0.21"))  # IVA 21%
    INVOICE_LANGUAGE: str = os.getenv("INVOICE_LANGUAGE", "es")
    SUPPORT_EMAIL: str = os.getenv("SUPPORT_EMAIL", "soporte@bakeryforecast.es")
    # Onboarding
    ONBOARDING_ENABLED: bool = os.getenv("ONBOARDING_ENABLED", "true").lower() == "true"
    DEMO_DATA_ENABLED: bool = os.getenv("DEMO_DATA_ENABLED", "true").lower() == "true"
    # Compliance
    GDPR_COMPLIANCE_ENABLED: bool = True
    DATA_EXPORT_ENABLED: bool = True
    DATA_DELETION_ENABLED: bool = True
    # Stripe Payment Configuration
    STRIPE_PUBLISHABLE_KEY: str = os.getenv("STRIPE_PUBLISHABLE_KEY", "")
    STRIPE_SECRET_KEY: str = os.getenv("STRIPE_SECRET_KEY", "")
    STRIPE_WEBHOOK_SECRET: str = os.getenv("STRIPE_WEBHOOK_SECRET", "")
    # Stripe Price IDs for subscription plans
    # NOTE(review): the fallbacks look like real Stripe price IDs committed to
    # source — confirm whether these should be env-only with empty defaults.
    STARTER_MONTHLY_PRICE_ID: str = os.getenv("STARTER_MONTHLY_PRICE_ID", "price_1Sp0p3IzCdnBmAVT2Gs7z5np")
    STARTER_YEARLY_PRICE_ID: str = os.getenv("STARTER_YEARLY_PRICE_ID", "price_1Sp0twIzCdnBmAVTD1lNLedx")
    PROFESSIONAL_MONTHLY_PRICE_ID: str = os.getenv("PROFESSIONAL_MONTHLY_PRICE_ID", "price_1Sp0w7IzCdnBmAVTp0Jxhh1u")
    PROFESSIONAL_YEARLY_PRICE_ID: str = os.getenv("PROFESSIONAL_YEARLY_PRICE_ID", "price_1Sp0yAIzCdnBmAVTLoGl4QCb")
    ENTERPRISE_MONTHLY_PRICE_ID: str = os.getenv("ENTERPRISE_MONTHLY_PRICE_ID", "price_1Sp0zAIzCdnBmAVTXpApF7YO")
    ENTERPRISE_YEARLY_PRICE_ID: str = os.getenv("ENTERPRISE_YEARLY_PRICE_ID", "price_1Sp15mIzCdnBmAVTuxffMpV5")
    # Price ID mapping for easy lookup, keyed by (plan, billing_cycle).
    # Built at class-creation time, so env overrides must be set before import.
    STRIPE_PRICE_ID_MAPPING: ClassVar[Dict[Tuple[str, str], str]] = {
        ('starter', 'monthly'): STARTER_MONTHLY_PRICE_ID,
        ('starter', 'yearly'): STARTER_YEARLY_PRICE_ID,
        ('professional', 'monthly'): PROFESSIONAL_MONTHLY_PRICE_ID,
        ('professional', 'yearly'): PROFESSIONAL_YEARLY_PRICE_ID,
        ('enterprise', 'monthly'): ENTERPRISE_MONTHLY_PRICE_ID,
        ('enterprise', 'yearly'): ENTERPRISE_YEARLY_PRICE_ID,
    }
    # ============================================================
    # SCHEDULER CONFIGURATION
    # ============================================================
    # Usage tracking scheduler: daily snapshot time (hour/minute, tz label).
    USAGE_TRACKING_ENABLED: bool = os.getenv("USAGE_TRACKING_ENABLED", "true").lower() == "true"
    USAGE_TRACKING_HOUR: int = int(os.getenv("USAGE_TRACKING_HOUR", "2"))
    USAGE_TRACKING_MINUTE: int = int(os.getenv("USAGE_TRACKING_MINUTE", "0"))
    USAGE_TRACKING_TIMEZONE: str = os.getenv("USAGE_TRACKING_TIMEZONE", "UTC")
settings = TenantSettings()

View File

@@ -0,0 +1,12 @@
"""
Database configuration for tenant service
"""
from shared.database.base import DatabaseManager
from app.core.config import settings
# Single shared DatabaseManager for the tenant service; the connection URL is
# assembled by TenantSettings.DATABASE_URL (full override or env components).
database_manager = DatabaseManager(settings.DATABASE_URL, service_name="tenant-service")
# Alias for convenience: endpoints take a session via `Depends(get_db)`.
get_db = database_manager.get_db

View File

@@ -0,0 +1,127 @@
"""
Startup seeder for tenant service.
Seeds initial data (like pilot coupons) on service startup.
All operations are idempotent - safe to run multiple times.
"""
import os
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional
import structlog
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.coupon import CouponModel
logger = structlog.get_logger()
async def ensure_pilot_coupon(session: AsyncSession) -> Optional[CouponModel]:
    """
    Ensure the PILOT2025 coupon exists in the database.

    This coupon provides 3 months (90 days) free trial extension
    for the first 20 pilot customers.

    This function is idempotent - it will not create duplicates:
    an existing coupon with the same code is returned unchanged.

    Args:
        session: Database session

    Returns:
        The coupon model (existing or newly created), or None if pilot
        mode is disabled via VITE_PILOT_MODE_ENABLED.
    """
    # Check if pilot mode is enabled via environment variable.
    # NOTE(review): VITE_-prefixed vars are normally frontend build vars —
    # confirm these are intentionally shared with the backend deployment.
    pilot_mode_enabled = os.getenv("VITE_PILOT_MODE_ENABLED", "true").lower() == "true"
    if not pilot_mode_enabled:
        logger.info("Pilot mode is disabled, skipping coupon seeding")
        return None
    coupon_code = os.getenv("VITE_PILOT_COUPON_CODE", "PILOT2025")
    trial_months = int(os.getenv("VITE_PILOT_TRIAL_MONTHS", "3"))
    max_redemptions = int(os.getenv("PILOT_MAX_REDEMPTIONS", "20"))
    # Check if coupon already exists (idempotency guard)
    result = await session.execute(
        select(CouponModel).where(CouponModel.code == coupon_code)
    )
    existing_coupon = result.scalars().first()
    if existing_coupon:
        logger.info(
            "Pilot coupon already exists",
            code=coupon_code,
            current_redemptions=existing_coupon.current_redemptions,
            max_redemptions=existing_coupon.max_redemptions,
            active=existing_coupon.active
        )
        return existing_coupon
    # Create new coupon
    now = datetime.now(timezone.utc)
    valid_until = now + timedelta(days=180)  # Valid for 6 months
    trial_days = trial_months * 30  # Approximate days (months are not exactly 30 days)
    coupon = CouponModel(
        id=uuid.uuid4(),
        code=coupon_code,
        # "trial_extension" with discount_value measured in days, not currency
        discount_type="trial_extension",
        discount_value=trial_days,
        max_redemptions=max_redemptions,
        current_redemptions=0,
        valid_from=now,
        valid_until=valid_until,
        active=True,
        created_at=now,
        extra_data={
            "program": "pilot_launch_2025",
            "description": f"Programa piloto - {trial_months} meses gratis para los primeros {max_redemptions} clientes",
            "terms": "Válido para nuevos registros únicamente. Un cupón por cliente."
        }
    )
    session.add(coupon)
    await session.commit()
    await session.refresh(coupon)
    logger.info(
        "Pilot coupon created successfully",
        code=coupon_code,
        type="Trial Extension",
        value=f"{trial_days} days ({trial_months} months)",
        max_redemptions=max_redemptions,
        valid_until=valid_until.isoformat(),
        id=str(coupon.id)
    )
    return coupon
async def run_startup_seeders(database_manager) -> None:
    """
    Execute every startup seeder against a fresh session.

    Invoked once during service startup to ensure required seed data
    exists. Seeders are idempotent, and any failure is logged as a
    warning instead of aborting boot, because seed data is not critical.

    Args:
        database_manager: The database manager instance
    """
    logger.info("Running startup seeders...")
    try:
        async with database_manager.get_session() as session:
            # Currently the only seeder: the pilot-program coupon.
            await ensure_pilot_coupon(session)
    except Exception as exc:
        # Non-fatal by design: the service must still come up.
        logger.warning(
            "Startup seeder encountered an error (non-fatal)",
            error=str(exc)
        )
    else:
        logger.info("Startup seeders completed successfully")

View File

@@ -0,0 +1,103 @@
"""
Background job to process subscription downgrades at period end
Runs periodically to check for subscriptions with:
- status = 'pending_cancellation'
- cancellation_effective_date <= now()
Converts them to 'inactive' status
"""
import structlog
import asyncio
from datetime import datetime, timezone
from sqlalchemy import select, update
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_async_session_factory
from app.models.tenants import Subscription
logger = structlog.get_logger()
async def process_pending_cancellations():
    """
    Process all subscriptions that have reached their cancellation_effective_date.

    Finds subscriptions with status 'pending_cancellation' whose effective
    date is in the past, downgrades them to status 'inactive' on the 'free'
    plan at zero price, and commits in one batch.

    Returns:
        Number of subscriptions downgraded.

    Raises:
        Re-raises any database error after rolling back the session.
    """
    async_session_factory = get_async_session_factory()
    async with async_session_factory() as session:
        try:
            query = select(Subscription).where(
                Subscription.status == 'pending_cancellation',
                Subscription.cancellation_effective_date <= datetime.now(timezone.utc)
            )
            result = await session.execute(query)
            subscriptions_to_downgrade = result.scalars().all()
            downgraded_count = 0
            for subscription in subscriptions_to_downgrade:
                # Fix: capture the plan BEFORE mutating it — previously the
                # log below read subscription.plan after it was set to 'free',
                # so previous_plan always logged 'free'.
                previous_plan = subscription.plan
                subscription.status = 'inactive'
                subscription.plan = 'free'
                subscription.monthly_price = 0.0
                logger.info(
                    "subscription_downgraded_to_inactive",
                    tenant_id=str(subscription.tenant_id),
                    previous_plan=previous_plan,
                    cancellation_effective_date=subscription.cancellation_effective_date.isoformat()
                )
                downgraded_count += 1
            if downgraded_count > 0:
                # Single commit for the whole batch.
                await session.commit()
                logger.info(
                    "subscriptions_downgraded",
                    count=downgraded_count
                )
            else:
                logger.debug("no_subscriptions_to_downgrade")
            return downgraded_count
        except Exception as e:
            logger.error(
                "subscription_downgrade_job_failed",
                error=str(e)
            )
            await session.rollback()
            raise
async def run_subscription_downgrade_job():
    """
    Entry point for the periodic subscription downgrade job.

    Loops forever, processing pending cancellations once per interval.
    Errors in a single cycle are logged and the loop keeps running.
    """
    interval_seconds = 3600  # Run every hour
    logger.info("subscription_downgrade_job_started", interval_seconds=interval_seconds)
    while True:
        try:
            processed = await process_pending_cancellations()
        except Exception as exc:
            # Keep the loop alive; the next cycle may succeed.
            logger.error(
                "subscription_downgrade_job_error",
                error=str(exc)
            )
        else:
            logger.info(
                "subscription_downgrade_job_completed",
                downgraded_count=processed
            )
        await asyncio.sleep(interval_seconds)
if __name__ == "__main__":
asyncio.run(run_subscription_downgrade_job())

View File

@@ -0,0 +1,247 @@
"""
Usage Tracking Scheduler
Tracks daily usage snapshots for all active tenants
"""
import asyncio
import structlog
from datetime import datetime, timedelta, timezone
from typing import Optional
from sqlalchemy import select, func
logger = structlog.get_logger()
class UsageTrackingScheduler:
    """Scheduler for daily usage tracking.

    Runs a background asyncio task that, once per day at a configured
    UTC time, snapshots per-tenant usage counters (from the DB and the
    Redis quota keys) into daily Redis records via the usage_forecast API.
    """
    def __init__(self, db_manager, redis_client, config):
        # db_manager: provides async sessions; redis_client: async Redis;
        # config: settings object exposing USAGE_TRACKING_* values.
        self.db_manager = db_manager
        self.redis = redis_client
        self.config = config
        self._running = False
        self._task: Optional[asyncio.Task] = None
    def seconds_until_target_time(self) -> float:
        """Calculate seconds until next target time (default 2am UTC)."""
        now = datetime.now(timezone.utc)
        target = now.replace(
            hour=self.config.USAGE_TRACKING_HOUR,
            minute=self.config.USAGE_TRACKING_MINUTE,
            second=0,
            microsecond=0
        )
        # If today's slot already passed, schedule for tomorrow.
        if target <= now:
            target += timedelta(days=1)
        return (target - now).total_seconds()
    async def _get_tenant_usage(self, session, tenant_id: str) -> dict:
        """Get current usage counts for a tenant.

        Returns a dict of metric name -> value, or an empty dict on error.
        NOTE(review): callers pass tenant.id, which may be a UUID rather
        than the annotated str — confirm the Redis key format is stable.
        """
        usage = {}
        try:
            # Import models here to avoid circular imports
            from app.models.tenants import TenantMember
            # Users count
            result = await session.execute(
                select(func.count()).select_from(TenantMember).where(TenantMember.tenant_id == tenant_id)
            )
            usage['users'] = result.scalar() or 0
            # Get counts from other services via their databases
            # For now, we'll track basic metrics. More metrics can be added by querying other service databases
            # Training jobs today (from Redis quota tracking)
            today_key = f"quota:training_jobs:{tenant_id}:{datetime.now(timezone.utc).strftime('%Y-%m-%d')}"
            training_count = await self.redis.get(today_key)
            usage['training_jobs'] = int(training_count) if training_count else 0
            # Forecasts today (from Redis quota tracking)
            forecast_key = f"quota:forecasts:{tenant_id}:{datetime.now(timezone.utc).strftime('%Y-%m-%d')}"
            forecast_count = await self.redis.get(forecast_key)
            usage['forecasts'] = int(forecast_count) if forecast_count else 0
            # API calls this hour (from Redis quota tracking)
            hour_key = f"quota:api_calls:{tenant_id}:{datetime.now(timezone.utc).strftime('%Y-%m-%d-%H')}"
            api_count = await self.redis.get(hour_key)
            usage['api_calls'] = int(api_count) if api_count else 0
            # Storage (placeholder - implement based on file storage system)
            usage['storage'] = 0.0
        except Exception as e:
            logger.error("Error getting usage for tenant", tenant_id=tenant_id, error=str(e), exc_info=True)
            return {}
        return usage
    async def _track_metrics(self, tenant_id: str, usage: dict):
        """Track metrics to Redis via the usage_forecast snapshot helper.

        Per-metric failures are logged and do not abort the other metrics.
        """
        from app.api.usage_forecast import track_usage_snapshot
        for metric_name, value in usage.items():
            try:
                await track_usage_snapshot(tenant_id, metric_name, value)
            except Exception as e:
                logger.error(
                    "Failed to track metric",
                    tenant_id=tenant_id,
                    metric=metric_name,
                    error=str(e)
                )
    async def _run_cycle(self):
        """Execute one tracking cycle across all active tenants."""
        start_time = datetime.now(timezone.utc)
        logger.info("Starting daily usage tracking cycle")
        try:
            async with self.db_manager.get_session() as session:
                # Import models here to avoid circular imports
                from app.models.tenants import Tenant, Subscription
                from sqlalchemy import select
                # Get all active tenants (including cancelled-but-active subs,
                # so usage history is kept until the subscription actually ends)
                result = await session.execute(
                    select(Tenant, Subscription)
                    .join(Subscription, Tenant.id == Subscription.tenant_id)
                    .where(Tenant.is_active == True)
                    .where(Subscription.status.in_(['active', 'trialing', 'cancelled']))
                )
                tenants_data = result.all()
                total_tenants = len(tenants_data)
                success_count = 0
                error_count = 0
                logger.info(f"Found {total_tenants} active tenants to track")
                # Process each tenant independently; one failure doesn't stop the cycle
                for tenant, subscription in tenants_data:
                    try:
                        usage = await self._get_tenant_usage(session, tenant.id)
                        if usage:
                            await self._track_metrics(tenant.id, usage)
                            success_count += 1
                        else:
                            logger.warning(
                                "No usage data available for tenant",
                                tenant_id=tenant.id
                            )
                            error_count += 1
                    except Exception as e:
                        logger.error(
                            "Error tracking tenant usage",
                            tenant_id=tenant.id,
                            error=str(e),
                            exc_info=True
                        )
                        error_count += 1
                end_time = datetime.now(timezone.utc)
                duration = (end_time - start_time).total_seconds()
                logger.info(
                    "Daily usage tracking completed",
                    total_tenants=total_tenants,
                    success=success_count,
                    errors=error_count,
                    duration_seconds=duration
                )
        except Exception as e:
            logger.error("Usage tracking cycle failed", error=str(e), exc_info=True)
    async def _run_scheduler(self):
        """Main scheduler loop: wait until the target time, then run daily."""
        logger.info(
            "Usage tracking scheduler loop started",
            target_hour=self.config.USAGE_TRACKING_HOUR,
            target_minute=self.config.USAGE_TRACKING_MINUTE
        )
        # Initial delay to target time
        delay = self.seconds_until_target_time()
        logger.info(f"Waiting {delay/3600:.2f} hours until next run at {self.config.USAGE_TRACKING_HOUR:02d}:{self.config.USAGE_TRACKING_MINUTE:02d} UTC")
        try:
            await asyncio.sleep(delay)
        except asyncio.CancelledError:
            logger.info("Scheduler cancelled during initial delay")
            return
        while self._running:
            try:
                await self._run_cycle()
            except Exception as e:
                logger.error("Scheduler cycle error", error=str(e), exc_info=True)
            # Wait 24 hours until next run
            # NOTE(review): fixed 86400s sleep can drift from the target hour
            # over time (cycle duration adds up) — confirm this is acceptable.
            try:
                await asyncio.sleep(86400)
            except asyncio.CancelledError:
                logger.info("Scheduler cancelled during sleep")
                break
    def start(self) -> None:
        """Start the scheduler's background task (no-op if disabled or running)."""
        if not self.config.USAGE_TRACKING_ENABLED:
            logger.info("Usage tracking scheduler disabled by configuration")
            return
        if self._running:
            logger.warning("Usage tracking scheduler already running")
            return
        self._running = True
        self._task = asyncio.create_task(self._run_scheduler())
        logger.info("Usage tracking scheduler started successfully")
    async def stop(self) -> None:
        """Stop the scheduler gracefully by cancelling its task and awaiting it."""
        if not self._running:
            logger.debug("Scheduler not running, nothing to stop")
            return
        logger.info("Stopping usage tracking scheduler")
        self._running = False
        if self._task:
            self._task.cancel()
            try:
                await self._task
            except asyncio.CancelledError:
                logger.info("Scheduler task cancelled successfully")
        logger.info("Usage tracking scheduler stopped")
# Global instance
_scheduler: Optional[UsageTrackingScheduler] = None
async def start_scheduler(db_manager, redis_client, config):
    """Create the module-level usage tracking scheduler and start it.

    Re-raises any startup failure so the caller can decide whether to abort.
    """
    global _scheduler
    try:
        # Keep the instance reachable module-wide so stop_scheduler() can find it.
        _scheduler = UsageTrackingScheduler(db_manager, redis_client, config)
        _scheduler.start()
        logger.info("Usage tracking scheduler module initialized")
    except Exception as exc:
        logger.error("Failed to start usage tracking scheduler", error=str(exc), exc_info=True)
        raise
async def stop_scheduler():
    """Stop and discard the module-level usage tracking scheduler, if any."""
    global _scheduler
    if _scheduler is None:
        # Nothing was ever started (or it was already stopped).
        return
    await _scheduler.stop()
    _scheduler = None
    logger.info("Usage tracking scheduler module stopped")

175
services/tenant/app/main.py Normal file
View File

@@ -0,0 +1,175 @@
# services/tenant/app/main.py
"""
Tenant Service FastAPI application
"""
from fastapi import FastAPI
from sqlalchemy import text
from app.core.config import settings
from app.core.database import database_manager
from app.api import tenants, tenant_members, tenant_operations, webhooks, plans, subscription, tenant_settings, whatsapp_admin, usage_forecast, enterprise_upgrade, tenant_locations, tenant_hierarchy, internal_demo, network_alerts, onboarding
from shared.service_base import StandardFastAPIService
from shared.monitoring.system_metrics import SystemMetricsCollector
class TenantService(StandardFastAPIService):
    """Tenant Service with standardized setup.

    Subclass of StandardFastAPIService wiring tenant-specific startup and
    shutdown logic: migration check, model registration, Redis, system
    metrics, the usage tracking scheduler, and startup seeders.
    """
    # Expected alembic revision. NOTE(review): currently only logged, not
    # enforced, by verify_migrations below.
    expected_migration_version = "00001"
    async def verify_migrations(self):
        """Verify database schema matches the latest migrations.

        Best-effort check: logs the current alembic version when the
        alembic_version table exists, warns otherwise. Never raises, so
        service startup is not blocked during initial setup.
        """
        try:
            async with self.database_manager.get_session() as session:
                # Check if alembic_version table exists
                result = await session.execute(text("""
                    SELECT EXISTS (
                        SELECT FROM information_schema.tables
                        WHERE table_schema = 'public'
                        AND table_name = 'alembic_version'
                    )
                """))
                table_exists = result.scalar()
                if table_exists:
                    # If table exists, check the version
                    result = await session.execute(text("SELECT version_num FROM alembic_version"))
                    version = result.scalar()
                    # For now, just log the version instead of strict checking to avoid startup failures
                    self.logger.info(f"Migration verification successful: {version}")
                else:
                    # If table doesn't exist, migrations might not have run yet
                    # This is OK - the migration job should create it
                    self.logger.warning("alembic_version table does not exist yet - migrations may not have run")
        except Exception as e:
            self.logger.warning(f"Migration verification failed (this may be expected during initial setup): {e}")
    def __init__(self):
        # Define expected database tables for health checks
        tenant_expected_tables = ['tenants', 'tenant_members', 'subscriptions']
        # Note: api_prefix is empty because RouteBuilder already includes /api/v1
        super().__init__(
            service_name="tenant-service",
            app_name="Tenant Management Service",
            description="Multi-tenant bakery management service",
            version="1.0.0",
            log_level=settings.LOG_LEVEL,
            api_prefix="",
            database_manager=database_manager,
            expected_tables=tenant_expected_tables
        )
    async def on_startup(self, app: FastAPI):
        """Custom startup logic for tenant service.

        Order matters: migrations are verified first, models must be imported
        before anything touches the ORM, and Redis must be up before the
        usage tracking scheduler (which receives the client).
        """
        # Verify migrations first
        await self.verify_migrations()
        # Import models to ensure they're registered with SQLAlchemy
        from app.models.tenants import Tenant, TenantMember, Subscription
        from app.models.tenant_settings import TenantSettings
        self.logger.info("Tenant models imported successfully")
        # Initialize Redis
        from shared.redis_utils import initialize_redis, get_redis_client
        await initialize_redis(settings.REDIS_URL, db=settings.REDIS_DB, max_connections=20)
        redis_client = await get_redis_client()
        self.logger.info("Redis initialized successfully")
        # Initialize system metrics collection
        system_metrics = SystemMetricsCollector("tenant")
        self.logger.info("System metrics collection started")
        # Start usage tracking scheduler
        from app.jobs.usage_tracking_scheduler import start_scheduler
        await start_scheduler(self.database_manager, redis_client, settings)
        self.logger.info("Usage tracking scheduler started")
        # Run startup seeders (pilot coupon, etc.)
        from app.jobs.startup_seeder import run_startup_seeders
        await run_startup_seeders(self.database_manager)
        self.logger.info("Startup seeders completed")
    async def on_shutdown(self, app: FastAPI):
        """Custom shutdown logic for tenant service (reverse of startup)."""
        # Stop usage tracking scheduler
        from app.jobs.usage_tracking_scheduler import stop_scheduler
        await stop_scheduler()
        self.logger.info("Usage tracking scheduler stopped")
        # Close Redis connection
        from shared.redis_utils import close_redis
        await close_redis()
        self.logger.info("Redis connection closed")
        # Database cleanup is handled by the base class
    def get_service_features(self):
        """Return tenant-specific features"""
        return [
            "multi_tenant_management",
            "subscription_management",
            "tenant_isolation",
            "webhook_notifications",
            "member_management"
        ]
    def setup_custom_endpoints(self):
        """Setup custom endpoints for tenant service"""
        # Note: Metrics are exported via OpenTelemetry OTLP to SigNoz
        # The /metrics endpoint is not needed as metrics are pushed automatically
        # @self.app.get("/metrics")
        # async def metrics():
        #     """Prometheus metrics endpoint"""
        #     if self.metrics_collector:
        #         return self.metrics_collector.get_metrics()
        #     return {"metrics": "not_available"}
# Create service instance
service = TenantService()
# Create FastAPI app with standardized setup
app = service.create_app(
    docs_url="/docs",
    redoc_url="/redoc"
)
# Setup standard endpoints
service.setup_standard_endpoints()
# Setup custom endpoints
service.setup_custom_endpoints()
# Include routers.
# FIX: internal_demo.router was previously registered 12 times (the same
# add_router line was copy-pasted after almost every other router), which
# registers duplicate routes on the app. It is now registered exactly once.
service.add_router(plans.router, tags=["subscription-plans"])  # Public endpoint
service.add_router(internal_demo.router, tags=["internal-demo"])  # Internal demo data cloning
service.add_router(subscription.router, tags=["subscription"])
service.add_router(usage_forecast.router, prefix="/api/v1", tags=["usage-forecast"])  # Usage forecasting & predictive analytics
# Register settings router BEFORE tenants router to ensure proper route matching
service.add_router(tenant_settings.router, prefix="/api/v1/tenants", tags=["tenant-settings"])
service.add_router(whatsapp_admin.router, prefix="/api/v1", tags=["whatsapp-admin"])  # Admin WhatsApp management
service.add_router(tenants.router, tags=["tenants"])
service.add_router(tenant_members.router, tags=["tenant-members"])
service.add_router(tenant_operations.router, tags=["tenant-operations"])
service.add_router(webhooks.router, tags=["webhooks"])
service.add_router(enterprise_upgrade.router, tags=["enterprise"])  # Enterprise tier upgrade endpoints
service.add_router(tenant_locations.router, tags=["tenant-locations"])  # Tenant locations endpoints
service.add_router(tenant_hierarchy.router, tags=["tenant-hierarchy"])  # Tenant hierarchy endpoints
service.add_router(network_alerts.router, tags=["network-alerts"])  # Network alerts aggregation endpoints
service.add_router(onboarding.router, tags=["onboarding"])  # Onboarding status endpoints
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

View File

@@ -0,0 +1,31 @@
"""
Tenant Service Models Package
Import all models to ensure they are registered with SQLAlchemy Base.
"""
# Import AuditLog model for this service
from shared.security import create_audit_log_model
from shared.database.base import Base
# Create audit log model for this service
AuditLog = create_audit_log_model(Base)
# Import all models to register them with the Base metadata
from .tenants import Tenant, TenantMember, Subscription
from .tenant_location import TenantLocation
from .coupon import CouponModel, CouponRedemptionModel
from .events import Event, EventTemplate
# List all models for easier access
__all__ = [
"Tenant",
"TenantMember",
"Subscription",
"TenantLocation",
"AuditLog",
"CouponModel",
"CouponRedemptionModel",
"Event",
"EventTemplate",
]

View File

@@ -0,0 +1,64 @@
"""
SQLAlchemy models for coupon system
"""
import uuid
from datetime import datetime, timezone

from sqlalchemy import Column, String, Integer, Boolean, DateTime, ForeignKey, JSON, Index
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship

from shared.database import Base
class CouponModel(Base):
    """Coupon configuration table.

    Stores discount coupons (trial extensions, percentage or fixed-amount
    discounts) with their redemption limits and validity window.
    """
    __tablename__ = "coupons"
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    code = Column(String(50), unique=True, nullable=False, index=True)
    discount_type = Column(String(20), nullable=False)  # trial_extension, percentage, fixed_amount
    discount_value = Column(Integer, nullable=False)  # Days/percentage/cents depending on type
    max_redemptions = Column(Integer, nullable=True)  # None = unlimited
    current_redemptions = Column(Integer, nullable=False, default=0)
    valid_from = Column(DateTime(timezone=True), nullable=False)
    valid_until = Column(DateTime(timezone=True), nullable=True)  # None = no expiry
    active = Column(Boolean, nullable=False, default=True)
    # FIX: the column is DateTime(timezone=True) but the default used the
    # naive, deprecated datetime.utcnow; a timezone-aware UTC timestamp is
    # what the other models in this service store.
    created_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc))
    extra_data = Column(JSON, nullable=True)  # Renamed from metadata to avoid SQLAlchemy reserved name
    # Relationships
    redemptions = relationship("CouponRedemptionModel", back_populates="coupon")
    # Indexes for performance
    __table_args__ = (
        Index('idx_coupon_code_active', 'code', 'active'),
        Index('idx_coupon_valid_dates', 'valid_from', 'valid_until'),
    )
    def __repr__(self):
        return f"<Coupon(code='{self.code}', type='{self.discount_type}', value={self.discount_value})>"
class CouponRedemptionModel(Base):
    """Coupon redemption history table.

    One row per redemption event; the discount granted at redemption time
    is snapshotted in discount_applied.
    """
    __tablename__ = "coupon_redemptions"
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(String(255), nullable=False, index=True)
    coupon_code = Column(String(50), ForeignKey('coupons.code'), nullable=False)
    # FIX: timezone-aware UTC default; the column is DateTime(timezone=True)
    # and the previous datetime.utcnow default is naive and deprecated.
    redeemed_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc))
    discount_applied = Column(JSON, nullable=False)  # Details of discount applied
    extra_data = Column(JSON, nullable=True)  # Renamed from metadata to avoid SQLAlchemy reserved name
    # Relationships
    coupon = relationship("CouponModel", back_populates="redemptions")
    # Constraints
    __table_args__ = (
        Index('idx_redemption_tenant', 'tenant_id'),
        Index('idx_redemption_coupon', 'coupon_code'),
        # NOTE(review): this composite index does NOT prevent duplicate
        # redemptions as the original comment claimed - it is not unique.
        # Enforcing one redemption per (tenant, coupon) would need
        # unique=True here plus a matching migration.
        Index('idx_redemption_tenant_coupon', 'tenant_id', 'coupon_code'),
    )
    def __repr__(self):
        return f"<CouponRedemption(tenant_id='{self.tenant_id}', code='{self.coupon_code}')>"

View File

@@ -0,0 +1,136 @@
"""
Event Calendar Models
Database models for tracking local events, promotions, and special occasions
"""
from sqlalchemy import Column, Integer, String, DateTime, Text, Boolean, Float, Date
from sqlalchemy.dialects.postgresql import UUID
from shared.database.base import Base
from datetime import datetime, timezone
import uuid
class Event(Base):
    """
    Table to track events that affect bakery demand.
    Events include:
    - Local events (festivals, markets, concerts)
    - Promotions and sales
    - Weather events (heat waves, storms)
    - School holidays and breaks
    - Special occasions
    """
    __tablename__ = "events"
    # Primary identification
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # No FK: tenant lives in this service, but events keep only a soft reference
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    # Event information
    event_name = Column(String(500), nullable=False)
    event_type = Column(String(100), nullable=False, index=True)  # promotion, festival, holiday, weather, school_break, sport_event, etc.
    description = Column(Text, nullable=True)
    # Date and time
    event_date = Column(Date, nullable=False, index=True)
    start_time = Column(DateTime(timezone=True), nullable=True)
    end_time = Column(DateTime(timezone=True), nullable=True)
    is_all_day = Column(Boolean, default=True)
    # Impact estimation
    expected_impact = Column(String(50), nullable=True)  # low, medium, high, very_high
    impact_multiplier = Column(Float, nullable=True)  # Expected demand multiplier (e.g., 1.5 = 50% increase)
    affected_product_categories = Column(String(500), nullable=True)  # Comma-separated categories
    # Location
    location = Column(String(500), nullable=True)
    is_local = Column(Boolean, default=True)  # True if event is near bakery
    # Status
    is_confirmed = Column(Boolean, default=False)
    is_recurring = Column(Boolean, default=False)
    recurrence_pattern = Column(String(200), nullable=True)  # e.g., "weekly:monday", "monthly:first_saturday"
    # Actual impact (filled after event)
    actual_impact_multiplier = Column(Float, nullable=True)
    actual_sales_increase_percent = Column(Float, nullable=True)
    # Metadata
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
    created_by = Column(String(255), nullable=True)
    notes = Column(Text, nullable=True)
    def to_dict(self):
        """Serialize the event to a JSON-friendly dict.

        UUIDs become strings; date/datetime fields become ISO-8601 strings
        (or None when unset). All other columns pass through unchanged.
        """
        return {
            "id": str(self.id),
            "tenant_id": str(self.tenant_id),
            "event_name": self.event_name,
            "event_type": self.event_type,
            "description": self.description,
            "event_date": self.event_date.isoformat() if self.event_date else None,
            "start_time": self.start_time.isoformat() if self.start_time else None,
            "end_time": self.end_time.isoformat() if self.end_time else None,
            "is_all_day": self.is_all_day,
            "expected_impact": self.expected_impact,
            "impact_multiplier": self.impact_multiplier,
            "affected_product_categories": self.affected_product_categories,
            "location": self.location,
            "is_local": self.is_local,
            "is_confirmed": self.is_confirmed,
            "is_recurring": self.is_recurring,
            "recurrence_pattern": self.recurrence_pattern,
            "actual_impact_multiplier": self.actual_impact_multiplier,
            "actual_sales_increase_percent": self.actual_sales_increase_percent,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None,
            "created_by": self.created_by,
            "notes": self.notes
        }
class EventTemplate(Base):
    """
    Template for recurring events.
    Allows easy creation of events based on patterns.
    """
    __tablename__ = "event_templates"
    # Primary identification
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
    # Template information
    template_name = Column(String(500), nullable=False)
    event_type = Column(String(100), nullable=False)
    description = Column(Text, nullable=True)
    # Default values applied to events created from this template
    default_impact = Column(String(50), nullable=True)
    default_impact_multiplier = Column(Float, nullable=True)
    default_affected_categories = Column(String(500), nullable=True)
    # Recurrence
    recurrence_pattern = Column(String(200), nullable=False)  # e.g., "weekly:saturday", "monthly:last_sunday"
    is_active = Column(Boolean, default=True)
    # Metadata
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
    def to_dict(self):
        """Serialize the template to a JSON-friendly dict (UUIDs as strings,
        timestamps as ISO-8601 or None)."""
        return {
            "id": str(self.id),
            "tenant_id": str(self.tenant_id),
            "template_name": self.template_name,
            "event_type": self.event_type,
            "description": self.description,
            "default_impact": self.default_impact,
            "default_impact_multiplier": self.default_impact_multiplier,
            "default_affected_categories": self.default_affected_categories,
            "recurrence_pattern": self.recurrence_pattern,
            "is_active": self.is_active,
            "created_at": self.created_at.isoformat() if self.created_at else None,
            "updated_at": self.updated_at.isoformat() if self.updated_at else None
        }

View File

@@ -0,0 +1,59 @@
"""
Tenant Location Model
Represents physical locations for enterprise tenants (central production, retail outlets)
"""
from sqlalchemy import Column, String, Boolean, DateTime, Float, ForeignKey, Text, Integer, JSON
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from datetime import datetime, timezone
import uuid
from shared.database.base import Base
class TenantLocation(Base):
    """TenantLocation model - represents physical locations for enterprise tenants.

    A tenant may own several locations, e.g. one central production facility
    plus retail outlets (see location_type). Rows are removed with their
    tenant via the CASCADE foreign key.
    """
    __tablename__ = "tenant_locations"
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id", ondelete="CASCADE"), nullable=False, index=True)
    # Location information
    name = Column(String(200), nullable=False)
    location_type = Column(String(50), nullable=False)  # central_production, retail_outlet
    address = Column(Text, nullable=False)
    city = Column(String(100), default="Madrid")
    postal_code = Column(String(10), nullable=False)
    latitude = Column(Float, nullable=True)
    longitude = Column(Float, nullable=True)
    # Location-specific configuration
    delivery_windows = Column(JSON, nullable=True)  # { "monday": "08:00-12:00,14:00-18:00", ... }
    capacity = Column(Integer, nullable=True)  # For production capacity in kg/day or storage capacity
    max_delivery_radius_km = Column(Float, nullable=True, default=50.0)
    # Operational hours
    operational_hours = Column(JSON, nullable=True)  # { "monday": "06:00-20:00", ... }
    is_active = Column(Boolean, default=True)
    # Contact information
    contact_person = Column(String(200), nullable=True)
    contact_phone = Column(String(20), nullable=True)
    contact_email = Column(String(255), nullable=True)
    # Custom delivery scheduling configuration per location
    delivery_schedule_config = Column(JSON, nullable=True)  # { "delivery_days": "Mon,Wed,Fri", "time_window": "07:00-10:00" }
    # Metadata (trailing underscore avoids SQLAlchemy's reserved "metadata" name)
    metadata_ = Column(JSON, nullable=True)
    # Timestamps (timezone-aware UTC)
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
    # Relationships
    tenant = relationship("Tenant", back_populates="locations")
    def __repr__(self):
        return f"<TenantLocation(id={self.id}, tenant_id={self.tenant_id}, name={self.name}, type={self.location_type})>"

View File

@@ -0,0 +1,370 @@
# services/tenant/app/models/tenant_settings.py
"""
Tenant Settings Model
Centralized configuration storage for all tenant-specific operational settings
"""
from sqlalchemy import Column, String, Boolean, DateTime, ForeignKey, JSON
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from datetime import datetime, timezone
import uuid
from shared.database.base import Base
def _default_settings() -> dict:
    """Return a fresh dict of default values for every settings category.

    Single source of truth for tenant settings defaults: the per-column
    SQLAlchemy defaults on TenantSettings and
    TenantSettings.get_default_settings() both delegate here. (Previously
    the same literals were duplicated in both places and could drift apart.)
    A new dict is built on every call so rows never share a mutable default.
    """
    return {
        # Procurement & Auto-Approval Settings (Orders Service)
        "procurement_settings": {
            "auto_approve_enabled": True,
            "auto_approve_threshold_eur": 500.0,
            "auto_approve_min_supplier_score": 0.80,
            "require_approval_new_suppliers": True,
            "require_approval_critical_items": True,
            "procurement_lead_time_days": 3,
            "demand_forecast_days": 14,
            "safety_stock_percentage": 20.0,
            "po_approval_reminder_hours": 24,
            "po_critical_escalation_hours": 12,
            "use_reorder_rules": True,
            "economic_rounding": True,
            "respect_storage_limits": True,
            "use_supplier_minimums": True,
            "optimize_price_tiers": True
        },
        # Inventory Management Settings (Inventory Service)
        "inventory_settings": {
            "low_stock_threshold": 10,
            "reorder_point": 20,
            "reorder_quantity": 50,
            "expiring_soon_days": 7,
            "expiration_warning_days": 3,
            "quality_score_threshold": 8.0,
            "temperature_monitoring_enabled": True,
            "refrigeration_temp_min": 1.0,
            "refrigeration_temp_max": 4.0,
            "freezer_temp_min": -20.0,
            "freezer_temp_max": -15.0,
            "room_temp_min": 18.0,
            "room_temp_max": 25.0,
            "temp_deviation_alert_minutes": 15,
            "critical_temp_deviation_minutes": 5
        },
        # Production Settings (Production Service)
        "production_settings": {
            "planning_horizon_days": 7,
            "minimum_batch_size": 1.0,
            "maximum_batch_size": 100.0,
            "production_buffer_percentage": 10.0,
            "working_hours_per_day": 12,
            "max_overtime_hours": 4,
            "capacity_utilization_target": 0.85,
            "capacity_warning_threshold": 0.95,
            "quality_check_enabled": True,
            "minimum_yield_percentage": 85.0,
            "quality_score_threshold": 8.0,
            "schedule_optimization_enabled": True,
            "prep_time_buffer_minutes": 30,
            "cleanup_time_buffer_minutes": 15,
            "labor_cost_per_hour_eur": 15.0,
            "overhead_cost_percentage": 20.0
        },
        # Supplier Settings (Suppliers Service)
        "supplier_settings": {
            "default_payment_terms_days": 30,
            "default_delivery_days": 3,
            "excellent_delivery_rate": 95.0,
            "good_delivery_rate": 90.0,
            "excellent_quality_rate": 98.0,
            "good_quality_rate": 95.0,
            "critical_delivery_delay_hours": 24,
            "critical_quality_rejection_rate": 10.0,
            "high_cost_variance_percentage": 15.0,
            # Supplier Approval Workflow Settings
            "require_supplier_approval": True,
            "auto_approve_for_admin_owner": True,
            "approval_required_roles": ["member", "viewer"]
        },
        # POS Integration Settings (POS Service)
        "pos_settings": {
            "sync_interval_minutes": 5,
            "auto_sync_products": True,
            "auto_sync_transactions": True
        },
        # Order & Business Rules Settings (Orders Service)
        "order_settings": {
            "max_discount_percentage": 50.0,
            "default_delivery_window_hours": 48,
            "dynamic_pricing_enabled": False,
            "discount_enabled": True,
            "delivery_tracking_enabled": True
        },
        # Replenishment Planning Settings (Orchestrator Service)
        "replenishment_settings": {
            "projection_horizon_days": 7,
            "service_level": 0.95,
            "buffer_days": 1,
            "enable_auto_replenishment": True,
            "min_order_quantity": 1.0,
            "max_order_quantity": 1000.0,
            "demand_forecast_days": 14
        },
        # Safety Stock Settings (Orchestrator Service)
        "safety_stock_settings": {
            "service_level": 0.95,
            "method": "statistical",
            "min_safety_stock": 0.0,
            "max_safety_stock": 100.0,
            "reorder_point_calculation": "safety_stock_plus_lead_time_demand"
        },
        # MOQ Aggregation Settings (Orchestrator Service)
        "moq_settings": {
            "consolidation_window_days": 7,
            "allow_early_ordering": True,
            "enable_batch_optimization": True,
            "min_batch_size": 1.0,
            "max_batch_size": 1000.0
        },
        # Supplier Selection Settings (Orchestrator Service)
        "supplier_selection_settings": {
            "price_weight": 0.40,
            "lead_time_weight": 0.20,
            "quality_weight": 0.20,
            "reliability_weight": 0.20,
            "diversification_threshold": 1000,
            "max_single_percentage": 0.70,
            "enable_supplier_score_optimization": True
        },
        # ML Insights Settings (AI Insights Service)
        "ml_insights_settings": {
            # Inventory ML (Safety Stock Optimization)
            "inventory_lookback_days": 90,
            "inventory_min_history_days": 30,
            # Production ML (Yield Prediction)
            "production_lookback_days": 90,
            "production_min_history_runs": 30,
            # Procurement ML (Supplier Analysis & Price Forecasting)
            "supplier_analysis_lookback_days": 180,
            "supplier_analysis_min_orders": 10,
            "price_forecast_lookback_days": 180,
            "price_forecast_horizon_days": 30,
            # Forecasting ML (Dynamic Rules)
            "rules_generation_lookback_days": 90,
            "rules_generation_min_samples": 10,
            # Global ML Settings
            "enable_ml_insights": True,
            "ml_insights_auto_trigger": False,
            "ml_confidence_threshold": 0.80
        },
        # Notification Settings (Notification Service)
        "notification_settings": {
            # WhatsApp Configuration (Shared Account Model)
            "whatsapp_enabled": False,
            "whatsapp_phone_number_id": "",  # Meta WhatsApp Phone Number ID (from shared master account)
            "whatsapp_display_phone_number": "",  # Display format for UI (e.g., "+34 612 345 678")
            "whatsapp_default_language": "es",
            # Email Configuration
            "email_enabled": True,
            "email_from_address": "",
            "email_from_name": "",
            "email_reply_to": "",
            # Notification Preferences
            "enable_po_notifications": True,
            "enable_inventory_alerts": True,
            "enable_production_alerts": True,
            "enable_forecast_alerts": True,
            # Notification Channels
            "po_notification_channels": ["email"],  # ["email", "whatsapp"]
            "inventory_alert_channels": ["email"],
            "production_alert_channels": ["email"],
            "forecast_alert_channels": ["email"]
        }
    }


class TenantSettings(Base):
    """
    Centralized tenant settings model.
    Stores all operational configurations for a tenant across all services.
    One row per tenant (tenant_id is unique); each category is a JSON blob
    whose insert-time default comes from _default_settings().
    """
    __tablename__ = "tenant_settings"
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id", ondelete="CASCADE"), nullable=False, unique=True, index=True)
    # Per-category settings blobs. Each default pulls a fresh copy from
    # _default_settings() when a row is inserted, so the column defaults and
    # get_default_settings() can never drift apart.
    # Procurement & Auto-Approval Settings (Orders Service)
    procurement_settings = Column(JSON, nullable=False, default=lambda: _default_settings()["procurement_settings"])
    # Inventory Management Settings (Inventory Service)
    inventory_settings = Column(JSON, nullable=False, default=lambda: _default_settings()["inventory_settings"])
    # Production Settings (Production Service)
    production_settings = Column(JSON, nullable=False, default=lambda: _default_settings()["production_settings"])
    # Supplier Settings (Suppliers Service)
    supplier_settings = Column(JSON, nullable=False, default=lambda: _default_settings()["supplier_settings"])
    # POS Integration Settings (POS Service)
    pos_settings = Column(JSON, nullable=False, default=lambda: _default_settings()["pos_settings"])
    # Order & Business Rules Settings (Orders Service)
    order_settings = Column(JSON, nullable=False, default=lambda: _default_settings()["order_settings"])
    # Replenishment Planning Settings (Orchestrator Service)
    replenishment_settings = Column(JSON, nullable=False, default=lambda: _default_settings()["replenishment_settings"])
    # Safety Stock Settings (Orchestrator Service)
    safety_stock_settings = Column(JSON, nullable=False, default=lambda: _default_settings()["safety_stock_settings"])
    # MOQ Aggregation Settings (Orchestrator Service)
    moq_settings = Column(JSON, nullable=False, default=lambda: _default_settings()["moq_settings"])
    # Supplier Selection Settings (Orchestrator Service)
    supplier_selection_settings = Column(JSON, nullable=False, default=lambda: _default_settings()["supplier_selection_settings"])
    # ML Insights Settings (AI Insights Service)
    ml_insights_settings = Column(JSON, nullable=False, default=lambda: _default_settings()["ml_insights_settings"])
    # Notification Settings (Notification Service)
    notification_settings = Column(JSON, nullable=False, default=lambda: _default_settings()["notification_settings"])
    # Timestamps
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), nullable=False)
    updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc), nullable=False)
    # Relationships
    tenant = relationship("Tenant", backref="settings")
    def __repr__(self):
        return f"<TenantSettings(tenant_id={self.tenant_id})>"
    @staticmethod
    def get_default_settings() -> dict:
        """
        Get default settings for all categories.
        Returns a dictionary with default values for all setting categories
        (a fresh copy on every call).
        """
        return _default_settings()

View File

@@ -0,0 +1,221 @@
# services/tenant/app/models/tenants.py - FIXED VERSION
"""
Tenant models for bakery management - FIXED
Removed cross-service User relationship to eliminate circular dependencies
"""
from sqlalchemy import Column, String, Boolean, DateTime, Float, ForeignKey, Text, Integer, JSON
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import relationship
from datetime import datetime, timezone
import uuid
from shared.database.base import Base
class Tenant(Base):
"""Tenant/Bakery model"""
__tablename__ = "tenants"
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
name = Column(String(200), nullable=False)
subdomain = Column(String(100), unique=True)
business_type = Column(String(100), default="bakery")
business_model = Column(String(100), default="individual_bakery") # individual_bakery, central_baker_satellite, retail_bakery, hybrid_bakery
# Location info
address = Column(Text, nullable=False)
city = Column(String(100), default="Madrid")
postal_code = Column(String(10), nullable=False)
latitude = Column(Float)
longitude = Column(Float)
# Regional/Localization configuration
timezone = Column(String(50), default="Europe/Madrid", nullable=False)
currency = Column(String(3), default="EUR", nullable=False) # Currency code: EUR, USD, GBP
language = Column(String(5), default="es", nullable=False) # Language code: es, en, eu
# Contact info
phone = Column(String(20))
email = Column(String(255))
# Status
is_active = Column(Boolean, default=True)
# Demo account flags
is_demo = Column(Boolean, default=False, index=True)
is_demo_template = Column(Boolean, default=False, index=True)
base_demo_tenant_id = Column(UUID(as_uuid=True), nullable=True, index=True)
demo_session_id = Column(String(100), nullable=True, index=True)
demo_expires_at = Column(DateTime(timezone=True), nullable=True)
# ML status
ml_model_trained = Column(Boolean, default=False)
last_training_date = Column(DateTime(timezone=True))
# Additional metadata (JSON field for flexible data storage)
metadata_ = Column(JSON, nullable=True)
# Ownership (user_id without FK - cross-service reference)
owner_id = Column(UUID(as_uuid=True), nullable=False, index=True)
# Enterprise tier hierarchy fields
parent_tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id", ondelete="RESTRICT"), nullable=True, index=True)
tenant_type = Column(String(50), default="standalone", nullable=False) # standalone, parent, child
hierarchy_path = Column(String(500), nullable=True) # Materialized path for queries
# Timestamps
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
# 3D Secure (3DS) tracking
threeds_authentication_required = Column(Boolean, default=False)
threeds_authentication_required_at = Column(DateTime(timezone=True), nullable=True)
threeds_authentication_completed = Column(Boolean, default=False)
threeds_authentication_completed_at = Column(DateTime(timezone=True), nullable=True)
last_threeds_setup_intent_id = Column(String(255), nullable=True)
threeds_action_type = Column(String(100), nullable=True)
# Relationships - only within tenant service
members = relationship("TenantMember", back_populates="tenant", cascade="all, delete-orphan")
subscriptions = relationship("Subscription", back_populates="tenant", cascade="all, delete-orphan")
locations = relationship("TenantLocation", back_populates="tenant", cascade="all, delete-orphan")
child_tenants = relationship("Tenant", back_populates="parent_tenant", remote_side=[id])
parent_tenant = relationship("Tenant", back_populates="child_tenants", remote_side=[parent_tenant_id])
# REMOVED: users relationship - no cross-service SQLAlchemy relationships
@property
def subscription_tier(self):
    """
    Current subscription tier, derived from the first active subscription.

    Returns the ``plan`` of the first subscription whose status is
    ``'active'``, or ``"starter"`` when none is active. Requires the
    ``subscriptions`` relationship to be loaded; for performance-critical
    operations use the subscription cache service directly instead.
    """
    active = next(
        (candidate for candidate in self.subscriptions
         if candidate.status == 'active'),
        None,
    )
    # Fall back to the entry-level tier when no active subscription exists.
    return active.plan if active is not None else "starter"
def __repr__(self):
    # Compact debug representation; intentionally limited to id and name.
    return f"<Tenant(id={self.id}, name={self.name})>"
class TenantMember(Base):
    """
    Tenant membership model for team access.
    This model represents TENANT-SPECIFIC roles, which are distinct from global user roles.
    TENANT ROLES (stored here):
    - owner: Full control of the tenant, can transfer ownership, manage all aspects
    - admin: Tenant administrator, can manage team members and most operations
    - member: Standard team member, regular operational access
    - viewer: Read-only observer, view-only access to tenant data
    ROLE MAPPING TO GLOBAL ROLES:
    When users are created through tenant management (pilot phase), their tenant role
    is mapped to a global user role in the Auth service:
    - tenant 'admin' -> global 'admin' (system-wide admin access)
    - tenant 'member' -> global 'manager' (management-level access)
    - tenant 'viewer' -> global 'user' (basic user access)
    - tenant 'owner' -> No automatic global role (owner is tenant-specific)
    This mapping is implemented in app/api/tenant_members.py lines 68-76.
    Note: user_id is a cross-service reference (no FK) to avoid circular dependencies.
    User enrichment is handled at the service layer via Auth service calls.
    """
    __tablename__ = "tenant_members"
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Deleting a tenant cascades to its membership rows at the DB level.
    tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id", ondelete="CASCADE"), nullable=False)
    user_id = Column(UUID(as_uuid=True), nullable=False, index=True)  # No FK - cross-service reference
    # Role and permissions specific to this tenant.
    # Valid values: 'owner', 'admin', 'member', 'viewer', 'network_admin'
    # (note: 'network_admin' is accepted here but not listed in the docstring's
    # role mapping above - confirm which list is authoritative).
    role = Column(String(50), default="member")
    permissions = Column(Text)  # JSON string of permissions
    # Status / invitation lifecycle
    is_active = Column(Boolean, default=True)
    invited_by = Column(UUID(as_uuid=True))  # No FK - cross-service reference
    invited_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    joined_at = Column(DateTime(timezone=True))  # NULL until the invitation is accepted
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    # Relationships - only within tenant service
    tenant = relationship("Tenant", back_populates="members")
    # REMOVED: user relationship - no cross-service SQLAlchemy relationships
    def __repr__(self):
        # Compact debug representation; no user enrichment performed here.
        return f"<TenantMember(tenant_id={self.tenant_id}, user_id={self.user_id}, role={self.role})>"
# Additional models for subscriptions, plans, etc.
class Subscription(Base):
    """
    Subscription model for tenant billing with tenant linking support.

    A subscription may be created before a tenant exists (tenant-independent
    registration): tenant_id is then NULL and user_id identifies the
    purchaser; the row is linked to a tenant later (see is_tenant_linked /
    tenant_linking_status / can_be_linked_to_tenant).
    """
    __tablename__ = "subscriptions"
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    # Nullable: tenant-independent subscriptions get linked to a tenant later.
    tenant_id = Column(UUID(as_uuid=True), ForeignKey("tenants.id", ondelete="CASCADE"), nullable=True)
    # User reference for tenant-independent subscriptions (cross-service, no FK)
    user_id = Column(UUID(as_uuid=True), nullable=True, index=True)
    # Tenant linking status
    is_tenant_linked = Column(Boolean, default=False, nullable=False)
    tenant_linking_status = Column(String(50), nullable=True)  # pending, completed, failed
    linked_at = Column(DateTime(timezone=True), nullable=True)
    plan = Column(String(50), default="starter")  # starter, professional, enterprise
    status = Column(String(50), default="active")  # active, pending_cancellation, inactive, suspended, pending_tenant_linking
    # Billing
    monthly_price = Column(Float, default=0.0)
    billing_cycle = Column(String(20), default="monthly")  # monthly, yearly
    next_billing_date = Column(DateTime(timezone=True))
    trial_ends_at = Column(DateTime(timezone=True))
    cancelled_at = Column(DateTime(timezone=True), nullable=True)
    cancellation_effective_date = Column(DateTime(timezone=True), nullable=True)
    # Payment provider references (generic names for provider-agnostic design)
    subscription_id = Column(String(255), nullable=True)  # Payment provider subscription ID
    customer_id = Column(String(255), nullable=True)  # Payment provider customer ID
    # Plan quota limits snapshotted onto the subscription row
    max_users = Column(Integer, default=5)
    max_locations = Column(Integer, default=1)
    max_products = Column(Integer, default=50)
    # Features - plan feature flags stored as JSON
    features = Column(JSON)
    # Timestamps (timezone-aware UTC; onupdate refreshes updated_at on UPDATE)
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
    # 3D Secure (3DS) tracking
    threeds_authentication_required = Column(Boolean, default=False)
    threeds_authentication_required_at = Column(DateTime(timezone=True), nullable=True)
    threeds_authentication_completed = Column(Boolean, default=False)
    threeds_authentication_completed_at = Column(DateTime(timezone=True), nullable=True)
    last_threeds_setup_intent_id = Column(String(255), nullable=True)
    threeds_action_type = Column(String(100), nullable=True)
    # Relationships
    # NOTE(review): Tenant.subscriptions declares back_populates="tenant", but
    # this side omits back_populates="subscriptions" - SQLAlchemy may reject
    # this at mapper-configuration time; verify against the SQLAlchemy version in use.
    tenant = relationship("Tenant")
    def __repr__(self):
        # Compact debug representation of the subscription row.
        return f"<Subscription(id={self.id}, tenant_id={self.tenant_id}, user_id={self.user_id}, plan={self.plan}, status={self.status})>"
    def is_pending_tenant_linking(self) -> bool:
        """Check if subscription is waiting to be linked to a tenant."""
        return self.tenant_linking_status == "pending" and not self.is_tenant_linked
    def can_be_linked_to_tenant(self, user_id: str) -> bool:
        """Check if the given user may link this subscription to a tenant
        (must be pending, owned by the user, and not already attached)."""
        return (self.is_pending_tenant_linking() and
                str(self.user_id) == user_id and
                self.tenant_id is None)

View File

@@ -0,0 +1,16 @@
"""
Tenant Service Repositories
Repository implementations for tenant service
"""
from .base import TenantBaseRepository
from .tenant_repository import TenantRepository
from .tenant_member_repository import TenantMemberRepository
from .subscription_repository import SubscriptionRepository
__all__ = [
"TenantBaseRepository",
"TenantRepository",
"TenantMemberRepository",
"SubscriptionRepository"
]

View File

@@ -0,0 +1,234 @@
"""
Base Repository for Tenant Service
Service-specific repository base class with tenant management utilities
"""
from typing import Optional, List, Dict, Any, Type
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timedelta
import structlog
import json
from shared.database.repository import BaseRepository
from shared.database.exceptions import DatabaseError
logger = structlog.get_logger()
class TenantBaseRepository(BaseRepository):
    """
    Base repository for tenant service with common tenant operations.

    Extends the shared BaseRepository with helpers used across the tenant
    service: tenant-/user-scoped listing, soft (de)activation, conservative
    cleanup of stale rows, per-tenant statistics and input validation.
    Field-specific behaviour is feature-detected with hasattr so the class
    works for any tenant-service model.
    """
    def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 600):
        # Tenant data is relatively stable, medium cache time (10 minutes)
        super().__init__(model, session, cache_ttl)
    async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List:
        """Get records by tenant ID, newest first (unscoped listing when the model has no tenant_id)."""
        if hasattr(self.model, 'tenant_id'):
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={"tenant_id": tenant_id},
                order_by="created_at",
                order_desc=True
            )
        # Model is not tenant-scoped: fall back to an unfiltered listing.
        return await self.get_multi(skip=skip, limit=limit)
    async def get_by_user_id(self, user_id: str, skip: int = 0, limit: int = 100) -> List:
        """Get records by user ID (cross-service reference); falls back to owner_id, else returns []."""
        if hasattr(self.model, 'user_id'):
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={"user_id": user_id},
                order_by="created_at",
                order_desc=True
            )
        elif hasattr(self.model, 'owner_id'):
            # Tenant-style models track the owning user under owner_id instead.
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={"owner_id": user_id},
                order_by="created_at",
                order_desc=True
            )
        # Model has no user linkage at all.
        return []
    async def get_active_records(self, skip: int = 0, limit: int = 100) -> List:
        """Get active records, newest first (all records if the model has no is_active field)."""
        if hasattr(self.model, 'is_active'):
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={"is_active": True},
                order_by="created_at",
                order_desc=True
            )
        return await self.get_multi(skip=skip, limit=limit)
    async def deactivate_record(self, record_id: Any) -> Optional:
        """Deactivate a record instead of deleting it (hard-deletes when the model has no is_active flag)."""
        if hasattr(self.model, 'is_active'):
            return await self.update(record_id, {"is_active": False})
        return await self.delete(record_id)
    async def activate_record(self, record_id: Any) -> Optional:
        """Re-activate a record (plain fetch when the model has no is_active flag)."""
        if hasattr(self.model, 'is_active'):
            return await self.update(record_id, {"is_active": True})
        return await self.get_by_id(record_id)
    async def cleanup_old_records(self, days_old: int = 365) -> int:
        """
        Clean up old tenant records (very conservative - 1 year default).

        Deletes rows older than the cutoff; when the model has an is_active
        flag, only inactive rows are removed. Returns the number of rows
        deleted. Raises DatabaseError on failure.
        """
        try:
            # NOTE(review): cutoff is a naive UTC datetime while created_at
            # columns are timezone-aware - confirm the DB driver treats the
            # naive value as UTC.
            cutoff_date = datetime.utcnow() - timedelta(days=days_old)
            # Safe to interpolate: __tablename__ is model-defined, not user input.
            table_name = self.model.__tablename__
            # Only delete inactive records that are very old
            conditions = [
                "created_at < :cutoff_date"
            ]
            if hasattr(self.model, 'is_active'):
                conditions.append("is_active = false")
            query_text = f"""
                DELETE FROM {table_name}
                WHERE {' AND '.join(conditions)}
            """
            result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
            deleted_count = result.rowcount
            logger.info(f"Cleaned up old {self.model.__name__} records",
                        deleted_count=deleted_count,
                        days_old=days_old)
            return deleted_count
        except Exception as e:
            logger.error("Failed to cleanup old records",
                         model=self.model.__name__,
                         error=str(e))
            raise DatabaseError(f"Cleanup failed: {str(e)}")
    async def get_statistics_by_tenant(self, tenant_id: str) -> Dict[str, Any]:
        """
        Get statistics for a tenant: total/active/inactive row counts plus
        rows created in the last 7 days. Best-effort: returns zeros on any
        failure instead of raising.
        """
        try:
            table_name = self.model.__tablename__
            # Get basic counts
            total_records = await self.count(filters={"tenant_id": tenant_id})
            # Get active records if applicable
            active_records = total_records
            if hasattr(self.model, 'is_active'):
                active_records = await self.count(filters={
                    "tenant_id": tenant_id,
                    "is_active": True
                })
            # Get recent activity (records in last 7 days)
            # NOTE(review): naive utcnow compared against a timestamptz column -
            # confirm driver semantics (see cleanup_old_records).
            seven_days_ago = datetime.utcnow() - timedelta(days=7)
            recent_query = text(f"""
                SELECT COUNT(*) as count
                FROM {table_name}
                WHERE tenant_id = :tenant_id
                AND created_at >= :seven_days_ago
            """)
            result = await self.session.execute(recent_query, {
                "tenant_id": tenant_id,
                "seven_days_ago": seven_days_ago
            })
            recent_records = result.scalar() or 0
            return {
                "total_records": total_records,
                "active_records": active_records,
                "inactive_records": total_records - active_records,
                "recent_records_7d": recent_records
            }
        except Exception as e:
            logger.error("Failed to get tenant statistics",
                         model=self.model.__name__,
                         tenant_id=tenant_id,
                         error=str(e))
            # Statistics failures must not break callers - return zeroed stats.
            return {
                "total_records": 0,
                "active_records": 0,
                "inactive_records": 0,
                "recent_records_7d": 0
            }
    def _validate_tenant_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
        """
        Validate tenant-related data.

        Checks required fields plus, when present: id string formats, email
        shape, phone length, latitude/longitude ranges and JSON-encoded
        fields. Returns {"is_valid": bool, "errors": [str, ...]}; never raises.
        """
        errors = []
        for field in required_fields:
            # Note: falsy-but-meaningful values (0, False, "") are treated as missing.
            if field not in data or not data[field]:
                errors.append(f"Missing required field: {field}")
        # Validate tenant_id format if present
        if "tenant_id" in data and data["tenant_id"]:
            tenant_id = data["tenant_id"]
            if not isinstance(tenant_id, str) or len(tenant_id) < 1:
                errors.append("Invalid tenant_id format")
        # Validate user_id format if present
        if "user_id" in data and data["user_id"]:
            user_id = data["user_id"]
            if not isinstance(user_id, str) or len(user_id) < 1:
                errors.append("Invalid user_id format")
        # Validate owner_id format if present
        if "owner_id" in data and data["owner_id"]:
            owner_id = data["owner_id"]
            if not isinstance(owner_id, str) or len(owner_id) < 1:
                errors.append("Invalid owner_id format")
        # Validate email format if present (cheap shape check, not full RFC validation)
        if "email" in data and data["email"]:
            email = data["email"]
            if "@" not in email or "." not in email.split("@")[-1]:
                errors.append("Invalid email format")
        # Validate phone format if present (basic validation)
        if "phone" in data and data["phone"]:
            phone = data["phone"]
            if not isinstance(phone, str) or len(phone) < 9:
                errors.append("Invalid phone format")
        # Validate coordinates if present
        if "latitude" in data and data["latitude"] is not None:
            try:
                lat = float(data["latitude"])
                if lat < -90 or lat > 90:
                    errors.append("Invalid latitude - must be between -90 and 90")
            except (ValueError, TypeError):
                errors.append("Invalid latitude format")
        if "longitude" in data and data["longitude"] is not None:
            try:
                lng = float(data["longitude"])
                if lng < -180 or lng > 180:
                    errors.append("Invalid longitude - must be between -180 and 180")
            except (ValueError, TypeError):
                errors.append("Invalid longitude format")
        # Validate JSON fields (string values must parse as JSON)
        json_fields = ["permissions"]
        for field in json_fields:
            if field in data and data[field]:
                if isinstance(data[field], str):
                    try:
                        json.loads(data[field])
                    except json.JSONDecodeError:
                        errors.append(f"Invalid JSON format in {field}")
        return {
            "is_valid": len(errors) == 0,
            "errors": errors
        }

View File

@@ -0,0 +1,326 @@
"""
Repository for coupon data access and validation
"""
from datetime import datetime, timezone
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import and_, select
from sqlalchemy.orm import selectinload
from app.models.coupon import CouponModel, CouponRedemptionModel
from shared.subscription.coupons import (
Coupon,
CouponRedemption,
CouponValidationResult,
DiscountType,
calculate_trial_end_date,
format_discount_description
)
class CouponRepository:
    """
    Data access layer for coupon operations.

    Handles lookup, validation and redemption of coupons, converting between
    the CouponModel/CouponRedemptionModel ORM rows and the shared
    Coupon/CouponRedemption dataclasses. User-facing error messages are
    intentionally in Spanish; internal reasons from the dataclass are English.
    """
    def __init__(self, db: AsyncSession):
        # The caller owns the session/transaction except in redeem_coupon,
        # which commits/rolls back itself.
        self.db = db
    async def get_coupon_by_code(self, code: str) -> Optional[Coupon]:
        """
        Retrieve coupon by code (codes are stored upper-case).
        Returns None if not found.
        """
        result = await self.db.execute(
            select(CouponModel).where(CouponModel.code == code.upper())
        )
        coupon_model = result.scalar_one_or_none()
        if not coupon_model:
            return None
        return self._model_to_dataclass(coupon_model)
    async def validate_coupon(
        self,
        code: str,
        tenant_id: str
    ) -> CouponValidationResult:
        """
        Validate a coupon code for a specific tenant.
        Checks: existence, validity, redemption limits, and if tenant already used it.
        """
        # Get coupon
        coupon = await self.get_coupon_by_code(code)
        if not coupon:
            return CouponValidationResult(
                valid=False,
                coupon=None,
                error_message="Código de cupón inválido",
                discount_preview=None
            )
        # Check if coupon can be redeemed
        can_redeem, reason = coupon.can_be_redeemed()
        if not can_redeem:
            # Map the dataclass's English reasons to user-facing Spanish text;
            # unknown reasons fall through untranslated.
            error_messages = {
                "Coupon is inactive": "Este cupón no está activo",
                "Coupon is not yet valid": "Este cupón aún no es válido",
                "Coupon has expired": "Este cupón ha expirado",
                "Coupon has reached maximum redemptions": "Este cupón ha alcanzado su límite de usos"
            }
            return CouponValidationResult(
                valid=False,
                coupon=coupon,
                error_message=error_messages.get(reason, reason),
                discount_preview=None
            )
        # Check if tenant already redeemed this coupon (one redemption per tenant)
        result = await self.db.execute(
            select(CouponRedemptionModel).where(
                and_(
                    CouponRedemptionModel.tenant_id == tenant_id,
                    CouponRedemptionModel.coupon_code == code.upper()
                )
            )
        )
        existing_redemption = result.scalar_one_or_none()
        if existing_redemption:
            return CouponValidationResult(
                valid=False,
                coupon=coupon,
                error_message="Ya has utilizado este cupón",
                discount_preview=None
            )
        # Generate discount preview
        discount_preview = self._generate_discount_preview(coupon)
        return CouponValidationResult(
            valid=True,
            coupon=coupon,
            error_message=None,
            discount_preview=discount_preview
        )
    async def redeem_coupon(
        self,
        code: str,
        tenant_id: Optional[str],
        base_trial_days: int = 0
    ) -> tuple[bool, Optional[CouponRedemption], Optional[str]]:
        """
        Redeem a coupon for a tenant.
        For tenant-independent registrations, tenant_id can be None initially.
        Returns (success, redemption, error_message)
        """
        # For tenant-independent registrations, skip tenant validation
        if tenant_id:
            # Validate first (existence, availability, per-tenant reuse)
            validation = await self.validate_coupon(code, tenant_id)
            if not validation.valid:
                return False, None, validation.error_message
            coupon = validation.coupon
        else:
            # Just get the coupon and validate its general availability
            coupon = await self.get_coupon_by_code(code)
            if not coupon:
                return False, None, "Código de cupón inválido"
            # Check if coupon can be redeemed
            can_redeem, reason = coupon.can_be_redeemed()
            if not can_redeem:
                error_messages = {
                    "Coupon is inactive": "Este cupón no está activo",
                    "Coupon is not yet valid": "Este cupón aún no es válido",
                    "Coupon has expired": "Este cupón ha expirado",
                    "Coupon has reached maximum redemptions": "Este cupón ha alcanzado su límite de usos"
                }
                return False, None, error_messages.get(reason, reason)
        # Calculate discount applied
        discount_applied = self._calculate_discount_applied(
            coupon,
            base_trial_days
        )
        # Only create redemption record if tenant_id is provided
        # For tenant-independent subscriptions, skip redemption record creation
        if tenant_id:
            # Create redemption record
            redemption_model = CouponRedemptionModel(
                tenant_id=tenant_id,
                coupon_code=code.upper(),
                redeemed_at=datetime.now(timezone.utc),
                discount_applied=discount_applied,
                extra_data={
                    "coupon_type": coupon.discount_type.value,
                    "coupon_value": coupon.discount_value
                }
            )
            self.db.add(redemption_model)
            # Increment coupon redemption count
            # NOTE(review): plain read-modify-write increment - concurrent
            # redemptions could race; confirm acceptable at current volume.
            result = await self.db.execute(
                select(CouponModel).where(CouponModel.code == code.upper())
            )
            coupon_model = result.scalar_one_or_none()
            if coupon_model:
                coupon_model.current_redemptions += 1
            try:
                # Commit redemption row and counter together, then refresh to
                # obtain the DB-generated id.
                await self.db.commit()
                await self.db.refresh(redemption_model)
                redemption = CouponRedemption(
                    id=str(redemption_model.id),
                    tenant_id=redemption_model.tenant_id,
                    coupon_code=redemption_model.coupon_code,
                    redeemed_at=redemption_model.redeemed_at,
                    discount_applied=redemption_model.discount_applied,
                    extra_data=redemption_model.extra_data
                )
                return True, redemption, None
            except Exception as e:
                # Roll back both the redemption row and the counter increment.
                await self.db.rollback()
                return False, None, f"Error al aplicar el cupón: {str(e)}"
        else:
            # For tenant-independent subscriptions, return discount without creating redemption
            # The redemption will be created when the tenant is linked
            # NOTE(review): this path neither increments current_redemptions
            # nor commits anything - presumably deferred to tenant linking;
            # confirm the linking flow records the redemption.
            redemption = CouponRedemption(
                id="pending",  # Temporary ID
                tenant_id="pending",  # Will be set during tenant linking
                coupon_code=code.upper(),
                redeemed_at=datetime.now(timezone.utc),
                discount_applied=discount_applied,
                extra_data={
                    "coupon_type": coupon.discount_type.value,
                    "coupon_value": coupon.discount_value
                }
            )
            return True, redemption, None
    async def get_redemption_by_tenant_and_code(
        self,
        tenant_id: str,
        code: str
    ) -> Optional[CouponRedemption]:
        """Get existing redemption for tenant and coupon code, or None."""
        result = await self.db.execute(
            select(CouponRedemptionModel).where(
                and_(
                    CouponRedemptionModel.tenant_id == tenant_id,
                    CouponRedemptionModel.coupon_code == code.upper()
                )
            )
        )
        redemption_model = result.scalar_one_or_none()
        if not redemption_model:
            return None
        return CouponRedemption(
            id=str(redemption_model.id),
            tenant_id=redemption_model.tenant_id,
            coupon_code=redemption_model.coupon_code,
            redeemed_at=redemption_model.redeemed_at,
            discount_applied=redemption_model.discount_applied,
            extra_data=redemption_model.extra_data
        )
    async def get_coupon_usage_stats(self, code: str) -> Optional[dict]:
        """Get usage statistics for a coupon, or None if the code is unknown."""
        result = await self.db.execute(
            select(CouponModel).where(CouponModel.code == code.upper())
        )
        coupon_model = result.scalar_one_or_none()
        if not coupon_model:
            return None
        # NOTE(review): loads every redemption row just to count them -
        # a COUNT(*) query would avoid materialising the rows.
        count_result = await self.db.execute(
            select(CouponRedemptionModel).where(
                CouponRedemptionModel.coupon_code == code.upper()
            )
        )
        redemptions_count = len(count_result.scalars().all())
        return {
            "code": coupon_model.code,
            "current_redemptions": coupon_model.current_redemptions,
            "max_redemptions": coupon_model.max_redemptions,
            # None max_redemptions means unlimited, so no "remaining" figure.
            "redemptions_remaining": (
                coupon_model.max_redemptions - coupon_model.current_redemptions
                if coupon_model.max_redemptions
                else None
            ),
            "active": coupon_model.active,
            "valid_from": coupon_model.valid_from.isoformat(),
            "valid_until": coupon_model.valid_until.isoformat() if coupon_model.valid_until else None
        }
    def _model_to_dataclass(self, model: CouponModel) -> Coupon:
        """Convert SQLAlchemy model to the shared Coupon dataclass."""
        return Coupon(
            id=str(model.id),
            code=model.code,
            discount_type=DiscountType(model.discount_type),
            discount_value=model.discount_value,
            max_redemptions=model.max_redemptions,
            current_redemptions=model.current_redemptions,
            valid_from=model.valid_from,
            valid_until=model.valid_until,
            active=model.active,
            created_at=model.created_at,
            extra_data=model.extra_data
        )
    def _generate_discount_preview(self, coupon: Coupon) -> dict:
        """Generate a preview of the discount to be applied (shown before redemption)."""
        description = format_discount_description(coupon)
        preview = {
            "description": description,
            "discount_type": coupon.discount_type.value,
            "discount_value": coupon.discount_value
        }
        if coupon.discount_type == DiscountType.TRIAL_EXTENSION:
            # Preview assumes zero base trial days; the real figure is
            # computed at redemption time in _calculate_discount_applied.
            trial_end = calculate_trial_end_date(0, coupon.discount_value)
            preview["trial_end_date"] = trial_end.isoformat()
            preview["total_trial_days"] = 0 + coupon.discount_value
        return preview
    def _calculate_discount_applied(
        self,
        coupon: Coupon,
        base_trial_days: int
    ) -> dict:
        """Calculate the actual discount that will be applied at redemption."""
        discount = {
            "type": coupon.discount_type.value,
            "value": coupon.discount_value,
            "description": format_discount_description(coupon)
        }
        if coupon.discount_type == DiscountType.TRIAL_EXTENSION:
            total_trial_days = base_trial_days + coupon.discount_value
            trial_end = calculate_trial_end_date(base_trial_days, coupon.discount_value)
            discount["base_trial_days"] = base_trial_days
            discount["extension_days"] = coupon.discount_value
            discount["total_trial_days"] = total_trial_days
            discount["trial_end_date"] = trial_end.isoformat()
        elif coupon.discount_type == DiscountType.PERCENTAGE:
            discount["percentage_off"] = coupon.discount_value
        elif coupon.discount_type == DiscountType.FIXED_AMOUNT:
            # Fixed amounts are stored in cents; expose euros for display.
            discount["amount_off_cents"] = coupon.discount_value
            discount["amount_off_euros"] = coupon.discount_value / 100
        return discount

View File

@@ -0,0 +1,283 @@
"""
Event Repository
Data access layer for events
"""
from typing import List, Optional, Dict, Any
from datetime import date, datetime
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, func
from uuid import UUID
import structlog
from app.models.events import Event, EventTemplate
from shared.database.repository import BaseRepository
logger = structlog.get_logger()
class EventRepository(BaseRepository[Event]):
    """
    Repository for event management.

    Read helpers are best-effort: failures are logged and an empty list is
    returned so callers degrade gracefully instead of crashing. Write helpers
    flush (to populate IDs) but do not commit; the caller owns the transaction.
    """
    def __init__(self, session: AsyncSession):
        super().__init__(Event, session)
    async def get_events_by_date_range(
        self,
        tenant_id: UUID,
        start_date: date,
        end_date: date,
        event_types: List[str] = None,
        confirmed_only: bool = False
    ) -> List[Event]:
        """
        Get events within a date range.
        Args:
            tenant_id: Tenant UUID
            start_date: Start date (inclusive)
            end_date: End date (inclusive)
            event_types: Optional filter by event types
            confirmed_only: Only return confirmed events
        Returns:
            List of Event objects, ordered by event_date (empty on error)
        """
        try:
            query = select(Event).where(
                and_(
                    Event.tenant_id == tenant_id,
                    Event.event_date >= start_date,
                    Event.event_date <= end_date
                )
            )
            if event_types:
                query = query.where(Event.event_type.in_(event_types))
            if confirmed_only:
                query = query.where(Event.is_confirmed == True)
            query = query.order_by(Event.event_date)
            result = await self.session.execute(query)
            events = result.scalars().all()
            logger.debug("Retrieved events by date range",
                         tenant_id=str(tenant_id),
                         start_date=start_date.isoformat(),
                         end_date=end_date.isoformat(),
                         count=len(events))
            return list(events)
        except Exception as e:
            # Best-effort read: log and return no events rather than raise.
            logger.error("Failed to get events by date range",
                         tenant_id=str(tenant_id),
                         error=str(e))
            return []
    async def get_events_for_date(
        self,
        tenant_id: UUID,
        event_date: date
    ) -> List[Event]:
        """
        Get all events for a specific date.
        Args:
            tenant_id: Tenant UUID
            event_date: Date to get events for
        Returns:
            List of Event objects ordered by start_time (empty on error)
        """
        try:
            query = select(Event).where(
                and_(
                    Event.tenant_id == tenant_id,
                    Event.event_date == event_date
                )
            ).order_by(Event.start_time)
            result = await self.session.execute(query)
            events = result.scalars().all()
            return list(events)
        except Exception as e:
            logger.error("Failed to get events for date",
                         tenant_id=str(tenant_id),
                         error=str(e))
            return []
    async def get_upcoming_events(
        self,
        tenant_id: UUID,
        days_ahead: int = 30,
        limit: int = 100
    ) -> List[Event]:
        """
        Get upcoming events (today through today + days_ahead, inclusive).
        Args:
            tenant_id: Tenant UUID
            days_ahead: Number of days to look ahead
            limit: Maximum number of events to return
        Returns:
            List of upcoming Event objects ordered by date (empty on error)
        """
        try:
            # Local import: timedelta is not imported at module level.
            from datetime import date, timedelta
            today = date.today()
            future_date = today + timedelta(days=days_ahead)
            query = select(Event).where(
                and_(
                    Event.tenant_id == tenant_id,
                    Event.event_date >= today,
                    Event.event_date <= future_date
                )
            ).order_by(Event.event_date).limit(limit)
            result = await self.session.execute(query)
            events = result.scalars().all()
            return list(events)
        except Exception as e:
            logger.error("Failed to get upcoming events",
                         tenant_id=str(tenant_id),
                         error=str(e))
            return []
    async def create_event(self, event_data: Dict[str, Any]) -> Event:
        """Create a new event. Flushes (no commit) so the generated ID is available; re-raises on failure."""
        try:
            event = Event(**event_data)
            self.session.add(event)
            await self.session.flush()
            logger.info("Created event",
                        event_id=str(event.id),
                        event_name=event.event_name,
                        event_date=event.event_date.isoformat())
            return event
        except Exception as e:
            logger.error("Failed to create event", error=str(e))
            raise
    async def update_event_actual_impact(
        self,
        event_id: UUID,
        actual_impact_multiplier: float,
        actual_sales_increase_percent: float
    ) -> Optional[Event]:
        """
        Update event with actual impact after it occurs.
        Args:
            event_id: Event UUID
            actual_impact_multiplier: Actual demand multiplier observed
            actual_sales_increase_percent: Actual sales increase percentage
        Returns:
            Updated Event, or None when the event is missing or the update fails
        """
        try:
            # NOTE(review): relies on BaseRepository.get(); other repositories
            # in this service use get_by_id - confirm both accessors exist.
            event = await self.get(event_id)
            if not event:
                return None
            event.actual_impact_multiplier = actual_impact_multiplier
            event.actual_sales_increase_percent = actual_sales_increase_percent
            await self.session.flush()
            logger.info("Updated event actual impact",
                        event_id=str(event_id),
                        actual_multiplier=actual_impact_multiplier)
            return event
        except Exception as e:
            # Swallowed by design: caller treats None as "not updated".
            logger.error("Failed to update event actual impact",
                         event_id=str(event_id),
                         error=str(e))
            return None
    async def get_events_by_type(
        self,
        tenant_id: UUID,
        event_type: str,
        limit: int = 100
    ) -> List[Event]:
        """Get events by type, newest event_date first (empty list on error)."""
        try:
            query = select(Event).where(
                and_(
                    Event.tenant_id == tenant_id,
                    Event.event_type == event_type
                )
            ).order_by(Event.event_date.desc()).limit(limit)
            result = await self.session.execute(query)
            events = result.scalars().all()
            return list(events)
        except Exception as e:
            logger.error("Failed to get events by type",
                         tenant_id=str(tenant_id),
                         event_type=event_type,
                         error=str(e))
            return []
class EventTemplateRepository(BaseRepository[EventTemplate]):
    """
    Data-access layer for reusable event templates.

    Reads are best-effort (log and return []); writes flush without
    committing, leaving transaction control to the caller.
    """

    def __init__(self, session: AsyncSession):
        super().__init__(EventTemplate, session)

    async def get_active_templates(self, tenant_id: UUID) -> List[EventTemplate]:
        """Return the tenant's active templates, ordered by template name (empty on error)."""
        try:
            stmt = (
                select(EventTemplate)
                .where(
                    and_(
                        EventTemplate.tenant_id == tenant_id,
                        EventTemplate.is_active == True
                    )
                )
                .order_by(EventTemplate.template_name)
            )
            rows = await self.session.execute(stmt)
            return list(rows.scalars().all())
        except Exception as e:
            logger.error("Failed to get active templates",
                         tenant_id=str(tenant_id),
                         error=str(e))
            return []

    async def create_template(self, template_data: Dict[str, Any]) -> EventTemplate:
        """Insert a new template and flush so its generated ID is populated; re-raises on failure."""
        try:
            new_template = EventTemplate(**template_data)
            self.session.add(new_template)
            await self.session.flush()
            logger.info("Created event template",
                        template_id=str(new_template.id),
                        template_name=new_template.template_name)
            return new_template
        except Exception as e:
            logger.error("Failed to create event template", error=str(e))
            raise

View File

@@ -0,0 +1,812 @@
"""
Subscription Repository
Repository for subscription operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, text, and_
from datetime import datetime, timedelta
import structlog
import json
from .base import TenantBaseRepository
from app.models.tenants import Subscription
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
from shared.subscription.plans import SubscriptionPlanMetadata, QuotaLimits, PlanPricing
logger = structlog.get_logger()
class SubscriptionRepository(TenantBaseRepository):
    """
    Repository for subscription operations.

    Persists Subscription rows, layering plan-aware defaulting and
    payment-provider-ID deduplication on top of TenantBaseRepository.
    """
    def __init__(self, model_class, session: AsyncSession, cache_ttl: Optional[int] = 600):
        # Subscriptions are relatively stable, medium cache time (10 minutes)
        super().__init__(model_class, session, cache_ttl)
    async def create_subscription(self, subscription_data: Dict[str, Any]) -> Subscription:
        """
        Create a new subscription with validation.

        Fills in plan-derived defaults (price, quotas, features, billing
        dates) from the centralized plan configuration, rejects a second
        active subscription for the same tenant, and updates in place when a
        row with the same provider subscription_id already exists (idempotent
        upsert). Raises ValidationError, DuplicateRecordError or DatabaseError.
        """
        try:
            # Validate subscription data
            validation_result = self._validate_tenant_data(
                subscription_data,
                ["tenant_id", "plan"]
            )
            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid subscription data: {validation_result['errors']}")
            # Check for existing active subscription
            existing_subscription = await self.get_active_subscription(
                subscription_data["tenant_id"]
            )
            if existing_subscription:
                raise DuplicateRecordError(f"Tenant already has an active subscription")
            # Set default values based on plan from centralized configuration
            plan = subscription_data["plan"]
            plan_info = SubscriptionPlanMetadata.get_plan_info(plan)
            # Set defaults from centralized plan configuration
            if "monthly_price" not in subscription_data:
                billing_cycle = subscription_data.get("billing_cycle", "monthly")
                subscription_data["monthly_price"] = float(
                    PlanPricing.get_price(plan, billing_cycle)
                )
            # NOTE(review): `or -1` maps a quota of 0 (as well as None) to -1 -
            # confirm -1 means "unlimited" and a zero quota is never intended.
            if "max_users" not in subscription_data:
                subscription_data["max_users"] = QuotaLimits.get_limit('MAX_USERS', plan) or -1
            if "max_locations" not in subscription_data:
                subscription_data["max_locations"] = QuotaLimits.get_limit('MAX_LOCATIONS', plan) or -1
            if "max_products" not in subscription_data:
                subscription_data["max_products"] = QuotaLimits.get_limit('MAX_PRODUCTS', plan) or -1
            if "features" not in subscription_data:
                # Feature flags: every feature listed for the plan is enabled.
                subscription_data["features"] = {
                    feature: True for feature in plan_info.get("features", [])
                }
            # Set default subscription values
            if "status" not in subscription_data:
                subscription_data["status"] = "active"
            if "billing_cycle" not in subscription_data:
                subscription_data["billing_cycle"] = "monthly"
            if "next_billing_date" not in subscription_data:
                # Set next billing date based on cycle
                # NOTE(review): naive utcnow written to a timezone-aware
                # column - confirm the driver stores it as UTC.
                if subscription_data["billing_cycle"] == "yearly":
                    subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=365)
                else:
                    subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=30)
            # Check if subscription with this subscription_id already exists to prevent duplicates
            if subscription_data.get('subscription_id'):
                existing_subscription = await self.get_by_provider_id(subscription_data['subscription_id'])
                if existing_subscription:
                    # Update the existing subscription instead of creating a duplicate
                    updated_subscription = await self.update(str(existing_subscription.id), subscription_data)
                    logger.info("Existing subscription updated",
                                subscription_id=subscription_data['subscription_id'],
                                tenant_id=subscription_data.get('tenant_id'),
                                plan=subscription_data.get('plan'))
                    return updated_subscription
            # Create subscription
            subscription = await self.create(subscription_data)
            logger.info("Subscription created successfully",
                        subscription_id=subscription.id,
                        tenant_id=subscription.tenant_id,
                        plan=subscription.plan,
                        monthly_price=subscription.monthly_price)
            return subscription
        except (ValidationError, DuplicateRecordError):
            # Propagate validation/duplicate errors unchanged for API mapping.
            raise
        except Exception as e:
            logger.error("Failed to create subscription",
                         tenant_id=subscription_data.get("tenant_id"),
                         plan=subscription_data.get("plan"),
                         error=str(e))
            raise DatabaseError(f"Failed to create subscription: {str(e)}")
async def get_by_tenant_id(self, tenant_id: str) -> Optional[Subscription]:
"""Get subscription by tenant ID"""
try:
subscriptions = await self.get_multi(
filters={
"tenant_id": tenant_id
},
limit=1,
order_by="created_at",
order_desc=True
)
return subscriptions[0] if subscriptions else None
except Exception as e:
logger.error("Failed to get subscription by tenant ID",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get subscription: {str(e)}")
async def get_by_provider_id(self, subscription_id: str) -> Optional[Subscription]:
"""Get subscription by payment provider subscription ID"""
try:
subscriptions = await self.get_multi(
filters={
"subscription_id": subscription_id
},
limit=1,
order_by="created_at",
order_desc=True
)
return subscriptions[0] if subscriptions else None
except Exception as e:
logger.error("Failed to get subscription by provider ID",
subscription_id=subscription_id,
error=str(e))
raise DatabaseError(f"Failed to get subscription: {str(e)}")
async def get_active_subscription(self, tenant_id: str) -> Optional[Subscription]:
"""Get active subscription for tenant"""
try:
subscriptions = await self.get_multi(
filters={
"tenant_id": tenant_id,
"status": "active"
},
limit=1,
order_by="created_at",
order_desc=True
)
return subscriptions[0] if subscriptions else None
except Exception as e:
logger.error("Failed to get active subscription",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get subscription: {str(e)}")
async def get_tenant_subscriptions(
self,
tenant_id: str,
include_inactive: bool = False
) -> List[Subscription]:
"""Get all subscriptions for a tenant"""
try:
filters = {"tenant_id": tenant_id}
if not include_inactive:
filters["status"] = "active"
return await self.get_multi(
filters=filters,
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get tenant subscriptions",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get subscriptions: {str(e)}")
async def update_subscription_plan(
self,
subscription_id: str,
new_plan: str,
billing_cycle: str = "monthly"
) -> Optional[Subscription]:
"""Update subscription plan and pricing using centralized configuration"""
try:
valid_plans = ["starter", "professional", "enterprise"]
if new_plan not in valid_plans:
raise ValidationError(f"Invalid plan. Must be one of: {valid_plans}")
# Get current subscription to find tenant_id for cache invalidation
subscription = await self.get_by_id(subscription_id)
if not subscription:
raise ValidationError(f"Subscription {subscription_id} not found")
# Get new plan configuration from centralized source
plan_info = SubscriptionPlanMetadata.get_plan_info(new_plan)
# Update subscription with new plan details
update_data = {
"plan": new_plan,
"monthly_price": float(PlanPricing.get_price(new_plan, billing_cycle)),
"billing_cycle": billing_cycle,
"max_users": QuotaLimits.get_limit('MAX_USERS', new_plan) or -1,
"max_locations": QuotaLimits.get_limit('MAX_LOCATIONS', new_plan) or -1,
"max_products": QuotaLimits.get_limit('MAX_PRODUCTS', new_plan) or -1,
"features": {feature: True for feature in plan_info.get("features", [])},
"updated_at": datetime.utcnow()
}
updated_subscription = await self.update(subscription_id, update_data)
# Invalidate cache
await self._invalidate_cache(str(subscription.tenant_id))
logger.info("Subscription plan updated",
subscription_id=subscription_id,
new_plan=new_plan,
new_price=update_data["monthly_price"])
return updated_subscription
except ValidationError:
raise
except Exception as e:
logger.error("Failed to update subscription plan",
subscription_id=subscription_id,
new_plan=new_plan,
error=str(e))
raise DatabaseError(f"Failed to update plan: {str(e)}")
async def update_subscription_status(
self,
subscription_id: str,
status: str,
additional_data: Dict[str, Any] = None
) -> Optional[Subscription]:
"""Update subscription status with optional additional data"""
try:
# Get subscription to find tenant_id for cache invalidation
subscription = await self.get_by_id(subscription_id)
if not subscription:
raise ValidationError(f"Subscription {subscription_id} not found")
update_data = {
"status": status,
"updated_at": datetime.utcnow()
}
# Merge additional data if provided
if additional_data:
update_data.update(additional_data)
updated_subscription = await self.update(subscription_id, update_data)
# Invalidate cache
if subscription.tenant_id:
await self._invalidate_cache(str(subscription.tenant_id))
logger.info("Subscription status updated",
subscription_id=subscription_id,
new_status=status,
additional_data=additional_data)
return updated_subscription
except ValidationError:
raise
except Exception as e:
logger.error("Failed to update subscription status",
subscription_id=subscription_id,
status=status,
error=str(e))
raise DatabaseError(f"Failed to update subscription status: {str(e)}")
async def cancel_subscription(
self,
subscription_id: str,
reason: str = None
) -> Optional[Subscription]:
"""Cancel a subscription"""
try:
# Get subscription to find tenant_id for cache invalidation
subscription = await self.get_by_id(subscription_id)
if not subscription:
raise ValidationError(f"Subscription {subscription_id} not found")
update_data = {
"status": "cancelled",
"updated_at": datetime.utcnow()
}
updated_subscription = await self.update(subscription_id, update_data)
# Invalidate cache
await self._invalidate_cache(str(subscription.tenant_id))
logger.info("Subscription cancelled",
subscription_id=subscription_id,
reason=reason)
return updated_subscription
except Exception as e:
logger.error("Failed to cancel subscription",
subscription_id=subscription_id,
error=str(e))
raise DatabaseError(f"Failed to cancel subscription: {str(e)}")
async def suspend_subscription(
self,
subscription_id: str,
reason: str = None
) -> Optional[Subscription]:
"""Suspend a subscription"""
try:
# Get subscription to find tenant_id for cache invalidation
subscription = await self.get_by_id(subscription_id)
if not subscription:
raise ValidationError(f"Subscription {subscription_id} not found")
update_data = {
"status": "suspended",
"updated_at": datetime.utcnow()
}
updated_subscription = await self.update(subscription_id, update_data)
# Invalidate cache
await self._invalidate_cache(str(subscription.tenant_id))
logger.info("Subscription suspended",
subscription_id=subscription_id,
reason=reason)
return updated_subscription
except Exception as e:
logger.error("Failed to suspend subscription",
subscription_id=subscription_id,
error=str(e))
raise DatabaseError(f"Failed to suspend subscription: {str(e)}")
async def reactivate_subscription(
self,
subscription_id: str
) -> Optional[Subscription]:
"""Reactivate a cancelled or suspended subscription"""
try:
# Get subscription to find tenant_id for cache invalidation
subscription = await self.get_by_id(subscription_id)
if not subscription:
raise ValidationError(f"Subscription {subscription_id} not found")
# Reset billing date when reactivating
next_billing_date = datetime.utcnow() + timedelta(days=30)
update_data = {
"status": "active",
"next_billing_date": next_billing_date,
"updated_at": datetime.utcnow()
}
updated_subscription = await self.update(subscription_id, update_data)
# Invalidate cache
await self._invalidate_cache(str(subscription.tenant_id))
logger.info("Subscription reactivated",
subscription_id=subscription_id,
next_billing_date=next_billing_date)
return updated_subscription
except Exception as e:
logger.error("Failed to reactivate subscription",
subscription_id=subscription_id,
error=str(e))
raise DatabaseError(f"Failed to reactivate subscription: {str(e)}")
    async def get_subscriptions_due_for_billing(
        self,
        days_ahead: int = 3
    ) -> List[Subscription]:
        """Get subscriptions that need billing in the next N days.

        Selects active subscriptions whose next_billing_date falls on or
        before now + `days_ahead` days, ordered soonest-first.

        Args:
            days_ahead: Look-ahead window in days (default 3).

        Returns:
            Hydrated Subscription model instances; an empty list on any
            failure — errors are logged, not raised, so a billing sweep
            never aborts on a read error.
        """
        try:
            cutoff_date = datetime.utcnow() + timedelta(days=days_ahead)
            query_text = """
            SELECT * FROM subscriptions
            WHERE status = 'active'
            AND next_billing_date <= :cutoff_date
            ORDER BY next_billing_date ASC
            """
            result = await self.session.execute(text(query_text), {
                "cutoff_date": cutoff_date
            })
            # Raw SQL bypasses the ORM, so rows are hydrated into model
            # instances by hand from each row's column mapping.
            subscriptions = []
            for row in result.fetchall():
                record_dict = dict(row._mapping)
                subscription = self.model(**record_dict)
                subscriptions.append(subscription)
            return subscriptions
        except Exception as e:
            logger.error("Failed to get subscriptions due for billing",
                        days_ahead=days_ahead,
                        error=str(e))
            return []
async def update_billing_date(
self,
subscription_id: str,
next_billing_date: datetime
) -> Optional[Subscription]:
"""Update next billing date for subscription"""
try:
updated_subscription = await self.update(subscription_id, {
"next_billing_date": next_billing_date,
"updated_at": datetime.utcnow()
})
logger.info("Subscription billing date updated",
subscription_id=subscription_id,
next_billing_date=next_billing_date)
return updated_subscription
except Exception as e:
logger.error("Failed to update billing date",
subscription_id=subscription_id,
error=str(e))
raise DatabaseError(f"Failed to update billing date: {str(e)}")
async def get_subscription_statistics(self) -> Dict[str, Any]:
"""Get subscription statistics"""
try:
# Get counts by plan
plan_query = text("""
SELECT plan, COUNT(*) as count
FROM subscriptions
WHERE status = 'active'
GROUP BY plan
ORDER BY count DESC
""")
result = await self.session.execute(plan_query)
subscriptions_by_plan = {row.plan: row.count for row in result.fetchall()}
# Get counts by status
status_query = text("""
SELECT status, COUNT(*) as count
FROM subscriptions
GROUP BY status
ORDER BY count DESC
""")
result = await self.session.execute(status_query)
subscriptions_by_status = {row.status: row.count for row in result.fetchall()}
# Get revenue statistics
revenue_query = text("""
SELECT
SUM(monthly_price) as total_monthly_revenue,
AVG(monthly_price) as avg_monthly_price,
COUNT(*) as total_active_subscriptions
FROM subscriptions
WHERE status = 'active'
""")
revenue_result = await self.session.execute(revenue_query)
revenue_row = revenue_result.fetchone()
# Get upcoming billing count
thirty_days_ahead = datetime.utcnow() + timedelta(days=30)
upcoming_billing = len(await self.get_subscriptions_due_for_billing(30))
return {
"subscriptions_by_plan": subscriptions_by_plan,
"subscriptions_by_status": subscriptions_by_status,
"total_monthly_revenue": float(revenue_row.total_monthly_revenue or 0),
"avg_monthly_price": float(revenue_row.avg_monthly_price or 0),
"total_active_subscriptions": int(revenue_row.total_active_subscriptions or 0),
"upcoming_billing_30d": upcoming_billing
}
except Exception as e:
logger.error("Failed to get subscription statistics", error=str(e))
return {
"subscriptions_by_plan": {},
"subscriptions_by_status": {},
"total_monthly_revenue": 0.0,
"avg_monthly_price": 0.0,
"total_active_subscriptions": 0,
"upcoming_billing_30d": 0
}
    async def cleanup_old_subscriptions(self, days_old: int = 730) -> int:
        """Clean up very old cancelled subscriptions (2 years).

        Permanently deletes rows whose status is 'cancelled' or 'suspended'
        and whose updated_at is older than `days_old` days.

        Args:
            days_old: Age threshold in days (default 730, i.e. ~2 years).

        Returns:
            Number of rows deleted.

        Raises:
            DatabaseError: on any execution failure.
        """
        try:
            cutoff_date = datetime.utcnow() - timedelta(days=days_old)
            query_text = """
            DELETE FROM subscriptions
            WHERE status IN ('cancelled', 'suspended')
            AND updated_at < :cutoff_date
            """
            # NOTE(review): no explicit commit after the DELETE — presumably
            # the surrounding session/unit-of-work commits; confirm with callers.
            result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
            deleted_count = result.rowcount
            logger.info("Cleaned up old subscriptions",
                       deleted_count=deleted_count,
                       days_old=days_old)
            return deleted_count
        except Exception as e:
            logger.error("Failed to cleanup old subscriptions",
                        error=str(e))
            raise DatabaseError(f"Cleanup failed: {str(e)}")
async def _invalidate_cache(self, tenant_id: str) -> None:
"""
Invalidate subscription cache for a tenant
Args:
tenant_id: Tenant ID
"""
try:
from app.services.subscription_cache import get_subscription_cache_service
cache_service = get_subscription_cache_service()
await cache_service.invalidate_subscription_cache(tenant_id)
logger.debug("Invalidated subscription cache from repository",
tenant_id=tenant_id)
except Exception as e:
logger.warning("Failed to invalidate cache (non-critical)",
tenant_id=tenant_id, error=str(e))
# ========================================================================
# TENANT-INDEPENDENT SUBSCRIPTION METHODS (New Architecture)
# ========================================================================
    async def create_tenant_independent_subscription(
        self,
        subscription_data: Dict[str, Any]
    ) -> Subscription:
        """Create a subscription not linked to any tenant (for registration flow).

        The row is stored with tenant_id=None and tenant_linking_status
        'pending'; it is attached later via link_subscription_to_tenant once
        the tenant exists. Missing pricing, quota, feature and billing fields
        are defaulted from the centralized plan configuration. The call is
        idempotent on the provider subscription_id: an existing row is updated
        instead of duplicated, and a create-time DuplicateRecordError (lost
        race) falls back to returning the row the winner created.

        Args:
            subscription_data: Column values; must include plan,
                subscription_id and customer_id, and must NOT carry a tenant_id.

        Raises:
            ValidationError: missing required fields, or tenant_id supplied.
            DuplicateRecordError: re-raised only if the race fallback fails.
            DatabaseError: on any other persistence failure.
        """
        try:
            # Validate required data for tenant-independent subscription
            # user_id may not exist during registration, so validate other required fields
            required_fields = ["plan", "subscription_id", "customer_id"]
            validation_result = self._validate_tenant_data(subscription_data, required_fields)
            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid subscription data: {validation_result['errors']}")
            # Ensure tenant_id is not provided (this is tenant-independent)
            if "tenant_id" in subscription_data and subscription_data["tenant_id"]:
                raise ValidationError("tenant_id should not be provided for tenant-independent subscriptions")
            # Set tenant-independent specific fields
            subscription_data["tenant_id"] = None
            subscription_data["is_tenant_linked"] = False
            subscription_data["tenant_linking_status"] = "pending"
            subscription_data["linked_at"] = None
            # Set default values based on plan from centralized configuration
            plan = subscription_data["plan"]
            plan_info = SubscriptionPlanMetadata.get_plan_info(plan)
            # Set defaults from centralized plan configuration
            if "monthly_price" not in subscription_data:
                billing_cycle = subscription_data.get("billing_cycle", "monthly")
                subscription_data["monthly_price"] = float(
                    PlanPricing.get_price(plan, billing_cycle)
                )
            # Quota defaults: falls back to -1 when get_limit() returns a falsy
            # value (presumably meaning "unlimited" — confirm against QuotaLimits).
            if "max_users" not in subscription_data:
                subscription_data["max_users"] = QuotaLimits.get_limit('MAX_USERS', plan) or -1
            if "max_locations" not in subscription_data:
                subscription_data["max_locations"] = QuotaLimits.get_limit('MAX_LOCATIONS', plan) or -1
            if "max_products" not in subscription_data:
                subscription_data["max_products"] = QuotaLimits.get_limit('MAX_PRODUCTS', plan) or -1
            if "features" not in subscription_data:
                subscription_data["features"] = {
                    feature: True for feature in plan_info.get("features", [])
                }
            # Set default subscription values. Note the status differs from
            # tenant-linked creation: the row stays non-billable until linked.
            if "status" not in subscription_data:
                subscription_data["status"] = "pending_tenant_linking"
            if "billing_cycle" not in subscription_data:
                subscription_data["billing_cycle"] = "monthly"
            if "next_billing_date" not in subscription_data:
                # Set next billing date based on cycle
                if subscription_data["billing_cycle"] == "yearly":
                    subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=365)
                else:
                    subscription_data["next_billing_date"] = datetime.utcnow() + timedelta(days=30)
            # Check if subscription with this subscription_id already exists
            existing_subscription = await self.get_by_provider_id(subscription_data['subscription_id'])
            if existing_subscription:
                # Update the existing subscription instead of creating a duplicate
                updated_subscription = await self.update(str(existing_subscription.id), subscription_data)
                logger.info("Existing tenant-independent subscription updated",
                           subscription_id=subscription_data['subscription_id'],
                           user_id=subscription_data.get('user_id'),
                           plan=subscription_data.get('plan'))
                return updated_subscription
            else:
                # Create new subscription, but handle potential duplicate errors
                try:
                    subscription = await self.create(subscription_data)
                    logger.info("Tenant-independent subscription created successfully",
                               subscription_id=subscription.id,
                               user_id=subscription.user_id,
                               plan=subscription.plan,
                               monthly_price=subscription.monthly_price)
                    return subscription
                except DuplicateRecordError:
                    # Another process may have created the subscription between our check and create
                    # Try to get the existing subscription and return it
                    final_subscription = await self.get_by_provider_id(subscription_data['subscription_id'])
                    if final_subscription:
                        logger.info("Race condition detected: subscription already created by another process",
                                   subscription_id=subscription_data['subscription_id'])
                        return final_subscription
                    else:
                        # This shouldn't happen, but re-raise the error if we can't find it
                        raise
        except (ValidationError, DuplicateRecordError):
            raise
        except Exception as e:
            logger.error("Failed to create tenant-independent subscription",
                        user_id=subscription_data.get("user_id"),
                        plan=subscription_data.get("plan"),
                        error=str(e))
            raise DatabaseError(f"Failed to create tenant-independent subscription: {str(e)}")
async def get_pending_tenant_linking_subscriptions(self) -> List[Subscription]:
"""Get all subscriptions waiting to be linked to tenants"""
try:
subscriptions = await self.get_multi(
filters={
"tenant_linking_status": "pending",
"is_tenant_linked": False
},
order_by="created_at",
order_desc=True
)
return subscriptions
except Exception as e:
logger.error("Failed to get pending tenant linking subscriptions",
error=str(e))
raise DatabaseError(f"Failed to get pending subscriptions: {str(e)}")
async def get_pending_subscriptions_by_user(self, user_id: str) -> List[Subscription]:
"""Get pending tenant linking subscriptions for a specific user"""
try:
subscriptions = await self.get_multi(
filters={
"user_id": user_id,
"tenant_linking_status": "pending",
"is_tenant_linked": False
},
order_by="created_at",
order_desc=True
)
return subscriptions
except Exception as e:
logger.error("Failed to get pending subscriptions by user",
user_id=user_id,
error=str(e))
raise DatabaseError(f"Failed to get pending subscriptions: {str(e)}")
async def link_subscription_to_tenant(
self,
subscription_id: str,
tenant_id: str,
user_id: str
) -> Subscription:
"""Link a pending subscription to a tenant"""
try:
# Get the subscription first
subscription = await self.get_by_id(subscription_id)
if not subscription:
raise ValidationError(f"Subscription {subscription_id} not found")
# Validate subscription can be linked
if not subscription.can_be_linked_to_tenant(user_id):
raise ValidationError(
f"Subscription {subscription_id} cannot be linked to tenant by user {user_id}. "
f"Current status: {subscription.tenant_linking_status}, "
f"User: {subscription.user_id}, "
f"Already linked: {subscription.is_tenant_linked}"
)
# Update subscription with tenant information
update_data = {
"tenant_id": tenant_id,
"is_tenant_linked": True,
"tenant_linking_status": "completed",
"linked_at": datetime.utcnow(),
"status": "active", # Activate subscription when linked to tenant
"updated_at": datetime.utcnow()
}
updated_subscription = await self.update(subscription_id, update_data)
# Invalidate cache for the tenant
await self._invalidate_cache(tenant_id)
logger.info("Subscription linked to tenant successfully",
subscription_id=subscription_id,
tenant_id=tenant_id,
user_id=user_id)
return updated_subscription
except Exception as e:
logger.error("Failed to link subscription to tenant",
subscription_id=subscription_id,
tenant_id=tenant_id,
user_id=user_id,
error=str(e))
raise DatabaseError(f"Failed to link subscription to tenant: {str(e)}")
    async def cleanup_orphaned_subscriptions(self, days_old: int = 30) -> int:
        """Clean up subscriptions that were never linked to tenants.

        Deletes rows still in tenant_linking_status='pending' (and not
        linked) created more than `days_old` days ago — i.e. abandoned
        registrations.

        Args:
            days_old: Minimum age in days before a pending row is purged.

        Returns:
            Number of rows deleted.

        Raises:
            DatabaseError: on any execution failure.
        """
        try:
            cutoff_date = datetime.utcnow() - timedelta(days=days_old)
            query_text = """
            DELETE FROM subscriptions
            WHERE tenant_linking_status = 'pending'
            AND is_tenant_linked = FALSE
            AND created_at < :cutoff_date
            """
            # NOTE(review): no explicit commit after the DELETE — presumably
            # the surrounding session/unit-of-work commits; confirm with callers.
            result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
            deleted_count = result.rowcount
            logger.info("Cleaned up orphaned subscriptions",
                       deleted_count=deleted_count,
                       days_old=days_old)
            return deleted_count
        except Exception as e:
            logger.error("Failed to cleanup orphaned subscriptions",
                        error=str(e))
            raise DatabaseError(f"Cleanup failed: {str(e)}")
async def get_by_customer_id(self, customer_id: str) -> List[Subscription]:
"""
Get subscriptions by Stripe customer ID
Args:
customer_id: Stripe customer ID
Returns:
List of Subscription objects
"""
try:
query = select(Subscription).where(Subscription.customer_id == customer_id)
result = await self.session.execute(query)
subscriptions = result.scalars().all()
logger.debug("Found subscriptions by customer_id",
customer_id=customer_id,
count=len(subscriptions))
return subscriptions
except Exception as e:
logger.error("Error getting subscriptions by customer_id",
customer_id=customer_id,
error=str(e))
raise DatabaseError(f"Failed to get subscriptions by customer_id: {str(e)}")

View File

@@ -0,0 +1,218 @@
"""
Tenant Location Repository
Handles database operations for tenant location data
"""
from typing import List, Optional, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete
from sqlalchemy.orm import selectinload
import structlog
from app.models.tenant_location import TenantLocation
from app.models.tenants import Tenant
from shared.database.exceptions import DatabaseError
from .base import BaseRepository
logger = structlog.get_logger()
class TenantLocationRepository(BaseRepository):
    """Repository encapsulating all persistence of TenantLocation rows.

    Write operations commit (and on failure roll back) the session
    themselves; read operations never mutate transaction state.
    """

    def __init__(self, session: AsyncSession):
        super().__init__(TenantLocation, session)

    async def create_location(self, location_data: Dict[str, Any]) -> TenantLocation:
        """Insert a new TenantLocation row.

        Args:
            location_data: Column values for the new location.

        Returns:
            The persisted TenantLocation, refreshed from the database.
        """
        try:
            record = TenantLocation(**location_data)
            self.session.add(record)
            await self.session.commit()
            await self.session.refresh(record)
            logger.info(f"Created new tenant location: {record.id} for tenant {record.tenant_id}")
            return record
        except Exception as exc:
            await self.session.rollback()
            logger.error(f"Failed to create tenant location: {str(exc)}")
            raise DatabaseError(f"Failed to create tenant location: {str(exc)}")

    async def get_location_by_id(self, location_id: str) -> Optional[TenantLocation]:
        """Fetch one location by primary key, or None when absent."""
        try:
            query = select(TenantLocation).where(TenantLocation.id == location_id)
            res = await self.session.execute(query)
            return res.scalar_one_or_none()
        except Exception as exc:
            logger.error(f"Failed to get location by ID: {str(exc)}")
            raise DatabaseError(f"Failed to get location by ID: {str(exc)}")

    async def get_locations_by_tenant(self, tenant_id: str) -> List[TenantLocation]:
        """Return every location row belonging to the given tenant."""
        try:
            query = select(TenantLocation).where(TenantLocation.tenant_id == tenant_id)
            res = await self.session.execute(query)
            return res.scalars().all()
        except Exception as exc:
            logger.error(f"Failed to get locations by tenant: {str(exc)}")
            raise DatabaseError(f"Failed to get locations by tenant: {str(exc)}")

    async def get_location_by_type(self, tenant_id: str, location_type: str) -> Optional[TenantLocation]:
        """Fetch the tenant's location of a given type, or None.

        Args:
            tenant_id: UUID of the tenant.
            location_type: e.g. 'central_production' or 'retail_outlet'.
        """
        try:
            query = select(TenantLocation).where(
                TenantLocation.tenant_id == tenant_id,
                TenantLocation.location_type == location_type
            )
            res = await self.session.execute(query)
            return res.scalar_one_or_none()
        except Exception as exc:
            logger.error(f"Failed to get location by type: {str(exc)}")
            raise DatabaseError(f"Failed to get location by type: {str(exc)}")

    async def update_location(self, location_id: str, location_data: Dict[str, Any]) -> Optional[TenantLocation]:
        """Apply column updates to a location.

        Returns:
            The updated row, or None (after rollback) when no such location
            exists. The commit only happens after the row has been re-read.
        """
        try:
            await self.session.execute(
                update(TenantLocation)
                .where(TenantLocation.id == location_id)
                .values(**location_data)
            )
            # Re-read the row to confirm it exists and return fresh state.
            res = await self.session.execute(
                select(TenantLocation).where(TenantLocation.id == location_id)
            )
            record = res.scalar_one_or_none()
            if record is None:
                await self.session.rollback()
                logger.warning(f"Location not found for update: {location_id}")
                return None
            await self.session.commit()
            logger.info(f"Updated tenant location: {location_id}")
            return record
        except Exception as exc:
            await self.session.rollback()
            logger.error(f"Failed to update location: {str(exc)}")
            raise DatabaseError(f"Failed to update location: {str(exc)}")

    async def delete_location(self, location_id: str) -> bool:
        """Delete a location row.

        Returns:
            True when a row was removed, False (after rollback) otherwise.
        """
        try:
            res = await self.session.execute(
                delete(TenantLocation).where(TenantLocation.id == location_id)
            )
            if res.rowcount > 0:
                await self.session.commit()
                logger.info(f"Deleted tenant location: {location_id}")
                return True
            await self.session.rollback()
            logger.warning(f"Location not found for deletion: {location_id}")
            return False
        except Exception as exc:
            await self.session.rollback()
            logger.error(f"Failed to delete location: {str(exc)}")
            raise DatabaseError(f"Failed to delete location: {str(exc)}")

    async def get_active_locations_by_tenant(self, tenant_id: str) -> List[TenantLocation]:
        """Return only the tenant's locations flagged is_active."""
        try:
            query = select(TenantLocation).where(
                TenantLocation.tenant_id == tenant_id,
                TenantLocation.is_active == True
            )
            res = await self.session.execute(query)
            return res.scalars().all()
        except Exception as exc:
            logger.error(f"Failed to get active locations by tenant: {str(exc)}")
            raise DatabaseError(f"Failed to get active locations by tenant: {str(exc)}")

    async def get_locations_by_tenant_with_type(self, tenant_id: str, location_types: List[str]) -> List[TenantLocation]:
        """Return the tenant's locations whose type is in `location_types`."""
        try:
            query = select(TenantLocation).where(
                TenantLocation.tenant_id == tenant_id,
                TenantLocation.location_type.in_(location_types)
            )
            res = await self.session.execute(query)
            return res.scalars().all()
        except Exception as exc:
            logger.error(f"Failed to get locations by tenant and type: {str(exc)}")
            raise DatabaseError(f"Failed to get locations by tenant and type: {str(exc)}")

View File

@@ -0,0 +1,588 @@
"""
Tenant Member Repository
Repository for tenant membership operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, text, and_
from datetime import datetime, timedelta
import structlog
import json
from .base import TenantBaseRepository
from app.models.tenants import TenantMember
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
from shared.config.base import is_internal_service
logger = structlog.get_logger()
class TenantMemberRepository(TenantBaseRepository):
"""Repository for tenant member operations"""
def __init__(self, model_class, session: AsyncSession, cache_ttl: Optional[int] = 300):
# Member data changes more frequently, shorter cache time (5 minutes)
super().__init__(model_class, session, cache_ttl)
async def create_membership(self, membership_data: Dict[str, Any]) -> TenantMember:
"""Create a new tenant membership with validation"""
try:
# Validate membership data
validation_result = self._validate_tenant_data(
membership_data,
["tenant_id", "user_id", "role"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid membership data: {validation_result['errors']}")
# Check for existing membership
existing_membership = await self.get_membership(
membership_data["tenant_id"],
membership_data["user_id"]
)
if existing_membership and existing_membership.is_active:
raise DuplicateRecordError(f"User is already an active member of this tenant")
# Set default values
if "is_active" not in membership_data:
membership_data["is_active"] = True
if "joined_at" not in membership_data:
membership_data["joined_at"] = datetime.utcnow()
# Set permissions based on role
if "permissions" not in membership_data:
membership_data["permissions"] = self._get_default_permissions(
membership_data["role"]
)
# If reactivating existing membership
if existing_membership and not existing_membership.is_active:
# Update existing membership
update_data = {
key: value for key, value in membership_data.items()
if key not in ["tenant_id", "user_id"]
}
membership = await self.update(existing_membership.id, update_data)
else:
# Create new membership
membership = await self.create(membership_data)
logger.info("Tenant membership created",
membership_id=membership.id,
tenant_id=membership.tenant_id,
user_id=membership.user_id,
role=membership.role)
return membership
except (ValidationError, DuplicateRecordError):
raise
except Exception as e:
logger.error("Failed to create membership",
tenant_id=membership_data.get("tenant_id"),
user_id=membership_data.get("user_id"),
error=str(e))
raise DatabaseError(f"Failed to create membership: {str(e)}")
async def get_membership(self, tenant_id: str, user_id: str) -> Optional[TenantMember]:
"""Get specific membership by tenant and user"""
try:
# Validate that user_id is a proper UUID format for actual users
# Service names like 'inventory-service' should be handled differently
import uuid
try:
uuid.UUID(user_id)
is_valid_uuid = True
except ValueError:
is_valid_uuid = False
# For internal service access, return None to indicate no user membership
# Service access should be handled at the API layer
if not is_valid_uuid:
if is_internal_service(user_id):
# This is a known internal service request, return None
# Service access is granted at the API endpoint level
logger.debug("Internal service detected in membership lookup",
service=user_id,
tenant_id=tenant_id)
return None
elif user_id == "unknown-service":
# Special handling for 'unknown-service' which commonly occurs in demo sessions
# This happens when service identification fails during demo operations
logger.warning("Demo session service identification issue",
service=user_id,
tenant_id=tenant_id,
message="Service not properly identified - likely demo session context")
return None
else:
# This is an unknown service
# Return None to prevent database errors, but log a warning
logger.warning("Unknown service detected in membership lookup",
service=user_id,
tenant_id=tenant_id,
message="Service not in internal services registry")
return None
memberships = await self.get_multi(
filters={
"tenant_id": tenant_id,
"user_id": user_id
},
limit=1,
order_by="created_at",
order_desc=True
)
return memberships[0] if memberships else None
except Exception as e:
logger.error("Failed to get membership",
tenant_id=tenant_id,
user_id=user_id,
error=str(e))
raise DatabaseError(f"Failed to get membership: {str(e)}")
async def get_tenant_members(
self,
tenant_id: str,
active_only: bool = True,
role: str = None,
include_user_info: bool = False
) -> List[TenantMember]:
"""Get all members of a tenant with optional user info enrichment"""
try:
filters = {"tenant_id": tenant_id}
if active_only:
filters["is_active"] = True
if role:
filters["role"] = role
members = await self.get_multi(
filters=filters,
order_by="joined_at",
order_desc=False
)
# If include_user_info is True, enrich with user data from auth service
if include_user_info and members:
members = await self._enrich_members_with_user_info(members)
return members
except Exception as e:
logger.error("Failed to get tenant members",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get members: {str(e)}")
async def _enrich_members_with_user_info(self, members: List[TenantMember]) -> List[TenantMember]:
    """Enrich member objects with user information from the auth service.

    Tries the batch endpoint first; when the batch call fails (non-200 or
    exception) it falls back to one GET per user. Enrichment is best-effort:
    on any unexpected failure the members are returned unmodified, and
    members whose user data is missing get "Unknown User" defaults.

    Fix: the individual-call fallback was previously duplicated verbatim in
    two branches; it is now a single nested helper.
    """
    try:
        import httpx
        import os
        if not members:
            return members
        # Unique user IDs across all memberships
        user_ids = list(set(str(member.user_id) for member in members))
        if not user_ids:
            return members
        # Internal service-to-service communication
        auth_service_url = os.getenv('AUTH_SERVICE_URL', 'http://auth-service:8000')
        user_data_map: Dict[str, Any] = {}

        async def _fetch_individually(client) -> None:
            # Fallback path: one GET per user; failures are logged and skipped.
            for user_id in user_ids:
                try:
                    response = await client.get(
                        f"{auth_service_url}/api/v1/auth/users/{user_id}",
                        timeout=5.0,
                        headers={"x-internal-service": "tenant-service"}
                    )
                    if response.status_code == 200:
                        user_data_map[user_id] = response.json()
                except Exception as e:
                    logger.warning(f"Failed to fetch user data for {user_id}", error=str(e))
                    continue

        async with httpx.AsyncClient() as client:
            try:
                # Use batch endpoint for efficiency
                response = await client.post(
                    f"{auth_service_url}/api/v1/auth/users/batch",
                    json={"user_ids": user_ids},
                    timeout=10.0,
                    headers={"x-internal-service": "tenant-service"}
                )
                if response.status_code == 200:
                    batch_result = response.json()
                    user_data_map = batch_result.get("users", {})
                    logger.info(
                        "Batch user fetch successful",
                        requested_count=len(user_ids),
                        found_count=batch_result.get("found_count", 0)
                    )
                else:
                    logger.warning(
                        "Batch user fetch failed, falling back to individual calls",
                        status_code=response.status_code
                    )
                    await _fetch_individually(client)
            except Exception as e:
                logger.warning("Batch user fetch failed, falling back to individual calls", error=str(e))
                await _fetch_individually(client)

        # Attach user fields to each member (attributes, not columns)
        for member in members:
            user_data = user_data_map.get(str(member.user_id))
            if user_data is not None:
                member.user_email = user_data.get("email")
                member.user_full_name = user_data.get("full_name")
                member.user = user_data  # Store full user object for compatibility
            else:
                member.user_email = None
                member.user_full_name = "Unknown User"
                member.user = None
        return members
    except Exception as e:
        logger.warning("Failed to enrich members with user info", error=str(e))
        # Return members without enrichment if it fails
        return members
async def get_user_memberships(
    self,
    user_id: str,
    active_only: bool = True
) -> List[TenantMember]:
    """Return every tenant membership held by a user, newest first.

    Raises:
        DatabaseError: If the lookup fails.
    """
    try:
        criteria: Dict[str, Any] = {"user_id": user_id}
        if active_only:
            criteria["is_active"] = True
        return await self.get_multi(
            filters=criteria,
            order_by="joined_at",
            order_desc=True
        )
    except Exception as e:
        logger.error("Failed to get user memberships",
                     user_id=user_id,
                     error=str(e))
        raise DatabaseError(f"Failed to get memberships: {str(e)}")
async def verify_user_access(
    self,
    user_id: str,
    tenant_id: str
) -> Dict[str, Any]:
    """Verify if user has access to tenant and return access details.

    Returns a dict with has_access, role, and permissions (always a list);
    when access is granted it also carries membership_id and joined_at.
    Errors are reported as no-access rather than raised.
    """
    try:
        membership = await self.get_membership(tenant_id, user_id)
        if not membership or not membership.is_active:
            return {
                "has_access": False,
                "role": "none",
                "permissions": []
            }
        # Parse stored permissions JSON
        permissions = []
        if membership.permissions:
            try:
                permissions = json.loads(membership.permissions)
            except json.JSONDecodeError:
                logger.warning("Invalid permissions JSON for membership",
                               membership_id=membership.id)
                # BUG FIX: _get_default_permissions returns a JSON-encoded
                # *string*; decode it so callers always receive a list here.
                permissions = json.loads(self._get_default_permissions(membership.role))
        return {
            "has_access": True,
            "role": membership.role,
            "permissions": permissions,
            "membership_id": str(membership.id),
            "joined_at": membership.joined_at.isoformat() if membership.joined_at else None
        }
    except Exception as e:
        logger.error("Failed to verify user access",
                     user_id=user_id,
                     tenant_id=tenant_id,
                     error=str(e))
        return {
            "has_access": False,
            "role": "none",
            "permissions": []
        }
async def update_member_role(
    self,
    tenant_id: str,
    user_id: str,
    new_role: str,
    updated_by: str = None
) -> Optional[TenantMember]:
    """Update a member's role, resetting permissions to the role defaults.

    Args:
        tenant_id: Tenant the membership belongs to.
        user_id: User whose role changes.
        new_role: One of owner/admin/member/viewer.
        updated_by: Optional actor ID, recorded in the audit log only.

    Raises:
        ValidationError: For an unknown role or missing membership.
        DatabaseError: On storage failures.
    """
    try:
        valid_roles = ["owner", "admin", "member", "viewer"]
        if new_role not in valid_roles:
            raise ValidationError(f"Invalid role. Must be one of: {valid_roles}")
        membership = await self.get_membership(tenant_id, user_id)
        if not membership:
            raise ValidationError("Membership not found")
        # BUG FIX: _get_default_permissions already returns a JSON-encoded
        # string; the previous code wrapped it in json.dumps() again, which
        # stored a double-encoded value that json.loads() would decode to a
        # string instead of a list.
        updated_membership = await self.update(membership.id, {
            "role": new_role,
            "permissions": self._get_default_permissions(new_role)
        })
        logger.info("Member role updated",
                    membership_id=membership.id,
                    tenant_id=tenant_id,
                    user_id=user_id,
                    old_role=membership.role,
                    new_role=new_role,
                    updated_by=updated_by)
        return updated_membership
    except ValidationError:
        raise
    except Exception as e:
        logger.error("Failed to update member role",
                     tenant_id=tenant_id,
                     user_id=user_id,
                     new_role=new_role,
                     error=str(e))
        raise DatabaseError(f"Failed to update role: {str(e)}")
async def deactivate_membership(
    self,
    tenant_id: str,
    user_id: str,
    deactivated_by: str = None
) -> Optional[TenantMember]:
    """Remove a user from a tenant by marking the membership inactive.

    Raises:
        ValidationError: If the membership is missing or belongs to the owner.
        DatabaseError: On storage failures.
    """
    try:
        membership = await self.get_membership(tenant_id, user_id)
        if membership is None:
            raise ValidationError("Membership not found")
        # Owners can never be removed this way
        if membership.role == "owner":
            raise ValidationError("Cannot deactivate the owner membership")
        deactivated = await self.update(membership.id, {"is_active": False})
        logger.info("Membership deactivated",
                    membership_id=membership.id,
                    tenant_id=tenant_id,
                    user_id=user_id,
                    deactivated_by=deactivated_by)
        return deactivated
    except ValidationError:
        raise
    except Exception as e:
        logger.error("Failed to deactivate membership",
                     tenant_id=tenant_id,
                     user_id=user_id,
                     error=str(e))
        raise DatabaseError(f"Failed to deactivate membership: {str(e)}")
async def reactivate_membership(
    self,
    tenant_id: str,
    user_id: str,
    reactivated_by: str = None
) -> Optional[TenantMember]:
    """Reactivate a previously deactivated membership.

    The join date is refreshed to "now" so the member appears newly joined.

    Raises:
        ValidationError: If no membership exists for the pair.
        DatabaseError: On storage failures.
    """
    try:
        membership = await self.get_membership(tenant_id, user_id)
        if membership is None:
            raise ValidationError("Membership not found")
        reactivated = await self.update(membership.id, {
            "is_active": True,
            "joined_at": datetime.utcnow()  # Update join date
        })
        logger.info("Membership reactivated",
                    membership_id=membership.id,
                    tenant_id=tenant_id,
                    user_id=user_id,
                    reactivated_by=reactivated_by)
        return reactivated
    except ValidationError:
        raise
    except Exception as e:
        logger.error("Failed to reactivate membership",
                     tenant_id=tenant_id,
                     user_id=user_id,
                     error=str(e))
        raise DatabaseError(f"Failed to reactivate membership: {str(e)}")
async def get_membership_statistics(self, tenant_id: str) -> Dict[str, Any]:
    """Return membership statistics for a tenant.

    Includes total/active/inactive counts, a per-role breakdown of active
    members, and how many active members joined in the last 30 days.
    Returns an all-zero structure on any failure instead of raising.

    Fix: removed a dead get_multi(limit=1000) query whose result
    (recent_joins) was computed and never used.
    """
    try:
        # Per-role counts for active members only
        role_query = text("""
            SELECT role, COUNT(*) as count
            FROM tenant_members
            WHERE tenant_id = :tenant_id AND is_active = true
            GROUP BY role
            ORDER BY count DESC
        """)
        result = await self.session.execute(role_query, {"tenant_id": tenant_id})
        members_by_role = {row.role: row.count for row in result.fetchall()}
        # Get basic counts
        total_members = await self.count(filters={"tenant_id": tenant_id})
        active_members = await self.count(filters={
            "tenant_id": tenant_id,
            "is_active": True
        })
        # Members joined in the last 30 days, filtered in Python because the
        # generic filter dict does not support date-range comparisons.
        thirty_days_ago = datetime.utcnow() - timedelta(days=30)
        recent_members = 0
        all_active_members = await self.get_tenant_members(tenant_id, active_only=True)
        for member in all_active_members:
            if member.joined_at and member.joined_at >= thirty_days_ago:
                recent_members += 1
        return {
            "total_members": total_members,
            "active_members": active_members,
            "inactive_members": total_members - active_members,
            "members_by_role": members_by_role,
            "recent_joins_30d": recent_members
        }
    except Exception as e:
        logger.error("Failed to get membership statistics",
                     tenant_id=tenant_id,
                     error=str(e))
        return {
            "total_members": 0,
            "active_members": 0,
            "inactive_members": 0,
            "members_by_role": {},
            "recent_joins_30d": 0
        }
def _get_default_permissions(self, role: str) -> str:
    """Return the default permission set for *role* as a JSON-encoded string.

    Unknown roles fall back to read-only access.
    """
    defaults = {
        "owner": ["read", "write", "admin", "delete"],
        "admin": ["read", "write", "admin"],
        "member": ["read", "write"],
        "viewer": ["read"],
    }
    return json.dumps(defaults.get(role, ["read"]))
async def bulk_update_permissions(
    self,
    tenant_id: str,
    role_permissions: Dict[str, List[str]]
) -> int:
    """Overwrite permissions for every active member of the given roles.

    Args:
        tenant_id: Tenant whose members are updated.
        role_permissions: Mapping of role name to the permission list to set.

    Returns:
        Number of memberships updated.

    Raises:
        DatabaseError: If any update fails.
    """
    try:
        updated_count = 0
        for role, permissions in role_permissions.items():
            # Encode once per role rather than once per member
            encoded = json.dumps(permissions)
            role_members = await self.get_tenant_members(
                tenant_id, active_only=True, role=role
            )
            for member in role_members:
                await self.update(member.id, {"permissions": encoded})
                updated_count += 1
        logger.info("Bulk updated member permissions",
                    tenant_id=tenant_id,
                    updated_count=updated_count,
                    roles=list(role_permissions.keys()))
        return updated_count
    except Exception as e:
        logger.error("Failed to bulk update permissions",
                     tenant_id=tenant_id,
                     error=str(e))
        raise DatabaseError(f"Bulk permission update failed: {str(e)}")
async def cleanup_inactive_memberships(self, days_old: int = 180) -> int:
    """Clean up old inactive memberships.

    Deletes tenant_members rows that are inactive AND were created more than
    *days_old* days ago, returning the number of rows removed.

    NOTE(review): no explicit commit is issued after the DELETE — presumably
    the session/base repository commits elsewhere; confirm the deletion is
    actually persisted.

    Raises:
        DatabaseError: If the DELETE fails.
    """
    try:
        cutoff_date = datetime.utcnow() - timedelta(days=days_old)
        query_text = """
            DELETE FROM tenant_members
            WHERE is_active = false
            AND created_at < :cutoff_date
        """
        result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
        deleted_count = result.rowcount
        logger.info("Cleaned up inactive memberships",
                    deleted_count=deleted_count,
                    days_old=days_old)
        return deleted_count
    except Exception as e:
        logger.error("Failed to cleanup inactive memberships",
                     error=str(e))
        raise DatabaseError(f"Cleanup failed: {str(e)}")

View File

@@ -0,0 +1,680 @@
"""
Tenant Repository
Repository for tenant operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, text, and_
from datetime import datetime, timedelta
import structlog
import uuid
from .base import TenantBaseRepository
from app.models.tenants import Tenant, Subscription
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
logger = structlog.get_logger()
class TenantRepository(TenantBaseRepository):
    """Repository for tenant operations"""

    def __init__(self, model_class, session: AsyncSession, cache_ttl: Optional[int] = 600):
        # Tenants are relatively stable, longer cache time (10 minutes)
        super().__init__(model_class, session, cache_ttl)
async def create_tenant(self, tenant_data: Dict[str, Any]) -> Tenant:
    """Create a new tenant after validating required fields.

    A unique subdomain is derived from the name when none is supplied;
    a caller-supplied subdomain must not already exist. Missing optional
    fields receive sensible defaults.

    Raises:
        ValidationError: When required fields are missing/invalid.
        DuplicateRecordError: When the requested subdomain is taken.
        DatabaseError: On any other storage failure.
    """
    try:
        validation_result = self._validate_tenant_data(
            tenant_data,
            ["name", "address", "postal_code", "owner_id"]
        )
        if not validation_result["is_valid"]:
            raise ValidationError(f"Invalid tenant data: {validation_result['errors']}")
        # Subdomain: derive one when absent/empty, otherwise enforce uniqueness
        if not tenant_data.get("subdomain"):
            tenant_data["subdomain"] = await self._generate_unique_subdomain(tenant_data["name"])
        elif await self.get_by_subdomain(tenant_data["subdomain"]):
            raise DuplicateRecordError(f"Subdomain {tenant_data['subdomain']} already exists")
        # Fill in defaults without clobbering caller-supplied values
        tenant_data.setdefault("business_type", "bakery")
        tenant_data.setdefault("city", "Madrid")
        tenant_data.setdefault("is_active", True)
        tenant_data.setdefault("ml_model_trained", False)
        tenant = await self.create(tenant_data)
        logger.info("Tenant created successfully",
                    tenant_id=tenant.id,
                    name=tenant.name,
                    subdomain=tenant.subdomain,
                    owner_id=tenant.owner_id)
        return tenant
    except (ValidationError, DuplicateRecordError):
        raise
    except Exception as e:
        logger.error("Failed to create tenant",
                     name=tenant_data.get("name"),
                     error=str(e))
        raise DatabaseError(f"Failed to create tenant: {str(e)}")
async def get_by_subdomain(self, subdomain: str) -> Optional[Tenant]:
    """Look up a tenant by its unique subdomain; None when absent."""
    try:
        return await self.get_by_field("subdomain", subdomain)
    except Exception as e:
        logger.error("Failed to get tenant by subdomain",
                     subdomain=subdomain,
                     error=str(e))
        raise DatabaseError(f"Failed to get tenant: {str(e)}")
async def get_tenants_by_owner(self, owner_id: str) -> List[Tenant]:
    """Return all active tenants owned by *owner_id*, newest first."""
    try:
        return await self.get_multi(
            filters={"owner_id": owner_id, "is_active": True},
            order_by="created_at",
            order_desc=True
        )
    except Exception as e:
        logger.error("Failed to get tenants by owner",
                     owner_id=owner_id,
                     error=str(e))
        raise DatabaseError(f"Failed to get tenants: {str(e)}")
async def get_active_tenants(self, skip: int = 0, limit: int = 100) -> List[Tenant]:
    """Return active tenants using offset/limit pagination."""
    return await self.get_active_records(skip=skip, limit=limit)
async def search_tenants(
    self,
    search_term: str,
    business_type: str = None,
    city: str = None,
    skip: int = 0,
    limit: int = 50
) -> List[Tenant]:
    """Search active tenants by name/address substring, with optional filters.

    Args:
        search_term: Case-insensitive substring matched against name/address.
        business_type: Optional exact business-type filter.
        city: Optional case-insensitive city filter.
        skip/limit: Pagination.

    Returns:
        Matching tenants ordered by name; an empty list on any failure.
    """
    try:
        # Only the WHERE clause structure is built dynamically;
        # all values go through bound parameters.
        where_clauses = [
            "is_active = true",
            "(LOWER(name) LIKE LOWER(:search_term) OR LOWER(address) LIKE LOWER(:search_term))",
        ]
        params: Dict[str, Any] = {
            "skip": skip,
            "limit": limit,
            "search_term": f"%{search_term}%",
        }
        if business_type:
            where_clauses.append("business_type = :business_type")
            params["business_type"] = business_type
        if city:
            where_clauses.append("LOWER(city) = LOWER(:city)")
            params["city"] = city
        query_text = f"""
            SELECT * FROM tenants
            WHERE {' AND '.join(where_clauses)}
            ORDER BY name ASC
            LIMIT :limit OFFSET :skip
        """
        result = await self.session.execute(text(query_text), params)
        return [self.model(**dict(row._mapping)) for row in result.fetchall()]
    except Exception as e:
        logger.error("Failed to search tenants",
                     search_term=search_term,
                     error=str(e))
        return []
async def update_tenant_model_status(
    self,
    tenant_id: str,
    ml_model_trained: bool,
    last_training_date: datetime = None
) -> Optional[Tenant]:
    """Update a tenant's ML training flag and last-training timestamp.

    When the flag is set without an explicit date, "now" is stamped.

    Raises:
        DatabaseError: On storage failures.
    """
    try:
        changes: Dict[str, Any] = {
            "ml_model_trained": ml_model_trained,
            "updated_at": datetime.utcnow()
        }
        if last_training_date:
            changes["last_training_date"] = last_training_date
        elif ml_model_trained:
            # Trained without a supplied date: record the current time
            changes["last_training_date"] = datetime.utcnow()
        updated_tenant = await self.update(tenant_id, changes)
        logger.info("Tenant model status updated",
                    tenant_id=tenant_id,
                    ml_model_trained=ml_model_trained,
                    last_training_date=last_training_date)
        return updated_tenant
    except Exception as e:
        logger.error("Failed to update tenant model status",
                     tenant_id=tenant_id,
                     error=str(e))
        raise DatabaseError(f"Failed to update model status: {str(e)}")
async def get_tenants_by_location(
    self,
    latitude: float,
    longitude: float,
    radius_km: float = 10.0,
    limit: int = 50
) -> List[Tenant]:
    """Return active tenants within *radius_km* of a point, nearest first.

    Uses the Haversine great-circle formula (Earth radius 6371 km).

    Fixes:
    - The previous query filtered with ``HAVING distance_km <= :radius_km``
      without a GROUP BY and referencing a SELECT alias — invalid in
      PostgreSQL (the dialect the rest of this file targets). The distance
      is now computed in a subquery and filtered in the outer WHERE.
    - The acos() argument is clamped to [-1, 1] to avoid domain errors from
      floating-point rounding at near-zero/antipodal distances.

    Returns an empty list on any failure.
    """
    try:
        query_text = """
            SELECT * FROM (
                SELECT *,
                       (6371 * acos(LEAST(1.0, GREATEST(-1.0,
                           cos(radians(:latitude)) *
                           cos(radians(latitude)) *
                           cos(radians(longitude) - radians(:longitude)) +
                           sin(radians(:latitude)) *
                           sin(radians(latitude))
                       )))) AS distance_km
                FROM tenants
                WHERE is_active = true
                AND latitude IS NOT NULL
                AND longitude IS NOT NULL
            ) AS tenants_with_distance
            WHERE distance_km <= :radius_km
            ORDER BY distance_km ASC
            LIMIT :limit
        """
        result = await self.session.execute(text(query_text), {
            "latitude": latitude,
            "longitude": longitude,
            "radius_km": radius_km,
            "limit": limit
        })
        tenants = []
        for row in result.fetchall():
            # Build the model without the calculated distance_km column
            record_dict = dict(row._mapping)
            record_dict.pop("distance_km", None)
            tenants.append(self.model(**record_dict))
        return tenants
    except Exception as e:
        logger.error("Failed to get tenants by location",
                     latitude=latitude,
                     longitude=longitude,
                     radius_km=radius_km,
                     error=str(e))
        return []
async def get_tenant_statistics(self) -> Dict[str, Any]:
    """Get global tenant statistics.

    Aggregates total/active counts, per-business-type and per-subscription
    -plan breakdowns, ML model training coverage, and registrations from the
    last 30 days. Returns an all-zero structure on any error instead of
    raising.
    """
    try:
        # Get basic counts
        total_tenants = await self.count()
        active_tenants = await self.count(filters={"is_active": True})
        # Get tenants by business type
        business_type_query = text("""
            SELECT business_type, COUNT(*) as count
            FROM tenants
            WHERE is_active = true
            GROUP BY business_type
            ORDER BY count DESC
        """)
        result = await self.session.execute(business_type_query)
        business_type_stats = {row.business_type: row.count for row in result.fetchall()}
        # Get tenants by subscription tier - now from subscriptions table
        tier_query = text("""
            SELECT s.plan as subscription_tier, COUNT(*) as count
            FROM tenants t
            LEFT JOIN subscriptions s ON t.id = s.tenant_id AND s.status = 'active'
            WHERE t.is_active = true
            GROUP BY s.plan
            ORDER BY count DESC
        """)
        tier_result = await self.session.execute(tier_query)
        tier_stats = {}
        for row in tier_result.fetchall():
            # NULL plan means the tenant has no active subscription row
            tier = row.subscription_tier if row.subscription_tier else "no_subscription"
            tier_stats[tier] = row.count
        # Get model training statistics
        model_query = text("""
            SELECT
                COUNT(CASE WHEN ml_model_trained = true THEN 1 END) as trained_count,
                COUNT(CASE WHEN ml_model_trained = false THEN 1 END) as untrained_count,
                AVG(EXTRACT(EPOCH FROM (NOW() - last_training_date))/86400) as avg_days_since_training
            FROM tenants
            WHERE is_active = true
        """)
        model_result = await self.session.execute(model_query)
        model_row = model_result.fetchone()
        # Get recent registrations (last 30 days)
        # NOTE(review): the filter value is an operator-prefixed string; this
        # only works if the base repository's count() parses such expressions
        # — confirm, otherwise it silently degenerates to an equality match.
        thirty_days_ago = datetime.utcnow() - timedelta(days=30)
        recent_registrations = await self.count(filters={
            "created_at": f">= '{thirty_days_ago.isoformat()}'"
        })
        return {
            "total_tenants": total_tenants,
            "active_tenants": active_tenants,
            "inactive_tenants": total_tenants - active_tenants,
            "tenants_by_business_type": business_type_stats,
            "tenants_by_subscription": tier_stats,
            "model_training": {
                "trained_tenants": int(model_row.trained_count or 0),
                "untrained_tenants": int(model_row.untrained_count or 0),
                "avg_days_since_training": float(model_row.avg_days_since_training or 0)
            } if model_row else {
                "trained_tenants": 0,
                "untrained_tenants": 0,
                "avg_days_since_training": 0.0
            },
            "recent_registrations_30d": recent_registrations
        }
    except Exception as e:
        logger.error("Failed to get tenant statistics", error=str(e))
        # Zeroed fallback keeps dashboards rendering even when stats fail
        return {
            "total_tenants": 0,
            "active_tenants": 0,
            "inactive_tenants": 0,
            "tenants_by_business_type": {},
            "tenants_by_subscription": {},
            "model_training": {
                "trained_tenants": 0,
                "untrained_tenants": 0,
                "avg_days_since_training": 0.0
            },
            "recent_registrations_30d": 0
        }
async def _generate_unique_subdomain(self, name: str) -> str:
    """Generate a unique, URL-safe subdomain derived from a tenant name.

    Lowercases the name, strips accents via Unicode NFKD decomposition
    (generalizes the previous hard-coded a/e/i/o/u/n replacements to any
    diacritic — n-with-tilde decomposes to n + combining tilde and is still
    covered), keeps only alphanumerics and single hyphens, guarantees a
    minimum length, and appends a numeric — or, past 9999 attempts, a
    UUID — suffix until the candidate is free. Falls back to a random
    UUID-based subdomain on any unexpected error.
    """
    try:
        import unicodedata
        subdomain = name.lower().replace(' ', '-')
        # Decompose accented characters and drop the combining marks
        subdomain = ''.join(
            c for c in unicodedata.normalize('NFKD', subdomain)
            if not unicodedata.combining(c)
        )
        # Keep only alphanumeric and hyphens
        subdomain = ''.join(c for c in subdomain if c.isalnum() or c == '-')
        # Collapse consecutive hyphens
        while '--' in subdomain:
            subdomain = subdomain.replace('--', '-')
        subdomain = subdomain.strip('-')
        # Ensure minimum length
        if len(subdomain) < 3:
            subdomain = f"tenant-{subdomain}"
        if not await self.get_by_subdomain(subdomain):
            return subdomain
        # Taken: probe numeric suffixes
        counter = 1
        while True:
            candidate = f"{subdomain}-{counter}"
            if not await self.get_by_subdomain(candidate):
                return candidate
            counter += 1
            if counter > 9999:
                # Give up on sequential suffixes; use randomness instead
                return f"{subdomain}-{uuid.uuid4().hex[:6]}"
    except Exception as e:
        logger.error("Failed to generate unique subdomain",
                     name=name,
                     error=str(e))
        # Fallback to UUID-based subdomain
        return f"tenant-{uuid.uuid4().hex[:8]}"
async def deactivate_tenant(self, tenant_id: str) -> Optional[Tenant]:
    """Deactivate a tenant (soft delete via the base repository helper)."""
    return await self.deactivate_record(tenant_id)

async def activate_tenant(self, tenant_id: str) -> Optional[Tenant]:
    """Re-activate a previously deactivated tenant."""
    return await self.activate_record(tenant_id)
async def get_child_tenants(self, parent_tenant_id: str) -> List[Tenant]:
    """Return the active child tenants of a parent, oldest first."""
    try:
        return await self.get_multi(
            filters={"parent_tenant_id": parent_tenant_id, "is_active": True},
            order_by="created_at",
            order_desc=False
        )
    except Exception as e:
        logger.error("Failed to get child tenants",
                     parent_tenant_id=parent_tenant_id,
                     error=str(e))
        raise DatabaseError(f"Failed to get child tenants: {str(e)}")
async def get(self, record_id: Any) -> Optional[Tenant]:
    """Get tenant by ID.

    Thin alias for get_by_id, kept for callers that expect the generic
    repository ``get`` name.
    """
    return await self.get_by_id(record_id)
async def get_child_tenant_count(self, parent_tenant_id: str) -> int:
    """Return the number of active child tenants for a parent tenant.

    Uses a COUNT query instead of materializing every child row (the
    previous implementation fetched all children just to len() them).
    Returns 0 on any failure.
    """
    try:
        return await self.count(filters={
            "parent_tenant_id": parent_tenant_id,
            "is_active": True
        })
    except Exception as e:
        logger.error("Failed to get child tenant count",
                     parent_tenant_id=parent_tenant_id,
                     error=str(e))
        return 0
async def get_user_tenants_with_hierarchy(self, user_id: str) -> List[Dict[str, Any]]:
    """
    Get all tenants a user has access to, organized in hierarchy.
    Returns parent tenants with their children nested.

    "Access" means the user owns the tenant OR appears in tenant_members
    for it. Returns an empty list on any failure.
    """
    try:
        # One query for every active tenant the user owns or is a member of.
        # ORDER BY tenant_type DESC sorts parents/standalones before children.
        query_text = """
            SELECT DISTINCT t.*
            FROM tenants t
            LEFT JOIN tenant_members tm ON t.id = tm.tenant_id
            WHERE (t.owner_id = :user_id OR tm.user_id = :user_id)
            AND t.is_active = true
            ORDER BY t.tenant_type DESC, t.created_at ASC
        """
        result = await self.session.execute(text(query_text), {"user_id": user_id})
        tenants = []
        for row in result.fetchall():
            record_dict = dict(row._mapping)
            tenant = self.model(**record_dict)
            tenants.append(tenant)
        # Organize into hierarchy
        tenant_hierarchy = []
        parent_map = {}
        # First pass: collect all parent/standalone tenants
        for tenant in tenants:
            if tenant.tenant_type in ['parent', 'standalone']:
                tenant_dict = {
                    'id': str(tenant.id),
                    'name': tenant.name,
                    'subdomain': tenant.subdomain,
                    'tenant_type': tenant.tenant_type,
                    'business_type': tenant.business_type,
                    'business_model': tenant.business_model,
                    'city': tenant.city,
                    'is_active': tenant.is_active,
                    # Only parents can hold children; standalones get None
                    'children': [] if tenant.tenant_type == 'parent' else None
                }
                tenant_hierarchy.append(tenant_dict)
                parent_map[str(tenant.id)] = tenant_dict
        # Second pass: attach children to their parents.
        # Children whose parent is not accessible to this user are dropped.
        for tenant in tenants:
            if tenant.tenant_type == 'child' and tenant.parent_tenant_id:
                parent_id = str(tenant.parent_tenant_id)
                if parent_id in parent_map:
                    child_dict = {
                        'id': str(tenant.id),
                        'name': tenant.name,
                        'subdomain': tenant.subdomain,
                        'tenant_type': 'child',
                        'parent_tenant_id': parent_id,
                        'city': tenant.city,
                        'is_active': tenant.is_active
                    }
                    parent_map[parent_id]['children'].append(child_dict)
        return tenant_hierarchy
    except Exception as e:
        logger.error("Failed to get user tenants with hierarchy",
                     user_id=user_id,
                     error=str(e))
        return []
async def get_tenants_by_session_id(self, session_id: str) -> List[Tenant]:
    """Return active tenants tagged with the given demo session ID, newest first."""
    try:
        return await self.get_multi(
            filters={"demo_session_id": session_id, "is_active": True},
            order_by="created_at",
            order_desc=True
        )
    except Exception as e:
        logger.error("Failed to get tenants by session ID",
                     session_id=session_id,
                     error=str(e))
        raise DatabaseError(f"Failed to get tenants by session ID: {str(e)}")
async def get_professional_demo_tenants(self, session_id: str) -> List[Tenant]:
    """Return professional-bakery demo tenants scoped to one demo session.

    Args:
        session_id: Required demo session ID; results are always
            session-scoped so concurrent demos don't leak into each other.
    """
    try:
        return await self.get_multi(
            filters={
                "business_model": "professional_bakery",
                "is_demo": True,
                "is_active": True,
                "demo_session_id": session_id
            },
            order_by="created_at",
            order_desc=True
        )
    except Exception as e:
        logger.error("Failed to get professional demo tenants",
                     session_id=session_id,
                     error=str(e))
        raise DatabaseError(f"Failed to get professional demo tenants: {str(e)}")
async def get_enterprise_demo_tenants(self, session_id: str) -> List[Tenant]:
    """Return enterprise demo tenants for a session: parents, then children.

    Args:
        session_id: Required demo session ID; both queries are session-scoped.

    Returns:
        Parent demo tenants followed by child demo tenants for the session.
    """
    try:
        async def _demo_tenants_of_type(tenant_type: str) -> List[Tenant]:
            # Shared query shape; only the tenant_type differs
            return await self.get_multi(
                filters={
                    "tenant_type": tenant_type,
                    "is_demo": True,
                    "is_active": True,
                    "demo_session_id": session_id
                },
                order_by="created_at",
                order_desc=True
            )
        parents = await _demo_tenants_of_type("parent")
        children = await _demo_tenants_of_type("child")
        return parents + children
    except Exception as e:
        logger.error("Failed to get enterprise demo tenants",
                     session_id=session_id,
                     error=str(e))
        raise DatabaseError(f"Failed to get enterprise demo tenants: {str(e)}")
async def get_by_customer_id(self, customer_id: str) -> Optional[Tenant]:
    """Return the tenant linked to a Stripe customer ID via its subscription.

    Tenants don't carry customer_id directly, so the lookup joins through
    the subscriptions table.

    Fix: a customer can have several subscription rows over time, in which
    case scalar_one_or_none() raises MultipleResultsFound — contradicting
    the "None otherwise" contract. The first matching tenant is returned
    instead.
    """
    try:
        query = select(Tenant).join(
            Subscription, Subscription.tenant_id == Tenant.id
        ).where(Subscription.customer_id == customer_id)
        result = await self.session.execute(query)
        tenant = result.scalars().first()
        if tenant:
            logger.debug("Found tenant by customer_id",
                         customer_id=customer_id,
                         tenant_id=str(tenant.id))
        else:
            logger.debug("No tenant found for customer_id",
                         customer_id=customer_id)
        return tenant
    except Exception as e:
        logger.error("Error getting tenant by customer_id",
                     customer_id=customer_id,
                     error=str(e))
        raise DatabaseError(f"Failed to get tenant by customer_id: {str(e)}")
async def get_user_primary_tenant(self, user_id: str) -> Optional[Tenant]:
    """Return the primary tenant for a user: the oldest tenant they own.

    Fix: users can own several tenants (get_tenants_by_owner returns a
    list), so scalar_one_or_none() would raise MultipleResultsFound. The
    query now orders by created_at and returns the earliest-created owned
    tenant deterministically; None when the user owns nothing.
    """
    try:
        logger.debug("Getting primary tenant for user", user_id=user_id)
        query = (
            select(Tenant)
            .where(Tenant.owner_id == user_id)
            .order_by(Tenant.created_at)
        )
        result = await self.session.execute(query)
        tenant = result.scalars().first()
        if tenant:
            logger.debug("Found primary tenant for user",
                         user_id=user_id,
                         tenant_id=str(tenant.id))
        else:
            logger.debug("No primary tenant found for user", user_id=user_id)
        return tenant
    except Exception as e:
        logger.error("Error getting primary tenant for user",
                     user_id=user_id,
                     error=str(e))
        raise DatabaseError(f"Failed to get primary tenant for user: {str(e)}")
async def get_any_user_tenant(self, user_id: str) -> Optional[Tenant]:
    """Return any tenant the user can access via tenant_members, or None.

    Fix: the join yields multiple rows when the user belongs to several
    tenants, making scalar_one_or_none() raise MultipleResultsFound; the
    first row is returned instead, matching the "any tenant" contract.
    """
    try:
        logger.debug("Getting any accessible tenant for user", user_id=user_id)
        from app.models.tenants import TenantMember
        query = select(Tenant).join(
            TenantMember, Tenant.id == TenantMember.tenant_id
        ).where(TenantMember.user_id == user_id)
        result = await self.session.execute(query)
        tenant = result.scalars().first()
        if tenant:
            logger.debug("Found accessible tenant for user",
                         user_id=user_id,
                         tenant_id=str(tenant.id))
        else:
            logger.debug("No accessible tenants found for user", user_id=user_id)
        return tenant
    except Exception as e:
        logger.error("Error getting accessible tenant for user",
                     user_id=user_id,
                     error=str(e))
        raise DatabaseError(f"Failed to get accessible tenant for user: {str(e)}")

View File

@@ -0,0 +1,82 @@
# services/tenant/app/repositories/tenant_settings_repository.py
"""
Tenant Settings Repository
Data access layer for tenant settings
"""
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from typing import Optional
from uuid import UUID
import structlog
from ..models.tenant_settings import TenantSettings
logger = structlog.get_logger()
class TenantSettingsRepository:
    """Data-access layer for TenantSettings rows (one row per tenant)."""

    def __init__(self, db: AsyncSession):
        self.db = db

    async def get_by_tenant_id(self, tenant_id: UUID) -> Optional[TenantSettings]:
        """Return the settings row for *tenant_id*, or None when absent."""
        stmt = select(TenantSettings).where(TenantSettings.tenant_id == tenant_id)
        result = await self.db.execute(stmt)
        return result.scalar_one_or_none()

    async def create(self, settings: TenantSettings) -> TenantSettings:
        """Persist a new settings row and return the refreshed instance."""
        self.db.add(settings)
        await self.db.commit()
        await self.db.refresh(settings)
        return settings

    async def update(self, settings: TenantSettings) -> TenantSettings:
        """Commit pending changes on *settings* and return the refreshed instance.

        The instance is expected to already be attached to this session with
        its fields mutated by the caller.
        """
        await self.db.commit()
        await self.db.refresh(settings)
        return settings

    async def delete(self, tenant_id: UUID) -> None:
        """Delete the settings row for *tenant_id*, if one exists."""
        stmt = select(TenantSettings).where(TenantSettings.tenant_id == tenant_id)
        result = await self.db.execute(stmt)
        row = result.scalar_one_or_none()
        if row is not None:
            await self.db.delete(row)
            await self.db.commit()

View File

View File

@@ -0,0 +1,89 @@
"""
Tenant Location Schemas
"""
from pydantic import BaseModel, Field, field_validator
from typing import Optional, List, Dict, Any
from datetime import datetime
from uuid import UUID
class TenantLocationBase(BaseModel):
    """Base schema for tenant location"""
    # Display name of the location
    name: str = Field(..., min_length=1, max_length=200)
    # Restricted to the five supported facility kinds
    location_type: str = Field(..., pattern=r'^(central_production|retail_outlet|warehouse|store|branch)$')
    address: str = Field(..., min_length=10, max_length=500)
    city: str = Field(default="Madrid", max_length=100)
    postal_code: str = Field(..., min_length=3, max_length=10)
    # WGS84 coordinates; optional until geocoded
    latitude: Optional[float] = Field(None, ge=-90, le=90)
    longitude: Optional[float] = Field(None, ge=-180, le=180)
    contact_person: Optional[str] = Field(None, max_length=200)
    contact_phone: Optional[str] = Field(None, max_length=20)
    # Plain string — not validated as an e-mail address here
    contact_email: Optional[str] = Field(None, max_length=255)
    is_active: bool = True
    # Free-form JSON configuration blobs (schema owned by callers)
    delivery_windows: Optional[Dict[str, Any]] = None
    operational_hours: Optional[Dict[str, Any]] = None
    capacity: Optional[int] = Field(None, ge=0)
    max_delivery_radius_km: Optional[float] = Field(None, ge=0)
    delivery_schedule_config: Optional[Dict[str, Any]] = None
    metadata: Optional[Dict[str, Any]] = Field(None)
class TenantLocationCreate(TenantLocationBase):
    """Schema for creating a tenant location.

    Adds the owning tenant to the shared base fields.
    """
    # Accepted as a plain string here; parsed as a UUID in the API layer.
    tenant_id: str  # This will be validated as UUID in the API layer
class TenantLocationUpdate(BaseModel):
    """Schema for updating a tenant location.

    Every field is optional; only the fields supplied by the caller are
    intended to be applied (partial update).
    """
    name: Optional[str] = Field(None, min_length=1, max_length=200)
    location_type: Optional[str] = Field(None, pattern=r'^(central_production|retail_outlet|warehouse|store|branch)$')
    address: Optional[str] = Field(None, min_length=10, max_length=500)
    city: Optional[str] = Field(None, max_length=100)
    postal_code: Optional[str] = Field(None, min_length=3, max_length=10)
    latitude: Optional[float] = Field(None, ge=-90, le=90)
    longitude: Optional[float] = Field(None, ge=-180, le=180)
    contact_person: Optional[str] = Field(None, max_length=200)
    contact_phone: Optional[str] = Field(None, max_length=20)
    contact_email: Optional[str] = Field(None, max_length=255)
    is_active: Optional[bool] = None
    delivery_windows: Optional[Dict[str, Any]] = None
    operational_hours: Optional[Dict[str, Any]] = None
    capacity: Optional[int] = Field(None, ge=0)
    max_delivery_radius_km: Optional[float] = Field(None, ge=0)
    delivery_schedule_config: Optional[Dict[str, Any]] = None
    metadata: Optional[Dict[str, Any]] = Field(None)
class TenantLocationResponse(TenantLocationBase):
    """Schema for tenant location response.

    Extends the base fields with identifiers and timestamps loaded from the
    ORM row (``from_attributes``).
    """
    id: str
    tenant_id: str
    created_at: datetime
    updated_at: Optional[datetime]

    @field_validator('id', 'tenant_id', mode='before')
    @classmethod
    def convert_uuid_to_string(cls, v):
        """Convert UUID objects to strings for JSON serialization"""
        # ORM rows carry uuid.UUID values; the response exposes them as str.
        if isinstance(v, UUID):
            return str(v)
        return v

    class Config:
        from_attributes = True
        populate_by_name = True
class TenantLocationsResponse(BaseModel):
    """Schema for multiple tenant locations response.

    ``total`` is the overall count, which may exceed ``len(locations)`` when
    the endpoint paginates — TODO confirm against the API layer.
    """
    locations: List[TenantLocationResponse]
    total: int
class TenantLocationTypeFilter(BaseModel):
    """Schema for filtering locations by type.

    Defaults to all five supported types (i.e. no filtering).
    """
    location_types: List[str] = Field(
        default=["central_production", "retail_outlet", "warehouse", "store", "branch"],
        description="List of location types to include"
    )

View File

@@ -0,0 +1,316 @@
# services/tenant/app/schemas/tenant_settings.py
"""
Tenant Settings Schemas
Pydantic models for API request/response validation
"""
from pydantic import BaseModel, Field, validator
from typing import Optional
from datetime import datetime
from uuid import UUID
# ================================================================
# SETTING CATEGORY SCHEMAS
# ================================================================
class ProcurementSettings(BaseModel):
    """Procurement and auto-approval settings.

    Thresholds controlling automatic purchase-order approval, planning
    horizons, and which ingredient/supplier constraints the procurement
    calculations honour.
    """
    auto_approve_enabled: bool = True
    # POs at or below this amount may be auto-approved (subject to the flags below).
    auto_approve_threshold_eur: float = Field(500.0, ge=0, le=10000)
    auto_approve_min_supplier_score: float = Field(0.80, ge=0.0, le=1.0)
    require_approval_new_suppliers: bool = True
    require_approval_critical_items: bool = True
    procurement_lead_time_days: int = Field(3, ge=1, le=30)
    demand_forecast_days: int = Field(14, ge=1, le=90)
    safety_stock_percentage: float = Field(20.0, ge=0.0, le=100.0)
    po_approval_reminder_hours: int = Field(24, ge=1, le=168)
    po_critical_escalation_hours: int = Field(12, ge=1, le=72)
    use_reorder_rules: bool = Field(True, description="Use ingredient reorder point and reorder quantity in procurement calculations")
    economic_rounding: bool = Field(True, description="Round order quantities to economic multiples (reorder_quantity or supplier minimum_order_quantity)")
    respect_storage_limits: bool = Field(True, description="Enforce max_stock_level constraints on orders")
    use_supplier_minimums: bool = Field(True, description="Respect supplier minimum_order_quantity and minimum_order_amount")
    optimize_price_tiers: bool = Field(True, description="Optimize order quantities to capture volume discount price tiers")
class InventorySettings(BaseModel):
    """Inventory management settings.

    Stock thresholds, expiry warnings, and temperature-monitoring ranges.
    NOTE: the range validators rely on field declaration order — each
    ``*_min`` field is declared (and therefore validated) before its
    ``*_max`` counterpart, so ``values`` already holds the minimum.
    """
    low_stock_threshold: int = Field(10, ge=1, le=1000)
    reorder_point: int = Field(20, ge=1, le=1000)
    reorder_quantity: int = Field(50, ge=1, le=1000)
    expiring_soon_days: int = Field(7, ge=1, le=30)
    expiration_warning_days: int = Field(3, ge=1, le=14)
    quality_score_threshold: float = Field(8.0, ge=0.0, le=10.0)
    temperature_monitoring_enabled: bool = True
    # Temperature bands in °C for each storage class.
    refrigeration_temp_min: float = Field(1.0, ge=-5.0, le=10.0)
    refrigeration_temp_max: float = Field(4.0, ge=-5.0, le=10.0)
    freezer_temp_min: float = Field(-20.0, ge=-30.0, le=0.0)
    freezer_temp_max: float = Field(-15.0, ge=-30.0, le=0.0)
    room_temp_min: float = Field(18.0, ge=10.0, le=35.0)
    room_temp_max: float = Field(25.0, ge=10.0, le=35.0)
    temp_deviation_alert_minutes: int = Field(15, ge=1, le=60)
    critical_temp_deviation_minutes: int = Field(5, ge=1, le=30)

    @validator('refrigeration_temp_max')
    def validate_refrigeration_range(cls, v, values):
        # Reject inverted or empty refrigeration bands.
        if 'refrigeration_temp_min' in values and v <= values['refrigeration_temp_min']:
            raise ValueError('refrigeration_temp_max must be greater than refrigeration_temp_min')
        return v

    @validator('freezer_temp_max')
    def validate_freezer_range(cls, v, values):
        # Reject inverted or empty freezer bands.
        if 'freezer_temp_min' in values and v <= values['freezer_temp_min']:
            raise ValueError('freezer_temp_max must be greater than freezer_temp_min')
        return v

    @validator('room_temp_max')
    def validate_room_range(cls, v, values):
        # Reject inverted or empty room-temperature bands.
        if 'room_temp_min' in values and v <= values['room_temp_min']:
            raise ValueError('room_temp_max must be greater than room_temp_min')
        return v
class ProductionSettings(BaseModel):
    """Production settings.

    Planning horizon, batch-size limits, capacity targets, quality gates,
    and cost parameters for production scheduling. Range validators depend
    on the min field being declared before its max counterpart.
    """
    planning_horizon_days: int = Field(7, ge=1, le=30)
    minimum_batch_size: float = Field(1.0, ge=0.1, le=100.0)
    maximum_batch_size: float = Field(100.0, ge=1.0, le=1000.0)
    production_buffer_percentage: float = Field(10.0, ge=0.0, le=50.0)
    working_hours_per_day: int = Field(12, ge=1, le=24)
    max_overtime_hours: int = Field(4, ge=0, le=12)
    # Fractions of capacity: target utilization must stay below the warning level.
    capacity_utilization_target: float = Field(0.85, ge=0.5, le=1.0)
    capacity_warning_threshold: float = Field(0.95, ge=0.7, le=1.0)
    quality_check_enabled: bool = True
    minimum_yield_percentage: float = Field(85.0, ge=50.0, le=100.0)
    quality_score_threshold: float = Field(8.0, ge=0.0, le=10.0)
    schedule_optimization_enabled: bool = True
    prep_time_buffer_minutes: int = Field(30, ge=0, le=120)
    cleanup_time_buffer_minutes: int = Field(15, ge=0, le=120)
    labor_cost_per_hour_eur: float = Field(15.0, ge=5.0, le=100.0)
    overhead_cost_percentage: float = Field(20.0, ge=0.0, le=50.0)

    @validator('maximum_batch_size')
    def validate_batch_size_range(cls, v, values):
        # Reject inverted or empty batch-size ranges.
        if 'minimum_batch_size' in values and v <= values['minimum_batch_size']:
            raise ValueError('maximum_batch_size must be greater than minimum_batch_size')
        return v

    @validator('capacity_warning_threshold')
    def validate_capacity_threshold(cls, v, values):
        # The warning threshold must sit strictly above the utilization target.
        if 'capacity_utilization_target' in values and v <= values['capacity_utilization_target']:
            raise ValueError('capacity_warning_threshold must be greater than capacity_utilization_target')
        return v
class SupplierSettings(BaseModel):
    """Supplier management settings.

    Default commercial terms plus the rate thresholds that grade supplier
    performance ("excellent" must exceed "good"; enforced by validators that
    rely on the excellent_* fields being declared first).
    """
    default_payment_terms_days: int = Field(30, ge=1, le=90)
    default_delivery_days: int = Field(3, ge=1, le=30)
    # Performance grading thresholds (percentages).
    excellent_delivery_rate: float = Field(95.0, ge=90.0, le=100.0)
    good_delivery_rate: float = Field(90.0, ge=80.0, le=99.0)
    excellent_quality_rate: float = Field(98.0, ge=90.0, le=100.0)
    good_quality_rate: float = Field(95.0, ge=80.0, le=99.0)
    critical_delivery_delay_hours: int = Field(24, ge=1, le=168)
    critical_quality_rejection_rate: float = Field(10.0, ge=0.0, le=50.0)
    high_cost_variance_percentage: float = Field(15.0, ge=0.0, le=100.0)

    @validator('good_delivery_rate')
    def validate_delivery_rates(cls, v, values):
        # "good" must be a strictly lower bar than "excellent".
        if 'excellent_delivery_rate' in values and v >= values['excellent_delivery_rate']:
            raise ValueError('good_delivery_rate must be less than excellent_delivery_rate')
        return v

    @validator('good_quality_rate')
    def validate_quality_rates(cls, v, values):
        # "good" must be a strictly lower bar than "excellent".
        if 'excellent_quality_rate' in values and v >= values['excellent_quality_rate']:
            raise ValueError('good_quality_rate must be less than excellent_quality_rate')
        return v
class POSSettings(BaseModel):
    """POS integration settings.

    How often to poll the point-of-sale system and which entities to sync.
    """
    sync_interval_minutes: int = Field(5, ge=1, le=60)
    auto_sync_products: bool = True
    auto_sync_transactions: bool = True
class OrderSettings(BaseModel):
    """Order and business rules settings.

    Discount limits, delivery windows, and feature toggles for ordering.
    """
    max_discount_percentage: float = Field(50.0, ge=0.0, le=100.0)
    default_delivery_window_hours: int = Field(48, ge=1, le=168)
    dynamic_pricing_enabled: bool = False
    discount_enabled: bool = True
    delivery_tracking_enabled: bool = True
class ReplenishmentSettings(BaseModel):
    """Replenishment planning settings.

    Projection horizon, service level, and order-quantity bounds used by
    the replenishment planner.
    """
    projection_horizon_days: int = Field(7, ge=1, le=30)
    # Target probability of not stocking out (0..1).
    service_level: float = Field(0.95, ge=0.0, le=1.0)
    buffer_days: int = Field(1, ge=0, le=14)
    enable_auto_replenishment: bool = True
    min_order_quantity: float = Field(1.0, ge=0.1, le=1000.0)
    max_order_quantity: float = Field(1000.0, ge=1.0, le=10000.0)
    demand_forecast_days: int = Field(14, ge=1, le=90)
class SafetyStockSettings(BaseModel):
    """Safety stock settings.

    Method and bounds for safety-stock and reorder-point calculations.
    The string fields are method selectors interpreted by the consuming
    service — valid values are not constrained here.
    """
    service_level: float = Field(0.95, ge=0.0, le=1.0)
    method: str = Field("statistical", description="Method for safety stock calculation")
    min_safety_stock: float = Field(0.0, ge=0.0, le=1000.0)
    max_safety_stock: float = Field(100.0, ge=0.0, le=1000.0)
    reorder_point_calculation: str = Field("safety_stock_plus_lead_time_demand", description="Method for reorder point calculation")
class MOQSettings(BaseModel):
    """MOQ (minimum order quantity) aggregation settings.

    Controls how demand is consolidated across days/batches to reach
    supplier minimums.
    """
    consolidation_window_days: int = Field(7, ge=1, le=30)
    allow_early_ordering: bool = True
    enable_batch_optimization: bool = True
    min_batch_size: float = Field(1.0, ge=0.1, le=1000.0)
    max_batch_size: float = Field(1000.0, ge=1.0, le=10000.0)
class SupplierSelectionSettings(BaseModel):
    """Supplier selection settings.

    Scoring weights used to rank suppliers (price/lead-time/quality/
    reliability) plus diversification limits. The four weights must sum
    to at most 1.0.
    """
    price_weight: float = Field(0.40, ge=0.0, le=1.0)
    lead_time_weight: float = Field(0.20, ge=0.0, le=1.0)
    quality_weight: float = Field(0.20, ge=0.0, le=1.0)
    reliability_weight: float = Field(0.20, ge=0.0, le=1.0)
    diversification_threshold: int = Field(1000, ge=0, le=1000)
    max_single_percentage: float = Field(0.70, ge=0.0, le=1.0)
    enable_supplier_score_optimization: bool = True

    @validator('reliability_weight', always=True)
    def validate_weights_sum(cls, v, values):
        """Ensure the four scoring weights sum to 1.0 or less.

        Attached (with ``always=True``) to the last-declared weight so that
        ``values`` already holds the three previously validated weights and
        the current value ``v`` is included explicitly.

        FIX: the previous implementation ran per-field but summed only
        already-validated fields and hard-coded defaults — it never included
        the value being validated, so e.g. reliability_weight=0.9 with the
        other weights small passed validation.
        """
        total = (v
                 + values.get('price_weight', 0.40)
                 + values.get('lead_time_weight', 0.20)
                 + values.get('quality_weight', 0.20))
        if total > 1.0:
            raise ValueError('Weights must sum to 1.0 or less')
        return v
class MLInsightsSettings(BaseModel):
    """ML Insights configuration settings.

    Per-domain lookback windows and minimum-history requirements for the
    ML insight generators, plus global enable/trigger/confidence knobs.
    """
    # Inventory ML (Safety Stock Optimization)
    inventory_lookback_days: int = Field(90, ge=30, le=365, description="Days of demand history for safety stock analysis")
    inventory_min_history_days: int = Field(30, ge=7, le=180, description="Minimum days of history required")
    # Production ML (Yield Prediction)
    production_lookback_days: int = Field(90, ge=30, le=365, description="Days of production history for yield analysis")
    production_min_history_runs: int = Field(30, ge=10, le=100, description="Minimum production runs required")
    # Procurement ML (Supplier Analysis & Price Forecasting)
    supplier_analysis_lookback_days: int = Field(180, ge=30, le=730, description="Days of order history for supplier analysis")
    supplier_analysis_min_orders: int = Field(10, ge=5, le=100, description="Minimum orders required for analysis")
    price_forecast_lookback_days: int = Field(180, ge=90, le=730, description="Days of price history for forecasting")
    price_forecast_horizon_days: int = Field(30, ge=7, le=90, description="Days to forecast ahead")
    # Forecasting ML (Dynamic Rules)
    rules_generation_lookback_days: int = Field(90, ge=30, le=365, description="Days of sales history for rule learning")
    rules_generation_min_samples: int = Field(10, ge=5, le=100, description="Minimum samples required for rule generation")
    # Global ML Settings
    enable_ml_insights: bool = Field(True, description="Enable/disable ML insights generation")
    ml_insights_auto_trigger: bool = Field(False, description="Automatically trigger ML insights in daily workflow")
    ml_confidence_threshold: float = Field(0.80, ge=0.0, le=1.0, description="Minimum confidence threshold for ML recommendations")
class NotificationSettings(BaseModel):
    """Notification and communication settings.

    WhatsApp/email channel configuration and per-alert-type channel routing.
    NOTE: the WhatsApp validator relies on ``whatsapp_enabled`` being
    declared before ``whatsapp_phone_number_id`` (validation order follows
    declaration order).
    """
    # WhatsApp Configuration (Shared Account Model)
    whatsapp_enabled: bool = Field(False, description="Enable WhatsApp notifications for this tenant")
    whatsapp_phone_number_id: str = Field("", description="Meta WhatsApp Phone Number ID (from shared master account)")
    whatsapp_display_phone_number: str = Field("", description="Display format for UI (e.g., '+34 612 345 678')")
    whatsapp_default_language: str = Field("es", description="Default language for WhatsApp templates")
    # Email Configuration
    email_enabled: bool = Field(True, description="Enable email notifications for this tenant")
    email_from_address: str = Field("", description="Custom from email address (optional)")
    email_from_name: str = Field("", description="Custom from name (optional)")
    email_reply_to: str = Field("", description="Reply-to email address (optional)")
    # Notification Preferences
    enable_po_notifications: bool = Field(True, description="Enable purchase order notifications")
    enable_inventory_alerts: bool = Field(True, description="Enable inventory alerts")
    enable_production_alerts: bool = Field(True, description="Enable production alerts")
    enable_forecast_alerts: bool = Field(True, description="Enable forecast alerts")
    # Notification Channels
    po_notification_channels: list[str] = Field(["email"], description="Channels for PO notifications (email, whatsapp)")
    inventory_alert_channels: list[str] = Field(["email"], description="Channels for inventory alerts")
    production_alert_channels: list[str] = Field(["email"], description="Channels for production alerts")
    forecast_alert_channels: list[str] = Field(["email"], description="Channels for forecast alerts")

    @validator('po_notification_channels', 'inventory_alert_channels', 'production_alert_channels', 'forecast_alert_channels')
    def validate_channels(cls, v):
        """Validate that channels are valid"""
        valid_channels = ["email", "whatsapp", "sms", "push"]
        for channel in v:
            if channel not in valid_channels:
                raise ValueError(f"Invalid channel: {channel}. Must be one of {valid_channels}")
        return v

    @validator('whatsapp_phone_number_id')
    def validate_phone_number_id(cls, v, values):
        """Validate phone number ID is provided if WhatsApp is enabled"""
        # whatsapp_enabled is validated first, so it is available in values.
        if values.get('whatsapp_enabled') and not v:
            raise ValueError("whatsapp_phone_number_id is required when WhatsApp is enabled")
        return v
# ================================================================
# REQUEST/RESPONSE SCHEMAS
# ================================================================
class TenantSettingsResponse(BaseModel):
    """Response schema for tenant settings.

    Aggregates every settings category for a tenant into one payload;
    loaded from the ORM model via ``from_attributes``.
    """
    id: UUID
    tenant_id: UUID
    procurement_settings: ProcurementSettings
    inventory_settings: InventorySettings
    production_settings: ProductionSettings
    supplier_settings: SupplierSettings
    pos_settings: POSSettings
    order_settings: OrderSettings
    replenishment_settings: ReplenishmentSettings
    safety_stock_settings: SafetyStockSettings
    moq_settings: MOQSettings
    supplier_selection_settings: SupplierSelectionSettings
    ml_insights_settings: MLInsightsSettings
    notification_settings: NotificationSettings
    created_at: datetime
    updated_at: datetime

    class Config:
        from_attributes = True
class TenantSettingsUpdate(BaseModel):
    """Schema for updating tenant settings.

    Partial update: each category is optional and only provided categories
    are intended to be applied.
    """
    procurement_settings: Optional[ProcurementSettings] = None
    inventory_settings: Optional[InventorySettings] = None
    production_settings: Optional[ProductionSettings] = None
    supplier_settings: Optional[SupplierSettings] = None
    pos_settings: Optional[POSSettings] = None
    order_settings: Optional[OrderSettings] = None
    replenishment_settings: Optional[ReplenishmentSettings] = None
    safety_stock_settings: Optional[SafetyStockSettings] = None
    moq_settings: Optional[MOQSettings] = None
    supplier_selection_settings: Optional[SupplierSelectionSettings] = None
    ml_insights_settings: Optional[MLInsightsSettings] = None
    notification_settings: Optional[NotificationSettings] = None
class CategoryUpdateRequest(BaseModel):
    """Schema for updating a single settings category.

    ``settings`` is an untyped dict; it is presumably validated against the
    matching category model by the endpoint — TODO confirm in the API layer.
    """
    settings: dict
class CategoryResetResponse(BaseModel):
    """Response schema for resetting a settings category to its defaults."""
    # Name of the category that was reset.
    category: str
    # The category's settings after the reset.
    settings: dict
    message: str

View File

@@ -0,0 +1,386 @@
# services/tenant/app/schemas/tenants.py
"""
Tenant schemas - FIXED VERSION
"""
from pydantic import BaseModel, Field, field_validator, ValidationInfo
from typing import Optional, List, Dict, Any
from datetime import datetime
from uuid import UUID
import re
class BakeryRegistration(BaseModel):
    """Bakery registration schema.

    Payload for registering a new tenant (Spanish market: 5-digit postal
    code, Spanish phone formats).
    """
    name: str = Field(..., min_length=2, max_length=200)
    address: str = Field(..., min_length=10, max_length=500)
    city: str = Field(default="Madrid", max_length=100)
    postal_code: str = Field(..., pattern=r"^\d{5}$")
    phone: str = Field(..., min_length=9, max_length=20)
    business_type: str = Field(default="bakery")
    business_model: Optional[str] = Field(default="individual_bakery")
    coupon_code: Optional[str] = Field(None, max_length=50, description="Promotional coupon code")
    # Subscription linking fields (for new multi-phase registration architecture)
    subscription_id: Optional[str] = Field(None, description="Existing subscription ID to link to this tenant")
    link_existing_subscription: Optional[bool] = Field(False, description="Flag to link an existing subscription during tenant creation")

    @field_validator('phone')
    @classmethod
    def validate_spanish_phone(cls, v):
        """Validate Spanish phone number"""
        # Remove spaces and common separators
        phone = re.sub(r'[\s\-\(\)]', '', v)
        # Spanish mobile: +34 6/7/8/9 + 8 digits
        # Spanish landline: +34 9 + 8 digits
        # (the mobile pattern already matches leading 9, so the landline
        # pattern is effectively a subset kept for documentation purposes)
        patterns = [
            r'^(\+34|0034|34)?[6789]\d{8}$',  # Mobile
            r'^(\+34|0034|34)?9\d{8}$',  # Landline
        ]
        if not any(re.match(pattern, phone) for pattern in patterns):
            raise ValueError('Invalid Spanish phone number')
        # NOTE: the original (unseparated) value is returned, not the
        # normalized one.
        return v

    @field_validator('business_type')
    @classmethod
    def validate_business_type(cls, v):
        """Restrict business type to the supported set."""
        valid_types = ['bakery', 'coffee_shop', 'pastry_shop', 'restaurant']
        if v not in valid_types:
            raise ValueError(f'Business type must be one of: {valid_types}')
        return v

    @field_validator('business_model')
    @classmethod
    def validate_business_model(cls, v):
        """Restrict business model to the supported set (None allowed)."""
        if v is None:
            return v
        valid_models = ['individual_bakery', 'central_baker_satellite', 'retail_bakery', 'hybrid_bakery']
        if v not in valid_models:
            raise ValueError(f'Business model must be one of: {valid_models}')
        return v
class TenantResponse(BaseModel):
    """Tenant response schema - Updated to use subscription relationship.

    Serializes a tenant row; UUID columns are coerced to strings by the
    ``mode='before'`` validator so the payload is JSON-friendly.
    """
    id: str  # ✅ Keep as str for Pydantic validation
    name: str
    subdomain: Optional[str]
    business_type: str
    business_model: Optional[str]
    tenant_type: Optional[str] = "standalone"  # standalone, parent, or child
    parent_tenant_id: Optional[str] = None  # For child tenants
    address: str
    city: str
    postal_code: str
    # Regional/Localization settings
    timezone: Optional[str] = "Europe/Madrid"
    currency: Optional[str] = "EUR"  # Currency code: EUR, USD, GBP
    language: Optional[str] = "es"  # Language code: es, en, eu
    phone: Optional[str]
    is_active: bool
    subscription_plan: Optional[str] = None  # Populated from subscription relationship or service
    ml_model_trained: bool
    last_training_date: Optional[datetime]
    owner_id: str  # ✅ Keep as str for Pydantic validation
    created_at: datetime

    # ✅ FIX: Add custom validator to convert UUID to string
    @field_validator('id', 'owner_id', 'parent_tenant_id', mode='before')
    @classmethod
    def convert_uuid_to_string(cls, v):
        """Convert UUID objects to strings for JSON serialization"""
        if isinstance(v, UUID):
            return str(v)
        return v

    class Config:
        from_attributes = True
class TenantAccessResponse(BaseModel):
    """Tenant access verification response.

    Result of checking whether a user may act within a tenant.
    """
    has_access: bool
    role: str
    permissions: List[str]
class TenantMemberResponse(BaseModel):
    """Tenant member response - FIXED VERSION with enriched user data.

    Membership row plus optional user details that the service layer fills
    in from the user service.
    """
    id: str
    user_id: str
    role: str
    is_active: bool
    joined_at: Optional[datetime]
    # Enriched user fields (populated via service layer)
    user_email: Optional[str] = None
    user_full_name: Optional[str] = None
    user: Optional[Dict[str, Any]] = None  # Full user object for compatibility

    # ✅ FIX: Add custom validator to convert UUID to string
    @field_validator('id', 'user_id', mode='before')
    @classmethod
    def convert_uuid_to_string(cls, v):
        """Convert UUID objects to strings for JSON serialization"""
        if isinstance(v, UUID):
            return str(v)
        return v

    class Config:
        from_attributes = True
class TenantUpdate(BaseModel):
    """Tenant update schema.

    Partial update — only provided fields are intended to be applied.
    """
    name: Optional[str] = Field(None, min_length=2, max_length=200)
    address: Optional[str] = Field(None, min_length=10, max_length=500)
    phone: Optional[str] = None
    business_type: Optional[str] = None
    business_model: Optional[str] = None
    # Regional/Localization settings
    timezone: Optional[str] = None
    currency: Optional[str] = Field(None, pattern=r'^(EUR|USD|GBP)$')  # Currency code
    language: Optional[str] = Field(None, pattern=r'^(es|en|eu)$')  # Language code
class TenantListResponse(BaseModel):
    """Response schema for listing tenants (paginated)."""
    tenants: List[TenantResponse]
    # Total rows across all pages, not just this page.
    total: int
    page: int
    per_page: int
    has_next: bool
    has_prev: bool
class TenantMemberInvitation(BaseModel):
    """Schema for inviting a member to a tenant.

    Note: 'owner' is deliberately not an invitable role.
    """
    email: str = Field(..., pattern=r'^[^@]+@[^@]+\.[^@]+$')
    role: str = Field(..., pattern=r'^(admin|member|viewer)$')
    message: Optional[str] = Field(None, max_length=500)
class TenantMemberUpdate(BaseModel):
    """Schema for updating tenant member.

    Unlike invitations, 'owner' is an assignable role here.
    """
    role: Optional[str] = Field(None, pattern=r'^(owner|admin|member|viewer)$')
    is_active: Optional[bool] = None
class AddMemberWithUserCreate(BaseModel):
    """Schema for adding member with optional user creation (pilot phase).

    Exactly one mode must be used:
    - existing user: supply ``user_id``
    - new user: supply ``create_user=True`` plus ``email``/``full_name``/``password``
    """
    # For existing users
    user_id: Optional[str] = Field(None, description="ID of existing user to add")
    # For new user creation
    create_user: bool = Field(False, description="Whether to create a new user")
    email: Optional[str] = Field(None, description="Email for new user (if create_user=True)")
    full_name: Optional[str] = Field(None, min_length=2, max_length=100, description="Full name for new user")
    password: Optional[str] = Field(None, min_length=8, max_length=128, description="Password for new user")
    phone: Optional[str] = Field(None, description="Phone number for new user")
    language: Optional[str] = Field("es", pattern="^(es|en|eu)$", description="Preferred language")
    timezone: Optional[str] = Field("Europe/Madrid", description="User timezone")
    # Common fields
    role: str = Field(..., pattern=r'^(admin|member|viewer)$', description="Role in the tenant")

    @field_validator('role')
    @classmethod
    def validate_mode_consistency(cls, v, info: ValidationInfo):
        """Cross-field validation of the user_id / create_user modes.

        Attached to ``role`` — the only required field, declared last — so
        it always runs and ``info.data`` holds every other field.

        FIX: the previous validator was attached to ``user_id``, which is
        declared BEFORE ``create_user``, so ``create_user`` was never
        present in ``info.data`` when it ran: the "cannot specify both"
        rule could never fire, and the either/or check misbehaved. The
        per-field required checks also never ran for omitted optional
        fields (defaults are not validated).
        """
        user_id = info.data.get('user_id')
        create_user = bool(info.data.get('create_user', False))
        if user_id and create_user:
            raise ValueError("Cannot specify both user_id and create_user")
        if not user_id and not create_user:
            raise ValueError("Either user_id or create_user must be provided")
        if create_user:
            # These fields are mandatory only in user-creation mode.
            for field_name in ('email', 'full_name', 'password'):
                if not info.data.get(field_name):
                    raise ValueError(f"{field_name} is required when create_user is True")
        return v
class TenantSubscriptionUpdate(BaseModel):
    """Schema for updating tenant subscription."""
    # Target plan tier.
    plan: str = Field(..., pattern=r'^(basic|professional|enterprise)$')
    billing_cycle: str = Field(default="monthly", pattern=r'^(monthly|yearly)$')
class TenantStatsResponse(BaseModel):
    """Tenant statistics response.

    Aggregated counts and subscription state for a tenant dashboard.
    """
    tenant_id: str
    total_members: int
    active_members: int
    total_predictions: int
    models_trained: int
    last_training_date: Optional[datetime]
    subscription_plan: str
    subscription_status: str
# ============================================================================
# ENTERPRISE CHILD TENANT SCHEMAS
# ============================================================================
class ChildTenantCreate(BaseModel):
    """Schema for creating a child tenant in enterprise hierarchy - Updated to match tenant model.

    Contact/business fields are optional and inherit from the parent tenant
    when omitted; coordinates may be geocoded from the address.
    """
    name: str = Field(..., min_length=2, max_length=200, description="Child tenant name (e.g., 'Madrid - Salamanca')")
    city: str = Field(..., min_length=2, max_length=100, description="City where the outlet is located")
    zone: Optional[str] = Field(None, max_length=100, description="Zone or neighborhood")
    address: str = Field(..., min_length=10, max_length=500, description="Full address of the outlet")
    postal_code: str = Field(..., pattern=r"^\d{5}$", description="5-digit postal code")
    location_code: str = Field(..., min_length=1, max_length=10, description="Short location code (e.g., MAD, BCN)")
    # Coordinates (can be geocoded from address if not provided)
    latitude: Optional[float] = Field(None, ge=-90, le=90, description="Latitude coordinate")
    longitude: Optional[float] = Field(None, ge=-180, le=180, description="Longitude coordinate")
    # Contact info (inherits from parent if not provided)
    phone: Optional[str] = Field(None, min_length=9, max_length=20, description="Contact phone")
    email: Optional[str] = Field(None, description="Contact email")
    # Business info
    business_type: Optional[str] = Field(None, max_length=100, description="Type of business")
    business_model: Optional[str] = Field(None, max_length=100, description="Business model")
    # Timezone configuration
    timezone: Optional[str] = Field(None, max_length=50, description="Timezone for scheduling")
    # Additional metadata
    metadata: Optional[Dict[str, Any]] = Field(None, description="Additional metadata for the child tenant")

    @field_validator('location_code')
    @classmethod
    def validate_location_code(cls, v):
        """Ensure location code is uppercase and alphanumeric"""
        # Hyphens/underscores are allowed as separators; a code made up of
        # only separators is rejected because ''.isalnum() is False.
        if not v.replace('-', '').replace('_', '').isalnum():
            raise ValueError('Location code must be alphanumeric (with optional hyphens/underscores)')
        return v.upper()

    @field_validator('phone')
    @classmethod
    def validate_phone(cls, v):
        """Validate Spanish phone number if provided"""
        if v is None:
            return v
        # Strip spaces and common separators before matching.
        phone = re.sub(r'[\s\-\(\)]', '', v)
        patterns = [
            r'^(\+34|0034|34)?[6789]\d{8}$',  # Mobile
            r'^(\+34|0034|34)?9\d{8}$',  # Landline
        ]
        if not any(re.match(pattern, phone) for pattern in patterns):
            raise ValueError('Invalid Spanish phone number')
        return v

    @field_validator('business_type')
    @classmethod
    def validate_business_type(cls, v):
        """Validate business type if provided"""
        if v is None:
            return v
        valid_types = ['bakery', 'coffee_shop', 'pastry_shop', 'restaurant']
        if v not in valid_types:
            raise ValueError(f'Business type must be one of: {valid_types}')
        return v

    @field_validator('business_model')
    @classmethod
    def validate_business_model(cls, v):
        """Validate business model if provided"""
        if v is None:
            return v
        valid_models = ['individual_bakery', 'central_baker_satellite', 'retail_bakery', 'hybrid_bakery']
        if v not in valid_models:
            raise ValueError(f'Business model must be one of: {valid_models}')
        return v

    @field_validator('timezone')
    @classmethod
    def validate_timezone(cls, v):
        """Validate timezone if provided.

        FIX: the previous pattern ``^[A-Za-z_+/]+$`` rejected valid IANA
        zone names containing digits or hyphens (e.g. 'Etc/GMT+2',
        'America/Port-au-Prince'); those characters are now accepted.
        """
        if v is None:
            return v
        # Basic shape check for IANA-style names; not a full tz-database lookup.
        if not re.match(r'^[A-Za-z0-9_+\-/]+$', v):
            raise ValueError('Invalid timezone format')
        return v
class BulkChildTenantsCreate(BaseModel):
    """Schema for bulk creating child tenants during onboarding.

    Accepts 1-50 outlets and optionally triggers automatic distribution
    route setup between the parent and the new children.
    """
    child_tenants: List[ChildTenantCreate] = Field(
        ...,
        min_length=1,
        max_length=50,
        description="List of child tenants to create (1-50)"
    )
    # Optional: Auto-configure distribution routes
    auto_configure_distribution: bool = Field(
        True,
        description="Whether to automatically set up distribution routes between parent and children"
    )

    @field_validator('child_tenants')
    @classmethod
    def validate_unique_location_codes(cls, v):
        """Ensure all location codes are unique within the batch"""
        # Codes are already upper-cased by ChildTenantCreate's validator,
        # so this duplicate check is effectively case-insensitive.
        location_codes = [ct.location_code for ct in v]
        if len(location_codes) != len(set(location_codes)):
            raise ValueError('Location codes must be unique within the batch')
        return v
class ChildTenantResponse(TenantResponse):
    """Response schema for child tenant - extends TenantResponse.

    Adds hierarchy/location fields specific to child outlets.
    """
    location_code: Optional[str] = None
    zone: Optional[str] = None
    # Materialized path within the tenant hierarchy.
    hierarchy_path: Optional[str] = None

    class Config:
        from_attributes = True
class BulkChildTenantsResponse(BaseModel):
    """Response schema for bulk child tenant creation.

    Reports per-batch success/failure counts; partial success is possible
    (some tenants created, others listed in ``failed_tenants``).
    """
    parent_tenant_id: str
    created_count: int
    failed_count: int
    created_tenants: List[ChildTenantResponse]
    failed_tenants: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="List of failed tenants with error details"
    )
    distribution_configured: bool = False

    @field_validator('parent_tenant_id', mode='before')
    @classmethod
    def convert_uuid_to_string(cls, v):
        """Convert UUID objects to strings for JSON serialization"""
        if isinstance(v, UUID):
            return str(v)
        return v
class TenantHierarchyResponse(BaseModel):
    """Response schema for tenant hierarchy information.

    Describes where a tenant sits in the parent/child tree.
    """
    tenant_id: str
    tenant_type: str = Field(..., description="Type: standalone, parent, or child")
    parent_tenant_id: Optional[str] = Field(None, description="Parent tenant ID if this is a child")
    hierarchy_path: Optional[str] = Field(None, description="Materialized path for hierarchy queries")
    child_count: int = Field(0, description="Number of child tenants (for parent tenants)")
    hierarchy_level: int = Field(0, description="Level in hierarchy: 0=parent, 1=child, 2=grandchild, etc.")

    @field_validator('tenant_id', 'parent_tenant_id', mode='before')
    @classmethod
    def convert_uuid_to_string(cls, v):
        """Convert UUID objects to strings for JSON serialization"""
        # parent_tenant_id may legitimately be None (standalone/parent tenants).
        if v is None:
            return v
        if isinstance(v, UUID):
            return str(v)
        return v

    class Config:
        from_attributes = True
class TenantSearchRequest(BaseModel):
    """Tenant search request schema.

    All filters are optional; ``limit``/``offset`` provide pagination.
    """
    query: Optional[str] = None
    business_type: Optional[str] = None
    city: Optional[str] = None
    status: Optional[str] = None
    limit: int = Field(default=50, ge=1, le=100)
    offset: int = Field(default=0, ge=0)

View File

@@ -0,0 +1,19 @@
"""
Tenant Service Layer
Business logic services for tenant operations
"""
from .tenant_service import TenantService, EnhancedTenantService
from .subscription_service import SubscriptionService
from .payment_service import PaymentService
from .coupon_service import CouponService
from .subscription_orchestration_service import SubscriptionOrchestrationService
__all__ = [
"TenantService",
"EnhancedTenantService",
"SubscriptionService",
"PaymentService",
"CouponService",
"SubscriptionOrchestrationService"
]

View File

@@ -0,0 +1,108 @@
"""
Coupon Service - Coupon Operations
This service handles ONLY coupon validation and redemption
NO payment provider interactions, NO subscription logic
"""
import structlog
from typing import Dict, Any, Optional, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from app.repositories.coupon_repository import CouponRepository
logger = structlog.get_logger()
class CouponService:
    """Service for handling coupon validation and redemption"""

    def __init__(self, db_session: AsyncSession):
        # Keep the raw session and a repository bound to it; all coupon
        # persistence goes through the repository.
        self.db_session = db_session
        self.coupon_repo = CouponRepository(db_session)

    async def validate_coupon_code(
        self,
        coupon_code: str,
        tenant_id: str
    ) -> Dict[str, Any]:
        """Validate a coupon code for a tenant.

        Args:
            coupon_code: Coupon code to validate
            tenant_id: Tenant ID

        Returns:
            Dictionary with validation results (best-effort: repository
            failures yield a generic invalid result, never an exception)
        """
        try:
            outcome = await self.coupon_repo.validate_coupon(coupon_code, tenant_id)
            coupon_summary = None
            if outcome.coupon:
                coupon_summary = {
                    "code": outcome.coupon.code,
                    "discount_type": outcome.coupon.discount_type.value,
                    "discount_value": outcome.coupon.discount_value
                }
            return {
                "valid": outcome.valid,
                "error_message": outcome.error_message,
                "discount_preview": outcome.discount_preview,
                "coupon": coupon_summary
            }
        except Exception as e:
            logger.error("Failed to validate coupon", error=str(e), coupon_code=coupon_code)
            return {
                "valid": False,
                "error_message": "Error al validar el cupón",
                "discount_preview": None,
                "coupon": None
            }

    async def redeem_coupon(
        self,
        coupon_code: str,
        tenant_id: str,
        base_trial_days: int = 0
    ) -> Tuple[bool, Optional[Dict[str, Any]], Optional[str]]:
        """Redeem a coupon for a tenant.

        Args:
            coupon_code: Coupon code to redeem
            tenant_id: Tenant ID
            base_trial_days: Base trial days without coupon

        Returns:
            Tuple of (success, discount_applied, error_message)
        """
        try:
            ok, redemption, error = await self.coupon_repo.redeem_coupon(
                coupon_code,
                tenant_id,
                base_trial_days
            )
            if not (ok and redemption):
                return False, None, error
            return True, redemption.discount_applied, None
        except Exception as e:
            logger.error("Failed to redeem coupon", error=str(e), coupon_code=coupon_code)
            return False, None, f"Error al aplicar el cupón: {str(e)}"

    async def get_coupon_by_code(self, coupon_code: str) -> Optional[Any]:
        """Get coupon details by code.

        Args:
            coupon_code: Coupon code to retrieve

        Returns:
            Coupon object, or None when not found or on repository failure
        """
        try:
            return await self.coupon_repo.get_coupon_by_code(coupon_code)
        except Exception as e:
            logger.error("Failed to get coupon by code", error=str(e), coupon_code=coupon_code)
            return None

View File

@@ -0,0 +1,365 @@
# services/tenant/app/services/network_alerts_service.py
"""
Network Alerts Service
Business logic for aggregating and managing alerts across enterprise networks
"""
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta
import uuid
import structlog
logger = structlog.get_logger()
class NetworkAlertsService:
    """
    Service for aggregating and managing alerts across enterprise networks.

    Pulls the list of child outlets from the tenant service, gathers each
    outlet's alerts, and provides correlation, trend, prioritization and
    acknowledge/resolve operations on the combined set.

    NOTE: alert data is currently simulated (see get_alerts_for_tenant);
    a real implementation would query the alert service.
    """

    def __init__(self, tenant_client, alerts_client):
        # Injected service clients; alerts_client is reserved for the real
        # alert-service integration once simulation is replaced.
        self.tenant_client = tenant_client
        self.alerts_client = alerts_client

    async def get_child_tenants(self, parent_id: str) -> List[Dict[str, Any]]:
        """
        Get all child tenants for a parent tenant, enriched with name,
        subdomain and city from the tenant service.

        Raises:
            Exception: If the tenant service lookup fails.
        """
        try:
            children = await self.tenant_client.get_child_tenants(parent_id)
            enriched_children = []
            for child in children:
                child_details = await self.tenant_client.get_tenant(child['id'])
                enriched_children.append({
                    'id': child['id'],
                    'name': child_details.get('name', f"Outlet {child['id']}"),
                    'subdomain': child_details.get('subdomain'),
                    'city': child_details.get('city')
                })
            return enriched_children
        except Exception as e:
            # Structured fields instead of an f-string message, matching the
            # structlog style used by the rest of this service.
            logger.error("Failed to get child tenants", parent_id=parent_id, error=str(e))
            raise Exception(f"Failed to get child tenants: {str(e)}")

    async def get_alerts_for_tenant(self, tenant_id: str) -> List[Dict[str, Any]]:
        """
        Get alerts for a specific tenant.

        NOTE: returns simulated demo data; a real implementation would call
        the alert service via self.alerts_client.
        """
        try:
            simulated_alerts = []
            alert_types = ['inventory', 'production', 'delivery', 'equipment', 'quality']
            severities = ['critical', 'high', 'medium', 'low']
            for i in range(3):  # Generate 3 sample alerts per tenant
                alert = {
                    'alert_id': str(uuid.uuid4()),
                    'tenant_id': tenant_id,
                    'alert_type': alert_types[i % len(alert_types)],
                    'severity': severities[i % len(severities)],
                    'title': f"{alert_types[i % len(alert_types)].title()} Alert Detected",
                    'message': f"Sample {alert_types[i % len(alert_types)]} alert for tenant {tenant_id}",
                    'timestamp': (datetime.now() - timedelta(hours=i)).isoformat(),
                    'status': 'active' if i < 2 else 'resolved',
                    'source_system': f"{alert_types[i % len(alert_types)]}-service",
                    'related_entity_id': f"entity-{i+1}",
                    'related_entity_type': alert_types[i % len(alert_types)]
                }
                simulated_alerts.append(alert)
            return simulated_alerts
        except Exception as e:
            # Structured fields instead of an f-string message (consistency).
            logger.error("Failed to get alerts for tenant", tenant_id=tenant_id, error=str(e))
            raise Exception(f"Failed to get alerts: {str(e)}")

    async def get_network_alerts(self, parent_id: str) -> List[Dict[str, Any]]:
        """
        Get all alerts across the network (union of every child's alerts).
        """
        try:
            child_tenants = await self.get_child_tenants(parent_id)
            all_alerts = []
            for child in child_tenants:
                child_alerts = await self.get_alerts_for_tenant(child['id'])
                all_alerts.extend(child_alerts)
            return all_alerts
        except Exception as e:
            logger.error("Failed to get network alerts", parent_id=parent_id, error=str(e))
            raise Exception(f"Failed to get network alerts: {str(e)}")

    async def detect_alert_correlations(
        self,
        alerts: List[Dict[str, Any]],
        min_correlation_strength: float = 0.7
    ) -> List[Dict[str, Any]]:
        """
        Detect correlations between alerts.

        Groups alerts by type and emits one 'temporal' correlation per type
        that occurs at least twice, keeping groups whose fixed strength
        (0.85) meets *min_correlation_strength*. Time proximity is not yet
        factored into the grouping.
        """
        try:
            correlations = []
            # Group alerts by type.
            alert_groups = {}
            for alert in alerts:
                alert_type = alert['alert_type']
                if alert_type not in alert_groups:
                    alert_groups[alert_type] = []
                alert_groups[alert_type].append(alert)
            # Create correlation groups (2+ alerts of the same type).
            for alert_type, group in alert_groups.items():
                if len(group) >= 2:
                    primary_alert = group[0]
                    related_alerts = group[1:]
                    correlation = {
                        'correlation_id': str(uuid.uuid4()),
                        'primary_alert': primary_alert,
                        'related_alerts': related_alerts,
                        'correlation_type': 'temporal',
                        'correlation_strength': 0.85,
                        'impact_analysis': f"Multiple {alert_type} alerts detected within short timeframe"
                    }
                    if correlation['correlation_strength'] >= min_correlation_strength:
                        correlations.append(correlation)
            return correlations
        except Exception as e:
            logger.error("Failed to detect alert correlations", error=str(e))
            raise Exception(f"Failed to detect correlations: {str(e)}")

    async def acknowledge_alert(self, parent_id: str, alert_id: str) -> Dict[str, Any]:
        """
        Acknowledge an alert (simulated; no persistent status update yet).
        """
        try:
            logger.info("Alert acknowledged", parent_id=parent_id, alert_id=alert_id)
            return {
                'success': True,
                'alert_id': alert_id,
                'status': 'acknowledged'
            }
        except Exception as e:
            logger.error("Failed to acknowledge alert", parent_id=parent_id, alert_id=alert_id, error=str(e))
            raise Exception(f"Failed to acknowledge alert: {str(e)}")

    async def resolve_alert(self, parent_id: str, alert_id: str, resolution_notes: Optional[str] = None) -> Dict[str, Any]:
        """
        Resolve an alert (simulated; no persistent status update yet).
        """
        try:
            logger.info("Alert resolved", parent_id=parent_id, alert_id=alert_id, notes=resolution_notes)
            return {
                'success': True,
                'alert_id': alert_id,
                'status': 'resolved',
                'resolution_notes': resolution_notes
            }
        except Exception as e:
            logger.error("Failed to resolve alert", parent_id=parent_id, alert_id=alert_id, error=str(e))
            raise Exception(f"Failed to resolve alert: {str(e)}")

    async def get_alert_trends(self, parent_id: str, days: int = 30) -> List[Dict[str, Any]]:
        """
        Get simulated daily alert-count trends for the last *days* days,
        sorted oldest-first.
        """
        try:
            trends = []
            end_date = datetime.now()
            for i in range(days):
                date = end_date - timedelta(days=i)
                base_count = 5
                # NOTE(review): variation is keyed on the loop index, not the
                # actual weekday of `date` - confirm whether a real weekday
                # pattern is intended here.
                weekly_variation = int((i % 7) * 1.5)
                daily_noise = (i % 3 - 1)  # Daily noise
                alert_count = max(1, base_count + weekly_variation + daily_noise)
                trends.append({
                    'date': date.strftime('%Y-%m-%d'),
                    'total_alerts': alert_count,
                    'critical_alerts': max(0, int(alert_count * 0.1)),
                    'high_alerts': max(0, int(alert_count * 0.2)),
                    'medium_alerts': max(0, int(alert_count * 0.4)),
                    'low_alerts': max(0, int(alert_count * 0.3))
                })
            # Sort by date (oldest first)
            trends.sort(key=lambda x: x['date'])
            return trends
        except Exception as e:
            logger.error("Failed to get alert trends", parent_id=parent_id, error=str(e))
            raise Exception(f"Failed to get alert trends: {str(e)}")

    async def get_prioritized_alerts(self, parent_id: str, limit: int = 10) -> List[Dict[str, Any]]:
        """
        Get prioritized alerts based on impact and urgency.

        Priority = severity score x recency score; newer alerts rank higher.
        """
        try:
            all_alerts = await self.get_network_alerts(parent_id)
            if not all_alerts:
                return []
            severity_scores = {'critical': 4, 'high': 3, 'medium': 2, 'low': 1}
            for alert in all_alerts:
                severity_score = severity_scores.get(alert['severity'], 1)
                timestamp = datetime.fromisoformat(alert['timestamp'])
                # Newer alerts get a HIGHER multiplier (3 for today, down to
                # 1 for alerts 2+ days old). The previous formula grew with
                # age, inverting the stated intent.
                age_days = (datetime.now() - timestamp).days
                recency_score = max(1, 3 - age_days)
                alert['priority_score'] = severity_score * recency_score
            # Sort by priority score (highest first) and keep the top N.
            all_alerts.sort(key=lambda x: x['priority_score'], reverse=True)
            prioritized = all_alerts[:limit]
            # Strip the internal scoring field before returning.
            for alert in prioritized:
                alert.pop('priority_score', None)
            return prioritized
        except Exception as e:
            logger.error("Failed to get prioritized alerts", parent_id=parent_id, error=str(e))
            raise Exception(f"Failed to get prioritized alerts: {str(e)}")
# Helper class for alert analysis
class AlertAnalyzer:
    """Static helpers for analysing collections of alert dicts."""

    @staticmethod
    def calculate_alert_severity_score(alert: Dict[str, Any]) -> float:
        """Map an alert's severity label to a score in (0, 1]; unknown -> 0.25."""
        return {'critical': 1.0, 'high': 0.75, 'medium': 0.5, 'low': 0.25}.get(
            alert['severity'], 0.25
        )

    @staticmethod
    def detect_alert_patterns(alerts: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Detect coarse patterns, e.g. a single dominant alert type (>60%)."""
        if not alerts:
            return {'patterns': [], 'anomalies': []}
        patterns = []
        anomalies = []
        # Tally occurrences per alert type.
        type_counts = {}
        for entry in alerts:
            kind = entry['alert_type']
            type_counts[kind] = type_counts.get(kind, 0) + 1
        total = len(alerts)
        for kind, occurrences in type_counts.items():
            if occurrences / total > 0.6:  # More than 60% of one type
                patterns.append({
                    'type': 'dominant_alert_type',
                    'pattern': f'{kind} alerts dominate ({occurrences}/{total})',
                    'confidence': 0.85
                })
        return {'patterns': patterns, 'anomalies': anomalies}
# Helper class for alert correlation
class AlertCorrelator:
    """Static helper for estimating pairwise alert correlation."""

    @staticmethod
    def calculate_correlation_strength(alert1: Dict[str, Any], alert2: Dict[str, Any]) -> float:
        """Return a 0..1 strength from type match and timestamp proximity."""
        # Matching types weigh fully; mismatched types are discounted to 0.3.
        type_factor = 1.0 if alert1['alert_type'] == alert2['alert_type'] else 0.3
        first = datetime.fromisoformat(alert1['timestamp'])
        second = datetime.fromisoformat(alert2['timestamp'])
        gap_hours = abs((second - first).total_seconds() / 3600)
        # Proximity decays linearly to zero over a 24-hour window.
        proximity = max(0, 1.0 - min(1.0, gap_hours / 24))
        return type_factor * proximity
# Helper class for alert prioritization
class AlertPrioritizer:
    """Static helper for scoring alert priority."""

    @staticmethod
    def calculate_priority_score(alert: Dict[str, Any]) -> float:
        """Combine a severity base score with a linearly decaying recency bonus."""
        base = {'critical': 100, 'high': 75, 'medium': 50, 'low': 25}.get(
            alert['severity'], 25
        )
        raised_at = datetime.fromisoformat(alert['timestamp'])
        age_hours = (datetime.now() - raised_at).total_seconds() / 3600
        # Bonus decays from 50 to 0 over the first 50 hours of alert age.
        return base + max(0, 50 - age_hours)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,358 @@
"""
Registration State Management Service
Tracks registration progress and handles state transitions
"""
import structlog
from typing import Dict, Any, Optional
from datetime import datetime
from enum import Enum
from uuid import uuid4
from shared.exceptions.registration_exceptions import (
RegistrationStateError,
InvalidStateTransitionError
)
# Configure logging
logger = structlog.get_logger()
class RegistrationState(Enum):
    """Registration process states.

    Ordered lifecycle; valid transitions are enforced by
    RegistrationStateService._is_valid_transition.
    """
    INITIATED = "initiated"  # Registration record created
    PAYMENT_VERIFICATION_PENDING = "payment_verification_pending"  # Awaiting payment check
    PAYMENT_VERIFIED = "payment_verified"  # Payment method confirmed
    SUBSCRIPTION_CREATED = "subscription_created"  # Billing subscription exists
    USER_CREATED = "user_created"  # User account provisioned
    COMPLETED = "completed"  # Terminal success state
    FAILED = "failed"  # Terminal failure state
class RegistrationStateService:
    """
    Registration State Management Service

    Tracks and manages registration process state, enforcing valid forward
    transitions and rollbacks to earlier states.
    """
    def __init__(self):
        """Initialize state service"""
        # In production, this would use a database
        self.registration_states = {}

    def _get_state_or_raise(self, state_id: str) -> Dict[str, Any]:
        """Return the state record for *state_id* or raise RegistrationStateError."""
        record = self.registration_states.get(state_id)
        if record is None:
            raise RegistrationStateError(f"Registration state {state_id} not found")
        return record

    async def create_registration_state(
        self,
        email: str,
        user_data: Dict[str, Any]
    ) -> str:
        """
        Create new registration state
        Args:
            email: User email
            user_data: Registration data
        Returns:
            Registration state ID
        """
        try:
            state_id = str(uuid4())
            registration_state = {
                'state_id': state_id,
                'email': email,
                'current_state': RegistrationState.INITIATED.value,
                'created_at': datetime.now().isoformat(),
                'updated_at': datetime.now().isoformat(),
                'user_data': user_data,
                # External references, filled in as later steps complete.
                'setup_intent_id': None,
                'customer_id': None,
                'subscription_id': None,
                'error': None
            }
            self.registration_states[state_id] = registration_state
            logger.info("Registration state created",
                       state_id=state_id,
                       email=email,
                       current_state=RegistrationState.INITIATED.value)
            return state_id
        except Exception as e:
            logger.error("Failed to create registration state",
                        error=str(e),
                        email=email,
                        exc_info=True)
            raise RegistrationStateError(f"State creation failed: {str(e)}") from e

    async def transition_state(
        self,
        state_id: str,
        new_state: RegistrationState,
        context: Optional[Dict[str, Any]] = None
    ) -> None:
        """
        Transition registration to new state with validation
        Args:
            state_id: Registration state ID
            new_state: New state to transition to
            context: Additional context data
        Raises:
            InvalidStateTransitionError: If transition is invalid
            RegistrationStateError: If state is missing or transition fails
        """
        # Resolve the record BEFORE the try block: the original read
        # `current_state` inside the generic except handler, which raised
        # UnboundLocalError (masking the real error) whenever the state
        # lookup itself failed.
        record = self._get_state_or_raise(state_id)
        current_state = record['current_state']
        try:
            # Validate state transition
            if not self._is_valid_transition(current_state, new_state.value):
                raise InvalidStateTransitionError(
                    f"Invalid transition from {current_state} to {new_state.value}"
                )
            # Update state
            record['current_state'] = new_state.value
            record['updated_at'] = datetime.now().isoformat()
            # Update context data
            if context:
                record.update(context)
            logger.info("Registration state transitioned",
                       state_id=state_id,
                       from_state=current_state,
                       to_state=new_state.value)
        except InvalidStateTransitionError:
            raise
        except Exception as e:
            logger.error("State transition failed",
                        error=str(e),
                        state_id=state_id,
                        from_state=current_state,
                        to_state=new_state.value,
                        exc_info=True)
            raise RegistrationStateError(f"State transition failed: {str(e)}") from e

    async def update_state_context(
        self,
        state_id: str,
        context: Dict[str, Any]
    ) -> None:
        """
        Update state context data
        Args:
            state_id: Registration state ID
            context: Context data to update
        Raises:
            RegistrationStateError: If state is missing or update fails
        """
        record = self._get_state_or_raise(state_id)
        try:
            record.update(context)
            record['updated_at'] = datetime.now().isoformat()
            logger.debug("Registration state context updated",
                        state_id=state_id,
                        context_keys=list(context.keys()))
        except Exception as e:
            logger.error("State context update failed",
                        error=str(e),
                        state_id=state_id,
                        exc_info=True)
            raise RegistrationStateError(f"State context update failed: {str(e)}") from e

    async def get_registration_state(
        self,
        state_id: str
    ) -> Dict[str, Any]:
        """
        Get registration state by ID
        Args:
            state_id: Registration state ID
        Returns:
            Registration state data
        Raises:
            RegistrationStateError: If state not found
        """
        # Direct lookup; the original wrapped its own not-found error in a
        # second RegistrationStateError ("State retrieval failed: ...").
        return self._get_state_or_raise(state_id)

    async def rollback_state(
        self,
        state_id: str,
        target_state: RegistrationState
    ) -> None:
        """
        Rollback registration to previous state
        Args:
            state_id: Registration state ID
            target_state: State to rollback to
        Raises:
            InvalidStateTransitionError: If rollback direction is invalid
            RegistrationStateError: If state is missing or rollback fails
        """
        # Same fix as transition_state: resolve before the try so the except
        # handler never references an unbound `current_state`.
        record = self._get_state_or_raise(state_id)
        current_state = record['current_state']
        try:
            # Only allow rollback to earlier states
            if not self._can_rollback(current_state, target_state.value):
                raise InvalidStateTransitionError(
                    f"Cannot rollback from {current_state} to {target_state.value}"
                )
            # Update state
            record['current_state'] = target_state.value
            record['updated_at'] = datetime.now().isoformat()
            record['error'] = "Registration rolled back"
            logger.warning("Registration state rolled back",
                          state_id=state_id,
                          from_state=current_state,
                          to_state=target_state.value)
        except InvalidStateTransitionError:
            raise
        except Exception as e:
            logger.error("State rollback failed",
                        error=str(e),
                        state_id=state_id,
                        from_state=current_state,
                        to_state=target_state.value,
                        exc_info=True)
            raise RegistrationStateError(f"State rollback failed: {str(e)}") from e

    async def mark_registration_failed(
        self,
        state_id: str,
        error: str
    ) -> None:
        """
        Mark registration as failed
        Args:
            state_id: Registration state ID
            error: Error message
        Raises:
            RegistrationStateError: If state is missing or operation fails
        """
        record = self._get_state_or_raise(state_id)
        try:
            self.registration_states[state_id]['current_state'] = RegistrationState.FAILED.value
            record['error'] = error
            record['updated_at'] = datetime.now().isoformat()
            logger.error("Registration marked as failed",
                        state_id=state_id,
                        error=error)
        except Exception as e:
            logger.error("Failed to mark registration as failed",
                        error=str(e),
                        state_id=state_id,
                        exc_info=True)
            raise RegistrationStateError(f"Mark failed operation failed: {str(e)}") from e

    def _is_valid_transition(self, current_state: str, new_state: str) -> bool:
        """
        Validate state transition
        Args:
            current_state: Current state
            new_state: New state
        Returns:
            True if transition is valid
        """
        # Define valid state transitions; every non-terminal state may also
        # transition to FAILED. Terminal states allow no transitions.
        valid_transitions = {
            RegistrationState.INITIATED.value: [
                RegistrationState.PAYMENT_VERIFICATION_PENDING.value,
                RegistrationState.FAILED.value
            ],
            RegistrationState.PAYMENT_VERIFICATION_PENDING.value: [
                RegistrationState.PAYMENT_VERIFIED.value,
                RegistrationState.FAILED.value
            ],
            RegistrationState.PAYMENT_VERIFIED.value: [
                RegistrationState.SUBSCRIPTION_CREATED.value,
                RegistrationState.FAILED.value
            ],
            RegistrationState.SUBSCRIPTION_CREATED.value: [
                RegistrationState.USER_CREATED.value,
                RegistrationState.FAILED.value
            ],
            RegistrationState.USER_CREATED.value: [
                RegistrationState.COMPLETED.value,
                RegistrationState.FAILED.value
            ],
            RegistrationState.COMPLETED.value: [],
            RegistrationState.FAILED.value: []
        }
        return new_state in valid_transitions.get(current_state, [])

    def _can_rollback(self, current_state: str, target_state: str) -> bool:
        """
        Check if rollback to target state is allowed
        Args:
            current_state: Current state
            target_state: Target state for rollback
        Returns:
            True if rollback is allowed (target precedes current in the
            lifecycle; FAILED is not a rollback target)
        """
        # Define state order for rollback validation
        state_order = [
            RegistrationState.INITIATED.value,
            RegistrationState.PAYMENT_VERIFICATION_PENDING.value,
            RegistrationState.PAYMENT_VERIFIED.value,
            RegistrationState.SUBSCRIPTION_CREATED.value,
            RegistrationState.USER_CREATED.value,
            RegistrationState.COMPLETED.value
        ]
        try:
            current_index = state_order.index(current_state)
            target_index = state_order.index(target_state)
            # Can only rollback to earlier states
            return target_index < current_index
        except ValueError:
            # Either state is outside the ordered lifecycle (e.g. FAILED).
            return False

# Singleton instance for dependency injection
registration_state_service = RegistrationStateService()

View File

@@ -0,0 +1,258 @@
"""
Subscription Cache Service
Provides Redis-based caching for subscription data with 10-minute TTL
"""
import structlog
from typing import Optional, Dict, Any
import json
from datetime import datetime, timezone
from app.repositories import SubscriptionRepository
from app.models.tenants import Subscription
logger = structlog.get_logger()
# Cache TTL in seconds (10 minutes)
SUBSCRIPTION_CACHE_TTL = 600
class SubscriptionCacheService:
    """Service for cached subscription lookups.

    Reads go through Redis first (10-minute TTL, see SUBSCRIPTION_CACHE_TTL)
    and fall back to the database on cache miss or any Redis failure.
    All failures degrade gracefully: tier lookups fall back to "starter",
    full lookups to None.
    """
    def __init__(self, redis_client=None, database_manager=None):
        # Both dependencies are optional; Redis may be attached later and the
        # database manager is created lazily (see ensure_database_manager).
        self.redis = redis_client
        self.database_manager = database_manager
    async def ensure_database_manager(self):
        """Ensure database manager is properly initialized (lazy creation)."""
        if self.database_manager is None:
            # Imported lazily to avoid configuration loading at module import.
            from app.core.config import settings
            from shared.database.base import create_database_manager
            self.database_manager = create_database_manager(settings.DATABASE_URL, "tenant-service")
    async def get_tenant_tier_cached(self, tenant_id: str) -> str:
        """
        Get tenant subscription tier with Redis caching
        Args:
            tenant_id: Tenant ID
        Returns:
            Subscription tier (starter, professional, enterprise);
            "starter" when no active subscription exists or on any error
        """
        try:
            # Ensure database manager is initialized
            await self.ensure_database_manager()
            cache_key = f"subscription:tier:{tenant_id}"
            # Try to get from cache
            if self.redis:
                try:
                    cached_tier = await self.redis.get(cache_key)
                    if cached_tier:
                        logger.debug("Subscription tier cache hit", tenant_id=tenant_id, tier=cached_tier)
                        # Redis may return bytes or str depending on client config.
                        return cached_tier.decode('utf-8') if isinstance(cached_tier, bytes) else cached_tier
                except Exception as e:
                    # Redis problems are non-fatal; continue to the database.
                    logger.warning("Redis cache read failed, falling back to database",
                                 tenant_id=tenant_id, error=str(e))
            # Cache miss or Redis unavailable - fetch from database
            logger.debug("Subscription tier cache miss", tenant_id=tenant_id)
            async with self.database_manager.get_session() as db_session:
                subscription_repo = SubscriptionRepository(Subscription, db_session)
                subscription = await subscription_repo.get_active_subscription(tenant_id)
                if not subscription:
                    logger.warning("No active subscription found, returning starter tier",
                                 tenant_id=tenant_id)
                    return "starter"
                tier = subscription.plan
                # Cache the result (best-effort; failures only logged)
                if self.redis:
                    try:
                        await self.redis.setex(cache_key, SUBSCRIPTION_CACHE_TTL, tier)
                        logger.debug("Cached subscription tier", tenant_id=tenant_id, tier=tier)
                    except Exception as e:
                        logger.warning("Failed to cache subscription tier",
                                     tenant_id=tenant_id, error=str(e))
                return tier
        except Exception as e:
            logger.error("Failed to get subscription tier",
                        tenant_id=tenant_id, error=str(e))
            return "starter"  # Fallback to starter on error
    async def get_tenant_subscription_cached(self, tenant_id: str) -> Optional[Dict[str, Any]]:
        """
        Get full tenant subscription with Redis caching
        Args:
            tenant_id: Tenant ID
        Returns:
            Subscription data as dictionary (JSON-serializable; datetimes as
            ISO strings) or None when missing or on error
        """
        try:
            # Ensure database manager is initialized
            await self.ensure_database_manager()
            cache_key = f"subscription:full:{tenant_id}"
            # Try to get from cache
            if self.redis:
                try:
                    cached_data = await self.redis.get(cache_key)
                    if cached_data:
                        logger.debug("Subscription cache hit", tenant_id=tenant_id)
                        data = json.loads(cached_data.decode('utf-8') if isinstance(cached_data, bytes) else cached_data)
                        return data
                except Exception as e:
                    # Redis problems are non-fatal; continue to the database.
                    logger.warning("Redis cache read failed, falling back to database",
                                 tenant_id=tenant_id, error=str(e))
            # Cache miss or Redis unavailable - fetch from database
            logger.debug("Subscription cache miss", tenant_id=tenant_id)
            async with self.database_manager.get_session() as db_session:
                subscription_repo = SubscriptionRepository(Subscription, db_session)
                subscription = await subscription_repo.get_active_subscription(tenant_id)
                if not subscription:
                    logger.warning("No active subscription found", tenant_id=tenant_id)
                    return None
                # Convert to dictionary (datetimes serialized to ISO-8601 so the
                # payload round-trips through JSON in Redis)
                subscription_data = {
                    "id": str(subscription.id),
                    "tenant_id": str(subscription.tenant_id),
                    "plan": subscription.plan,
                    "status": subscription.status,
                    "monthly_price": subscription.monthly_price,
                    "billing_cycle": subscription.billing_cycle,
                    "next_billing_date": subscription.next_billing_date.isoformat() if subscription.next_billing_date else None,
                    "trial_ends_at": subscription.trial_ends_at.isoformat() if subscription.trial_ends_at else None,
                    "cancelled_at": subscription.cancelled_at.isoformat() if subscription.cancelled_at else None,
                    "cancellation_effective_date": subscription.cancellation_effective_date.isoformat() if subscription.cancellation_effective_date else None,
                    "max_users": subscription.max_users,
                    "max_locations": subscription.max_locations,
                    "max_products": subscription.max_products,
                    "features": subscription.features,
                    "created_at": subscription.created_at.isoformat() if subscription.created_at else None,
                    "updated_at": subscription.updated_at.isoformat() if subscription.updated_at else None
                }
                # Cache the result (best-effort; failures only logged)
                if self.redis:
                    try:
                        await self.redis.setex(
                            cache_key,
                            SUBSCRIPTION_CACHE_TTL,
                            json.dumps(subscription_data)
                        )
                        logger.debug("Cached subscription data", tenant_id=tenant_id)
                    except Exception as e:
                        logger.warning("Failed to cache subscription data",
                                     tenant_id=tenant_id, error=str(e))
                return subscription_data
        except Exception as e:
            logger.error("Failed to get subscription",
                        tenant_id=tenant_id, error=str(e))
            return None
    async def invalidate_subscription_cache(self, tenant_id: str) -> None:
        """
        Invalidate subscription cache for a tenant (both tier and full keys).
        Args:
            tenant_id: Tenant ID
        """
        try:
            if not self.redis:
                logger.debug("Redis not available, skipping cache invalidation",
                           tenant_id=tenant_id)
                return
            tier_key = f"subscription:tier:{tenant_id}"
            full_key = f"subscription:full:{tenant_id}"
            # Delete both cache keys
            await self.redis.delete(tier_key, full_key)
            logger.info("Invalidated subscription cache",
                       tenant_id=tenant_id)
        except Exception as e:
            # Invalidation is best-effort; a stale entry expires via TTL.
            logger.warning("Failed to invalidate subscription cache",
                         tenant_id=tenant_id, error=str(e))
    async def warm_cache(self, tenant_id: str) -> None:
        """
        Pre-warm the cache by loading subscription data
        Args:
            tenant_id: Tenant ID
        """
        try:
            logger.debug("Warming subscription cache", tenant_id=tenant_id)
            # Load both tier and full subscription to cache
            await self.get_tenant_tier_cached(tenant_id)
            await self.get_tenant_subscription_cached(tenant_id)
            logger.info("Subscription cache warmed", tenant_id=tenant_id)
        except Exception as e:
            logger.warning("Failed to warm subscription cache",
                         tenant_id=tenant_id, error=str(e))
# Singleton instance for easy access
_cache_service_instance: Optional[SubscriptionCacheService] = None

def get_subscription_cache_service(redis_client=None) -> SubscriptionCacheService:
    """
    Get or create subscription cache service singleton

    Args:
        redis_client: Optional Redis client; when omitted, a best-effort
            attempt is made to initialize one (only possible when no event
            loop is currently running).

    Returns:
        SubscriptionCacheService instance
    """
    global _cache_service_instance
    if _cache_service_instance is None:
        redis_client_instance = redis_client
        if redis_client_instance is None:
            from shared.redis_utils import initialize_redis
            from app.core.config import settings
            import asyncio
            # asyncio.get_event_loop() is deprecated for loop detection;
            # probe for a running loop explicitly instead.
            try:
                asyncio.get_running_loop()
                # A loop is already running: we cannot block on asyncio.run
                # here, so the service starts without Redis (DB-only mode).
            except RuntimeError:
                # No running loop - safe to bootstrap Redis synchronously.
                try:
                    redis_client_instance = asyncio.run(initialize_redis(settings.REDIS_URL))
                except Exception as e:
                    # Degrade gracefully to DB-only lookups, but log the cause
                    # instead of silently swallowing it (was a bare `except`).
                    logger.warning("Redis initialization failed; subscription cache disabled",
                                   error=str(e))
        _cache_service_instance = SubscriptionCacheService(redis_client=redis_client_instance)
    elif redis_client is not None and _cache_service_instance.redis is None:
        # Late-attach a Redis client to an already-created singleton.
        _cache_service_instance.redis = redis_client
    return _cache_service_instance

View File

@@ -0,0 +1,705 @@
"""
Subscription Limit Service
Service for validating tenant actions against subscription limits and features
"""
import structlog
from typing import Dict, Any, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from fastapi import HTTPException, status
from datetime import datetime, timezone
from app.repositories import SubscriptionRepository, TenantRepository, TenantMemberRepository
from app.models.tenants import Subscription, Tenant, TenantMember
from shared.database.exceptions import DatabaseError
from shared.database.base import create_database_manager
from shared.subscription.plans import SubscriptionPlanMetadata, get_training_job_quota, get_forecast_quota
from shared.clients.recipes_client import create_recipes_client
from shared.clients.suppliers_client import create_suppliers_client
logger = structlog.get_logger()
class SubscriptionLimitService:
"""Service for validating subscription limits and features"""
def __init__(self, database_manager=None, redis_client=None):
self.database_manager = database_manager or create_database_manager()
self.redis = redis_client
async def _init_repositories(self, session):
"""Initialize repositories with session"""
self.subscription_repo = SubscriptionRepository(Subscription, session)
self.tenant_repo = TenantRepository(Tenant, session)
self.member_repo = TenantMemberRepository(TenantMember, session)
return {
'subscription': self.subscription_repo,
'tenant': self.tenant_repo,
'member': self.member_repo
}
async def get_tenant_subscription_limits(self, tenant_id: str) -> Dict[str, Any]:
"""Get current subscription limits for a tenant"""
try:
async with self.database_manager.get_session() as db_session:
await self._init_repositories(db_session)
subscription = await self.subscription_repo.get_active_subscription(tenant_id)
if not subscription:
# Return basic limits if no subscription
return {
"plan": "starter",
"max_users": 5,
"max_locations": 1,
"max_products": 50,
"features": {
"inventory_management": "basic",
"demand_prediction": "basic",
"production_reports": "basic",
"analytics": "basic",
"support": "email",
"ai_model_configuration": "basic" # Added AI model configuration for all tiers
}
}
return {
"plan": subscription.plan,
"max_users": subscription.max_users,
"max_locations": subscription.max_locations,
"max_products": subscription.max_products,
"features": subscription.features or {}
}
except Exception as e:
logger.error("Failed to get subscription limits",
tenant_id=tenant_id,
error=str(e))
# Return basic limits on error
return {
"plan": "starter",
"max_users": 5,
"max_locations": 1,
"max_products": 50,
"features": {}
}
async def can_add_location(self, tenant_id: str) -> Dict[str, Any]:
"""Check if tenant can add another location"""
try:
async with self.database_manager.get_session() as db_session:
await self._init_repositories(db_session)
# Get subscription limits
subscription = await self.subscription_repo.get_active_subscription(tenant_id)
if not subscription:
return {"can_add": False, "reason": "No active subscription"}
# Check if unlimited locations (-1)
if subscription.max_locations == -1:
return {"can_add": True, "reason": "Unlimited locations allowed"}
# Count current locations
# Currently, each tenant has 1 location (their primary bakery location)
# This is stored in tenant.address, tenant.city, tenant.postal_code
# If multi-location support is added in the future, this would query a locations table
current_locations = 1 # Each tenant has one primary location
can_add = current_locations < subscription.max_locations
return {
"can_add": can_add,
"current_count": current_locations,
"max_allowed": subscription.max_locations,
"reason": "Within limits" if can_add else f"Maximum {subscription.max_locations} locations allowed for {subscription.plan} plan"
}
except Exception as e:
logger.error("Failed to check location limits",
tenant_id=tenant_id,
error=str(e))
return {"can_add": False, "reason": "Error checking limits"}
async def can_add_product(self, tenant_id: str) -> Dict[str, Any]:
    """Check whether the tenant may create another product under its plan."""
    try:
        async with self.database_manager.get_session() as session:
            await self._init_repositories(session)
            sub = await self.subscription_repo.get_active_subscription(tenant_id)
            if sub is None:
                return {"can_add": False, "reason": "No active subscription"}
            # -1 marks an unlimited product quota.
            if sub.max_products == -1:
                return {"can_add": True, "reason": "Unlimited products allowed"}
            # Live count comes from the inventory service.
            used = await self._get_ingredient_count(tenant_id)
            allowed = used < sub.max_products
            reason = (
                "Within limits"
                if allowed
                else f"Maximum {sub.max_products} products allowed for {sub.plan} plan"
            )
            return {
                "can_add": allowed,
                "current_count": used,
                "max_allowed": sub.max_products,
                "reason": reason,
            }
    except Exception as e:
        logger.error("Failed to check product limits",
                     tenant_id=tenant_id,
                     error=str(e))
        return {"can_add": False, "reason": "Error checking limits"}
async def can_add_user(self, tenant_id: str) -> Dict[str, Any]:
    """Check whether the tenant may invite another active member."""
    try:
        async with self.database_manager.get_session() as session:
            await self._init_repositories(session)
            sub = await self.subscription_repo.get_active_subscription(tenant_id)
            if sub is None:
                return {"can_add": False, "reason": "No active subscription"}
            # -1 marks an unlimited seat count.
            if sub.max_users == -1:
                return {"can_add": True, "reason": "Unlimited users allowed"}
            # Only active members count against the seat quota.
            members = await self.member_repo.get_tenant_members(tenant_id, active_only=True)
            used = len(members)
            allowed = used < sub.max_users
            reason = (
                "Within limits"
                if allowed
                else f"Maximum {sub.max_users} users allowed for {sub.plan} plan"
            )
            return {
                "can_add": allowed,
                "current_count": used,
                "max_allowed": sub.max_users,
                "reason": reason,
            }
    except Exception as e:
        logger.error("Failed to check user limits",
                     tenant_id=tenant_id,
                     error=str(e))
        return {"can_add": False, "reason": "Error checking limits"}
async def can_add_recipe(self, tenant_id: str) -> Dict[str, Any]:
    """Check whether the tenant may create another recipe under its plan."""
    try:
        async with self.database_manager.get_session() as session:
            await self._init_repositories(session)
            sub = await self.subscription_repo.get_active_subscription(tenant_id)
            if sub is None:
                return {"can_add": False, "reason": "No active subscription"}
            # Recipe limits live in the plan metadata, not on the subscription row.
            cap = await self._get_limit_from_plan(sub.plan, 'recipes')
            # None and -1 both mean unlimited.
            if cap is None or cap == -1:
                return {"can_add": True, "reason": "Unlimited recipes allowed"}
            used = await self._get_recipe_count(tenant_id)
            allowed = used < cap
            reason = (
                "Within limits"
                if allowed
                else f"Maximum {cap} recipes allowed for {sub.plan} plan"
            )
            return {
                "can_add": allowed,
                "current_count": used,
                "max_allowed": cap,
                "reason": reason,
            }
    except Exception as e:
        logger.error("Failed to check recipe limits",
                     tenant_id=tenant_id,
                     error=str(e))
        return {"can_add": False, "reason": "Error checking limits"}
async def can_add_supplier(self, tenant_id: str) -> Dict[str, Any]:
    """Check whether the tenant may register another supplier under its plan."""
    try:
        async with self.database_manager.get_session() as session:
            await self._init_repositories(session)
            sub = await self.subscription_repo.get_active_subscription(tenant_id)
            if sub is None:
                return {"can_add": False, "reason": "No active subscription"}
            # Supplier limits live in the plan metadata, not on the subscription row.
            cap = await self._get_limit_from_plan(sub.plan, 'suppliers')
            # None and -1 both mean unlimited.
            if cap is None or cap == -1:
                return {"can_add": True, "reason": "Unlimited suppliers allowed"}
            used = await self._get_supplier_count(tenant_id)
            allowed = used < cap
            reason = (
                "Within limits"
                if allowed
                else f"Maximum {cap} suppliers allowed for {sub.plan} plan"
            )
            return {
                "can_add": allowed,
                "current_count": used,
                "max_allowed": cap,
                "reason": reason,
            }
    except Exception as e:
        logger.error("Failed to check supplier limits",
                     tenant_id=tenant_id,
                     error=str(e))
        return {"can_add": False, "reason": "Error checking limits"}
async def has_feature(self, tenant_id: str, feature: str) -> Dict[str, Any]:
    """Report whether *feature* is included in the tenant's active plan."""
    try:
        async with self.database_manager.get_session() as session:
            await self._init_repositories(session)
            sub = await self.subscription_repo.get_active_subscription(tenant_id)
            if sub is None:
                return {"has_feature": False, "reason": "No active subscription"}
            feature_map = sub.features or {}
            available = feature in feature_map
            reason = (
                "Feature available"
                if available
                else f"Feature '{feature}' not available in {sub.plan} plan"
            )
            return {
                "has_feature": available,
                "feature_value": feature_map.get(feature),
                "plan": sub.plan,
                "reason": reason,
            }
    except Exception as e:
        logger.error("Failed to check feature access",
                     tenant_id=tenant_id,
                     feature=feature,
                     error=str(e))
        return {"has_feature": False, "reason": "Error checking feature access"}
async def get_feature_level(self, tenant_id: str, feature: str) -> Optional[str]:
    """Return the configured level/value of *feature*, or None if absent or on error."""
    try:
        async with self.database_manager.get_session() as session:
            await self._init_repositories(session)
            sub = await self.subscription_repo.get_active_subscription(tenant_id)
            if sub is None:
                return None
            return (sub.features or {}).get(feature)
    except Exception as e:
        logger.error("Failed to get feature level",
                     tenant_id=tenant_id,
                     feature=feature,
                     error=str(e))
        return None
async def validate_plan_upgrade(self, tenant_id: str, new_plan: str) -> Dict[str, Any]:
    """Validate if a tenant can upgrade to a new plan.

    Verifies that *new_plan* is a known plan and that the tenant's current
    active-user count fits within the new plan's user limit. Returns a dict
    with ``can_upgrade`` plus either a rejection reason or upgrade details
    (price delta and the new plan's feature list).
    """
    try:
        async with self.database_manager.get_session() as db_session:
            await self._init_repositories(db_session)
            # Get current subscription
            current_subscription = await self.subscription_repo.get_active_subscription(tenant_id)
            if not current_subscription:
                return {"can_upgrade": True, "reason": "No current subscription, can start with any plan"}
            # Define plan hierarchy
            plan_hierarchy = {"starter": 1, "professional": 2, "enterprise": 3}
            current_level = plan_hierarchy.get(current_subscription.plan, 0)
            new_level = plan_hierarchy.get(new_plan, 0)
            # A level of 0 means the requested plan name is unknown.
            if new_level == 0:
                return {"can_upgrade": False, "reason": f"Invalid plan: {new_plan}"}
            # NOTE(review): current_level is computed but never compared with
            # new_level, so downgrades also pass this check — confirm intended.
            # Check current usage against new plan limits
            from shared.subscription.plans import SubscriptionPlanMetadata, PlanPricing
            new_plan_config = SubscriptionPlanMetadata.get_plan_info(new_plan)
            # Get the max_users limit from the plan limits
            plan_limits = new_plan_config.get('limits', {})
            max_users_limit = plan_limits.get('users', 5)  # Default to 5 if not specified
            # Convert "Unlimited" string to None for comparison
            if max_users_limit == "Unlimited":
                max_users_limit = None
            elif max_users_limit is None:
                max_users_limit = -1  # Use -1 to represent unlimited in the comparison
            # NOTE(review): both None and -1 end up skipping the check below,
            # so the two unlimited sentinels behave identically here.
            # Check if current usage fits new plan
            members = await self.member_repo.get_tenant_members(tenant_id, active_only=True)
            current_users = len(members)
            if max_users_limit is not None and max_users_limit != -1 and current_users > max_users_limit:
                return {
                    "can_upgrade": False,
                    "reason": f"Current usage ({current_users} users) exceeds {new_plan} plan limits ({max_users_limit} users)"
                }
            return {
                "can_upgrade": True,
                "current_plan": current_subscription.plan,
                "new_plan": new_plan,
                # Positive when the new plan is more expensive per month.
                "price_change": float(PlanPricing.get_price(new_plan)) - current_subscription.monthly_price,
                "new_features": new_plan_config.get("features", []),
                "reason": "Upgrade validation successful"
            }
    except Exception as e:
        logger.error("Failed to validate plan upgrade",
                     tenant_id=tenant_id,
                     new_plan=new_plan,
                     error=str(e))
        return {"can_upgrade": False, "reason": "Error validating upgrade"}
async def get_usage_summary(self, tenant_id: str) -> Dict[str, Any]:
    """Get a summary of current usage vs limits for a tenant - ALL 9 METRICS.

    Aggregates: users, locations, products, recipes, suppliers, daily AI
    quotas (training jobs, forecasts), hourly API calls and file storage.
    Falls back to static demo data when the tenant has no active
    subscription, and to ``{"error": ...}`` on unexpected failures.
    """
    try:
        async with self.database_manager.get_session() as db_session:
            await self._init_repositories(db_session)
            subscription = await self.subscription_repo.get_active_subscription(tenant_id)
            if not subscription:
                # Hard-coded demo payload so the dashboard renders for
                # tenants without a subscription (e.g. demo accounts).
                logger.info("No subscription found, returning mock data", tenant_id=tenant_id)
                return {
                    "plan": "demo",
                    "monthly_price": 0,
                    "status": "active",
                    "billing_cycle": "monthly",
                    "usage": {
                        "users": {
                            "current": 1,
                            "limit": 5,
                            "unlimited": False,
                            "usage_percentage": 20.0
                        },
                        "locations": {
                            "current": 1,
                            "limit": 1,
                            "unlimited": False,
                            "usage_percentage": 100.0
                        },
                        "products": {
                            "current": 0,
                            "limit": 50,
                            "unlimited": False,
                            "usage_percentage": 0.0
                        },
                        "recipes": {
                            "current": 0,
                            "limit": 100,
                            "unlimited": False,
                            "usage_percentage": 0.0
                        },
                        "suppliers": {
                            "current": 0,
                            "limit": 20,
                            "unlimited": False,
                            "usage_percentage": 0.0
                        },
                        "training_jobs_today": {
                            "current": 0,
                            "limit": 2,
                            "unlimited": False,
                            "usage_percentage": 0.0
                        },
                        "forecasts_today": {
                            "current": 0,
                            "limit": 10,
                            "unlimited": False,
                            "usage_percentage": 0.0
                        },
                        "api_calls_this_hour": {
                            "current": 0,
                            "limit": 100,
                            "unlimited": False,
                            "usage_percentage": 0.0
                        },
                        "file_storage_used_gb": {
                            "current": 0.0,
                            "limit": 1.0,
                            "unlimited": False,
                            "usage_percentage": 0.0
                        }
                    },
                    "features": {},
                    "next_billing_date": None,
                    "trial_ends_at": None
                }
            # Get current usage - Team & Organization
            members = await self.member_repo.get_tenant_members(tenant_id, active_only=True)
            current_users = len(members)
            current_locations = 1  # Each tenant has one primary location
            # Get current usage - Products & Inventory (parallel calls for performance)
            import asyncio
            current_products, current_recipes, current_suppliers = await asyncio.gather(
                self._get_ingredient_count(tenant_id),
                self._get_recipe_count(tenant_id),
                self._get_supplier_count(tenant_id)
            )
            # Get current usage - IA & Analytics + API & Storage (parallel Redis calls for performance)
            training_jobs_usage, forecasts_usage, api_calls_usage, storage_usage = await asyncio.gather(
                self._get_training_jobs_today(tenant_id, subscription.plan),
                self._get_forecasts_today(tenant_id, subscription.plan),
                self._get_api_calls_this_hour(tenant_id, subscription.plan),
                self._get_file_storage_usage_gb(tenant_id, subscription.plan)
            )
            # Get limits from subscription
            # Recipe/supplier caps come from plan metadata (None = unlimited);
            # user/location/product caps live on the subscription row (-1 = unlimited).
            recipes_limit = await self._get_limit_from_plan(subscription.plan, 'recipes')
            suppliers_limit = await self._get_limit_from_plan(subscription.plan, 'suppliers')
            return {
                "plan": subscription.plan,
                "monthly_price": subscription.monthly_price,
                "status": subscription.status,
                "billing_cycle": subscription.billing_cycle or "monthly",
                "usage": {
                    # Team & Organization
                    "users": {
                        "current": current_users,
                        "limit": subscription.max_users,
                        "unlimited": subscription.max_users == -1,
                        "usage_percentage": 0 if subscription.max_users == -1 else self._calculate_percentage(current_users, subscription.max_users)
                    },
                    "locations": {
                        "current": current_locations,
                        "limit": subscription.max_locations,
                        "unlimited": subscription.max_locations == -1,
                        "usage_percentage": 0 if subscription.max_locations == -1 else self._calculate_percentage(current_locations, subscription.max_locations)
                    },
                    # Products & Inventory
                    "products": {
                        "current": current_products,
                        "limit": subscription.max_products,
                        "unlimited": subscription.max_products == -1,
                        "usage_percentage": 0 if subscription.max_products == -1 else self._calculate_percentage(current_products, subscription.max_products)
                    },
                    "recipes": {
                        "current": current_recipes,
                        "limit": recipes_limit,
                        "unlimited": recipes_limit is None,
                        "usage_percentage": self._calculate_percentage(current_recipes, recipes_limit)
                    },
                    "suppliers": {
                        "current": current_suppliers,
                        "limit": suppliers_limit,
                        "unlimited": suppliers_limit is None,
                        "usage_percentage": self._calculate_percentage(current_suppliers, suppliers_limit)
                    },
                    # IA & Analytics (Daily quotas)
                    "training_jobs_today": training_jobs_usage,
                    "forecasts_today": forecasts_usage,
                    # API & Storage
                    "api_calls_this_hour": api_calls_usage,
                    "file_storage_used_gb": storage_usage
                },
                "features": subscription.features or {},
                "next_billing_date": subscription.next_billing_date.isoformat() if subscription.next_billing_date else None,
                "trial_ends_at": subscription.trial_ends_at.isoformat() if subscription.trial_ends_at else None
            }
    except Exception as e:
        logger.error("Failed to get usage summary",
                     tenant_id=tenant_id,
                     error=str(e))
        return {"error": "Failed to get usage summary"}
async def _get_ingredient_count(self, tenant_id: str) -> int:
    """Count the tenant's ingredients via the shared inventory client.

    Returns 0 on any failure so subscription views keep rendering.
    """
    try:
        from app.core.config import settings
        from shared.clients.inventory_client import create_inventory_client
        # Authenticated service-to-service client with built-in resilience.
        client = create_inventory_client(settings, service_name="tenant")
        total = await client.count_ingredients(tenant_id)
        logger.info(
            "Retrieved ingredient count via inventory client",
            tenant_id=tenant_id,
            count=total
        )
        return total
    except Exception as e:
        logger.error(
            "Error getting ingredient count via inventory client",
            tenant_id=tenant_id,
            error=str(e)
        )
        return 0
async def _get_recipe_count(self, tenant_id: str) -> int:
    """Get recipe count from recipes service using shared client.

    Returns 0 on any failure so the subscription display never breaks.
    """
    try:
        from app.core.config import settings
        # Bug fix: create_recipes_client was referenced without being imported
        # (unlike _get_ingredient_count, which imports its client locally), so
        # every call raised NameError and silently hit the 0 fallback.
        # Module path mirrors the inventory client — TODO confirm it matches
        # the shared.clients package layout.
        from shared.clients.recipes_client import create_recipes_client
        # Use the shared recipes client with proper authentication and resilience
        recipes_client = create_recipes_client(settings, service_name="tenant")
        count = await recipes_client.count_recipes(tenant_id)
        logger.info(
            "Retrieved recipe count via recipes client",
            tenant_id=tenant_id,
            count=count
        )
        return count
    except Exception as e:
        logger.error(
            "Error getting recipe count via recipes client",
            tenant_id=tenant_id,
            error=str(e)
        )
        # Return 0 as fallback to avoid breaking subscription display
        return 0
async def _get_supplier_count(self, tenant_id: str) -> int:
    """Get supplier count from suppliers service using shared client.

    Returns 0 on any failure so the subscription display never breaks.
    """
    try:
        from app.core.config import settings
        # Bug fix: create_suppliers_client was referenced without being
        # imported, so every call raised NameError and silently hit the 0
        # fallback. Module path mirrors the inventory client — TODO confirm
        # it matches the shared.clients package layout.
        from shared.clients.suppliers_client import create_suppliers_client
        # Use the shared suppliers client with proper authentication and resilience
        suppliers_client = create_suppliers_client(settings, service_name="tenant")
        count = await suppliers_client.count_suppliers(tenant_id)
        logger.info(
            "Retrieved supplier count via suppliers client",
            tenant_id=tenant_id,
            count=count
        )
        return count
    except Exception as e:
        logger.error(
            "Error getting supplier count via suppliers client",
            tenant_id=tenant_id,
            error=str(e)
        )
        # Return 0 as fallback to avoid breaking subscription display
        return 0
async def _get_redis_quota(self, quota_key: str) -> int:
    """Return the integer counter stored at *quota_key* (0 when missing or on failure)."""
    try:
        if not self.redis:
            # Lazily create the Redis client on first use.
            from app.core.config import settings
            import shared.redis_utils
            self.redis = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
            if not self.redis:
                return 0
        raw = await self.redis.get(quota_key)
        return int(raw) if raw else 0
    except Exception as e:
        logger.error("Error getting Redis quota", key=quota_key, error=str(e))
        return 0
async def _get_training_jobs_today(self, tenant_id: str, plan: str) -> Dict[str, Any]:
    """Return today's training-job usage versus the plan's daily quota."""
    try:
        today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        used = await self._get_redis_quota(f"quota:daily:training_jobs:{tenant_id}:{today}")
        cap = get_training_job_quota(plan)
        return {
            "current": used,
            "limit": cap,
            "unlimited": cap is None,
            "usage_percentage": self._calculate_percentage(used, cap)
        }
    except Exception as e:
        logger.error("Error getting training jobs today", tenant_id=tenant_id, error=str(e))
        # Fail open: report unlimited/zero usage rather than break the summary.
        return {"current": 0, "limit": None, "unlimited": True, "usage_percentage": 0.0}
async def _get_forecasts_today(self, tenant_id: str, plan: str) -> Dict[str, Any]:
    """Return today's forecast-generation usage versus the plan's daily quota."""
    try:
        today = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        used = await self._get_redis_quota(f"quota:daily:forecast_generation:{tenant_id}:{today}")
        cap = get_forecast_quota(plan)
        return {
            "current": used,
            "limit": cap,
            "unlimited": cap is None,
            "usage_percentage": self._calculate_percentage(used, cap)
        }
    except Exception as e:
        logger.error("Error getting forecasts today", tenant_id=tenant_id, error=str(e))
        # Fail open: report unlimited/zero usage rather than break the summary.
        return {"current": 0, "limit": None, "unlimited": True, "usage_percentage": 0.0}
async def _get_api_calls_this_hour(self, tenant_id: str, plan: str) -> Dict[str, Any]:
    """Return the current hour's API-call usage versus the plan's hourly cap."""
    try:
        hour = datetime.now(timezone.utc).strftime('%Y-%m-%d-%H')
        used = await self._get_redis_quota(f"quota:hourly:api_calls:{tenant_id}:{hour}")
        cap = SubscriptionPlanMetadata.PLANS.get(plan, {}).get('limits', {}).get('api_calls_per_hour')
        return {
            "current": used,
            "limit": cap,
            "unlimited": cap is None,
            "usage_percentage": self._calculate_percentage(used, cap)
        }
    except Exception as e:
        logger.error("Error getting API calls this hour", tenant_id=tenant_id, error=str(e))
        # Fail open: report unlimited/zero usage rather than break the summary.
        return {"current": 0, "limit": None, "unlimited": True, "usage_percentage": 0.0}
async def _get_file_storage_usage_gb(self, tenant_id: str, plan: str) -> Dict[str, Any]:
    """Return file-storage usage in GB versus the plan's storage cap."""
    try:
        raw_bytes = await self._get_redis_quota(f"storage:total_bytes:{tenant_id}")
        # Convert bytes -> GiB, rounded to 2 decimals for display.
        used_gb = round(raw_bytes / (1024 ** 3), 2) if raw_bytes > 0 else 0.0
        cap = SubscriptionPlanMetadata.PLANS.get(plan, {}).get('limits', {}).get('file_storage_gb')
        return {
            "current": used_gb,
            "limit": cap,
            "unlimited": cap is None,
            "usage_percentage": self._calculate_percentage(used_gb, cap)
        }
    except Exception as e:
        logger.error("Error getting file storage usage", tenant_id=tenant_id, error=str(e))
        # Fail open: report unlimited/zero usage rather than break the summary.
        return {"current": 0.0, "limit": None, "unlimited": True, "usage_percentage": 0.0}
def _calculate_percentage(self, current: float, limit: Optional[int]) -> float:
"""Calculate usage percentage"""
if limit is None or limit == -1:
return 0.0
if limit == 0:
return 0.0
return round((current / limit) * 100, 1)
async def _get_limit_from_plan(self, plan: str, limit_key: str) -> Optional[int]:
    """Look up a numeric limit in the plan metadata; None means unlimited."""
    limits = SubscriptionPlanMetadata.PLANS.get(plan, {}).get('limits', {})
    value = limits.get(limit_key)
    # Normalize the -1 "unlimited" sentinel (and missing keys) to None.
    return None if value == -1 else value

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,792 @@
"""
Subscription Service - State Manager
This service handles ONLY subscription database operations and state management
NO payment provider interactions, NO orchestration, NO coupon logic
"""
import structlog
from typing import Dict, Any, Optional, List
from datetime import datetime, timezone, timedelta
from uuid import UUID
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from app.models.tenants import Subscription, Tenant
from app.repositories.subscription_repository import SubscriptionRepository
from app.core.config import settings
from shared.database.exceptions import DatabaseError, ValidationError
from shared.subscription.plans import PlanPricing, QuotaLimits, SubscriptionPlanMetadata
logger = structlog.get_logger()
class SubscriptionService:
"""Service for managing subscription state and database operations ONLY"""
def __init__(self, db_session: AsyncSession):
    """Bind the service to a caller-provided async session.

    The service does not own the session lifecycle; commit/rollback and
    closing are the caller's responsibility.
    """
    self.db_session = db_session
    self.subscription_repo = SubscriptionRepository(Subscription, db_session)
async def create_subscription_record(
    self,
    tenant_id: str,
    subscription_id: str,
    customer_id: str,
    plan: str,
    status: str,
    trial_period_days: Optional[int] = None,
    billing_interval: str = "monthly"
) -> Subscription:
    """
    Create a local subscription record in the database.

    Acts as an upsert keyed on the provider subscription ID: if a record
    with the same ``subscription_id`` already exists it is updated in place
    instead of creating a duplicate.

    Args:
        tenant_id: Tenant ID
        subscription_id: Payment provider subscription ID
        customer_id: Payment provider customer ID
        plan: Subscription plan
        status: Subscription status
        trial_period_days: Optional trial period in days
        billing_interval: Billing interval (monthly or yearly)
    Returns:
        Created (or updated) Subscription object
    Raises:
        ValidationError: if the tenant does not exist or the ID is malformed.
        DatabaseError: on any other persistence failure.
    """
    try:
        # UUID() doubles as input validation for the tenant identifier.
        tenant_uuid = UUID(tenant_id)
        # Verify tenant exists
        query = select(Tenant).where(Tenant.id == tenant_uuid)
        result = await self.db_session.execute(query)
        tenant = result.scalar_one_or_none()
        if not tenant:
            raise ValidationError(f"Tenant not found: {tenant_id}")
        # Create local subscription record
        subscription_data = {
            'tenant_id': str(tenant_id),
            'subscription_id': subscription_id,
            'customer_id': customer_id,
            'plan': plan,
            'status': status,
            'created_at': datetime.now(timezone.utc),
            'billing_cycle': billing_interval
        }
        # Add trial-related data if applicable
        if trial_period_days and trial_period_days > 0:
            from datetime import timedelta
            trial_ends_at = datetime.now(timezone.utc) + timedelta(days=trial_period_days)
            subscription_data['trial_ends_at'] = trial_ends_at
        # Check if subscription with this subscription_id already exists to prevent duplicates
        existing_subscription = await self.subscription_repo.get_by_provider_id(subscription_id)
        if existing_subscription:
            # Update the existing subscription instead of creating a duplicate
            updated_subscription = await self.subscription_repo.update(str(existing_subscription.id), subscription_data)
            logger.info("Existing subscription updated",
                        tenant_id=tenant_id,
                        subscription_id=subscription_id,
                        plan=plan)
            return updated_subscription
        else:
            # Create new subscription
            created_subscription = await self.subscription_repo.create(subscription_data)
            logger.info("subscription_record_created",
                        tenant_id=tenant_id,
                        subscription_id=subscription_id,
                        plan=plan)
            return created_subscription
    except ValidationError as ve:
        logger.error(f"create_subscription_record_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
        raise ve
    except Exception as e:
        logger.error(f"create_subscription_record_failed, tenant_id={tenant_id}, error={str(e)}")
        raise DatabaseError(f"Failed to create subscription record: {str(e)}")
async def update_subscription_status(
    self,
    tenant_id: str,
    status: str,
    stripe_data: Optional[Dict[str, Any]] = None
) -> Subscription:
    """
    Update subscription status in database.

    Also derives the ``is_active`` flag from the new status and clears
    ``cancelled_at`` when the subscription becomes active again.

    Args:
        tenant_id: Tenant ID
        status: New subscription status
        stripe_data: Optional data from Stripe to update (currently ignored,
            see inline note below)
    Returns:
        Updated Subscription object
    Raises:
        ValidationError: if no subscription exists for the tenant.
        DatabaseError: on any other persistence failure.
    """
    try:
        tenant_uuid = UUID(tenant_id)
        # Get subscription from repository
        subscription = await self.subscription_repo.get_by_tenant_id(str(tenant_uuid))
        if not subscription:
            raise ValidationError(f"Subscription not found for tenant {tenant_id}")
        # Prepare update data
        update_data = {
            'status': status,
            'updated_at': datetime.now(timezone.utc)
        }
        # Include Stripe data if provided
        if stripe_data:
            # Note: current_period_start and current_period_end are not in the local model
            # These would need to be stored separately or handled differently
            # For now, we'll skip storing these Stripe-specific fields in the local model
            pass
        # Update status flags based on status value
        if status == 'active':
            update_data['is_active'] = True
            update_data['cancelled_at'] = None
        elif status in ['canceled', 'past_due', 'unpaid', 'inactive']:
            update_data['is_active'] = False
        elif status == 'pending_cancellation':
            update_data['is_active'] = True  # Still active until effective date
        updated_subscription = await self.subscription_repo.update(str(subscription.id), update_data)
        # Invalidate subscription cache
        # NOTE(review): _invalidate_cache is defined elsewhere in this class —
        # assumed to clear any cached subscription lookups for the tenant.
        await self._invalidate_cache(tenant_id)
        logger.info("subscription_status_updated",
                    tenant_id=tenant_id,
                    old_status=subscription.status,
                    new_status=status)
        return updated_subscription
    except ValidationError as ve:
        logger.error(f"update_subscription_status_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
        raise ve
    except Exception as e:
        logger.error(f"update_subscription_status_failed, tenant_id={tenant_id}, error={str(e)}")
        raise DatabaseError(f"Failed to update subscription status: {str(e)}")
async def get_subscription_by_tenant_id(
    self,
    tenant_id: str
) -> Optional[Subscription]:
    """Fetch the subscription for *tenant_id*; None when absent or on error."""
    try:
        # UUID() validates the identifier format before the repository lookup.
        valid_id = UUID(tenant_id)
        return await self.subscription_repo.get_by_tenant_id(str(valid_id))
    except Exception as e:
        logger.error("get_subscription_by_tenant_id_failed",
                     error=str(e), tenant_id=tenant_id)
        return None
async def get_subscription_by_provider_id(
    self,
    subscription_id: str
) -> Optional[Subscription]:
    """Fetch a subscription by its payment-provider ID; None when absent or on error."""
    try:
        return await self.subscription_repo.get_by_provider_id(subscription_id)
    except Exception as e:
        # Swallow lookup failures: callers treat None as "not found".
        logger.error("get_subscription_by_provider_id_failed",
                     error=str(e), subscription_id=subscription_id)
        return None
async def get_subscriptions_by_customer_id(self, customer_id: str) -> List[Subscription]:
    """Fetch all subscriptions for a Stripe customer; empty list on error."""
    try:
        return await self.subscription_repo.get_by_customer_id(customer_id)
    except Exception as e:
        # Swallow lookup failures: callers iterate the result either way.
        logger.error("get_subscriptions_by_customer_id_failed",
                     error=str(e), customer_id=customer_id)
        return []
async def cancel_subscription(
    self,
    tenant_id: str,
    reason: str = ""
) -> Dict[str, Any]:
    """
    Mark subscription as pending cancellation in database.

    The subscription stays usable (read-only) until the end of the current
    billing period; the effective date is the next billing date, or 30 days
    out when no billing date is recorded.

    Args:
        tenant_id: Tenant ID to cancel subscription for
        reason: Optional cancellation reason (truncated to 200 chars in logs)
    Returns:
        Dictionary with cancellation details
    Raises:
        ValidationError: if no subscription exists or it is already
            cancelled/inactive.
        DatabaseError: on any other persistence failure.
    """
    try:
        tenant_uuid = UUID(tenant_id)
        # Get subscription from repository
        subscription = await self.subscription_repo.get_by_tenant_id(str(tenant_uuid))
        if not subscription:
            raise ValidationError(f"Subscription not found for tenant {tenant_id}")
        # Idempotency guard: don't cancel twice.
        if subscription.status in ['pending_cancellation', 'inactive']:
            raise ValidationError(f"Subscription is already {subscription.status}")
        # Calculate cancellation effective date (end of billing period)
        # Falls back to now+30 days when no next billing date is stored.
        cancellation_effective_date = subscription.next_billing_date or (
            datetime.now(timezone.utc) + timedelta(days=30)
        )
        # Update subscription status in database
        update_data = {
            'status': 'pending_cancellation',
            'cancelled_at': datetime.now(timezone.utc),
            'cancellation_effective_date': cancellation_effective_date
        }
        updated_subscription = await self.subscription_repo.update(str(subscription.id), update_data)
        # Invalidate subscription cache
        await self._invalidate_cache(tenant_id)
        days_remaining = (cancellation_effective_date - datetime.now(timezone.utc)).days
        logger.info(
            "subscription_cancelled",
            tenant_id=str(tenant_id),
            effective_date=cancellation_effective_date.isoformat(),
            reason=reason[:200] if reason else None
        )
        return {
            "success": True,
            "message": "Subscription cancelled successfully. You will have read-only access until the end of your billing period.",
            "status": "pending_cancellation",
            "cancellation_effective_date": cancellation_effective_date.isoformat(),
            "days_remaining": days_remaining,
            "read_only_mode_starts": cancellation_effective_date.isoformat()
        }
    except ValidationError as ve:
        logger.error(f"subscription_cancellation_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
        raise ve
    except Exception as e:
        logger.error(f"subscription_cancellation_failed, tenant_id={tenant_id}, error={str(e)}")
        raise DatabaseError(f"Failed to cancel subscription: {str(e)}")
async def reactivate_subscription(
    self,
    tenant_id: str,
    plan: str = "starter"
) -> Dict[str, Any]:
    """
    Reactivate a cancelled or inactive subscription.

    Clears the cancellation markers and, for fully inactive subscriptions,
    restarts the billing cycle 30 days from now.

    Args:
        tenant_id: Tenant ID to reactivate subscription for
        plan: Plan to reactivate with
    Returns:
        Dictionary with reactivation details
    Raises:
        ValidationError: if no subscription exists or its status is not
            'pending_cancellation' or 'inactive'.
        DatabaseError: on any other persistence failure.
    """
    try:
        tenant_uuid = UUID(tenant_id)
        # Get subscription from repository
        subscription = await self.subscription_repo.get_by_tenant_id(str(tenant_uuid))
        if not subscription:
            raise ValidationError(f"Subscription not found for tenant {tenant_id}")
        # Only cancelled-but-not-yet-effective or fully inactive subscriptions
        # can be reactivated.
        if subscription.status not in ['pending_cancellation', 'inactive']:
            raise ValidationError(f"Cannot reactivate subscription with status: {subscription.status}")
        # Update subscription status and plan
        update_data = {
            'status': 'active',
            'plan': plan,
            'cancelled_at': None,
            'cancellation_effective_date': None
        }
        # A fully inactive subscription gets a fresh 30-day billing window.
        if subscription.status == 'inactive':
            update_data['next_billing_date'] = datetime.now(timezone.utc) + timedelta(days=30)
        updated_subscription = await self.subscription_repo.update(str(subscription.id), update_data)
        # Invalidate subscription cache
        await self._invalidate_cache(tenant_id)
        logger.info(
            "subscription_reactivated",
            tenant_id=str(tenant_id),
            new_plan=plan
        )
        return {
            "success": True,
            "message": "Subscription reactivated successfully",
            "status": "active",
            "plan": plan,
            "next_billing_date": updated_subscription.next_billing_date.isoformat() if updated_subscription.next_billing_date else None
        }
    except ValidationError as ve:
        logger.error(f"subscription_reactivation_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
        raise ve
    except Exception as e:
        logger.error(f"subscription_reactivation_failed, tenant_id={tenant_id}, error={str(e)}")
        raise DatabaseError(f"Failed to reactivate subscription: {str(e)}")
async def get_subscription_status(
    self,
    tenant_id: str
) -> Dict[str, Any]:
    """Return current subscription status, including read-only-mode details.

    Raises:
        ValidationError: if the tenant has no subscription record.
        DatabaseError: on any other failure.
    """
    try:
        # UUID() validates the identifier before the lookup.
        subscription = await self.subscription_repo.get_by_tenant_id(str(UUID(tenant_id)))
        if subscription is None:
            raise ValidationError(f"Subscription not found for tenant {tenant_id}")
        # Tenants pending cancellation or inactive only get read access.
        read_only = subscription.status in ('pending_cancellation', 'inactive')
        effective = subscription.cancellation_effective_date
        remaining = None
        if subscription.status == 'pending_cancellation' and effective:
            remaining = (effective - datetime.now(timezone.utc)).days
        return {
            "tenant_id": str(tenant_id),
            "status": subscription.status,
            "plan": subscription.plan,
            "is_read_only": read_only,
            "cancellation_effective_date": effective.isoformat() if effective else None,
            "days_until_inactive": remaining,
        }
    except ValidationError as ve:
        logger.error("get_subscription_status_validation_failed",
                     error=str(ve), tenant_id=tenant_id)
        raise ve
    except Exception as e:
        logger.error(f"get_subscription_status_failed, tenant_id={tenant_id}, error={str(e)}")
        raise DatabaseError(f"Failed to get subscription status: {str(e)}")
async def update_subscription_plan_record(
    self,
    tenant_id: str,
    new_plan: str,
    new_status: str,
    new_period_start: datetime,
    new_period_end: datetime,
    billing_cycle: str = "monthly",
    proration_details: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Update local subscription plan record in database.

    Only ``plan`` and ``status`` are persisted locally; the period dates are
    echoed back in the result but not stored (see inline note).

    Args:
        tenant_id: Tenant ID
        new_plan: New plan name
        new_status: New subscription status
        new_period_start: New period start date (currently not persisted)
        new_period_end: New period end date (currently not persisted)
        billing_cycle: Billing cycle for the new plan (currently not persisted)
        proration_details: Proration details from payment provider
    Returns:
        Dictionary with update results
    Raises:
        ValidationError: if no subscription exists for the tenant.
        DatabaseError: on any other persistence failure.
    """
    try:
        tenant_uuid = UUID(tenant_id)
        # Get current subscription
        subscription = await self.subscription_repo.get_by_tenant_id(str(tenant_uuid))
        if not subscription:
            raise ValidationError(f"Subscription not found for tenant {tenant_id}")
        # Update local subscription record
        update_data = {
            'plan': new_plan,
            'status': new_status,
            'updated_at': datetime.now(timezone.utc)
        }
        # Note: current_period_start and current_period_end are not in the local model
        # These Stripe-specific fields would need to be stored separately
        updated_subscription = await self.subscription_repo.update(str(subscription.id), update_data)
        # Invalidate subscription cache
        await self._invalidate_cache(tenant_id)
        logger.info(
            "subscription_plan_record_updated",
            tenant_id=str(tenant_id),
            old_plan=subscription.plan,
            new_plan=new_plan,
            proration_amount=proration_details.get("net_amount", 0) if proration_details else 0
        )
        return {
            "success": True,
            "message": f"Subscription plan record updated to {new_plan}",
            "old_plan": subscription.plan,
            "new_plan": new_plan,
            "proration_details": proration_details,
            "new_status": new_status,
            "new_period_end": new_period_end.isoformat()
        }
    except ValidationError as ve:
        logger.error(f"update_subscription_plan_record_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
        raise ve
    except Exception as e:
        logger.error(f"update_subscription_plan_record_failed, tenant_id={tenant_id}, error={str(e)}")
        raise DatabaseError(f"Failed to update subscription plan record: {str(e)}")
async def update_billing_cycle_record(
    self,
    tenant_id: str,
    new_billing_cycle: str,
    new_status: str,
    new_period_start: datetime,
    new_period_end: datetime,
    current_plan: str,
    proration_details: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """
    Update local billing cycle record in database

    Args:
        tenant_id: Tenant ID (UUID string)
        new_billing_cycle: New billing cycle ('monthly' or 'yearly')
        new_status: New subscription status
        new_period_start: New period start date (not persisted locally; see note below)
        new_period_end: New period end date (echoed back in the result; not persisted locally)
        current_plan: Current plan name (informational only; not written)
        proration_details: Proration details from payment provider

    Returns:
        Dictionary with billing cycle update results

    Raises:
        ValidationError: If no subscription exists for the tenant
        DatabaseError: On any other failure while updating the record
    """
    try:
        # Validate that tenant_id is a well-formed UUID before querying.
        tenant_uuid = UUID(tenant_id)

        # Get current subscription
        subscription = await self.subscription_repo.get_by_tenant_id(str(tenant_uuid))
        if not subscription:
            raise ValidationError(f"Subscription not found for tenant {tenant_id}")

        # BUG FIX: previously only 'status' and 'updated_at' were written, so
        # the new billing cycle was never persisted locally. The local model
        # does carry a billing_cycle field (it is set on creation), so persist
        # the new cycle alongside the status.
        update_data = {
            'billing_cycle': new_billing_cycle,
            'status': new_status,
            'updated_at': datetime.now(timezone.utc)
        }
        # Note: current_period_start and current_period_end are not in the local model
        # These Stripe-specific fields would need to be stored separately
        updated_subscription = await self.subscription_repo.update(str(subscription.id), update_data)

        # Invalidate subscription cache (best-effort; never raises)
        await self._invalidate_cache(tenant_id)

        # The pre-update object still holds the old cycle for logging/result.
        old_billing_cycle = getattr(subscription, 'billing_cycle', 'monthly')
        logger.info(
            "subscription_billing_cycle_record_updated",
            tenant_id=str(tenant_id),
            old_billing_cycle=old_billing_cycle,
            new_billing_cycle=new_billing_cycle,
            proration_amount=proration_details.get("net_amount", 0) if proration_details else 0
        )
        return {
            "success": True,
            "message": f"Billing cycle record changed to {new_billing_cycle}",
            "old_billing_cycle": old_billing_cycle,
            "new_billing_cycle": new_billing_cycle,
            "proration_details": proration_details,
            "new_status": new_status,
            "new_period_end": new_period_end.isoformat()
        }
    except ValidationError as ve:
        logger.error(f"change_billing_cycle_validation_failed, tenant_id={tenant_id}, error={str(ve)}")
        raise ve
    except Exception as e:
        logger.error(f"change_billing_cycle_failed, tenant_id={tenant_id}, error={str(e)}")
        raise DatabaseError(f"Failed to change billing cycle: {str(e)}")
async def _invalidate_cache(self, tenant_id: str):
    """Best-effort invalidation of the cached subscription for a tenant.

    Any failure is logged and swallowed so that cache trouble never aborts
    the database update that triggered the invalidation.
    """
    try:
        # Imported lazily to avoid a module-level dependency cycle.
        import shared.redis_utils
        from app.services.subscription_cache import get_subscription_cache_service

        redis_conn = await shared.redis_utils.initialize_redis(settings.REDIS_URL)
        subscription_cache = get_subscription_cache_service(redis_conn)
        await subscription_cache.invalidate_subscription_cache(str(tenant_id))

        logger.info(
            "Subscription cache invalidated",
            tenant_id=str(tenant_id)
        )
    except Exception as err:
        logger.error(
            "Failed to invalidate subscription cache",
            tenant_id=str(tenant_id),
            error=str(err)
        )
async def validate_subscription_change(
    self,
    tenant_id: str,
    new_plan: str
) -> bool:
    """
    Validate if a subscription change is allowed

    Args:
        tenant_id: Tenant ID
        new_plan: New plan to validate

    Returns:
        True if change is allowed; False when the tenant has no subscription,
        when the subscription is pending cancellation or inactive, or when
        the lookup fails for any reason.
    """
    try:
        current = await self.get_subscription_by_tenant_id(tenant_id)
        if not current:
            return False
        # A subscription that is winding down or already inactive
        # cannot be moved to a different plan.
        return current.status not in ('pending_cancellation', 'inactive')
    except Exception as e:
        logger.error("validate_subscription_change_failed",
                     error=str(e), tenant_id=tenant_id)
        return False
# ========================================================================
# TENANT-INDEPENDENT SUBSCRIPTION METHODS (New Architecture)
# ========================================================================
async def create_tenant_independent_subscription_record(
    self,
    subscription_id: str,
    customer_id: str,
    plan: str,
    status: str,
    trial_period_days: Optional[int] = None,
    billing_interval: str = "monthly",
    user_id: Optional[str] = None
) -> Subscription:
    """
    Create a tenant-independent subscription record in the database

    This subscription is not linked to any tenant and will be linked during onboarding

    Args:
        subscription_id: Payment provider subscription ID
        customer_id: Payment provider customer ID
        plan: Subscription plan
        status: Subscription status
        trial_period_days: Optional trial period in days
        billing_interval: Billing interval (monthly or yearly)
        user_id: User ID who created this subscription

    Returns:
        Created Subscription object

    Raises:
        ValidationError: If the repository rejects the record
        DatabaseError: On any other persistence failure
    """
    try:
        from datetime import timedelta

        # Single timestamp so created_at and trial_ends_at share the same
        # base instant (previously two separate now() calls were used).
        now = datetime.now(timezone.utc)

        # Create tenant-independent subscription record
        subscription_data = {
            'subscription_id': subscription_id,
            'customer_id': customer_id,
            'plan': plan,  # Repository expects 'plan', not 'plan_id'
            'status': status,
            'created_at': now,
            'billing_cycle': billing_interval,
            'user_id': user_id,
            'is_tenant_linked': False,
            'tenant_linking_status': 'pending'
        }
        # Add trial-related data if applicable
        if trial_period_days and trial_period_days > 0:
            subscription_data['trial_ends_at'] = now + timedelta(days=trial_period_days)

        created_subscription = await self.subscription_repo.create_tenant_independent_subscription(subscription_data)
        logger.info("tenant_independent_subscription_record_created",
                    subscription_id=subscription_id,
                    user_id=user_id,
                    plan=plan)
        return created_subscription
    except ValidationError as ve:
        logger.error("create_tenant_independent_subscription_record_validation_failed",
                     error=str(ve), user_id=user_id)
        raise ve
    except Exception as e:
        logger.error("create_tenant_independent_subscription_record_failed",
                     error=str(e), user_id=user_id)
        raise DatabaseError(f"Failed to create tenant-independent subscription record: {str(e)}")
async def get_pending_tenant_linking_subscriptions(self) -> List[Subscription]:
    """Get all subscriptions waiting to be linked to tenants.

    Returns:
        List of Subscription objects with pending tenant linking.

    Raises:
        DatabaseError: If the repository query fails.
    """
    try:
        return await self.subscription_repo.get_pending_tenant_linking_subscriptions()
    except Exception as e:
        # Structured logging for consistency with the sibling methods
        # (previously an f-string message, which defeats log aggregation).
        logger.error("Failed to get pending tenant linking subscriptions",
                     error=str(e))
        raise DatabaseError(f"Failed to get pending subscriptions: {str(e)}")
async def get_pending_subscriptions_by_user(self, user_id: str) -> List[Subscription]:
    """Return the given user's subscriptions that still await tenant linking."""
    try:
        pending = await self.subscription_repo.get_pending_subscriptions_by_user(user_id)
    except Exception as e:
        logger.error("Failed to get pending subscriptions by user",
                     user_id=user_id, error=str(e))
        raise DatabaseError(f"Failed to get pending subscriptions: {str(e)}")
    return pending
async def link_subscription_to_tenant(
    self,
    subscription_id: str,
    tenant_id: str,
    user_id: str
) -> Subscription:
    """
    Link a pending subscription to a tenant.

    Completes the registration flow: the subscription created during
    registration is associated with the tenant created during onboarding.

    Args:
        subscription_id: Subscription ID to link
        tenant_id: Tenant ID to link to
        user_id: User ID performing the linking (for validation)

    Returns:
        Updated Subscription object

    Raises:
        DatabaseError: If the repository fails to perform the linking
    """
    try:
        linked = await self.subscription_repo.link_subscription_to_tenant(
            subscription_id, tenant_id, user_id
        )
    except Exception as e:
        logger.error("Failed to link subscription to tenant",
                     subscription_id=subscription_id,
                     tenant_id=tenant_id,
                     user_id=user_id,
                     error=str(e))
        raise DatabaseError(f"Failed to link subscription to tenant: {str(e)}")
    return linked
async def cleanup_orphaned_subscriptions(self, days_old: int = 30) -> int:
    """Clean up subscriptions that were never linked to tenants.

    Args:
        days_old: Minimum age in days before an unlinked subscription is
            considered orphaned (default 30).

    Returns:
        Number of subscriptions cleaned up.

    Raises:
        DatabaseError: If the repository cleanup fails.
    """
    try:
        return await self.subscription_repo.cleanup_orphaned_subscriptions(days_old)
    except Exception as e:
        # Structured logging for consistency with the sibling methods
        # (previously an f-string message, which defeats log aggregation).
        logger.error("Failed to cleanup orphaned subscriptions",
                     days_old=days_old, error=str(e))
        raise DatabaseError(f"Failed to cleanup orphaned subscriptions: {str(e)}")
# Fields that callers are permitted to change through update_subscription_info.
_SUBSCRIPTION_INFO_FIELDS = frozenset({
    'plan', 'status', 'is_tenant_linked', 'tenant_linking_status',
    'threeds_authentication_required', 'threeds_authentication_required_at',
    'threeds_authentication_completed', 'threeds_authentication_completed_at',
    'last_threeds_setup_intent_id', 'threeds_action_type'
})

async def update_subscription_info(
    self,
    subscription_id: str,
    update_data: Dict[str, Any]
) -> Subscription:
    """
    Update subscription-related information (3DS flags, status, etc.)

    Useful for updating tenant-independent subscriptions during registration.
    Keys outside the allow-list are silently dropped; if nothing remains,
    the current record is returned unchanged.

    Args:
        subscription_id: Subscription ID
        update_data: Dictionary with fields to update

    Returns:
        Updated Subscription object

    Raises:
        ValidationError: If no subscription exists with the given ID
        DatabaseError: On any other persistence failure
    """
    try:
        # Keep only allow-listed keys, preserving caller-supplied order.
        safe_updates = {
            field: value
            for field, value in update_data.items()
            if field in self._SUBSCRIPTION_INFO_FIELDS
        }
        if not safe_updates:
            logger.warning("No valid subscription info fields provided for update",
                           subscription_id=subscription_id)
            return await self.subscription_repo.get_by_id(subscription_id)

        refreshed = await self.subscription_repo.update(subscription_id, safe_updates)
        if not refreshed:
            raise ValidationError(f"Subscription not found: {subscription_id}")

        logger.info("Subscription info updated",
                    subscription_id=subscription_id,
                    updated_fields=list(safe_updates.keys()))
        return refreshed
    except ValidationError:
        raise
    except Exception as e:
        logger.error("Failed to update subscription info",
                     subscription_id=subscription_id,
                     error=str(e))
        raise DatabaseError(f"Failed to update subscription info: {str(e)}")

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,293 @@
# services/tenant/app/services/tenant_settings_service.py
"""
Tenant Settings Service
Business logic for managing tenant-specific operational settings
"""
import structlog
from sqlalchemy.ext.asyncio import AsyncSession
from uuid import UUID
from typing import Optional, Dict, Any
from fastapi import HTTPException, status
from ..models.tenant_settings import TenantSettings
from ..repositories.tenant_settings_repository import TenantSettingsRepository
from ..schemas.tenant_settings import (
TenantSettingsUpdate,
ProcurementSettings,
InventorySettings,
ProductionSettings,
SupplierSettings,
POSSettings,
OrderSettings,
ReplenishmentSettings,
SafetyStockSettings,
MOQSettings,
SupplierSelectionSettings,
NotificationSettings
)
logger = structlog.get_logger()
class TenantSettingsService:
    """
    Service for managing tenant settings.

    Validates incoming values against per-category Pydantic schemas and
    persists them through TenantSettingsRepository. A tenant's settings row
    is lazily created with model defaults on first read.
    """

    # Map category names to schema validators
    CATEGORY_SCHEMAS = {
        "procurement": ProcurementSettings,
        "inventory": InventorySettings,
        "production": ProductionSettings,
        "supplier": SupplierSettings,
        "pos": POSSettings,
        "order": OrderSettings,
        "replenishment": ReplenishmentSettings,
        "safety_stock": SafetyStockSettings,
        "moq": MOQSettings,
        "supplier_selection": SupplierSelectionSettings,
        "notification": NotificationSettings
    }

    # Map category names to database column names. The column name doubles
    # as the attribute name on both TenantSettings and TenantSettingsUpdate.
    CATEGORY_COLUMNS = {
        "procurement": "procurement_settings",
        "inventory": "inventory_settings",
        "production": "production_settings",
        "supplier": "supplier_settings",
        "pos": "pos_settings",
        "order": "order_settings",
        "replenishment": "replenishment_settings",
        "safety_stock": "safety_stock_settings",
        "moq": "moq_settings",
        "supplier_selection": "supplier_selection_settings",
        "notification": "notification_settings"
    }

    def __init__(self, db: AsyncSession):
        # Shared async session; the repository owns all query construction.
        self.db = db
        self.repository = TenantSettingsRepository(db)

    async def get_settings(self, tenant_id: UUID) -> TenantSettings:
        """
        Get tenant settings, creating defaults if they don't exist

        Args:
            tenant_id: UUID of the tenant

        Returns:
            TenantSettings object

        Raises:
            HTTPException: 500 if the lookup or default creation fails
        """
        try:
            # Try to get existing settings using repository
            settings = await self.repository.get_by_tenant_id(tenant_id)
            logger.info(f"Existing settings lookup for tenant {tenant_id}: {'found' if settings else 'not found'}")

            # Create default settings if they don't exist
            if not settings:
                logger.info(f"Creating default settings for tenant {tenant_id}")
                settings = await self._create_default_settings(tenant_id)
                logger.info(f"Successfully created default settings for tenant {tenant_id}")
            return settings
        except Exception as e:
            logger.error(f"Failed to get or create tenant settings, tenant_id={tenant_id}, error={str(e)}", exc_info=True)
            # Re-raise as HTTPException to match the expected behavior
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Failed to get tenant settings: {str(e)}"
            )

    async def update_settings(
        self,
        tenant_id: UUID,
        updates: TenantSettingsUpdate
    ) -> TenantSettings:
        """
        Update tenant settings

        Args:
            tenant_id: UUID of the tenant
            updates: TenantSettingsUpdate object with new values; only
                categories explicitly provided (not None) are changed

        Returns:
            Updated TenantSettings object
        """
        settings = await self.get_settings(tenant_id)

        # BUG FIX: the previous hand-written per-category if-chain omitted
        # notification_settings even though "notification" is declared in
        # CATEGORY_SCHEMAS/CATEGORY_COLUMNS. Drive the update from
        # CATEGORY_COLUMNS so every declared category is covered. getattr
        # with a None default keeps this safe if the update schema does not
        # (yet) expose a given category.
        for column_name in self.CATEGORY_COLUMNS.values():
            new_value = getattr(updates, column_name, None)
            if new_value is not None:
                setattr(settings, column_name, new_value.dict())

        return await self.repository.update(settings)

    async def get_category(self, tenant_id: UUID, category: str) -> Dict[str, Any]:
        """
        Get settings for a specific category

        Args:
            tenant_id: UUID of the tenant
            category: Category name (procurement, inventory, production, etc.)

        Returns:
            Dictionary with category settings

        Raises:
            HTTPException: 400 if category is invalid
        """
        if category not in self.CATEGORY_COLUMNS:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid category: {category}. Valid categories: {', '.join(self.CATEGORY_COLUMNS.keys())}"
            )
        settings = await self.get_settings(tenant_id)
        column_name = self.CATEGORY_COLUMNS[category]
        return getattr(settings, column_name)

    async def update_category(
        self,
        tenant_id: UUID,
        category: str,
        updates: Dict[str, Any]
    ) -> TenantSettings:
        """
        Update settings for a specific category

        Args:
            tenant_id: UUID of the tenant
            category: Category name
            updates: Dictionary with new values

        Returns:
            Updated TenantSettings object

        Raises:
            HTTPException: 400 if category is invalid, 422 if validation fails
        """
        if category not in self.CATEGORY_COLUMNS:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid category: {category}"
            )

        # Validate updates using the appropriate schema
        schema = self.CATEGORY_SCHEMAS[category]
        try:
            validated_data = schema(**updates)
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail=f"Validation error: {str(e)}"
            )

        # Get existing settings and replace the whole category payload
        settings = await self.get_settings(tenant_id)
        column_name = self.CATEGORY_COLUMNS[category]
        setattr(settings, column_name, validated_data.dict())
        return await self.repository.update(settings)

    async def reset_category(self, tenant_id: UUID, category: str) -> Dict[str, Any]:
        """
        Reset a category to default values

        Args:
            tenant_id: UUID of the tenant
            category: Category name

        Returns:
            Dictionary with reset category settings

        Raises:
            HTTPException: 400 if category is invalid
        """
        if category not in self.CATEGORY_COLUMNS:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid category: {category}"
            )

        # Get default settings for the category
        defaults = TenantSettings.get_default_settings()
        column_name = self.CATEGORY_COLUMNS[category]
        default_category_settings = defaults[column_name]

        # Update the category with defaults
        settings = await self.get_settings(tenant_id)
        setattr(settings, column_name, default_category_settings)
        await self.repository.update(settings)
        return default_category_settings

    async def _create_default_settings(self, tenant_id: UUID) -> TenantSettings:
        """
        Create default settings for a new tenant

        Args:
            tenant_id: UUID of the tenant

        Returns:
            Newly created TenantSettings object
        """
        defaults = TenantSettings.get_default_settings()

        # FIX: previously the constructor call enumerated columns by hand and
        # omitted notification_settings. Build kwargs from CATEGORY_COLUMNS so
        # every declared category whose default exists is included; the
        # membership check keeps this safe if get_default_settings() does not
        # provide a given column.
        settings_kwargs = {
            column_name: defaults[column_name]
            for column_name in self.CATEGORY_COLUMNS.values()
            if column_name in defaults
        }
        settings = TenantSettings(tenant_id=tenant_id, **settings_kwargs)
        return await self.repository.create(settings)

    async def delete_settings(self, tenant_id: UUID) -> None:
        """
        Delete tenant settings (used when tenant is deleted)

        Args:
            tenant_id: UUID of the tenant
        """
        await self.repository.delete(tenant_id)