bakery-ia/services/demo_session/app/services/cloning_strategies.py
"""
Cloning Strategy Pattern Implementation
Provides explicit, type-safe strategies for different demo account types
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Any, List, Optional
from datetime import datetime, timezone
import asyncio
import structlog
logger = structlog.get_logger()
@dataclass(frozen=True)
class CloningContext:
"""
Context object containing all data needed for cloning operations
Immutable to prevent state mutation bugs
"""
base_tenant_id: str
virtual_tenant_id: str
session_id: str
demo_account_type: str
session_metadata: Optional[Dict[str, Any]] = None
services_filter: Optional[List[str]] = None
# Orchestrator dependencies (injected)
orchestrator: Any = None # Will be CloneOrchestrator instance
def __post_init__(self):
"""Validate context after initialization"""
if not self.base_tenant_id:
raise ValueError("base_tenant_id is required")
if not self.virtual_tenant_id:
raise ValueError("virtual_tenant_id is required")
if not self.session_id:
raise ValueError("session_id is required")
class CloningStrategy(ABC):
"""
Abstract base class for cloning strategies
Each strategy is a leaf node - no recursion possible
"""
@abstractmethod
async def clone(self, context: CloningContext) -> Dict[str, Any]:
"""
Execute the cloning strategy
Args:
context: Immutable context with all required data
Returns:
Dictionary with cloning results
"""
pass
@abstractmethod
def get_strategy_name(self) -> str:
"""Return the name of this strategy for logging"""
pass
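# A minimal sketch of extending the hierarchy (illustrative only, not one of the
# shipped strategies). A custom strategy only has to implement clone() and
# get_strategy_name(), and can then be registered at runtime through
# CloningStrategyFactory.register_strategy("trial", TrialCloningStrategy()):
#
#     class TrialCloningStrategy(CloningStrategy):
#         def get_strategy_name(self) -> str:
#             return "trial"
#
#         async def clone(self, context: CloningContext) -> Dict[str, Any]:
#             # Clone a reduced data set for short-lived trial sessions.
#             return {"overall_status": "completed", "services": {},
#                     "total_records": 0, "failed_services": [],
#                     "duration_ms": 0, "rollback_stack": []}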
class ProfessionalCloningStrategy(CloningStrategy):
"""
Strategy for single-tenant professional demos
Clones all services for a single virtual tenant
"""
def get_strategy_name(self) -> str:
return "professional"
async def clone(self, context: CloningContext) -> Dict[str, Any]:
"""
Clone demo data for a professional (single-tenant) account
Process:
1. Validate context
2. Clone all services in parallel
3. Handle failures with partial success support
4. Return aggregated results
"""
logger.info(
"Executing professional cloning strategy",
session_id=context.session_id,
virtual_tenant_id=context.virtual_tenant_id,
base_tenant_id=context.base_tenant_id
)
start_time = datetime.now(timezone.utc)
# Determine which services to clone
services_to_clone = context.orchestrator.services
if context.services_filter:
services_to_clone = [
s for s in context.orchestrator.services
if s.name in context.services_filter
]
logger.info(
"Filtering services",
session_id=context.session_id,
services_filter=context.services_filter,
filtered_count=len(services_to_clone)
)
# Rollback stack for cleanup
rollback_stack = []
try:
# Create parallel tasks for all services
tasks = []
service_map = {}
for service_def in services_to_clone:
task = asyncio.create_task(
context.orchestrator._clone_service(
service_def=service_def,
base_tenant_id=context.base_tenant_id,
virtual_tenant_id=context.virtual_tenant_id,
demo_account_type=context.demo_account_type,
session_id=context.session_id,
session_metadata=context.session_metadata
)
)
tasks.append(task)
service_map[task] = service_def.name
# Process tasks as they complete for real-time progress updates
service_results = {}
total_records = 0
failed_services = []
required_service_failed = False
completed_count = 0
total_count = len(tasks)
            # Use asyncio.wait() rather than as_completed() so the completed task
            # objects can be mapped back to their service names via service_map.
            pending = set(tasks)
while pending:
# Wait for at least one task to complete
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
# Process each completed task
for completed_task in done:
try:
# Get the result from the completed task
result = await completed_task
# Get the service name from our mapping
                        service_name = service_map[completed_task]
service_def = next(s for s in services_to_clone if s.name == service_name)
service_results[service_name] = result
completed_count += 1
if result.get("status") == "failed":
failed_services.append(service_name)
if service_def.required:
required_service_failed = True
else:
total_records += result.get("records_cloned", 0)
# Track successful services for rollback
if result.get("status") == "completed":
rollback_stack.append({
"type": "service",
"service_name": service_name,
"tenant_id": context.virtual_tenant_id,
"session_id": context.session_id
})
# Update Redis with granular progress after each service completes
await context.orchestrator._update_progress_in_redis(context.session_id, {
"completed_services": completed_count,
"total_services": total_count,
"progress_percentage": int((completed_count / total_count) * 100),
"services": service_results,
"total_records_cloned": total_records
})
logger.info(
f"Service {service_name} completed ({completed_count}/{total_count})",
session_id=context.session_id,
records_cloned=result.get("records_cloned", 0)
)
except Exception as e:
# Handle exceptions from the task itself
                        service_name = service_map[completed_task]
service_def = next(s for s in services_to_clone if s.name == service_name)
logger.error(
f"Service {service_name} cloning failed with exception",
session_id=context.session_id,
error=str(e)
)
service_results[service_name] = {
"status": "failed",
"error": str(e),
"records_cloned": 0
}
failed_services.append(service_name)
completed_count += 1
if service_def.required:
required_service_failed = True
# Determine overall status
if required_service_failed:
overall_status = "failed"
elif failed_services:
overall_status = "partial"
else:
overall_status = "completed"
duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
logger.info(
"Professional cloning strategy completed",
session_id=context.session_id,
overall_status=overall_status,
total_records=total_records,
failed_services=failed_services,
duration_ms=duration_ms
)
return {
"overall_status": overall_status,
"services": service_results,
"total_records": total_records,
"failed_services": failed_services,
"duration_ms": duration_ms,
"rollback_stack": rollback_stack
}
except Exception as e:
logger.error(
"Professional cloning strategy failed",
session_id=context.session_id,
error=str(e),
exc_info=True
)
return {
"overall_status": "failed",
"error": str(e),
"services": {},
"total_records": 0,
"failed_services": [],
"duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
"rollback_stack": rollback_stack
}
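# Shape of the aggregated result returned by ProfessionalCloningStrategy.clone()
# (values and service names are illustrative; per-service entries come from
# CloneOrchestrator._clone_service):
#
#     {
#         "overall_status": "partial",
#         "services": {
#             "inventory": {"status": "completed", "records_cloned": 42},
#             "forecasting": {"status": "failed", "error": "timeout", "records_cloned": 0},
#         },
#         "total_records": 42,
#         "failed_services": ["forecasting"],
#         "duration_ms": 1830,
#         "rollback_stack": [
#             {"type": "service", "service_name": "inventory",
#              "tenant_id": "<virtual_tenant_id>", "session_id": "<session_id>"},
#         ],
#     }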
class EnterpriseCloningStrategy(CloningStrategy):
"""
Strategy for multi-tenant enterprise demos
Clones parent tenant + child tenants + distribution data
"""
def get_strategy_name(self) -> str:
return "enterprise"
async def clone(self, context: CloningContext) -> Dict[str, Any]:
"""
Clone demo data for an enterprise (multi-tenant) account
Process:
1. Validate enterprise metadata
2. Clone parent tenant using ProfessionalCloningStrategy
3. Clone child tenants in parallel
4. Update distribution data with child mappings
5. Return aggregated results
NOTE: No recursion - uses ProfessionalCloningStrategy as a helper
"""
logger.info(
"Executing enterprise cloning strategy",
session_id=context.session_id,
parent_tenant_id=context.virtual_tenant_id,
base_tenant_id=context.base_tenant_id
)
start_time = datetime.now(timezone.utc)
results = {
"parent": {},
"children": [],
"distribution": {},
"overall_status": "pending"
}
rollback_stack = []
try:
# Validate enterprise metadata
if not context.session_metadata:
raise ValueError("Enterprise cloning requires session_metadata")
is_enterprise = context.session_metadata.get("is_enterprise", False)
child_configs = context.session_metadata.get("child_configs", [])
child_tenant_ids = context.session_metadata.get("child_tenant_ids", [])
if not is_enterprise:
raise ValueError("session_metadata.is_enterprise must be True")
if not child_configs or not child_tenant_ids:
raise ValueError("Enterprise metadata missing child_configs or child_tenant_ids")
logger.info(
"Enterprise metadata validated",
session_id=context.session_id,
child_count=len(child_configs)
)
# Phase 1: Clone parent tenant
logger.info("Phase 1: Cloning parent tenant", session_id=context.session_id)
# Update progress
await context.orchestrator._update_progress_in_redis(context.session_id, {
"parent": {"overall_status": "pending"},
"children": [],
"distribution": {}
})
# Use ProfessionalCloningStrategy to clone parent
# This is composition, not recursion - explicit strategy usage
professional_strategy = ProfessionalCloningStrategy()
parent_context = CloningContext(
base_tenant_id=context.base_tenant_id,
virtual_tenant_id=context.virtual_tenant_id,
session_id=context.session_id,
demo_account_type="enterprise", # Explicit type for parent tenant
session_metadata=context.session_metadata,
orchestrator=context.orchestrator
)
parent_result = await professional_strategy.clone(parent_context)
results["parent"] = parent_result
# Update progress
await context.orchestrator._update_progress_in_redis(context.session_id, {
"parent": parent_result,
"children": [],
"distribution": {}
})
# Track parent for rollback
if parent_result.get("overall_status") not in ["failed"]:
rollback_stack.append({
"type": "tenant",
"tenant_id": context.virtual_tenant_id,
"session_id": context.session_id
})
# Validate parent success
parent_status = parent_result.get("overall_status")
if parent_status == "failed":
logger.error(
"Parent cloning failed, aborting enterprise demo",
session_id=context.session_id,
failed_services=parent_result.get("failed_services", [])
)
results["overall_status"] = "failed"
results["error"] = "Parent tenant cloning failed"
results["duration_ms"] = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
return results
if parent_status == "partial":
# Check if tenant service succeeded (critical)
parent_services = parent_result.get("services", {})
if parent_services.get("tenant", {}).get("status") != "completed":
logger.error(
"Tenant service failed in parent, cannot create children",
session_id=context.session_id
)
results["overall_status"] = "failed"
results["error"] = "Parent tenant creation failed - cannot create child tenants"
results["duration_ms"] = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
return results
logger.info(
"Parent cloning succeeded, proceeding with children",
session_id=context.session_id,
parent_status=parent_status
)
# Phase 2: Clone child tenants in parallel
logger.info(
"Phase 2: Cloning child outlets",
session_id=context.session_id,
child_count=len(child_configs)
)
# Update progress
await context.orchestrator._update_progress_in_redis(context.session_id, {
"parent": parent_result,
"children": [{"status": "pending"} for _ in child_configs],
"distribution": {}
})
child_tasks = []
            for child_config, child_id in zip(child_configs, child_tenant_ids):
task = context.orchestrator._clone_child_outlet(
base_tenant_id=child_config.get("base_tenant_id"),
virtual_child_id=child_id,
parent_tenant_id=context.virtual_tenant_id,
child_name=child_config.get("name"),
location=child_config.get("location"),
session_id=context.session_id
)
child_tasks.append(task)
child_results = await asyncio.gather(*child_tasks, return_exceptions=True)
# Process child results
children_data = []
failed_children = 0
for idx, result in enumerate(child_results):
if isinstance(result, Exception):
logger.error(
f"Child {idx} cloning failed",
session_id=context.session_id,
error=str(result)
)
children_data.append({
"status": "failed",
"error": str(result),
"child_id": child_tenant_ids[idx] if idx < len(child_tenant_ids) else None
})
failed_children += 1
else:
children_data.append(result)
if result.get("overall_status") == "failed":
failed_children += 1
else:
# Track for rollback
rollback_stack.append({
"type": "tenant",
"tenant_id": result.get("child_id"),
"session_id": context.session_id
})
results["children"] = children_data
# Update progress
await context.orchestrator._update_progress_in_redis(context.session_id, {
"parent": parent_result,
"children": children_data,
"distribution": {}
})
logger.info(
"Child cloning completed",
session_id=context.session_id,
total_children=len(child_configs),
failed_children=failed_children
)
# Phase 3: Clone distribution data
logger.info("Phase 3: Cloning distribution data", session_id=context.session_id)
# Find distribution service definition
dist_service_def = next(
(s for s in context.orchestrator.services if s.name == "distribution"),
None
)
if dist_service_def:
dist_result = await context.orchestrator._clone_service(
service_def=dist_service_def,
base_tenant_id=context.base_tenant_id,
virtual_tenant_id=context.virtual_tenant_id,
demo_account_type="enterprise",
session_id=context.session_id,
session_metadata=context.session_metadata
)
results["distribution"] = dist_result
# Update progress
await context.orchestrator._update_progress_in_redis(context.session_id, {
"parent": parent_result,
"children": children_data,
"distribution": dist_result
})
# Track for rollback
if dist_result.get("status") == "completed":
rollback_stack.append({
"type": "service",
"service_name": "distribution",
"tenant_id": context.virtual_tenant_id,
"session_id": context.session_id
})
else:
logger.warning("Distribution service not found in orchestrator", session_id=context.session_id)
# Determine overall status
if failed_children == len(child_configs):
overall_status = "failed"
elif failed_children > 0:
overall_status = "partial"
else:
overall_status = "completed" # Changed from "ready" to match professional strategy
            # Calculate total records cloned (parent + children + distribution)
            total_records_cloned = parent_result.get("total_records", 0)
            for child in children_data:
                if isinstance(child, dict):
                    total_records_cloned += child.get("total_records", child.get("records_cloned", 0))
            total_records_cloned += results["distribution"].get("records_cloned", 0)
results["overall_status"] = overall_status
results["total_records_cloned"] = total_records_cloned # Add for session manager
results["duration_ms"] = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
results["rollback_stack"] = rollback_stack
# Include services from parent for session manager compatibility
results["services"] = parent_result.get("services", {})
logger.info(
"Enterprise cloning strategy completed",
session_id=context.session_id,
overall_status=overall_status,
parent_status=parent_status,
children_status=f"{len(child_configs) - failed_children}/{len(child_configs)} succeeded",
total_records_cloned=total_records_cloned,
duration_ms=results["duration_ms"]
)
return results
except Exception as e:
logger.error(
"Enterprise cloning strategy failed",
session_id=context.session_id,
error=str(e),
exc_info=True
)
return {
"overall_status": "failed",
"error": str(e),
"parent": {},
"children": [],
"distribution": {},
"duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
"rollback_stack": rollback_stack
}
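# Shape of session_metadata expected by EnterpriseCloningStrategy, inferred from
# the keys the strategy reads above (all values are illustrative):
#
#     {
#         "is_enterprise": True,
#         "child_tenant_ids": ["child-demo-001", "child-demo-002"],
#         "child_configs": [
#             {"base_tenant_id": "outlet-base-001", "name": "Outlet North", "location": "Lyon"},
#             {"base_tenant_id": "outlet-base-002", "name": "Outlet South", "location": "Marseille"},
#         ],
#     }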
class CloningStrategyFactory:
"""
Factory for creating cloning strategies
Provides type-safe strategy selection
"""
_strategies: Dict[str, CloningStrategy] = {
"professional": ProfessionalCloningStrategy(),
"enterprise": EnterpriseCloningStrategy(),
"enterprise_child": ProfessionalCloningStrategy() # Alias: children use professional strategy
}
@classmethod
def get_strategy(cls, demo_account_type: str) -> CloningStrategy:
"""
Get the appropriate cloning strategy for the demo account type
Args:
demo_account_type: Type of demo account ("professional" or "enterprise")
Returns:
CloningStrategy instance
Raises:
ValueError: If demo_account_type is not supported
"""
strategy = cls._strategies.get(demo_account_type)
if not strategy:
raise ValueError(
f"Unknown demo_account_type: {demo_account_type}. "
f"Supported types: {list(cls._strategies.keys())}"
)
return strategy
@classmethod
def register_strategy(cls, name: str, strategy: CloningStrategy):
"""
Register a custom cloning strategy
Args:
name: Strategy name
strategy: Strategy instance
"""
cls._strategies[name] = strategy
logger.info(f"Registered custom cloning strategy: {name}")