"""
Cloning Strategy Pattern Implementation
Provides explicit, type-safe strategies for different demo account types
"""
import asyncio
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Any, List, Optional
from datetime import datetime, timezone

import structlog

logger = structlog.get_logger()


def _elapsed_ms(start_time: datetime) -> int:
    """Return the whole milliseconds elapsed since ``start_time`` (a UTC-aware datetime)."""
    return int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)


async def _publish_progress(orchestrator: Any, session_id: str, payload: Dict[str, Any]) -> None:
    """
    Publish progress through ``orchestrator._update_progress_in_redis`` on a best-effort basis.

    A failure to publish progress is logged and swallowed so it can never change
    the outcome of a cloning operation (e.g. flag an already-successful service as
    failed, or abort an enterprise clone).
    """
    try:
        await orchestrator._update_progress_in_redis(session_id, payload)
    except Exception as e:
        logger.warning(
            "Failed to publish cloning progress",
            session_id=session_id,
            error=str(e)
        )


@dataclass
class CloningContext:
    """
    Context object containing all data needed for cloning operations.

    Strategies treat the context as read-only and build a new CloningContext when
    they need different values (see the enterprise parent clone). The dataclass is
    not frozen, so this is a convention, not an enforced guarantee.

    The orchestrator is expected to expose ``services`` (objects with ``name`` and
    ``required``), ``_clone_service``, ``_clone_child_outlet`` and
    ``_update_progress_in_redis``; those are the members the strategies below use.
    """
    base_tenant_id: str
    virtual_tenant_id: str
    session_id: str
    demo_account_type: str
    session_metadata: Optional[Dict[str, Any]] = None
    services_filter: Optional[List[str]] = None

    # Orchestrator dependencies (injected)
    orchestrator: Any = None  # Will be CloneOrchestrator instance

    def __post_init__(self):
        """
        Validate the required identifiers after initialization.

        Raises:
            ValueError: If base_tenant_id, virtual_tenant_id or session_id is empty.
        """
        if not self.base_tenant_id:
            raise ValueError("base_tenant_id is required")
        if not self.virtual_tenant_id:
            raise ValueError("virtual_tenant_id is required")
        if not self.session_id:
            raise ValueError("session_id is required")


class CloningStrategy(ABC):
    """
    Abstract base class for cloning strategies.

    Each strategy is a leaf node - no recursion possible.
    """

    @abstractmethod
    async def clone(self, context: CloningContext) -> Dict[str, Any]:
        """
        Execute the cloning strategy.

        Args:
            context: Read-only context with all required data

        Returns:
            Dictionary with cloning results
        """
        pass

    @abstractmethod
    def get_strategy_name(self) -> str:
        """Return the name of this strategy for logging."""
        pass


class ProfessionalCloningStrategy(CloningStrategy):
    """
    Strategy for single-tenant professional demos.

    Clones all (optionally filtered) orchestrator services for a single virtual tenant.
    """

    def get_strategy_name(self) -> str:
        return "professional"

    async def clone(self, context: CloningContext) -> Dict[str, Any]:
        """
        Clone demo data for a professional (single-tenant) account.

        Process:
            1. Select services (all, or those named in ``services_filter``)
            2. Clone all selected services in parallel
            3. Record each result as it completes and publish progress (best-effort)
            4. Return aggregated results

        Returns:
            Dict with ``overall_status`` ("completed" | "partial" | "failed"),
            ``services`` (per-service results), ``total_records``,
            ``failed_services``, ``duration_ms`` and ``rollback_stack``; plus
            ``error`` when the strategy itself failed unexpectedly.
            "failed" means a required service failed; "partial" means only
            optional services failed.
        """
        logger.info(
            "Executing professional cloning strategy",
            session_id=context.session_id,
            virtual_tenant_id=context.virtual_tenant_id,
            base_tenant_id=context.base_tenant_id
        )

        start_time = datetime.now(timezone.utc)

        # Determine which services to clone
        services_to_clone = context.orchestrator.services
        if context.services_filter:
            services_to_clone = [
                s for s in context.orchestrator.services
                if s.name in context.services_filter
            ]
            logger.info(
                "Filtering services",
                session_id=context.session_id,
                services_filter=context.services_filter,
                filtered_count=len(services_to_clone)
            )

        # Successfully cloned services, recorded so callers can roll them back
        rollback_stack = []
        # Each clone task mapped to its service definition; also used to cancel
        # unfinished tasks if we exit abnormally
        task_to_service: Dict[Any, Any] = {}

        try:
            # Launch all service clones in parallel
            for service_def in services_to_clone:
                task = asyncio.create_task(
                    context.orchestrator._clone_service(
                        service_def=service_def,
                        base_tenant_id=context.base_tenant_id,
                        virtual_tenant_id=context.virtual_tenant_id,
                        demo_account_type=context.demo_account_type,
                        session_id=context.session_id,
                        session_metadata=context.session_metadata
                    )
                )
                task_to_service[task] = service_def

            service_results = {}
            total_records = 0
            failed_services = []
            required_service_failed = False
            completed_count = 0
            total_count = len(task_to_service)

            # asyncio.wait (rather than as_completed) hands back the original task
            # objects, so each result can be matched to its service definition and
            # progress reported as soon as each service finishes.
            pending = set(task_to_service)
            while pending:
                done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)

                for completed_task in done:
                    service_def = task_to_service[completed_task]
                    service_name = service_def.name
                    completed_count += 1

                    try:
                        result = completed_task.result()
                        status = result.get("status")
                    except Exception as e:
                        # The clone call itself raised (or returned something unusable)
                        logger.error(
                            f"Service {service_name} cloning failed with exception",
                            session_id=context.session_id,
                            error=str(e)
                        )
                        result = {
                            "status": "failed",
                            "error": str(e),
                            "records_cloned": 0
                        }
                        status = "failed"
                    else:
                        logger.info(
                            f"Service {service_name} completed ({completed_count}/{total_count})",
                            session_id=context.session_id,
                            records_cloned=result.get("records_cloned", 0)
                        )

                    service_results[service_name] = result
                    if status == "failed":
                        failed_services.append(service_name)
                        if service_def.required:
                            required_service_failed = True
                    else:
                        total_records += result.get("records_cloned", 0)

                    # Track successful services for rollback
                    if status == "completed":
                        rollback_stack.append({
                            "type": "service",
                            "service_name": service_name,
                            "tenant_id": context.virtual_tenant_id,
                            "session_id": context.session_id
                        })

                    # Granular progress after each service; a publish failure must
                    # not reclassify the service or double-count it
                    await _publish_progress(context.orchestrator, context.session_id, {
                        "completed_services": completed_count,
                        "total_services": total_count,
                        "progress_percentage": int((completed_count / total_count) * 100),
                        "services": service_results,
                        "total_records_cloned": total_records
                    })

            # Determine overall status
            if required_service_failed:
                overall_status = "failed"
            elif failed_services:
                overall_status = "partial"
            else:
                overall_status = "completed"

            duration_ms = _elapsed_ms(start_time)

            logger.info(
                "Professional cloning strategy completed",
                session_id=context.session_id,
                overall_status=overall_status,
                total_records=total_records,
                failed_services=failed_services,
                duration_ms=duration_ms
            )

            return {
                "overall_status": overall_status,
                "services": service_results,
                "total_records": total_records,
                "failed_services": failed_services,
                "duration_ms": duration_ms,
                "rollback_stack": rollback_stack
            }

        except Exception as e:
            logger.error(
                "Professional cloning strategy failed",
                session_id=context.session_id,
                error=str(e),
                exc_info=True
            )
            return {
                "overall_status": "failed",
                "error": str(e),
                "services": {},
                "total_records": 0,
                "failed_services": [],
                "duration_ms": _elapsed_ms(start_time),
                "rollback_stack": rollback_stack
            }

        finally:
            # On an abnormal exit (unexpected error or cancellation of this
            # coroutine) do not leave clone tasks running unobserved.
            for task in task_to_service:
                if not task.done():
                    task.cancel()


class EnterpriseCloningStrategy(CloningStrategy):
    """
    Strategy for multi-tenant enterprise demos.

    Clones parent tenant + child tenants + distribution data.
    """

    def get_strategy_name(self) -> str:
        return "enterprise"

    async def clone(self, context: CloningContext) -> Dict[str, Any]:
        """
        Clone demo data for an enterprise (multi-tenant) account.

        Process:
            1. Validate enterprise metadata
            2. Clone parent tenant using ProfessionalCloningStrategy
            3. Clone child tenants in parallel
            4. Clone the "distribution" service again (if the orchestrator has one),
               after the children exist
            5. Return aggregated results

        NOTE: No recursion - uses ProfessionalCloningStrategy as a helper.

        Returns:
            Dict with ``parent``, ``children``, ``distribution``, ``overall_status``,
            ``duration_ms`` and ``rollback_stack``. On success it also carries
            ``total_records_cloned`` and ``services``; on an early parent failure it
            carries ``error`` and ``services``. ``overall_status`` is "failed" if the
            parent or every child failed, "partial" if some children failed.
        """
        logger.info(
            "Executing enterprise cloning strategy",
            session_id=context.session_id,
            parent_tenant_id=context.virtual_tenant_id,
            base_tenant_id=context.base_tenant_id
        )

        start_time = datetime.now(timezone.utc)
        results = {
            "parent": {},
            "children": [],
            "distribution": {},
            "overall_status": "pending"
        }
        rollback_stack = []

        def _abort(error: str, parent_result: Dict[str, Any]) -> Dict[str, Any]:
            """Finalize ``results`` as a failed enterprise clone with the same keys as other paths."""
            results["overall_status"] = "failed"
            results["error"] = error
            results["duration_ms"] = _elapsed_ms(start_time)
            results["rollback_stack"] = rollback_stack
            results["services"] = parent_result.get("services", {})
            return results

        try:
            # Validate enterprise metadata
            if not context.session_metadata:
                raise ValueError("Enterprise cloning requires session_metadata")

            is_enterprise = context.session_metadata.get("is_enterprise", False)
            child_configs = context.session_metadata.get("child_configs", [])
            child_tenant_ids = context.session_metadata.get("child_tenant_ids", [])

            if not is_enterprise:
                raise ValueError("session_metadata.is_enterprise must be True")

            if not child_configs or not child_tenant_ids:
                raise ValueError("Enterprise metadata missing child_configs or child_tenant_ids")

            # Configs are paired with IDs positionally; zip drops any surplus from
            # the longer list, so status is computed against the pairs actually cloned.
            child_pairs = list(zip(child_configs, child_tenant_ids))
            if len(child_configs) != len(child_tenant_ids):
                logger.warning(
                    "child_configs and child_tenant_ids lengths differ; extra entries ignored",
                    session_id=context.session_id,
                    child_configs_count=len(child_configs),
                    child_tenant_ids_count=len(child_tenant_ids)
                )
            child_count = len(child_pairs)

            logger.info(
                "Enterprise metadata validated",
                session_id=context.session_id,
                child_count=child_count
            )

            # Phase 1: Clone parent tenant
            logger.info("Phase 1: Cloning parent tenant", session_id=context.session_id)

            await _publish_progress(context.orchestrator, context.session_id, {
                "parent": {"overall_status": "pending"},
                "children": [],
                "distribution": {}
            })

            # Use ProfessionalCloningStrategy to clone parent
            # This is composition, not recursion - explicit strategy usage.
            # NOTE(review): services_filter is not forwarded, so the parent clone
            # covers every orchestrator service - confirm this is intended.
            professional_strategy = ProfessionalCloningStrategy()
            parent_context = CloningContext(
                base_tenant_id=context.base_tenant_id,
                virtual_tenant_id=context.virtual_tenant_id,
                session_id=context.session_id,
                demo_account_type="enterprise",  # Explicit type for parent tenant
                session_metadata=context.session_metadata,
                orchestrator=context.orchestrator
            )

            parent_result = await professional_strategy.clone(parent_context)
            results["parent"] = parent_result

            await _publish_progress(context.orchestrator, context.session_id, {
                "parent": parent_result,
                "children": [],
                "distribution": {}
            })

            # Track parent for rollback
            parent_status = parent_result.get("overall_status")
            if parent_status != "failed":
                rollback_stack.append({
                    "type": "tenant",
                    "tenant_id": context.virtual_tenant_id,
                    "session_id": context.session_id
                })

            # Validate parent success
            if parent_status == "failed":
                logger.error(
                    "Parent cloning failed, aborting enterprise demo",
                    session_id=context.session_id,
                    failed_services=parent_result.get("failed_services", [])
                )
                return _abort("Parent tenant cloning failed", parent_result)

            if parent_status == "partial":
                # The tenant service is critical: children cannot be created without it
                parent_services = parent_result.get("services", {})
                if parent_services.get("tenant", {}).get("status") != "completed":
                    logger.error(
                        "Tenant service failed in parent, cannot create children",
                        session_id=context.session_id
                    )
                    return _abort(
                        "Parent tenant creation failed - cannot create child tenants",
                        parent_result
                    )

            logger.info(
                "Parent cloning succeeded, proceeding with children",
                session_id=context.session_id,
                parent_status=parent_status
            )

            # Phase 2: Clone child tenants in parallel
            logger.info(
                "Phase 2: Cloning child outlets",
                session_id=context.session_id,
                child_count=child_count
            )

            await _publish_progress(context.orchestrator, context.session_id, {
                "parent": parent_result,
                "children": [{"status": "pending"} for _ in child_pairs],
                "distribution": {}
            })

            child_coroutines = [
                context.orchestrator._clone_child_outlet(
                    base_tenant_id=child_config.get("base_tenant_id"),
                    virtual_child_id=child_id,
                    parent_tenant_id=context.virtual_tenant_id,
                    child_name=child_config.get("name"),
                    location=child_config.get("location"),
                    session_id=context.session_id
                )
                for child_config, child_id in child_pairs
            ]

            child_results = await asyncio.gather(*child_coroutines, return_exceptions=True)

            # Process child results
            children_data = []
            failed_children = 0

            for idx, ((_, child_id), result) in enumerate(zip(child_pairs, child_results)):
                # gather(return_exceptions=True) may also hand back CancelledError,
                # which is a BaseException but not an Exception (Python 3.8+).
                if isinstance(result, BaseException):
                    logger.error(
                        f"Child {idx} cloning failed",
                        session_id=context.session_id,
                        error=str(result)
                    )
                    children_data.append({
                        "status": "failed",
                        "error": str(result),
                        "child_id": child_id
                    })
                    failed_children += 1
                    continue

                children_data.append(result)
                if result.get("overall_status") == "failed":
                    failed_children += 1
                else:
                    # Track for rollback; fall back to the ID we assigned if the
                    # result does not echo it back
                    rollback_stack.append({
                        "type": "tenant",
                        "tenant_id": result.get("child_id", child_id),
                        "session_id": context.session_id
                    })

            results["children"] = children_data

            await _publish_progress(context.orchestrator, context.session_id, {
                "parent": parent_result,
                "children": children_data,
                "distribution": {}
            })

            logger.info(
                "Child cloning completed",
                session_id=context.session_id,
                total_children=child_count,
                failed_children=failed_children
            )

            # Phase 3: Clone distribution data
            logger.info("Phase 3: Cloning distribution data", session_id=context.session_id)

            dist_service_def = next(
                (s for s in context.orchestrator.services if s.name == "distribution"),
                None
            )

            if dist_service_def:
                dist_result = await context.orchestrator._clone_service(
                    service_def=dist_service_def,
                    base_tenant_id=context.base_tenant_id,
                    virtual_tenant_id=context.virtual_tenant_id,
                    demo_account_type="enterprise",
                    session_id=context.session_id,
                    session_metadata=context.session_metadata
                )
                results["distribution"] = dist_result

                await _publish_progress(context.orchestrator, context.session_id, {
                    "parent": parent_result,
                    "children": children_data,
                    "distribution": dist_result
                })

                # Track for rollback
                if dist_result.get("status") == "completed":
                    rollback_stack.append({
                        "type": "service",
                        "service_name": "distribution",
                        "tenant_id": context.virtual_tenant_id,
                        "session_id": context.session_id
                    })
            else:
                logger.warning("Distribution service not found in orchestrator", session_id=context.session_id)

            # Determine overall status from the children actually attempted
            if failed_children == child_count:
                overall_status = "failed"
            elif failed_children > 0:
                overall_status = "partial"
            else:
                overall_status = "completed"  # Matches professional strategy

            # Total records cloned (parent + all children).
            # NOTE(review): Phase 3 distribution records are not added. The parent
            # clone already runs every orchestrator service (including
            # "distribution"), so adding them could double count - confirm the
            # intended accounting.
            total_records_cloned = parent_result.get("total_records", 0)
            for child in children_data:
                if isinstance(child, dict):
                    total_records_cloned += child.get("total_records", child.get("records_cloned", 0))

            results["overall_status"] = overall_status
            results["total_records_cloned"] = total_records_cloned  # Add for session manager
            results["duration_ms"] = _elapsed_ms(start_time)
            results["rollback_stack"] = rollback_stack
            # Include services from parent for session manager compatibility
            results["services"] = parent_result.get("services", {})

            logger.info(
                "Enterprise cloning strategy completed",
                session_id=context.session_id,
                overall_status=overall_status,
                parent_status=parent_status,
                children_status=f"{child_count - failed_children}/{child_count} succeeded",
                total_records_cloned=total_records_cloned,
                duration_ms=results["duration_ms"]
            )

            return results

        except Exception as e:
            logger.error(
                "Enterprise cloning strategy failed",
                session_id=context.session_id,
                error=str(e),
                exc_info=True
            )
            return {
                "overall_status": "failed",
                "error": str(e),
                "parent": {},
                "children": [],
                "distribution": {},
                "duration_ms": _elapsed_ms(start_time),
                "rollback_stack": rollback_stack
            }


class CloningStrategyFactory:
    """
    Factory for creating cloning strategies.

    Provides type-safe strategy selection.
    """

    # Shared, process-wide registry; register_strategy mutates it in place
    _strategies: Dict[str, CloningStrategy] = {
        "professional": ProfessionalCloningStrategy(),
        "enterprise": EnterpriseCloningStrategy(),
        "enterprise_child": ProfessionalCloningStrategy()  # Alias: children use professional strategy
    }

    @classmethod
    def get_strategy(cls, demo_account_type: str) -> CloningStrategy:
        """
        Get the appropriate cloning strategy for the demo account type.

        Args:
            demo_account_type: Type of demo account ("professional", "enterprise",
                "enterprise_child", or any name added via register_strategy)

        Returns:
            CloningStrategy instance

        Raises:
            ValueError: If demo_account_type is not supported
        """
        strategy = cls._strategies.get(demo_account_type)
        if strategy is None:
            raise ValueError(
                f"Unknown demo_account_type: {demo_account_type}. "
                f"Supported types: {list(cls._strategies.keys())}"
            )
        return strategy

    @classmethod
    def register_strategy(cls, name: str, strategy: CloningStrategy):
        """
        Register a custom cloning strategy, replacing any existing one with the same name.

        Args:
            name: Strategy name
            strategy: Strategy instance
        """
        cls._strategies[name] = strategy
        logger.info(f"Registered custom cloning strategy: {name}")