demo seed change

Urtzi Alfaro
2025-12-13 23:57:54 +01:00
parent f3688dfb04
commit ff830a3415
299 changed files with 20328 additions and 19485 deletions

View File

@@ -1,7 +1,9 @@
"""Demo Session Services"""
from .session_manager import DemoSessionManager
from .data_cloner import DemoDataCloner
from .cleanup_service import DemoCleanupService
__all__ = ["DemoSessionManager", "DemoDataCloner", "DemoCleanupService"]
__all__ = [
"DemoSessionManager",
"DemoCleanupService",
]
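As an aside, a minimal sketch (not part of the commit) of how the trimmed package surface is consumed after this change; the constructor signature is assumed from the service code further down:

from app.services import DemoSessionManager, DemoCleanupService

async def teardown_demo(db, redis, session_id: str) -> None:
    # DemoDataCloner is gone from the package API; deletion now goes
    # through DemoSessionManager / DemoCleanupService instead.
    manager = DemoSessionManager(db, redis)  # assumed signature
    await manager.destroy_session(session_id)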

View File

@@ -4,14 +4,21 @@ Handles automatic cleanup of expired sessions
"""
from sqlalchemy.ext.asyncio import AsyncSession
-from sqlalchemy import select, update
-from datetime import datetime, timezone
-from typing import List
+from sqlalchemy import select
+from datetime import datetime, timezone, timedelta
import structlog
+import httpx
+import asyncio
+import os
from app.models import DemoSession, DemoSessionStatus
-from app.services.data_cloner import DemoDataCloner
from app.core.redis_wrapper import DemoRedisWrapper
+from app.monitoring.metrics import (
+demo_sessions_deleted_total,
+demo_session_cleanup_duration_seconds,
+demo_sessions_active
+)
logger = structlog.get_logger()
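# --- Illustrative aside, not part of this commit -----------------------------
# One plausible shape for the metrics imported above, using prometheus_client;
# the label sets mirror how the metrics are used later in this file, while the
# help strings are assumptions.
#
# from prometheus_client import Counter, Gauge, Histogram
#
# demo_sessions_deleted_total = Counter(
#     "demo_sessions_deleted_total",
#     "Demo sessions deleted, by tier and outcome",
#     ["tier", "status"],
# )
# demo_session_cleanup_duration_seconds = Histogram(
#     "demo_session_cleanup_duration_seconds",
#     "Duration of a cleanup pass in seconds",
#     ["tier"],
# )
# demo_sessions_active = Gauge(
#     "demo_sessions_active",
#     "Currently active demo sessions",
#     ["tier"],
# )
# ------------------------------------------------------------------------------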
@@ -22,7 +29,199 @@ class DemoCleanupService:
def __init__(self, db: AsyncSession, redis: DemoRedisWrapper):
self.db = db
self.redis = redis
-self.data_cloner = DemoDataCloner(db, redis)
+from app.core.config import settings
+self.internal_api_key = settings.INTERNAL_API_KEY
# Service URLs for cleanup
self.services = [
("tenant", os.getenv("TENANT_SERVICE_URL", "http://tenant-service:8000")),
("auth", os.getenv("AUTH_SERVICE_URL", "http://auth-service:8000")),
("inventory", os.getenv("INVENTORY_SERVICE_URL", "http://inventory-service:8000")),
("recipes", os.getenv("RECIPES_SERVICE_URL", "http://recipes-service:8000")),
("suppliers", os.getenv("SUPPLIERS_SERVICE_URL", "http://suppliers-service:8000")),
("production", os.getenv("PRODUCTION_SERVICE_URL", "http://production-service:8000")),
("procurement", os.getenv("PROCUREMENT_SERVICE_URL", "http://procurement-service:8000")),
("sales", os.getenv("SALES_SERVICE_URL", "http://sales-service:8000")),
("orders", os.getenv("ORDERS_SERVICE_URL", "http://orders-service:8000")),
("forecasting", os.getenv("FORECASTING_SERVICE_URL", "http://forecasting-service:8000")),
("orchestrator", os.getenv("ORCHESTRATOR_SERVICE_URL", "http://orchestrator-service:8000")),
]
async def cleanup_session(self, session: DemoSession) -> dict:
"""
Delete all data for a demo session across all services.
Returns:
{
"success": bool,
"total_deleted": int,
"duration_ms": int,
"details": {service: {records_deleted, duration_ms}},
"errors": []
}
"""
start_time = datetime.now(timezone.utc)
virtual_tenant_id = str(session.virtual_tenant_id)
session_id = session.session_id
logger.info(
"Starting demo session cleanup",
session_id=session_id,
virtual_tenant_id=virtual_tenant_id,
demo_account_type=session.demo_account_type
)
# Delete from all services in parallel
tasks = [
self._delete_from_service(name, url, virtual_tenant_id)
for name, url in self.services
]
service_results = await asyncio.gather(*tasks, return_exceptions=True)
# Aggregate results
total_deleted = 0
details = {}
errors = []
for (service_name, _), result in zip(self.services, service_results):
if isinstance(result, Exception):
errors.append(f"{service_name}: {str(result)}")
details[service_name] = {"status": "error", "error": str(result)}
else:
total_deleted += result.get("records_deleted", {}).get("total", 0)
details[service_name] = result
# Delete from Redis
await self._delete_redis_cache(virtual_tenant_id)
# Delete child tenants if enterprise
if session.demo_account_type == "enterprise":
child_metadata = session.session_metadata.get("children", [])
for child in child_metadata:
child_tenant_id = child["virtual_tenant_id"]
await self._delete_from_all_services(child_tenant_id)
duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
success = len(errors) == 0
logger.info(
"Demo session cleanup completed",
session_id=session_id,
virtual_tenant_id=virtual_tenant_id,
success=success,
total_deleted=total_deleted,
duration_ms=duration_ms,
error_count=len(errors)
)
return {
"success": success,
"total_deleted": total_deleted,
"duration_ms": duration_ms,
"details": details,
"errors": errors
}
async def _delete_from_service(
self,
service_name: str,
service_url: str,
virtual_tenant_id: str
) -> dict:
"""Delete all data from a single service"""
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.delete(
f"{service_url}/internal/demo/tenant/{virtual_tenant_id}",
headers={"X-Internal-API-Key": self.internal_api_key}
)
if response.status_code == 200:
return response.json()
elif response.status_code == 404:
# Already deleted or never existed - idempotent
return {
"service": service_name,
"status": "not_found",
"records_deleted": {"total": 0}
}
else:
raise Exception(f"HTTP {response.status_code}: {response.text}")
except Exception as e:
logger.error(
"Failed to delete from service",
service=service_name,
virtual_tenant_id=virtual_tenant_id,
error=str(e)
)
raise
async def _delete_redis_cache(self, virtual_tenant_id: str):
"""Delete all Redis keys for a virtual tenant"""
try:
client = await self.redis.get_client()
pattern = f"*:{virtual_tenant_id}:*"
keys = await client.keys(pattern)
if keys:
await client.delete(*keys)
logger.debug("Deleted Redis cache", tenant_id=virtual_tenant_id, keys_deleted=len(keys))
except Exception as e:
logger.warning("Failed to delete Redis cache", error=str(e), tenant_id=virtual_tenant_id)
async def _delete_from_all_services(self, virtual_tenant_id: str):
"""Delete data from all services for a tenant"""
tasks = [
self._delete_from_service(name, url, virtual_tenant_id)
for name, url in self.services
]
return await asyncio.gather(*tasks, return_exceptions=True)
async def _delete_tenant_data(self, tenant_id: str, session_id: str) -> dict:
"""Delete demo data for a tenant across all services"""
logger.info("Deleting tenant data", tenant_id=tenant_id, session_id=session_id)
results = {}
async def delete_from_service(service_name: str, service_url: str):
try:
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.delete(
f"{service_url}/internal/demo/tenant/{tenant_id}",
headers={"X-Internal-API-Key": self.internal_api_key}
)
if response.status_code == 200:
logger.debug(f"Deleted data from {service_name}", tenant_id=tenant_id)
return {"service": service_name, "status": "deleted"}
else:
logger.warning(
f"Failed to delete from {service_name}",
status_code=response.status_code,
tenant_id=tenant_id
)
return {"service": service_name, "status": "failed", "error": f"HTTP {response.status_code}"}
except Exception as e:
logger.warning(
f"Exception deleting from {service_name}",
error=str(e),
tenant_id=tenant_id
)
return {"service": service_name, "status": "failed", "error": str(e)}
# Delete from all services in parallel
tasks = [delete_from_service(name, url) for name, url in self.services]
service_results = await asyncio.gather(*tasks, return_exceptions=True)
for result in service_results:
if isinstance(result, Exception):
logger.error("Service deletion failed", error=str(result))
elif isinstance(result, dict):
results[result["service"]] = result
return results
async def cleanup_expired_sessions(self) -> dict:
"""
@@ -32,9 +231,9 @@ class DemoCleanupService:
Returns:
Cleanup statistics
"""
-from datetime import timedelta
logger.info("Starting demo session cleanup")
+start_time = datetime.now(timezone.utc)
now = datetime.now(timezone.utc)
stuck_threshold = now - timedelta(minutes=5) # Sessions pending > 5 min are stuck
@@ -97,10 +296,7 @@ class DemoCleanupService:
)
for child_id in child_tenant_ids:
try:
-await self.data_cloner.delete_session_data(
-str(child_id),
-session.session_id
-)
+await self._delete_tenant_data(child_id, session.session_id)
except Exception as child_error:
logger.error(
"Failed to delete child tenant",
@@ -109,11 +305,14 @@ class DemoCleanupService:
)
# Delete parent/main session data
-await self.data_cloner.delete_session_data(
+await self._delete_tenant_data(
str(session.virtual_tenant_id),
session.session_id
)
# Delete Redis data
await self.redis.delete_session_data(session.session_id)
stats["cleaned_up"] += 1
logger.info(
@@ -137,6 +336,19 @@ class DemoCleanupService:
)
logger.info("Demo session cleanup completed", stats=stats)
# Update Prometheus metrics
duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
demo_session_cleanup_duration_seconds.labels(tier="all").observe(duration_ms / 1000)
# Record deleted-session metrics per tier (the tier is each session's demo_account_type)
for session in all_sessions_to_cleanup:
demo_sessions_deleted_total.labels(
tier=session.demo_account_type,
status="success"
).inc()
demo_sessions_active.labels(tier=session.demo_account_type).dec()
return stats
async def cleanup_old_destroyed_sessions(self, days: int = 7) -> int:
@@ -149,8 +361,6 @@ class DemoCleanupService:
Returns:
Number of deleted records
"""
-from datetime import timedelta
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days)
result = await self.db.execute(

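For context, a minimal sketch (not in the diff) of a background task that would drive the cleanup pass above; the session factory and interval are assumptions:

import asyncio

async def cleanup_loop(make_db_session, redis):
    # make_db_session is assumed to be an async DB-session factory.
    while True:
        async with make_db_session() as db:
            service = DemoCleanupService(db, redis)
            stats = await service.cleanup_expired_sessions()
            # per-session results from cleanup_session() roll up into stats
        await asyncio.sleep(60)  # arbitrary interval for the sketch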
File diff suppressed because it is too large

View File

@@ -1,604 +0,0 @@
"""
Cloning Strategy Pattern Implementation
Provides explicit, type-safe strategies for different demo account types
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Any, List, Optional
from datetime import datetime, timezone
import structlog
logger = structlog.get_logger()
@dataclass
class CloningContext:
"""
Context object containing all data needed for cloning operations
Immutable to prevent state mutation bugs
"""
base_tenant_id: str
virtual_tenant_id: str
session_id: str
demo_account_type: str
session_metadata: Optional[Dict[str, Any]] = None
services_filter: Optional[List[str]] = None
# Orchestrator dependencies (injected)
orchestrator: Any = None # Will be CloneOrchestrator instance
def __post_init__(self):
"""Validate context after initialization"""
if not self.base_tenant_id:
raise ValueError("base_tenant_id is required")
if not self.virtual_tenant_id:
raise ValueError("virtual_tenant_id is required")
if not self.session_id:
raise ValueError("session_id is required")
class CloningStrategy(ABC):
"""
Abstract base class for cloning strategies
Each strategy is a leaf node - no recursion possible
"""
@abstractmethod
async def clone(self, context: CloningContext) -> Dict[str, Any]:
"""
Execute the cloning strategy
Args:
context: Immutable context with all required data
Returns:
Dictionary with cloning results
"""
pass
@abstractmethod
def get_strategy_name(self) -> str:
"""Return the name of this strategy for logging"""
pass
class ProfessionalCloningStrategy(CloningStrategy):
"""
Strategy for single-tenant professional demos
Clones all services for a single virtual tenant
"""
def get_strategy_name(self) -> str:
return "professional"
async def clone(self, context: CloningContext) -> Dict[str, Any]:
"""
Clone demo data for a professional (single-tenant) account
Process:
1. Validate context
2. Clone all services in parallel
3. Handle failures with partial success support
4. Return aggregated results
"""
logger.info(
"Executing professional cloning strategy",
session_id=context.session_id,
virtual_tenant_id=context.virtual_tenant_id,
base_tenant_id=context.base_tenant_id
)
start_time = datetime.now(timezone.utc)
# Determine which services to clone
services_to_clone = context.orchestrator.services
if context.services_filter:
services_to_clone = [
s for s in context.orchestrator.services
if s.name in context.services_filter
]
logger.info(
"Filtering services",
session_id=context.session_id,
services_filter=context.services_filter,
filtered_count=len(services_to_clone)
)
# Rollback stack for cleanup
rollback_stack = []
try:
# Import asyncio locally (stdlib, so there is no circular-import risk)
import asyncio
# Create parallel tasks for all services
tasks = []
service_map = {}
for service_def in services_to_clone:
task = asyncio.create_task(
context.orchestrator._clone_service(
service_def=service_def,
base_tenant_id=context.base_tenant_id,
virtual_tenant_id=context.virtual_tenant_id,
demo_account_type=context.demo_account_type,
session_id=context.session_id,
session_metadata=context.session_metadata
)
)
tasks.append(task)
service_map[task] = service_def.name
# Process tasks as they complete for real-time progress updates
service_results = {}
total_records = 0
failed_services = []
required_service_failed = False
completed_count = 0
total_count = len(tasks)
# Create a mapping from futures to service names to properly identify completed tasks
# We'll use asyncio.wait approach instead of as_completed to access the original tasks
pending = set(tasks)
completed_tasks_info = {task: service_map[task] for task in tasks} # Map tasks to service names
while pending:
# Wait for at least one task to complete
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
# Process each completed task
for completed_task in done:
try:
# Get the result from the completed task
result = await completed_task
# Get the service name from our mapping
service_name = completed_tasks_info[completed_task]
service_def = next(s for s in services_to_clone if s.name == service_name)
service_results[service_name] = result
completed_count += 1
if result.get("status") == "failed":
failed_services.append(service_name)
if service_def.required:
required_service_failed = True
else:
total_records += result.get("records_cloned", 0)
# Track successful services for rollback
if result.get("status") == "completed":
rollback_stack.append({
"type": "service",
"service_name": service_name,
"tenant_id": context.virtual_tenant_id,
"session_id": context.session_id
})
# Update Redis with granular progress after each service completes
await context.orchestrator._update_progress_in_redis(context.session_id, {
"completed_services": completed_count,
"total_services": total_count,
"progress_percentage": int((completed_count / total_count) * 100),
"services": service_results,
"total_records_cloned": total_records
})
logger.info(
f"Service {service_name} completed ({completed_count}/{total_count})",
session_id=context.session_id,
records_cloned=result.get("records_cloned", 0)
)
except Exception as e:
# Handle exceptions from the task itself
service_name = completed_tasks_info[completed_task]
service_def = next(s for s in services_to_clone if s.name == service_name)
logger.error(
f"Service {service_name} cloning failed with exception",
session_id=context.session_id,
error=str(e)
)
service_results[service_name] = {
"status": "failed",
"error": str(e),
"records_cloned": 0
}
failed_services.append(service_name)
completed_count += 1
if service_def.required:
required_service_failed = True
# Determine overall status
if required_service_failed:
overall_status = "failed"
elif failed_services:
overall_status = "partial"
else:
overall_status = "completed"
duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
logger.info(
"Professional cloning strategy completed",
session_id=context.session_id,
overall_status=overall_status,
total_records=total_records,
failed_services=failed_services,
duration_ms=duration_ms
)
return {
"overall_status": overall_status,
"services": service_results,
"total_records": total_records,
"failed_services": failed_services,
"duration_ms": duration_ms,
"rollback_stack": rollback_stack
}
except Exception as e:
logger.error(
"Professional cloning strategy failed",
session_id=context.session_id,
error=str(e),
exc_info=True
)
return {
"overall_status": "failed",
"error": str(e),
"services": {},
"total_records": 0,
"failed_services": [],
"duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
"rollback_stack": rollback_stack
}
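# --- Illustrative aside, not part of this file --------------------------------
# The completion loop above is a reusable idiom: map tasks to names, then drain
# them with asyncio.wait(FIRST_COMPLETED) so progress can be reported as each
# task finishes. A self-contained sketch:
#
# import asyncio
#
# async def run_named(coros_by_name: dict):
#     task_names = {asyncio.create_task(coro): name
#                   for name, coro in coros_by_name.items()}
#     pending, results = set(task_names), {}
#     while pending:
#         done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
#         for task in done:
#             name = task_names[task]
#             try:
#                 results[name] = await task
#             except Exception as exc:
#                 results[name] = {"status": "failed", "error": str(exc)}
#     return results
# -------------------------------------------------------------------------------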
class EnterpriseCloningStrategy(CloningStrategy):
"""
Strategy for multi-tenant enterprise demos
Clones parent tenant + child tenants + distribution data
"""
def get_strategy_name(self) -> str:
return "enterprise"
async def clone(self, context: CloningContext) -> Dict[str, Any]:
"""
Clone demo data for an enterprise (multi-tenant) account
Process:
1. Validate enterprise metadata
2. Clone parent tenant using ProfessionalCloningStrategy
3. Clone child tenants in parallel
4. Update distribution data with child mappings
5. Return aggregated results
NOTE: No recursion - uses ProfessionalCloningStrategy as a helper
"""
logger.info(
"Executing enterprise cloning strategy",
session_id=context.session_id,
parent_tenant_id=context.virtual_tenant_id,
base_tenant_id=context.base_tenant_id
)
start_time = datetime.now(timezone.utc)
results = {
"parent": {},
"children": [],
"distribution": {},
"overall_status": "pending"
}
rollback_stack = []
try:
# Validate enterprise metadata
if not context.session_metadata:
raise ValueError("Enterprise cloning requires session_metadata")
is_enterprise = context.session_metadata.get("is_enterprise", False)
child_configs = context.session_metadata.get("child_configs", [])
child_tenant_ids = context.session_metadata.get("child_tenant_ids", [])
if not is_enterprise:
raise ValueError("session_metadata.is_enterprise must be True")
if not child_configs or not child_tenant_ids:
raise ValueError("Enterprise metadata missing child_configs or child_tenant_ids")
logger.info(
"Enterprise metadata validated",
session_id=context.session_id,
child_count=len(child_configs)
)
# Phase 1: Clone parent tenant
logger.info("Phase 1: Cloning parent tenant", session_id=context.session_id)
# Update progress
await context.orchestrator._update_progress_in_redis(context.session_id, {
"parent": {"overall_status": "pending"},
"children": [],
"distribution": {}
})
# Use ProfessionalCloningStrategy to clone parent
# This is composition, not recursion - explicit strategy usage
professional_strategy = ProfessionalCloningStrategy()
parent_context = CloningContext(
base_tenant_id=context.base_tenant_id,
virtual_tenant_id=context.virtual_tenant_id,
session_id=context.session_id,
demo_account_type="enterprise", # Explicit type for parent tenant
session_metadata=context.session_metadata,
orchestrator=context.orchestrator
)
parent_result = await professional_strategy.clone(parent_context)
results["parent"] = parent_result
# Update progress
await context.orchestrator._update_progress_in_redis(context.session_id, {
"parent": parent_result,
"children": [],
"distribution": {}
})
# Track parent for rollback
if parent_result.get("overall_status") not in ["failed"]:
rollback_stack.append({
"type": "tenant",
"tenant_id": context.virtual_tenant_id,
"session_id": context.session_id
})
# Validate parent success
parent_status = parent_result.get("overall_status")
if parent_status == "failed":
logger.error(
"Parent cloning failed, aborting enterprise demo",
session_id=context.session_id,
failed_services=parent_result.get("failed_services", [])
)
results["overall_status"] = "failed"
results["error"] = "Parent tenant cloning failed"
results["duration_ms"] = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
return results
if parent_status == "partial":
# Check if tenant service succeeded (critical)
parent_services = parent_result.get("services", {})
if parent_services.get("tenant", {}).get("status") != "completed":
logger.error(
"Tenant service failed in parent, cannot create children",
session_id=context.session_id
)
results["overall_status"] = "failed"
results["error"] = "Parent tenant creation failed - cannot create child tenants"
results["duration_ms"] = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
return results
logger.info(
"Parent cloning succeeded, proceeding with children",
session_id=context.session_id,
parent_status=parent_status
)
# Phase 2: Clone child tenants in parallel
logger.info(
"Phase 2: Cloning child outlets",
session_id=context.session_id,
child_count=len(child_configs)
)
# Update progress
await context.orchestrator._update_progress_in_redis(context.session_id, {
"parent": parent_result,
"children": [{"status": "pending"} for _ in child_configs],
"distribution": {}
})
# Import asyncio for parallel execution
import asyncio
child_tasks = []
for idx, (child_config, child_id) in enumerate(zip(child_configs, child_tenant_ids)):
task = context.orchestrator._clone_child_outlet(
base_tenant_id=child_config.get("base_tenant_id"),
virtual_child_id=child_id,
parent_tenant_id=context.virtual_tenant_id,
child_name=child_config.get("name"),
location=child_config.get("location"),
session_id=context.session_id
)
child_tasks.append(task)
child_results = await asyncio.gather(*child_tasks, return_exceptions=True)
# Process child results
children_data = []
failed_children = 0
for idx, result in enumerate(child_results):
if isinstance(result, Exception):
logger.error(
f"Child {idx} cloning failed",
session_id=context.session_id,
error=str(result)
)
children_data.append({
"status": "failed",
"error": str(result),
"child_id": child_tenant_ids[idx] if idx < len(child_tenant_ids) else None
})
failed_children += 1
else:
children_data.append(result)
if result.get("overall_status") == "failed":
failed_children += 1
else:
# Track for rollback
rollback_stack.append({
"type": "tenant",
"tenant_id": result.get("child_id"),
"session_id": context.session_id
})
results["children"] = children_data
# Update progress
await context.orchestrator._update_progress_in_redis(context.session_id, {
"parent": parent_result,
"children": children_data,
"distribution": {}
})
logger.info(
"Child cloning completed",
session_id=context.session_id,
total_children=len(child_configs),
failed_children=failed_children
)
# Phase 3: Clone distribution data
logger.info("Phase 3: Cloning distribution data", session_id=context.session_id)
# Find distribution service definition
dist_service_def = next(
(s for s in context.orchestrator.services if s.name == "distribution"),
None
)
if dist_service_def:
dist_result = await context.orchestrator._clone_service(
service_def=dist_service_def,
base_tenant_id=context.base_tenant_id,
virtual_tenant_id=context.virtual_tenant_id,
demo_account_type="enterprise",
session_id=context.session_id,
session_metadata=context.session_metadata
)
results["distribution"] = dist_result
# Update progress
await context.orchestrator._update_progress_in_redis(context.session_id, {
"parent": parent_result,
"children": children_data,
"distribution": dist_result
})
# Track for rollback
if dist_result.get("status") == "completed":
rollback_stack.append({
"type": "service",
"service_name": "distribution",
"tenant_id": context.virtual_tenant_id,
"session_id": context.session_id
})
# (distribution records are folded into the total below)
else:
logger.warning("Distribution service not found in orchestrator", session_id=context.session_id)
# Determine overall status
if failed_children == len(child_configs):
overall_status = "failed"
elif failed_children > 0:
overall_status = "partial"
else:
overall_status = "completed" # Changed from "ready" to match professional strategy
# Calculate total records cloned (parent + distribution + all children)
total_records_cloned = parent_result.get("total_records", 0)
total_records_cloned += results.get("distribution", {}).get("records_cloned", 0)
for child in children_data:
if isinstance(child, dict):
total_records_cloned += child.get("total_records", child.get("records_cloned", 0))
results["overall_status"] = overall_status
results["total_records_cloned"] = total_records_cloned # Add for session manager
results["duration_ms"] = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
results["rollback_stack"] = rollback_stack
# Include services from parent for session manager compatibility
results["services"] = parent_result.get("services", {})
logger.info(
"Enterprise cloning strategy completed",
session_id=context.session_id,
overall_status=overall_status,
parent_status=parent_status,
children_status=f"{len(child_configs) - failed_children}/{len(child_configs)} succeeded",
total_records_cloned=total_records_cloned,
duration_ms=results["duration_ms"]
)
return results
except Exception as e:
logger.error(
"Enterprise cloning strategy failed",
session_id=context.session_id,
error=str(e),
exc_info=True
)
return {
"overall_status": "failed",
"error": str(e),
"parent": {},
"children": [],
"distribution": {},
"duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
"rollback_stack": rollback_stack
}
class CloningStrategyFactory:
"""
Factory for creating cloning strategies
Provides type-safe strategy selection
"""
_strategies: Dict[str, CloningStrategy] = {
"professional": ProfessionalCloningStrategy(),
"enterprise": EnterpriseCloningStrategy(),
"enterprise_child": ProfessionalCloningStrategy() # Alias: children use professional strategy
}
@classmethod
def get_strategy(cls, demo_account_type: str) -> CloningStrategy:
"""
Get the appropriate cloning strategy for the demo account type
Args:
demo_account_type: Type of demo account ("professional" or "enterprise")
Returns:
CloningStrategy instance
Raises:
ValueError: If demo_account_type is not supported
"""
strategy = cls._strategies.get(demo_account_type)
if not strategy:
raise ValueError(
f"Unknown demo_account_type: {demo_account_type}. "
f"Supported types: {list(cls._strategies.keys())}"
)
return strategy
@classmethod
def register_strategy(cls, name: str, strategy: CloningStrategy):
"""
Register a custom cloning strategy
Args:
name: Strategy name
strategy: Strategy instance
"""
cls._strategies[name] = strategy
logger.info(f"Registered custom cloning strategy: {name}")

View File

@@ -1,356 +0,0 @@
"""
Demo Data Cloner
Clones base demo data to session-specific virtual tenants
"""
from sqlalchemy.ext.asyncio import AsyncSession
from typing import Dict, Any, List, Optional
import httpx
import structlog
import uuid
import os
import asyncio
from app.core.redis_wrapper import DemoRedisWrapper
from app.core import settings
logger = structlog.get_logger()
class DemoDataCloner:
"""Clones demo data for isolated sessions"""
def __init__(self, db: AsyncSession, redis: DemoRedisWrapper):
self.db = db
self.redis = redis
self._http_client: Optional[httpx.AsyncClient] = None
async def get_http_client(self) -> httpx.AsyncClient:
"""Get or create shared HTTP client with connection pooling"""
if self._http_client is None:
self._http_client = httpx.AsyncClient(
timeout=httpx.Timeout(30.0, connect=10.0),  # httpx uses connect=, not connect_timeout=
limits=httpx.Limits(
max_connections=20,
max_keepalive_connections=10,
keepalive_expiry=30.0
)
)
return self._http_client
async def close(self):
"""Close HTTP client on cleanup"""
if self._http_client:
await self._http_client.aclose()
self._http_client = None
async def clone_tenant_data(
self,
session_id: str,
base_demo_tenant_id: str,
virtual_tenant_id: str,
demo_account_type: str
) -> Dict[str, Any]:
"""
Clone all demo data from base tenant to virtual tenant
Args:
session_id: Session ID
base_demo_tenant_id: Base demo tenant UUID
virtual_tenant_id: Virtual tenant UUID for this session
demo_account_type: Type of demo account
Returns:
Cloning statistics
"""
logger.info(
"Starting data cloning",
session_id=session_id,
base_demo_tenant_id=base_demo_tenant_id,
virtual_tenant_id=virtual_tenant_id
)
stats = {
"session_id": session_id,
"services_cloned": [],
"total_records": 0,
"redis_keys": 0
}
# Clone data from each service based on demo account type
services_to_clone = self._get_services_for_demo_type(demo_account_type)
for service_name in services_to_clone:
try:
service_stats = await self._clone_service_data(
service_name,
base_demo_tenant_id,
virtual_tenant_id,
session_id,
demo_account_type
)
stats["services_cloned"].append(service_name)
stats["total_records"] += service_stats.get("records_cloned", 0)
except Exception as e:
logger.error(
"Failed to clone service data",
service=service_name,
error=str(e)
)
# Populate Redis cache with hot data
redis_stats = await self._populate_redis_cache(
session_id,
virtual_tenant_id,
demo_account_type
)
stats["redis_keys"] = redis_stats.get("keys_created", 0)
logger.info(
"Data cloning completed",
session_id=session_id,
stats=stats
)
return stats
def _get_services_for_demo_type(self, demo_account_type: str) -> List[str]:
"""Get list of services to clone based on demo type"""
base_services = ["inventory", "sales", "orders", "pos"]
if demo_account_type == "professional":
# Professional adds recipes, production, suppliers, procurement, and alert processing
return base_services + ["recipes", "production", "suppliers", "procurement", "alert_processor"]
elif demo_account_type == "enterprise":
# Enterprise adds suppliers, procurement, distribution (parent-child network), and alert processing
return base_services + ["suppliers", "procurement", "distribution", "alert_processor"]
else:
# Default tenants get suppliers, procurement, distribution, and alert processing
return base_services + ["suppliers", "procurement", "distribution", "alert_processor"]
async def _clone_service_data(
self,
service_name: str,
base_tenant_id: str,
virtual_tenant_id: str,
session_id: str,
demo_account_type: str
) -> Dict[str, Any]:
"""
Clone data for a specific service
Args:
service_name: Name of the service
base_tenant_id: Source tenant ID
virtual_tenant_id: Target tenant ID
session_id: Session ID
demo_account_type: Type of demo account
Returns:
Cloning statistics
"""
service_url = self._get_service_url(service_name)
# Get internal API key from settings
from app.core.config import settings
internal_api_key = settings.INTERNAL_API_KEY
async with httpx.AsyncClient(timeout=30.0) as client:
response = await client.post(
f"{service_url}/internal/demo/clone",
json={
"base_tenant_id": base_tenant_id,
"virtual_tenant_id": virtual_tenant_id,
"session_id": session_id,
"demo_account_type": demo_account_type
},
headers={"X-Internal-API-Key": internal_api_key}
)
response.raise_for_status()
return response.json()
async def _populate_redis_cache(
self,
session_id: str,
virtual_tenant_id: str,
demo_account_type: str
) -> Dict[str, Any]:
"""
Populate Redis with frequently accessed data
Args:
session_id: Session ID
virtual_tenant_id: Virtual tenant ID
demo_account_type: Demo account type
Returns:
Statistics about cached data
"""
logger.info("Populating Redis cache", session_id=session_id)
keys_created = 0
# Cache inventory data (hot data)
try:
inventory_data = await self._fetch_inventory_data(virtual_tenant_id)
await self.redis.set_session_data(
session_id,
"inventory",
inventory_data,
ttl=settings.REDIS_SESSION_TTL
)
keys_created += 1
except Exception as e:
logger.error("Failed to cache inventory", error=str(e))
# Cache POS data
try:
pos_data = await self._fetch_pos_data(virtual_tenant_id)
await self.redis.set_session_data(
session_id,
"pos",
pos_data,
ttl=settings.REDIS_SESSION_TTL
)
keys_created += 1
except Exception as e:
logger.error("Failed to cache POS data", error=str(e))
# Cache recent sales
try:
sales_data = await self._fetch_recent_sales(virtual_tenant_id)
await self.redis.set_session_data(
session_id,
"recent_sales",
sales_data,
ttl=settings.REDIS_SESSION_TTL
)
keys_created += 1
except Exception as e:
logger.error("Failed to cache sales", error=str(e))
return {"keys_created": keys_created}
async def _fetch_inventory_data(self, tenant_id: str) -> Dict[str, Any]:
"""Fetch inventory data for caching"""
async with httpx.AsyncClient(timeout=httpx.Timeout(15.0, connect=5.0)) as client:
response = await client.get(
f"{settings.INVENTORY_SERVICE_URL}/api/inventory/summary",
headers={"X-Tenant-Id": tenant_id}
)
return response.json()
async def _fetch_pos_data(self, tenant_id: str) -> Dict[str, Any]:
"""Fetch POS data for caching"""
async with httpx.AsyncClient(timeout=httpx.Timeout(15.0, connect=5.0)) as client:
response = await client.get(
f"{settings.POS_SERVICE_URL}/api/pos/current-session",
headers={"X-Tenant-Id": tenant_id}
)
return response.json()
async def _fetch_recent_sales(self, tenant_id: str) -> Dict[str, Any]:
"""Fetch recent sales for caching"""
async with httpx.AsyncClient() as client:
response = await client.get(
f"{settings.SALES_SERVICE_URL}/api/sales/recent?limit=50",
headers={"X-Tenant-Id": tenant_id}
)
return response.json()
def _get_service_url(self, service_name: str) -> str:
"""Get service URL from settings"""
url_map = {
"inventory": settings.INVENTORY_SERVICE_URL,
"recipes": settings.RECIPES_SERVICE_URL,
"sales": settings.SALES_SERVICE_URL,
"orders": settings.ORDERS_SERVICE_URL,
"production": settings.PRODUCTION_SERVICE_URL,
"suppliers": settings.SUPPLIERS_SERVICE_URL,
"pos": settings.POS_SERVICE_URL,
"procurement": settings.PROCUREMENT_SERVICE_URL,
"distribution": settings.DISTRIBUTION_SERVICE_URL,
"forecasting": settings.FORECASTING_SERVICE_URL,
"alert_processor": settings.ALERT_PROCESSOR_SERVICE_URL,
}
return url_map.get(service_name, "")
async def delete_session_data(
self,
virtual_tenant_id: str,
session_id: str
):
"""
Delete all data for a session using parallel deletion for performance
Args:
virtual_tenant_id: Virtual tenant ID to delete
session_id: Session ID
"""
logger.info(
"Deleting session data",
virtual_tenant_id=virtual_tenant_id,
session_id=session_id
)
# Get shared HTTP client for all deletions
client = await self.get_http_client()
# Services list - all can be deleted in parallel as deletion endpoints
# handle their own internal ordering if needed
services = [
"forecasting",
"sales",
"orders",
"production",
"inventory",
"recipes",
"suppliers",
"pos",
"distribution",
"procurement",
"alert_processor"
]
# Create deletion tasks for all services
deletion_tasks = [
self._delete_service_data(service_name, virtual_tenant_id, client)
for service_name in services
]
# Execute all deletions in parallel with exception handling
results = await asyncio.gather(*deletion_tasks, return_exceptions=True)
# Log any failures
for service_name, result in zip(services, results):
if isinstance(result, Exception):
logger.error(
"Failed to delete service data",
service=service_name,
error=str(result)
)
# Delete from Redis
await self.redis.delete_session_data(session_id)
logger.info("Session data deleted", virtual_tenant_id=virtual_tenant_id)
async def _delete_service_data(
self,
service_name: str,
virtual_tenant_id: str,
client: httpx.AsyncClient
):
"""Delete data from a specific service using provided HTTP client"""
service_url = self._get_service_url(service_name)
# Get internal API key from settings
from app.core.config import settings
internal_api_key = settings.INTERNAL_API_KEY
await client.delete(
f"{service_url}/internal/demo/tenant/{virtual_tenant_id}",
headers={"X-Internal-API-Key": internal_api_key}
)
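The deletion path above leans on asyncio.gather(..., return_exceptions=True) so one failing service cannot abort the others; a small self-contained sketch of that error-partitioning idiom (names here are illustrative, not from the file):

import asyncio

async def gather_partition(coros_by_name: dict):
    # Exceptions come back as values instead of propagating.
    results = await asyncio.gather(*coros_by_name.values(), return_exceptions=True)
    ok, failed = {}, {}
    for name, result in zip(coros_by_name, results):
        (failed if isinstance(result, Exception) else ok)[name] = result
    return ok, failed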

View File

@@ -75,18 +75,11 @@ class DemoSessionManager:
base_tenant_id = uuid.UUID(base_tenant_id_str)
-# Validate that the base tenant ID exists in the tenant service
-# This is important to prevent cloning from non-existent base tenants
-await self._validate_base_tenant_exists(base_tenant_id, demo_account_type)
# Handle enterprise chain setup
child_tenant_ids = []
if demo_account_type == 'enterprise':
-# Validate child template tenants exist before proceeding
-child_configs = demo_config.get('children', [])
-await self._validate_child_template_tenants(child_configs)
# Generate child tenant IDs for enterprise demos
child_configs = demo_config.get('children', [])
child_tenant_ids = [uuid.uuid4() for _ in child_configs]
# Create session record using repository
@@ -208,9 +201,7 @@ class DemoSessionManager:
async def destroy_session(self, session_id: str):
"""
Destroy a demo session and cleanup resources
-Args:
-session_id: Session ID to destroy
+This triggers parallel deletion across all services.
"""
session = await self.get_session(session_id)
@@ -218,8 +209,30 @@ class DemoSessionManager:
logger.warning("Session not found for destruction", session_id=session_id)
return
-# Update session status via repository
-await self.repository.destroy(session_id)
+# Update status to DESTROYING
+await self.repository.update_fields(
+session_id,
+status=DemoSessionStatus.DESTROYING
+)
# Trigger cleanup across all services
cleanup_service = DemoCleanupService(self.db, self.redis)
result = await cleanup_service.cleanup_session(session)
if result["success"]:
# Update status to DESTROYED
await self.repository.update_fields(
session_id,
status=DemoSessionStatus.DESTROYED,
destroyed_at=datetime.now(timezone.utc)
)
else:
# Update status to FAILED with error details
await self.repository.update_fields(
session_id,
status=DemoSessionStatus.FAILED,
error_details=result["errors"]
)
# Delete Redis data
await self.redis.delete_session_data(session_id)
@@ -227,9 +240,34 @@ class DemoSessionManager:
logger.info(
"Session destroyed",
session_id=session_id,
-virtual_tenant_id=str(session.virtual_tenant_id)
+virtual_tenant_id=str(session.virtual_tenant_id),
+total_records_deleted=result.get("total_deleted", 0),
+duration_ms=result.get("duration_ms", 0)
)
async def _check_database_disk_space(self):
"""Check if database has sufficient disk space for demo operations"""
try:
# Execute a simple query to check database health and disk space
# This is a basic check - in production you might want more comprehensive monitoring
from sqlalchemy import text
# Check if we can execute a simple query (indicates basic database health)
result = await self.db.execute(text("SELECT 1"))
# Get the scalar result properly
scalar_result = result.scalar_one_or_none()
# For more comprehensive checking, you could add:
# 1. Check table sizes
# 2. Check available disk space via system queries (if permissions allow)
# 3. Check for long-running transactions that might block operations
logger.debug("Database health check passed", result=scalar_result)
except Exception as e:
logger.error("Database health check failed", error=str(e), exc_info=True)
raise RuntimeError(f"Database health check failed: {str(e)}")
async def _store_session_metadata(self, session: DemoSession):
"""Store session metadata in Redis"""
await self.redis.set_session_data(
@@ -274,6 +312,33 @@ class DemoSessionManager:
virtual_tenant_id=str(session.virtual_tenant_id)
)
# Check database disk space before starting cloning
try:
await self._check_database_disk_space()
except Exception as e:
logger.error(
"Database disk space check failed",
session_id=session.session_id,
error=str(e)
)
# Mark session as failed due to infrastructure issue
session.status = DemoSessionStatus.FAILED
session.cloning_completed_at = datetime.now(timezone.utc)
session.total_records_cloned = 0
session.cloning_progress = {
"error": "Database disk space issue detected",
"details": str(e)
}
await self.repository.update(session)
await self._cache_session_status(session)
return {
"overall_status": "failed",
"services": {},
"total_records": 0,
"failed_services": ["database"],
"error": "Database disk space issue"
}
# Mark cloning as started and update both database and Redis cache
session.cloning_started_at = datetime.now(timezone.utc)
await self.repository.update(session)
@@ -295,130 +360,7 @@ class DemoSessionManager:
return result
async def _validate_base_tenant_exists(self, base_tenant_id: uuid.UUID, demo_account_type: str) -> bool:
"""
Validate that the base tenant exists in the tenant service before starting cloning.
This prevents cloning from non-existent base tenants.
Args:
base_tenant_id: The UUID of the base tenant to validate
demo_account_type: The demo account type for logging
Returns:
True if tenant exists, raises exception otherwise
"""
logger.info(
"Validating base tenant exists before cloning",
base_tenant_id=str(base_tenant_id),
demo_account_type=demo_account_type
)
# Basic validation: check if UUID is valid (not empty/nil)
if str(base_tenant_id) == "00000000-0000-0000-0000-000000000000":
raise ValueError(f"Invalid base tenant ID: {base_tenant_id} for demo type: {demo_account_type}")
# BUG-008 FIX: Actually validate with tenant service
try:
from shared.clients.tenant_client import TenantServiceClient
tenant_client = TenantServiceClient(settings)
tenant = await tenant_client.get_tenant(str(base_tenant_id))
if not tenant:
error_msg = (
f"Base tenant {base_tenant_id} does not exist for demo type {demo_account_type}. "
f"Please verify the base_tenant_id in demo configuration."
)
logger.error(
"Base tenant validation failed",
base_tenant_id=str(base_tenant_id),
demo_account_type=demo_account_type
)
raise ValueError(error_msg)
logger.info(
"Base tenant validation passed",
base_tenant_id=str(base_tenant_id),
tenant_name=tenant.get("name", "unknown"),
demo_account_type=demo_account_type
)
return True
except ValueError:
# Re-raise ValueError from validation failure
raise
except Exception as e:
logger.error(
f"Error validating base tenant: {e}",
base_tenant_id=str(base_tenant_id),
demo_account_type=demo_account_type,
exc_info=True
)
raise ValueError(f"Cannot validate base tenant {base_tenant_id}: {str(e)}")
async def _validate_child_template_tenants(self, child_configs: list) -> bool:
"""
Validate that all child template tenants exist before cloning.
This prevents silent failures when child base tenants are missing.
Args:
child_configs: List of child configurations with base_tenant_id
Returns:
True if all child templates exist, raises exception otherwise
"""
if not child_configs:
logger.warning("No child configurations provided for validation")
return True
logger.info("Validating child template tenants", child_count=len(child_configs))
try:
from shared.clients.tenant_client import TenantServiceClient
tenant_client = TenantServiceClient(settings)
for child_config in child_configs:
child_base_id = child_config.get("base_tenant_id")
child_name = child_config.get("name", "unknown")
if not child_base_id:
raise ValueError(f"Child config missing base_tenant_id: {child_name}")
# Validate child template exists
child_tenant = await tenant_client.get_tenant(child_base_id)
if not child_tenant:
error_msg = (
f"Child template tenant {child_base_id} ('{child_name}') does not exist. "
f"Please verify the base_tenant_id in demo configuration."
)
logger.error(
"Child template validation failed",
base_tenant_id=child_base_id,
child_name=child_name
)
raise ValueError(error_msg)
logger.info(
"Child template validation passed",
base_tenant_id=child_base_id,
child_name=child_name,
tenant_name=child_tenant.get("name", "unknown")
)
logger.info("All child template tenants validated successfully")
return True
except ValueError:
# Re-raise ValueError from validation failure
raise
except Exception as e:
logger.error(
f"Error validating child template tenants: {e}",
exc_info=True
)
raise ValueError(f"Cannot validate child template tenants: {str(e)}")
async def _update_session_from_clone_result(
self,
@@ -573,4 +515,4 @@ class DemoSessionManager:
# Trigger new cloning attempt
result = await self.trigger_orchestrated_cloning(session, base_tenant_id)
-return result
+return result
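To close, the new destroy_session flow reduces to a small state machine; a sketch with the states taken from the diff (the repository and Redis call shapes follow the code above):

async def destroy_flow(repo, cleanup_service, redis, session) -> None:
    await repo.update_fields(session.session_id, status=DemoSessionStatus.DESTROYING)
    result = await cleanup_service.cleanup_session(session)
    if result["success"]:
        await repo.update_fields(session.session_id, status=DemoSessionStatus.DESTROYED,
                                 destroyed_at=datetime.now(timezone.utc))
    else:
        await repo.update_fields(session.session_id, status=DemoSessionStatus.FAILED,
                                 error_details=result["errors"])
    await redis.delete_session_data(session.session_id)  # Redis cleanup runs either way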