Improve the demo feature of the project
This commit is contained in:
@@ -27,31 +27,55 @@ class DemoCleanupService:
|
||||
async def cleanup_expired_sessions(self) -> dict:
|
||||
"""
|
||||
Find and cleanup all expired sessions
|
||||
Also cleans up sessions stuck in PENDING for too long (>5 minutes)
|
||||
|
||||
Returns:
|
||||
Cleanup statistics
|
||||
"""
|
||||
from datetime import timedelta
|
||||
|
||||
logger.info("Starting demo session cleanup")
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
stuck_threshold = now - timedelta(minutes=5) # Sessions pending > 5 min are stuck
|
||||
|
||||
# Find expired sessions
|
||||
# Find expired sessions (any status except EXPIRED and DESTROYED)
|
||||
result = await self.db.execute(
|
||||
select(DemoSession).where(
|
||||
DemoSession.status == DemoSessionStatus.ACTIVE,
|
||||
DemoSession.status.in_([
|
||||
DemoSessionStatus.PENDING,
|
||||
DemoSessionStatus.READY,
|
||||
DemoSessionStatus.PARTIAL,
|
||||
DemoSessionStatus.FAILED,
|
||||
DemoSessionStatus.ACTIVE # Legacy status, kept for compatibility
|
||||
]),
|
||||
DemoSession.expires_at < now
|
||||
)
|
||||
)
|
||||
expired_sessions = result.scalars().all()
|
||||
|
||||
# Also find sessions stuck in PENDING
|
||||
stuck_result = await self.db.execute(
|
||||
select(DemoSession).where(
|
||||
DemoSession.status == DemoSessionStatus.PENDING,
|
||||
DemoSession.created_at < stuck_threshold
|
||||
)
|
||||
)
|
||||
stuck_sessions = stuck_result.scalars().all()
|
||||
|
||||
# Combine both lists
|
||||
all_sessions_to_cleanup = list(expired_sessions) + list(stuck_sessions)
|
||||
|
||||
stats = {
|
||||
"total_expired": len(expired_sessions),
|
||||
"total_stuck": len(stuck_sessions),
|
||||
"total_to_cleanup": len(all_sessions_to_cleanup),
|
||||
"cleaned_up": 0,
|
||||
"failed": 0,
|
||||
"errors": []
|
||||
}
|
||||
|
||||
for session in expired_sessions:
|
||||
for session in all_sessions_to_cleanup:
|
||||
try:
|
||||
# Mark as expired
|
||||
session.status = DemoSessionStatus.EXPIRED
|
||||
@@ -128,6 +152,11 @@ class DemoCleanupService:
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# Count by status
|
||||
pending_count = len([s for s in all_sessions if s.status == DemoSessionStatus.PENDING])
|
||||
ready_count = len([s for s in all_sessions if s.status == DemoSessionStatus.READY])
|
||||
partial_count = len([s for s in all_sessions if s.status == DemoSessionStatus.PARTIAL])
|
||||
failed_count = len([s for s in all_sessions if s.status == DemoSessionStatus.FAILED])
|
||||
active_count = len([s for s in all_sessions if s.status == DemoSessionStatus.ACTIVE])
|
||||
expired_count = len([s for s in all_sessions if s.status == DemoSessionStatus.EXPIRED])
|
||||
destroyed_count = len([s for s in all_sessions if s.status == DemoSessionStatus.DESTROYED])
|
||||
@@ -135,13 +164,25 @@ class DemoCleanupService:
|
||||
# Find sessions that should be expired but aren't marked yet
|
||||
should_be_expired = len([
|
||||
s for s in all_sessions
|
||||
if s.status == DemoSessionStatus.ACTIVE and s.expires_at < now
|
||||
if s.status in [
|
||||
DemoSessionStatus.PENDING,
|
||||
DemoSessionStatus.READY,
|
||||
DemoSessionStatus.PARTIAL,
|
||||
DemoSessionStatus.FAILED,
|
||||
DemoSessionStatus.ACTIVE
|
||||
] and s.expires_at < now
|
||||
])
|
||||
|
||||
return {
|
||||
"total_sessions": len(all_sessions),
|
||||
"active_sessions": active_count,
|
||||
"expired_sessions": expired_count,
|
||||
"destroyed_sessions": destroyed_count,
|
||||
"by_status": {
|
||||
"pending": pending_count,
|
||||
"ready": ready_count,
|
||||
"partial": partial_count,
|
||||
"failed": failed_count,
|
||||
"active": active_count, # Legacy
|
||||
"expired": expired_count,
|
||||
"destroyed": destroyed_count
|
||||
},
|
||||
"pending_cleanup": should_be_expired
|
||||
}
|
||||
|
||||
330
services/demo_session/app/services/clone_orchestrator.py
Normal file
330
services/demo_session/app/services/clone_orchestrator.py
Normal file
@@ -0,0 +1,330 @@
|
||||
"""
|
||||
Demo Data Cloning Orchestrator
|
||||
Coordinates asynchronous cloning across microservices
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import httpx
|
||||
import structlog
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Any, List, Optional
|
||||
import os
|
||||
from enum import Enum
|
||||
|
||||
from app.models.demo_session import CloningStatus
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class ServiceDefinition:
    """Definition of a service that can clone demo data.

    Attributes:
        name: Short service identifier used in logs and result dictionaries.
        url: Base URL of the service, stored without a trailing slash.
        required: If True, failure blocks session creation.
        timeout: Per-request timeout in seconds.
    """

    def __init__(self, name: str, url: str, required: bool = True, timeout: float = 10.0):
        self.name = name
        # Endpoint paths are appended as f"{url}/internal/...", so a configured
        # trailing slash would otherwise produce "//internal/...".
        self.url = url.rstrip("/")
        self.required = required  # If True, failure blocks session creation
        self.timeout = timeout

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(name={self.name!r}, url={self.url!r}, "
            f"required={self.required!r}, timeout={self.timeout!r})"
        )
|
||||
|
||||
|
||||
class CloneOrchestrator:
    """Orchestrates parallel demo data cloning across services.

    Each participating service exposes ``POST /internal/demo/clone``; the
    orchestrator calls all of them concurrently and folds the per-service
    results into a single overall status ("ready", "partial" or "failed").
    """

    def __init__(self):
        # NOTE(review): the fallback is a fixed development key. Make sure
        # INTERNAL_API_KEY is always set outside local development.
        self.internal_api_key = os.getenv("INTERNAL_API_KEY", "dev-internal-key-change-in-production")

        # Define services that participate in cloning
        # URLs should be internal Kubernetes service names
        self.services = [
            ServiceDefinition(
                name="tenant",
                url=os.getenv("TENANT_SERVICE_URL", "http://tenant-service:8000"),
                required=True,  # Tenant must succeed - critical for session
                timeout=5.0
            ),
            ServiceDefinition(
                name="inventory",
                url=os.getenv("INVENTORY_SERVICE_URL", "http://inventory-service:8000"),
                required=False,  # Optional - provides ingredients/recipes
                timeout=10.0
            ),
            ServiceDefinition(
                name="recipes",
                url=os.getenv("RECIPES_SERVICE_URL", "http://recipes-service:8000"),
                required=False,  # Optional - provides recipes and production batches
                timeout=15.0
            ),
            ServiceDefinition(
                name="suppliers",
                url=os.getenv("SUPPLIERS_SERVICE_URL", "http://suppliers-service:8000"),
                required=False,  # Optional - provides supplier data and purchase orders
                timeout=20.0  # Longer - clones many entities
            ),
            ServiceDefinition(
                name="sales",
                url=os.getenv("SALES_SERVICE_URL", "http://sales-service:8000"),
                required=False,  # Optional - provides sales history
                timeout=10.0
            ),
            ServiceDefinition(
                name="orders",
                url=os.getenv("ORDERS_SERVICE_URL", "http://orders-service:8000"),
                required=False,  # Optional - provides customer orders & procurement
                timeout=15.0  # Slightly longer - clones more entities
            ),
            ServiceDefinition(
                name="production",
                url=os.getenv("PRODUCTION_SERVICE_URL", "http://production-service:8000"),
                required=False,  # Optional - provides production batches and quality checks
                timeout=20.0  # Longer - clones many entities
            ),
            ServiceDefinition(
                name="forecasting",
                url=os.getenv("FORECASTING_SERVICE_URL", "http://forecasting-service:8000"),
                required=False,  # Optional - provides historical forecasts
                timeout=15.0
            ),
        ]

    @staticmethod
    def _failure_result(service_name: str, error: str, duration_ms: int = 0) -> Dict[str, Any]:
        """Build the per-service result dictionary used for every failure path."""
        return {
            "service": service_name,
            "status": "failed",
            "records_cloned": 0,
            "error": error,
            "duration_ms": duration_ms
        }

    async def clone_all_services(
        self,
        base_tenant_id: str,
        virtual_tenant_id: str,
        demo_account_type: str,
        session_id: str
    ) -> Dict[str, Any]:
        """
        Orchestrate cloning across all services in parallel

        Args:
            base_tenant_id: Template tenant UUID
            virtual_tenant_id: Target virtual tenant UUID
            demo_account_type: Type of demo account
            session_id: Session ID for tracing

        Returns:
            Dictionary with overall status and per-service results. The
            overall status is "failed" if any required service failed,
            "partial" if only optional services failed, otherwise "ready".
        """
        logger.info(
            "Starting orchestrated cloning",
            session_id=session_id,
            virtual_tenant_id=virtual_tenant_id,
            demo_account_type=demo_account_type,
            service_count=len(self.services)
        )

        start_time = datetime.now(timezone.utc)

        # Snapshot the definitions so results are paired with exactly the
        # services that were called; gather() preserves argument order.
        services = list(self.services)
        results = await asyncio.gather(
            *(
                self._clone_service(
                    service_def=service_def,
                    base_tenant_id=base_tenant_id,
                    virtual_tenant_id=virtual_tenant_id,
                    demo_account_type=demo_account_type,
                    session_id=session_id
                )
                for service_def in services
            ),
            return_exceptions=True
        )

        # Process results
        service_results: Dict[str, Any] = {}
        total_records = 0
        failed_services: List[str] = []
        required_service_failed = False

        for service_def, result in zip(services, results):
            service_name = service_def.name
            failed = False

            # BaseException (not Exception): a cancelled child task is returned
            # by gather() as CancelledError, which does not derive from Exception.
            if isinstance(result, BaseException):
                error_text = str(result) or type(result).__name__
                logger.error(
                    "Service cloning failed with exception",
                    service=service_name,
                    error=error_text
                )
                service_results[service_name] = {
                    "status": CloningStatus.FAILED.value,
                    "records_cloned": 0,
                    "error": error_text,
                    "duration_ms": 0
                }
                failed = True
            else:
                service_results[service_name] = result
                status = result.get("status")
                if status == "completed":
                    # Tolerate an explicit null count from a service.
                    total_records += result.get("records_cloned", 0) or 0
                elif status == "failed":
                    failed = True
                # NOTE(review): any other status value is neither counted nor
                # treated as a failure - confirm that is the intended contract.

            if failed:
                failed_services.append(service_name)
                if service_def.required:
                    required_service_failed = True

        # Determine overall status
        if required_service_failed:
            overall_status = "failed"
        elif failed_services:
            overall_status = "partial"
        else:
            overall_status = "ready"

        duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)

        summary = {
            "overall_status": overall_status,
            "total_records_cloned": total_records,
            "duration_ms": duration_ms,
            "services": service_results,
            "failed_services": failed_services,
            "completed_at": datetime.now(timezone.utc).isoformat()
        }

        logger.info(
            "Orchestrated cloning completed",
            session_id=session_id,
            overall_status=overall_status,
            total_records=total_records,
            duration_ms=duration_ms,
            failed_services=failed_services
        )

        return summary

    async def _clone_service(
        self,
        service_def: ServiceDefinition,
        base_tenant_id: str,
        virtual_tenant_id: str,
        demo_account_type: str,
        session_id: str
    ) -> Dict[str, Any]:
        """
        Clone data from a single service

        Never raises for ordinary failures: HTTP errors, timeouts and other
        exceptions are converted into a result dictionary with status "failed".

        Args:
            service_def: Service definition
            base_tenant_id: Template tenant UUID
            virtual_tenant_id: Target virtual tenant UUID
            demo_account_type: Type of demo account
            session_id: Session ID for tracing

        Returns:
            Cloning result for this service (the service's JSON body on
            HTTP 200, otherwise a failure dictionary)
        """
        logger.info(
            "Cloning service data",
            service=service_def.name,
            url=service_def.url,
            session_id=session_id
        )

        try:
            async with httpx.AsyncClient(timeout=service_def.timeout) as client:
                response = await client.post(
                    f"{service_def.url}/internal/demo/clone",
                    params={
                        "base_tenant_id": base_tenant_id,
                        "virtual_tenant_id": virtual_tenant_id,
                        "demo_account_type": demo_account_type,
                        "session_id": session_id
                    },
                    headers={
                        "X-Internal-API-Key": self.internal_api_key
                    }
                )

                if response.status_code == 200:
                    result = response.json()
                    logger.info(
                        "Service cloning succeeded",
                        service=service_def.name,
                        records=result.get("records_cloned", 0),
                        duration_ms=result.get("duration_ms", 0)
                    )
                    return result

                error_msg = f"HTTP {response.status_code}: {response.text}"
                logger.error(
                    "Service cloning failed",
                    service=service_def.name,
                    error=error_msg
                )
                return self._failure_result(service_def.name, error_msg)

        # httpx signals timeouts with httpx.TimeoutException, which is not an
        # asyncio.TimeoutError; catch both so this branch is actually reached.
        except (httpx.TimeoutException, asyncio.TimeoutError):
            error_msg = f"Timeout after {service_def.timeout}s"
            logger.error(
                "Service cloning timeout",
                service=service_def.name,
                timeout=service_def.timeout
            )
            return self._failure_result(
                service_def.name,
                error_msg,
                duration_ms=int(service_def.timeout * 1000)
            )

        except Exception as e:
            # Some exceptions carry no message; fall back to the type name so
            # the stored error is never an empty string.
            error_msg = str(e) or type(e).__name__
            logger.error(
                "Service cloning exception",
                service=service_def.name,
                error=error_msg,
                exc_info=True
            )
            return self._failure_result(service_def.name, error_msg)

    async def health_check_services(self) -> Dict[str, bool]:
        """
        Check health of all cloning endpoints

        Returns:
            Dictionary mapping service names to availability status
        """
        services = list(self.services)
        results = await asyncio.gather(
            *(self._check_service_health(service_def) for service_def in services),
            return_exceptions=True
        )

        # Anything other than a literal True (False or a returned exception)
        # counts as unavailable.
        return {
            service_def.name: (result is True)
            for service_def, result in zip(services, results)
        }

    async def _check_service_health(self, service_def: ServiceDefinition) -> bool:
        """Check if a service's clone endpoint is available"""
        try:
            async with httpx.AsyncClient(timeout=2.0) as client:
                response = await client.get(
                    f"{service_def.url}/internal/demo/clone/health",
                    headers={"X-Internal-API-Key": self.internal_api_key}
                )
                return response.status_code == 200
        except Exception:
            # Deliberate best-effort probe: any failure means "unavailable".
            return False
|
||||
@@ -1,166 +0,0 @@
|
||||
"""
|
||||
Kubernetes Job-based Demo Data Cloner
|
||||
Triggers a K8s Job to clone demo data at the database level
|
||||
"""
|
||||
|
||||
import httpx
|
||||
import structlog
|
||||
from typing import Dict, Any
|
||||
import os
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class K8sJobCloner:
    """Triggers Kubernetes Jobs to clone demo data.

    Credentials are read from the pod's mounted service-account directory,
    so constructing this class raises if the token file is not present.
    """

    # Files mounted into the pod for the service account
    _TOKEN_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/token"
    _CA_CERT_PATH = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt"

    def __init__(self):
        self.k8s_api_url = os.getenv("KUBERNETES_SERVICE_HOST")
        self.namespace = os.getenv("POD_NAMESPACE", "bakery-ia")
        self.clone_job_image = os.getenv("CLONE_JOB_IMAGE", "bakery/inventory-service:latest")
        # Service account token for K8s API access
        with open(self._TOKEN_PATH, "r") as f:
            # Strip surrounding whitespace so a trailing newline cannot end up
            # inside the Authorization header.
            self.token = f.read().strip()

    async def clone_tenant_data(
        self,
        session_id: str,
        base_demo_tenant_id: str,
        virtual_tenant_id: str,
        demo_account_type: str
    ) -> Dict[str, Any]:
        """
        Clone demo data by creating a Kubernetes Job

        Args:
            session_id: Session ID
            base_demo_tenant_id: Base demo tenant UUID (not used in job approach)
            virtual_tenant_id: Virtual tenant UUID for this session
            demo_account_type: Type of demo account

        Returns:
            Job creation status: ``{"success": True, "job_name", "method"}``
            when the API answers 201, otherwise ``{"success": False, "error"}``
        """
        logger.info(
            "Triggering demo data cloning job",
            session_id=session_id,
            virtual_tenant_id=virtual_tenant_id,
            demo_account_type=demo_account_type,
            clone_image=self.clone_job_image
        )

        # NOTE(review): only the first 8 characters of the tenant ID are used,
        # so two tenants sharing that prefix would collide on the job name.
        job_name = f"demo-clone-{virtual_tenant_id[:8]}"

        # Create Job manifest
        job_manifest = {
            "apiVersion": "batch/v1",
            "kind": "Job",
            "metadata": {
                "name": job_name,
                "namespace": self.namespace,
                "labels": {
                    "app": "demo-clone",
                    "session-id": session_id,
                    "component": "runtime"
                }
            },
            "spec": {
                "ttlSecondsAfterFinished": 3600,
                "backoffLimit": 2,
                "template": {
                    "metadata": {
                        "labels": {"app": "demo-clone"}
                    },
                    "spec": {
                        "restartPolicy": "Never",
                        "containers": [{
                            "name": "clone-data",
                            "image": self.clone_job_image,  # Configured via environment variable
                            "imagePullPolicy": "IfNotPresent",  # Don't pull if image exists locally
                            "command": ["python", "/app/scripts/demo/clone_demo_tenant.py"],
                            "env": [
                                {"name": "VIRTUAL_TENANT_ID", "value": virtual_tenant_id},
                                {"name": "DEMO_ACCOUNT_TYPE", "value": demo_account_type},
                                {
                                    "name": "INVENTORY_DATABASE_URL",
                                    "valueFrom": {
                                        "secretKeyRef": {
                                            "name": "database-secrets",
                                            "key": "INVENTORY_DATABASE_URL"
                                        }
                                    }
                                },
                                {
                                    "name": "SALES_DATABASE_URL",
                                    "valueFrom": {
                                        "secretKeyRef": {
                                            "name": "database-secrets",
                                            "key": "SALES_DATABASE_URL"
                                        }
                                    }
                                },
                                {
                                    "name": "ORDERS_DATABASE_URL",
                                    "valueFrom": {
                                        "secretKeyRef": {
                                            "name": "database-secrets",
                                            "key": "ORDERS_DATABASE_URL"
                                        }
                                    }
                                },
                                {"name": "LOG_LEVEL", "value": "INFO"}
                            ],
                            "resources": {
                                "requests": {"memory": "256Mi", "cpu": "100m"},
                                "limits": {"memory": "512Mi", "cpu": "500m"}
                            }
                        }]
                    }
                }
            }
        }

        try:
            # Local import: ssl is not among this module's top-level imports.
            import ssl

            # Verify the API server against the CA bundle mounted next to the
            # token instead of disabling TLS verification; the request carries
            # a bearer token and must not be sent to an unverified peer.
            # A missing CA file is reported through the except branch below.
            ssl_context = ssl.create_default_context(cafile=self._CA_CERT_PATH)

            # Create the Job via K8s API
            async with httpx.AsyncClient(verify=ssl_context, timeout=30.0) as client:
                response = await client.post(
                    f"https://{self.k8s_api_url}/apis/batch/v1/namespaces/{self.namespace}/jobs",
                    json=job_manifest,
                    headers={
                        "Authorization": f"Bearer {self.token}",
                        "Content-Type": "application/json"
                    }
                )

                if response.status_code == 201:
                    logger.info(
                        "Demo clone job created successfully",
                        job_name=job_name,
                        session_id=session_id
                    )
                    return {
                        "success": True,
                        "job_name": job_name,
                        "method": "kubernetes_job"
                    }

                logger.error(
                    "Failed to create demo clone job",
                    status_code=response.status_code,
                    response=response.text
                )
                return {
                    "success": False,
                    "error": f"K8s API returned {response.status_code}"
                }

        except Exception as e:
            logger.error(
                "Error creating demo clone job",
                error=str(e),
                exc_info=True
            )
            return {
                "success": False,
                "error": str(e)
            }
|
||||
@@ -11,8 +11,9 @@ import uuid
|
||||
import secrets
|
||||
import structlog
|
||||
|
||||
from app.models import DemoSession, DemoSessionStatus
|
||||
from app.models import DemoSession, DemoSessionStatus, CloningStatus
|
||||
from app.core import RedisClient, settings
|
||||
from app.services.clone_orchestrator import CloneOrchestrator
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
@@ -23,6 +24,7 @@ class DemoSessionManager:
|
||||
def __init__(self, db: AsyncSession, redis: RedisClient):
|
||||
self.db = db
|
||||
self.redis = redis
|
||||
self.orchestrator = CloneOrchestrator()
|
||||
|
||||
async def create_session(
|
||||
self,
|
||||
@@ -56,16 +58,23 @@ class DemoSessionManager:
|
||||
if not demo_config:
|
||||
raise ValueError(f"Invalid demo account type: {demo_account_type}")
|
||||
|
||||
# Get base tenant ID for cloning
|
||||
base_tenant_id_str = demo_config.get("base_tenant_id")
|
||||
if not base_tenant_id_str:
|
||||
raise ValueError(f"Base tenant ID not configured for demo account type: {demo_account_type}")
|
||||
|
||||
base_tenant_id = uuid.UUID(base_tenant_id_str)
|
||||
|
||||
# Create session record
|
||||
session = DemoSession(
|
||||
session_id=session_id,
|
||||
user_id=uuid.UUID(user_id) if user_id else None,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
base_demo_tenant_id=uuid.uuid4(), # Will be set by seeding script
|
||||
base_demo_tenant_id=base_tenant_id,
|
||||
virtual_tenant_id=virtual_tenant_id,
|
||||
demo_account_type=demo_account_type,
|
||||
status=DemoSessionStatus.ACTIVE,
|
||||
status=DemoSessionStatus.PENDING, # Start as pending until cloning completes
|
||||
created_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(
|
||||
minutes=settings.DEMO_SESSION_DURATION_MINUTES
|
||||
@@ -265,3 +274,173 @@ class DemoSessionManager:
|
||||
) / max(len([s for s in all_sessions if s.destroyed_at]), 1),
|
||||
"total_requests": sum(s.request_count for s in all_sessions)
|
||||
}
|
||||
|
||||
async def trigger_orchestrated_cloning(
    self,
    session: DemoSession,
    base_tenant_id: str
) -> Dict[str, Any]:
    """
    Run the clone orchestrator for a session and store the outcome

    The start timestamp is committed first, then every service is cloned
    through the orchestrator, and finally the session row is updated from
    the orchestration result.

    Args:
        session: Demo session
        base_tenant_id: Template tenant ID to clone from

    Returns:
        Orchestration result
    """
    logger.info(
        "Triggering orchestrated cloning",
        session_id=session.session_id,
        virtual_tenant_id=str(session.virtual_tenant_id)
    )

    # Record when cloning began and persist it before calling the services
    started_at = datetime.now(timezone.utc)
    session.cloning_started_at = started_at
    await self.db.commit()

    # Fan out to the services
    clone_arguments = {
        "base_tenant_id": base_tenant_id,
        "virtual_tenant_id": str(session.virtual_tenant_id),
        "demo_account_type": session.demo_account_type,
        "session_id": session.session_id,
    }
    outcome = await self.orchestrator.clone_all_services(**clone_arguments)

    # Reflect the outcome on the session record
    await self._update_session_from_clone_result(session, outcome)

    return outcome
|
||||
|
||||
async def _update_session_from_clone_result(
    self,
    session: DemoSession,
    clone_result: Dict[str, Any]
):
    """Copy an orchestration result onto the session, persist it and cache it"""

    overall_status = clone_result.get("overall_status")

    # Translate the orchestrator outcome into a session status; an outcome
    # outside this table leaves the current status untouched.
    outcome_to_status = (
        ("ready", DemoSessionStatus.READY),
        ("failed", DemoSessionStatus.FAILED),
        ("partial", DemoSessionStatus.PARTIAL),
    )
    for outcome, mapped_status in outcome_to_status:
        if overall_status == outcome:
            session.status = mapped_status
            break

    # Cloning metadata
    session.cloning_completed_at = datetime.now(timezone.utc)
    session.total_records_cloned = clone_result.get("total_records_cloned", 0)
    session.cloning_progress = clone_result.get("services", {})

    # Legacy flags, kept for backward compatibility
    usable = overall_status in ("ready", "partial")
    if usable:
        session.data_cloned = True
        session.redis_populated = True

    await self.db.commit()
    await self.db.refresh(session)

    # Keep the Redis copy in step so status polling stays fast
    await self._cache_session_status(session)

    logger.info(
        "Session updated with clone results",
        session_id=session.session_id,
        status=session.status.value,
        total_records=session.total_records_cloned
    )
|
||||
|
||||
async def _cache_session_status(self, session: DemoSession):
    """Store a JSON snapshot of the session's status in Redis for two hours"""
    import json

    def _iso_or_none(moment):
        # Optional timestamps serialise to None when unset
        return moment.isoformat() if moment else None

    snapshot = {
        "session_id": session.session_id,
        "status": session.status.value,
        "progress": session.cloning_progress,
        "total_records_cloned": session.total_records_cloned,
        "cloning_started_at": _iso_or_none(session.cloning_started_at),
        "cloning_completed_at": _iso_or_none(session.cloning_completed_at),
        "expires_at": session.expires_at.isoformat()
    }
    serialized = json.dumps(snapshot)  # Redis stores the JSON string

    ttl_seconds = 7200  # Cache for 2 hours
    await self.redis.client.setex(
        f"session:{session.session_id}:status",
        ttl_seconds,
        serialized
    )
|
||||
|
||||
async def get_session_status(self, session_id: str) -> Optional[Dict[str, Any]]:
    """
    Get current session status with cloning progress

    The Redis snapshot is tried first; if it is absent or unreadable the
    session is loaded from the database and the snapshot is rewritten.

    Args:
        session_id: Session ID

    Returns:
        Status information including per-service progress, or None when
        no session with this ID exists
    """
    import json

    # Try Redis cache first
    status_key = f"session:{session_id}:status"
    cached = await self.redis.client.get(status_key)

    if cached:
        try:
            return json.loads(cached)
        except (TypeError, ValueError):
            # Unreadable cache entry: fall through to the database. When the
            # session exists, the re-cache below overwrites the bad entry.
            pass

    # Fall back to database
    session = await self.get_session(session_id)
    if not session:
        return None

    await self._cache_session_status(session)

    return {
        "session_id": session.session_id,
        "status": session.status.value,
        "progress": session.cloning_progress,
        "total_records_cloned": session.total_records_cloned,
        "cloning_started_at": session.cloning_started_at.isoformat() if session.cloning_started_at else None,
        "cloning_completed_at": session.cloning_completed_at.isoformat() if session.cloning_completed_at else None,
        "expires_at": session.expires_at.isoformat()
    }
|
||||
|
||||
async def retry_failed_cloning(
    self,
    session_id: str,
    services: Optional[list] = None
) -> Dict[str, Any]:
    """
    Retry failed cloning operations

    Args:
        session_id: Session ID
        services: Specific services to retry (defaults to all failed).
            NOTE(review): this value is only logged; the retry below
            re-runs the full orchestration for every service.

    Returns:
        Retry result

    Raises:
        ValueError: If the session does not exist or is not in the
            FAILED or PARTIAL state
    """
    session = await self.get_session(session_id)
    if not session:
        raise ValueError(f"Session not found: {session_id}")

    if session.status not in [DemoSessionStatus.FAILED, DemoSessionStatus.PARTIAL]:
        raise ValueError(f"Cannot retry session in {session.status.value} state")

    logger.info(
        "Retrying failed cloning",
        session_id=session_id,
        services=services
    )

    # Get base tenant ID from config. The account type may no longer be
    # configured, or may carry an empty base_tenant_id; in both cases fall
    # back to the tenant ID stored on the session instead of failing.
    demo_config = settings.DEMO_ACCOUNTS.get(session.demo_account_type) or {}
    base_tenant_id = demo_config.get("base_tenant_id") or str(session.base_demo_tenant_id)

    # Trigger new cloning attempt
    result = await self.trigger_orchestrated_cloning(session, base_tenant_id)

    return result
|
||||
|
||||
Reference in New Issue
Block a user