605 lines
23 KiB
Python
605 lines
23 KiB
Python
"""
|
|
Cloning Strategy Pattern Implementation
|
|
Provides explicit, type-safe strategies for different demo account types
|
|
"""
|
|
|
|
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import structlog
|
|
|
|
# Module-level structlog logger shared by the strategies and the factory in this module.
logger = structlog.get_logger()
|
|
|
|
|
|
@dataclass
class CloningContext:
    """
    Context object containing all data needed for cloning operations.

    The dataclass is not frozen, so nothing prevents attribute assignment at
    runtime. Treat instances as read-only: when different values are needed
    (as the enterprise strategy does for its parent tenant), build a new
    ``CloningContext`` instead of mutating a shared one.
    """

    # Source tenant ID, forwarded to the orchestrator as ``base_tenant_id`` (required).
    base_tenant_id: str
    # Tenant ID the data is cloned into (required).
    virtual_tenant_id: str
    # Demo session identifier used for logging and progress updates (required).
    session_id: str
    # Account type label forwarded to the orchestrator, e.g. "professional" or "enterprise".
    demo_account_type: str
    # Free-form session data; the enterprise strategy reads "is_enterprise",
    # "child_configs" and "child_tenant_ids" from it.
    session_metadata: Optional[Dict[str, Any]] = None
    # When non-empty, only orchestrator services whose name is listed here are cloned.
    services_filter: Optional[List[str]] = None

    # Orchestrator dependency (injected); expected to be a CloneOrchestrator instance.
    orchestrator: Any = None

    def __post_init__(self):
        """
        Validate the context after initialization.

        Raises:
            ValueError: If base_tenant_id, virtual_tenant_id or session_id is
                empty; fields are checked in that order and the first empty
                one is reported.
        """
        for field_name in ("base_tenant_id", "virtual_tenant_id", "session_id"):
            if not getattr(self, field_name):
                raise ValueError(f"{field_name} is required")
|
|
|
|
|
|
class CloningStrategy(ABC):
    """
    Interface shared by every demo-cloning strategy.

    Strategies are meant to be leaf nodes: a strategy may use another
    strategy directly as a helper (composition), but must not recurse into
    itself.
    """

    @abstractmethod
    async def clone(self, context: CloningContext) -> Dict[str, Any]:
        """
        Perform the cloning work described by ``context``.

        Args:
            context: Validated context carrying tenant IDs, session data and
                the injected orchestrator.

        Returns:
            A dictionary summarising the cloning outcome.
        """

    @abstractmethod
    def get_strategy_name(self) -> str:
        """Short strategy label used in log output."""
|
|
|
|
|
|
class ProfessionalCloningStrategy(CloningStrategy):
    """
    Strategy for single-tenant professional demos.

    Clones every orchestrator service (optionally narrowed by
    ``context.services_filter``) for one virtual tenant. Service clones run
    concurrently and progress is published as each one finishes.
    """

    def get_strategy_name(self) -> str:
        return "professional"

    @staticmethod
    def _select_services(context: CloningContext) -> Any:
        """Return the orchestrator service definitions to clone, applying ``services_filter``."""
        if not context.services_filter:
            return context.orchestrator.services
        selected = [
            s for s in context.orchestrator.services
            if s.name in context.services_filter
        ]
        logger.info(
            "Filtering services",
            session_id=context.session_id,
            services_filter=context.services_filter,
            filtered_count=len(selected)
        )
        return selected

    @staticmethod
    def _task_result(task: asyncio.Task, service_name: str, session_id: str) -> Dict[str, Any]:
        """
        Return the result of a finished service task.

        A task that raised, was cancelled, or returned something other than a
        dict is converted into a ``{"status": "failed", ...}`` result, so one
        misbehaving service never aborts processing of the others.
        """
        try:
            result = task.result()
        except asyncio.CancelledError:
            # Must precede ``except Exception``: on Python < 3.8 CancelledError
            # subclasses Exception, on 3.8+ it is a BaseException.
            error = "Service task was cancelled"
        except Exception as exc:
            error = str(exc)
        else:
            if isinstance(result, dict):
                return result
            error = f"Unexpected result type from _clone_service: {type(result).__name__}"

        logger.error(
            f"Service {service_name} cloning failed with exception",
            session_id=session_id,
            error=error
        )
        return {"status": "failed", "error": error, "records_cloned": 0}

    @staticmethod
    async def _report_progress(context: CloningContext, progress: Dict[str, Any]) -> None:
        """
        Publish progress through the orchestrator on a best-effort basis.

        A publishing failure is logged and swallowed. It must not be mistaken
        for a failure of the service whose completion is being reported.
        """
        try:
            await context.orchestrator._update_progress_in_redis(context.session_id, progress)
        except Exception as exc:
            logger.warning(
                "Failed to publish cloning progress",
                session_id=context.session_id,
                error=str(exc)
            )

    async def clone(self, context: CloningContext) -> Dict[str, Any]:
        """
        Clone demo data for a professional (single-tenant) account.

        Process:
            1. Select services (applying ``context.services_filter``).
            2. Clone all selected services concurrently.
            3. Record each result as it completes and publish progress after
               every service, whether it succeeded or failed.
            4. Aggregate: "failed" if any required service failed, "partial"
               if only non-required services failed, otherwise "completed".

        Args:
            context: Cloning context. ``context.orchestrator`` must provide
                ``services``, ``_clone_service`` and ``_update_progress_in_redis``.

        Returns:
            Dict with ``overall_status``, ``services`` (per-service results),
            ``total_records``, ``failed_services``, ``duration_ms`` and
            ``rollback_stack``. An unexpected error yields a "failed" result
            that also carries ``error``.
        """
        logger.info(
            "Executing professional cloning strategy",
            session_id=context.session_id,
            virtual_tenant_id=context.virtual_tenant_id,
            base_tenant_id=context.base_tenant_id
        )

        start_time = datetime.now(timezone.utc)
        services_to_clone = self._select_services(context)

        # Services that completed, so the caller can undo them on failure.
        rollback_stack = []
        # Running task -> its service definition. Keys are the original Task
        # objects, which asyncio.wait hands back (as_completed would not).
        task_to_service = {}

        try:
            for service_def in services_to_clone:
                task = asyncio.create_task(
                    context.orchestrator._clone_service(
                        service_def=service_def,
                        base_tenant_id=context.base_tenant_id,
                        virtual_tenant_id=context.virtual_tenant_id,
                        demo_account_type=context.demo_account_type,
                        session_id=context.session_id,
                        session_metadata=context.session_metadata
                    )
                )
                task_to_service[task] = service_def

            service_results = {}
            total_records = 0
            failed_services = []
            required_service_failed = False
            completed_count = 0
            total_count = len(task_to_service)

            pending = set(task_to_service)
            while pending:
                done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)

                for completed_task in done:
                    service_def = task_to_service[completed_task]
                    service_name = service_def.name
                    result = self._task_result(completed_task, service_name, context.session_id)

                    # Each finished task is counted exactly once, whatever its outcome.
                    completed_count += 1
                    service_results[service_name] = result
                    status = result.get("status")

                    if status == "failed":
                        failed_services.append(service_name)
                        if service_def.required:
                            required_service_failed = True
                    else:
                        total_records += result.get("records_cloned", 0)

                    # Track successful services for rollback
                    if status == "completed":
                        rollback_stack.append({
                            "type": "service",
                            "service_name": service_name,
                            "tenant_id": context.virtual_tenant_id,
                            "session_id": context.session_id
                        })

                    # Publish after every service (success or failure) so the
                    # reported progress never stalls on a failing service.
                    await self._report_progress(context, {
                        "completed_services": completed_count,
                        "total_services": total_count,
                        "progress_percentage": int((completed_count / total_count) * 100),
                        "services": service_results,
                        "total_records_cloned": total_records
                    })

                    logger.info(
                        f"Service {service_name} completed ({completed_count}/{total_count})",
                        session_id=context.session_id,
                        records_cloned=result.get("records_cloned", 0)
                    )

            if required_service_failed:
                overall_status = "failed"
            elif failed_services:
                overall_status = "partial"
            else:
                overall_status = "completed"

            duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)

            logger.info(
                "Professional cloning strategy completed",
                session_id=context.session_id,
                overall_status=overall_status,
                total_records=total_records,
                failed_services=failed_services,
                duration_ms=duration_ms
            )

            return {
                "overall_status": overall_status,
                "services": service_results,
                "total_records": total_records,
                "failed_services": failed_services,
                "duration_ms": duration_ms,
                "rollback_stack": rollback_stack
            }

        except Exception as e:
            logger.error(
                "Professional cloning strategy failed",
                session_id=context.session_id,
                error=str(e),
                exc_info=True
            )
            return {
                "overall_status": "failed",
                "error": str(e),
                "services": {},
                "total_records": 0,
                "failed_services": [],
                "duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
                "rollback_stack": rollback_stack
            }

        finally:
            # Never leave service clones running in the background when
            # exiting early (unexpected error, or this coroutine being
            # cancelled). On the normal path every task is already done.
            for task in task_to_service:
                if not task.done():
                    task.cancel()
|
|
|
|
|
|
class EnterpriseCloningStrategy(CloningStrategy):
    """
    Strategy for multi-tenant enterprise demos.

    Clones the parent tenant, then every child outlet concurrently, then the
    distribution data for the parent tenant.
    """

    def get_strategy_name(self) -> str:
        return "enterprise"

    @staticmethod
    def _elapsed_ms(start_time: datetime) -> int:
        """Whole milliseconds elapsed since ``start_time`` (a UTC-aware datetime)."""
        return int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)

    async def clone(self, context: CloningContext) -> Dict[str, Any]:
        """
        Clone demo data for an enterprise (multi-tenant) account.

        Process:
            1. Validate enterprise metadata.
            2. Clone the parent tenant using ProfessionalCloningStrategy.
            3. Clone child outlets concurrently.
            4. Clone distribution data for the parent tenant.
            5. Return aggregated results.

        NOTE: No recursion - ProfessionalCloningStrategy is used as a helper.

        Args:
            context: Cloning context. ``context.session_metadata`` must hold a
                truthy ``is_enterprise`` plus non-empty ``child_configs`` and
                ``child_tenant_ids``.

        Returns:
            Dict with ``parent``, ``children``, ``distribution``,
            ``overall_status``, ``duration_ms``, ``rollback_stack`` and
            ``services`` (the parent's per-service results). Runs that reach
            the children phase also carry ``total_records_cloned``; failures
            carry ``error``.
        """
        logger.info(
            "Executing enterprise cloning strategy",
            session_id=context.session_id,
            parent_tenant_id=context.virtual_tenant_id,
            base_tenant_id=context.base_tenant_id
        )

        start_time = datetime.now(timezone.utc)
        results = {
            "parent": {},
            "children": [],
            "distribution": {},
            "overall_status": "pending"
        }
        # Tenants/services created so far, so the caller can undo them.
        rollback_stack = []

        try:
            # Validate enterprise metadata
            if not context.session_metadata:
                raise ValueError("Enterprise cloning requires session_metadata")

            is_enterprise = context.session_metadata.get("is_enterprise", False)
            child_configs = context.session_metadata.get("child_configs", [])
            child_tenant_ids = context.session_metadata.get("child_tenant_ids", [])

            if not is_enterprise:
                raise ValueError("session_metadata.is_enterprise must be True")

            if not child_configs or not child_tenant_ids:
                raise ValueError("Enterprise metadata missing child_configs or child_tenant_ids")

            # zip() silently truncates to the shorter list, so every child
            # count below is derived from the pairs actually cloned.
            child_pairs = list(zip(child_configs, child_tenant_ids))
            child_count = len(child_pairs)
            if len(child_configs) != len(child_tenant_ids):
                logger.warning(
                    "child_configs and child_tenant_ids differ in length; extra entries are ignored",
                    session_id=context.session_id,
                    child_configs_count=len(child_configs),
                    child_tenant_ids_count=len(child_tenant_ids)
                )

            logger.info(
                "Enterprise metadata validated",
                session_id=context.session_id,
                child_count=child_count
            )

            # Phase 1: Clone parent tenant
            logger.info("Phase 1: Cloning parent tenant", session_id=context.session_id)

            await context.orchestrator._update_progress_in_redis(context.session_id, {
                "parent": {"overall_status": "pending"},
                "children": [],
                "distribution": {}
            })

            # Composition, not recursion: the parent is cloned by an explicit
            # ProfessionalCloningStrategy with its own context.
            parent_context = CloningContext(
                base_tenant_id=context.base_tenant_id,
                virtual_tenant_id=context.virtual_tenant_id,
                session_id=context.session_id,
                demo_account_type="enterprise",  # Explicit type for parent tenant
                session_metadata=context.session_metadata,
                orchestrator=context.orchestrator
            )
            parent_result = await ProfessionalCloningStrategy().clone(parent_context)
            results["parent"] = parent_result
            # Parent services are exposed on every return path for session
            # manager compatibility.
            results["services"] = parent_result.get("services", {})

            await context.orchestrator._update_progress_in_redis(context.session_id, {
                "parent": parent_result,
                "children": [],
                "distribution": {}
            })

            parent_status = parent_result.get("overall_status")

            # Track the parent tenant for rollback unless its clone failed outright.
            if parent_status != "failed":
                rollback_stack.append({
                    "type": "tenant",
                    "tenant_id": context.virtual_tenant_id,
                    "session_id": context.session_id
                })

            if parent_status == "failed":
                logger.error(
                    "Parent cloning failed, aborting enterprise demo",
                    session_id=context.session_id,
                    failed_services=parent_result.get("failed_services", [])
                )
                results["overall_status"] = "failed"
                results["error"] = "Parent tenant cloning failed"
                results["duration_ms"] = self._elapsed_ms(start_time)
                results["rollback_stack"] = rollback_stack
                return results

            if parent_status == "partial":
                # The tenant service is critical: children cannot exist without it.
                parent_services = parent_result.get("services", {})
                if parent_services.get("tenant", {}).get("status") != "completed":
                    logger.error(
                        "Tenant service failed in parent, cannot create children",
                        session_id=context.session_id
                    )
                    results["overall_status"] = "failed"
                    results["error"] = "Parent tenant creation failed - cannot create child tenants"
                    results["duration_ms"] = self._elapsed_ms(start_time)
                    # The partially cloned parent was tracked above; return it
                    # so the caller can clean it up.
                    results["rollback_stack"] = rollback_stack
                    return results

            logger.info(
                "Parent cloning succeeded, proceeding with children",
                session_id=context.session_id,
                parent_status=parent_status
            )

            # Phase 2: Clone child tenants in parallel
            logger.info(
                "Phase 2: Cloning child outlets",
                session_id=context.session_id,
                child_count=child_count
            )

            await context.orchestrator._update_progress_in_redis(context.session_id, {
                "parent": parent_result,
                "children": [{"status": "pending"} for _ in child_pairs],
                "distribution": {}
            })

            child_results = await asyncio.gather(
                *(
                    context.orchestrator._clone_child_outlet(
                        base_tenant_id=child_config.get("base_tenant_id"),
                        virtual_child_id=child_id,
                        parent_tenant_id=context.virtual_tenant_id,
                        child_name=child_config.get("name"),
                        location=child_config.get("location"),
                        session_id=context.session_id
                    )
                    for child_config, child_id in child_pairs
                ),
                return_exceptions=True
            )

            children_data = []
            failed_children = 0

            for idx, ((_, child_id), result) in enumerate(zip(child_pairs, child_results)):
                # gather(return_exceptions=True) returns errors as values. Check
                # BaseException: CancelledError is not an Exception on 3.8+.
                if isinstance(result, BaseException):
                    logger.error(
                        f"Child {idx} cloning failed",
                        session_id=context.session_id,
                        error=str(result)
                    )
                    children_data.append({
                        "status": "failed",
                        "error": str(result),
                        "child_id": child_id
                    })
                    failed_children += 1
                    continue

                children_data.append(result)
                if result.get("overall_status") == "failed":
                    failed_children += 1
                else:
                    rollback_stack.append({
                        "type": "tenant",
                        # Fall back to the ID we requested if the result omits it.
                        "tenant_id": result.get("child_id") or child_id,
                        "session_id": context.session_id
                    })

            results["children"] = children_data

            await context.orchestrator._update_progress_in_redis(context.session_id, {
                "parent": parent_result,
                "children": children_data,
                "distribution": {}
            })

            logger.info(
                "Child cloning completed",
                session_id=context.session_id,
                total_children=child_count,
                failed_children=failed_children
            )

            # Phase 3: Clone distribution data
            logger.info("Phase 3: Cloning distribution data", session_id=context.session_id)

            dist_service_def = next(
                (s for s in context.orchestrator.services if s.name == "distribution"),
                None
            )

            if dist_service_def:
                dist_result = await context.orchestrator._clone_service(
                    service_def=dist_service_def,
                    base_tenant_id=context.base_tenant_id,
                    virtual_tenant_id=context.virtual_tenant_id,
                    demo_account_type="enterprise",
                    session_id=context.session_id,
                    session_metadata=context.session_metadata
                )
                results["distribution"] = dist_result

                await context.orchestrator._update_progress_in_redis(context.session_id, {
                    "parent": parent_result,
                    "children": children_data,
                    "distribution": dist_result
                })

                if dist_result.get("status") == "completed":
                    rollback_stack.append({
                        "type": "service",
                        "service_name": "distribution",
                        "tenant_id": context.virtual_tenant_id,
                        "session_id": context.session_id
                    })
            else:
                logger.warning("Distribution service not found in orchestrator", session_id=context.session_id)

            if failed_children == child_count:
                overall_status = "failed"
            elif failed_children > 0:
                overall_status = "partial"
            else:
                overall_status = "completed"  # Same vocabulary as the professional strategy

            # Total = parent + all children.
            # NOTE(review): Phase 3 distribution records are intentionally not
            # added. Phase 3 only runs when "distribution" is in
            # orchestrator.services, and the parent phase (no services_filter)
            # clones every service in that list, so distribution was presumably
            # already cloned and counted there. Confirm before counting it again.
            total_records_cloned = parent_result.get("total_records", 0)
            for child in children_data:
                if isinstance(child, dict):
                    total_records_cloned += child.get("total_records", child.get("records_cloned", 0))

            results["overall_status"] = overall_status
            results["total_records_cloned"] = total_records_cloned  # Read by the session manager
            results["duration_ms"] = self._elapsed_ms(start_time)
            results["rollback_stack"] = rollback_stack

            logger.info(
                "Enterprise cloning strategy completed",
                session_id=context.session_id,
                overall_status=overall_status,
                parent_status=parent_status,
                children_status=f"{child_count - failed_children}/{child_count} succeeded",
                total_records_cloned=total_records_cloned,
                duration_ms=results["duration_ms"]
            )

            return results

        except Exception as e:
            logger.error(
                "Enterprise cloning strategy failed",
                session_id=context.session_id,
                error=str(e),
                exc_info=True
            )
            return {
                "overall_status": "failed",
                "error": str(e),
                "parent": {},
                "children": [],
                "distribution": {},
                "services": {},
                "duration_ms": self._elapsed_ms(start_time),
                "rollback_stack": rollback_stack
            }
|
|
|
|
class CloningStrategyFactory:
    """
    Registry-backed factory mapping a demo account type to its strategy.

    Strategy instances are created once, when this class body executes, and
    are shared by every caller. ``enterprise_child`` resolves to its own
    ProfessionalCloningStrategy instance.
    """

    _strategies: Dict[str, CloningStrategy] = dict(
        professional=ProfessionalCloningStrategy(),
        enterprise=EnterpriseCloningStrategy(),
        enterprise_child=ProfessionalCloningStrategy(),  # child outlets reuse the professional flow
    )

    @classmethod
    def get_strategy(cls, demo_account_type: str) -> CloningStrategy:
        """
        Look up the strategy registered for ``demo_account_type``.

        Args:
            demo_account_type: Registered type name, e.g. "professional",
                "enterprise" or "enterprise_child".

        Returns:
            The shared CloningStrategy instance for that type.

        Raises:
            ValueError: If nothing (or a falsy value) is registered under
                that name; the message lists the registered names.
        """
        registry = cls._strategies
        found = registry.get(demo_account_type)
        if found:
            return found

        raise ValueError(
            f"Unknown demo_account_type: {demo_account_type}. "
            f"Supported types: {list(registry)}"
        )

    @classmethod
    def register_strategy(cls, name: str, strategy: CloningStrategy):
        """
        Add or replace the strategy stored under ``name``.

        The registry dict is mutated in place, so the entry is visible to
        every subsequent ``get_strategy`` lookup.

        Args:
            name: Key to register the strategy under.
            strategy: Strategy instance to return for that key.
        """
        cls._strategies.update({name: strategy})
        logger.info(f"Registered custom cloning strategy: {name}")
|