"""
Demo Session Manager

Handles creation, extension, and destruction of demo sessions, including
orchestrated cloning of template-tenant data and a Redis-cached status
snapshot for fast polling.
"""

import json
import secrets
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any

import structlog
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.core import settings
from app.core.redis_wrapper import DemoRedisWrapper
from app.models import DemoSession, DemoSessionStatus, CloningStatus
from app.repositories.demo_session_repository import DemoSessionRepository
from app.services.clone_orchestrator import CloneOrchestrator

# Module-level structlog logger; structured context is passed as keyword
# arguments on each call (e.g. session_id=..., demo_account_type=...).
logger = structlog.get_logger()
|
|
|
|
|
|
class DemoSessionManager:
    """Manage the demo session lifecycle.

    Covers creating session records (plus their Redis metadata), extending,
    touching and destroying sessions, driving cross-service data cloning via
    ``CloneOrchestrator``, and caching a JSON status snapshot in Redis so that
    status polling does not have to hit the database.
    """

    # TTL (seconds) of the Redis status snapshot written by
    # ``_cache_session_status``: 2 hours.
    STATUS_CACHE_TTL_SECONDS = 7200

    def __init__(self, db: AsyncSession, redis: DemoRedisWrapper):
        """Wire the manager to its persistence and cache backends.

        Args:
            db: Async SQLAlchemy session, wrapped by ``DemoSessionRepository``.
            redis: Redis wrapper used for session data and status caching.
        """
        self.db = db
        self.redis = redis
        self.repository = DemoSessionRepository(db)
        # Pass Redis for real-time progress updates
        self.orchestrator = CloneOrchestrator(redis_manager=redis)

    # ------------------------------------------------------------------
    # Small internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _as_uuid(value: Any) -> uuid.UUID:
        """Return ``value`` as a UUID, accepting a UUID instance or its string form.

        Raises:
            ValueError: If ``value`` is not a valid UUID string.
        """
        if isinstance(value, uuid.UUID):
            return value
        return uuid.UUID(str(value))

    @staticmethod
    def _status_key(session_id: str) -> str:
        """Redis key holding the cached status snapshot for ``session_id``."""
        return f"session:{session_id}:status"

    @staticmethod
    def _build_status_payload(session: DemoSession) -> Dict[str, Any]:
        """Build the JSON-serializable status snapshot for ``session``.

        Shared by the Redis cache writer and the database fallback in
        ``get_session_status`` so both always return the same shape.
        """
        started = session.cloning_started_at
        completed = session.cloning_completed_at
        return {
            "session_id": session.session_id,
            "status": session.status.value,
            "progress": session.cloning_progress,
            "total_records_cloned": session.total_records_cloned,
            "cloning_started_at": started.isoformat() if started else None,
            "cloning_completed_at": completed.isoformat() if completed else None,
            "expires_at": session.expires_at.isoformat(),
        }

    @staticmethod
    def _tenant_client():
        """Create a tenant-service client.

        The import is kept local, as in the original code. Presumably this
        avoids an import-time dependency on ``shared``; confirm before hoisting.
        """
        from shared.clients.tenant_client import TenantServiceClient
        return TenantServiceClient(settings)

    # ------------------------------------------------------------------
    # Session lifecycle
    # ------------------------------------------------------------------

    async def create_session(
        self,
        demo_account_type: str,
        subscription_tier: Optional[str] = None,
        user_id: Optional[str] = None,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None
    ) -> DemoSession:
        """Create a new demo session in PENDING state.

        Validates the configured template tenant(s), persists the session
        record, and stores its metadata in Redis. Cloning is not started here;
        ``trigger_orchestrated_cloning`` does that.

        Args:
            demo_account_type: Key into ``settings.DEMO_ACCOUNTS``
                (e.g. 'professional' or 'enterprise').
            subscription_tier: Force a specific subscription tier. Falls back
                to the account config's ``subscription_tier``.
            user_id: Optional user ID (UUID string) if authenticated.
            ip_address: Client IP address.
            user_agent: Client user agent.

        Returns:
            The created demo session.

        Raises:
            ValueError: Unknown account type, missing or invalid base tenant
                ID, invalid ``user_id``, or a template tenant that fails
                validation.
        """
        logger.info("Creating demo session",
                    demo_account_type=demo_account_type,
                    subscription_tier=subscription_tier)

        # Public session ID plus a fresh virtual tenant to clone data into.
        session_id = f"demo_{secrets.token_urlsafe(16)}"
        virtual_tenant_id = uuid.uuid4()

        demo_config = settings.DEMO_ACCOUNTS.get(demo_account_type)
        if not demo_config:
            raise ValueError(f"Invalid demo account type: {demo_account_type}")

        effective_subscription_tier = subscription_tier or demo_config.get("subscription_tier")

        base_tenant_id_str = demo_config.get("base_tenant_id")
        if not base_tenant_id_str:
            raise ValueError(f"Base tenant ID not configured for demo account type: {demo_account_type}")
        base_tenant_id = self._as_uuid(base_tenant_id_str)

        # Fail fast rather than cloning from a template tenant that does not exist.
        await self._validate_base_tenant_exists(base_tenant_id, demo_account_type)

        is_enterprise = demo_account_type == 'enterprise'
        # `or []` also guards against an explicit `children: None` in config.
        child_configs = (demo_config.get('children') or []) if is_enterprise else []
        child_tenant_ids = []
        if is_enterprise:
            # Validate child template tenants exist before proceeding.
            await self._validate_child_template_tenants(child_configs)
            child_tenant_ids = [uuid.uuid4() for _ in child_configs]

        # A single timestamp keeps created_at, last_activity_at and expires_at consistent.
        now = datetime.now(timezone.utc)
        session_data = {
            "session_id": session_id,
            "user_id": self._as_uuid(user_id) if user_id else None,
            "ip_address": ip_address,
            "user_agent": user_agent,
            "base_demo_tenant_id": base_tenant_id,
            "virtual_tenant_id": virtual_tenant_id,
            "demo_account_type": demo_account_type,
            "status": DemoSessionStatus.PENDING,  # Stays pending until cloning completes
            "created_at": now,
            "expires_at": now + timedelta(minutes=settings.DEMO_SESSION_DURATION_MINUTES),
            "last_activity_at": now,
            "data_cloned": False,
            "redis_populated": False,
            "session_metadata": {
                "demo_config": demo_config,
                "subscription_tier": effective_subscription_tier,
                "extension_count": 0,
                "is_enterprise": is_enterprise,
                "child_tenant_ids": [str(tid) for tid in child_tenant_ids],
                "child_configs": child_configs,
            },
        }

        session = await self.repository.create(session_data)

        await self._store_session_metadata(session)

        logger.info(
            "Demo session created",
            session_id=session_id,
            virtual_tenant_id=str(virtual_tenant_id),
            demo_account_type=demo_account_type,
            is_enterprise=is_enterprise,
            child_tenant_count=len(child_tenant_ids),
            expires_at=session.expires_at.isoformat()
        )

        return session

    async def get_session(self, session_id: str) -> Optional[DemoSession]:
        """Return the session with ``session_id``, or None if it does not exist."""
        return await self.repository.get_by_session_id(session_id)

    async def get_session_by_virtual_tenant(self, virtual_tenant_id: str) -> Optional[DemoSession]:
        """Return the session owning ``virtual_tenant_id``, or None.

        Accepts either a UUID string or a ``uuid.UUID`` instance.

        Raises:
            ValueError: If ``virtual_tenant_id`` is a malformed UUID string.
        """
        return await self.repository.get_by_virtual_tenant_id(self._as_uuid(virtual_tenant_id))

    async def extend_session(self, session_id: str) -> DemoSession:
        """Push the session's expiration to now + DEMO_SESSION_DURATION_MINUTES.

        Args:
            session_id: Session ID to extend.

        Returns:
            The updated session.

        Raises:
            ValueError: If the session does not exist, is not ACTIVE, or has
                already used ``DEMO_SESSION_MAX_EXTENSIONS`` extensions.
        """
        session = await self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")

        # NOTE(review): nothing in this class sets ACTIVE (clone results map to
        # READY/PARTIAL/FAILED). Confirm another component transitions
        # sessions to ACTIVE, otherwise no session can ever be extended.
        if session.status != DemoSessionStatus.ACTIVE:
            raise ValueError(f"Cannot extend {session.status.value} session")

        # Work on a copy: if session_metadata is a plain JSON column (the model
        # is not visible here), SQLAlchemy does not detect in-place mutation and
        # the incremented count would never be flushed. Reassignment always is.
        metadata = dict(session.session_metadata or {})
        extension_count = metadata.get("extension_count", 0)
        if extension_count >= settings.DEMO_SESSION_MAX_EXTENSIONS:
            raise ValueError(f"Maximum extensions ({settings.DEMO_SESSION_MAX_EXTENSIONS}) reached")

        now = datetime.now(timezone.utc)
        new_expires_at = now + timedelta(minutes=settings.DEMO_SESSION_DURATION_MINUTES)

        session.expires_at = new_expires_at
        session.last_activity_at = now
        metadata["extension_count"] = extension_count + 1
        session.session_metadata = metadata

        session = await self.repository.update(session)

        await self.redis.extend_session_ttl(session_id, settings.REDIS_SESSION_TTL)
        # The cached status snapshot embeds expires_at; refresh it so pollers
        # do not keep seeing the pre-extension expiry.
        await self._cache_session_status(session)

        logger.info(
            "Session extended",
            session_id=session_id,
            new_expires_at=new_expires_at.isoformat(),
            extension_count=extension_count + 1
        )

        return session

    async def update_activity(self, session_id: str):
        """Update the session's last-activity timestamp (delegates to the repository)."""
        await self.repository.update_activity(session_id)

    async def mark_data_cloned(self, session_id: str):
        """Mark the session as having its data cloned."""
        await self.repository.mark_data_cloned(session_id)

    async def mark_redis_populated(self, session_id: str):
        """Mark the session as having its Redis data populated."""
        await self.repository.mark_redis_populated(session_id)

    async def destroy_session(self, session_id: str):
        """Destroy a demo session and clean up its cached state.

        A missing session is logged and ignored, not treated as an error.

        Args:
            session_id: Session ID to destroy.
        """
        session = await self.get_session(session_id)
        if not session:
            logger.warning("Session not found for destruction", session_id=session_id)
            return

        await self.repository.destroy(session_id)
        await self.redis.delete_session_data(session_id)

        # The status snapshot is written with the raw client under its own key.
        # Drop it so get_session_status() falls back to the database instead of
        # serving a stale pre-destruction status for up to the cache TTL.
        client = await self.redis.get_client()
        await client.delete(self._status_key(session_id))

        logger.info(
            "Session destroyed",
            session_id=session_id,
            virtual_tenant_id=str(session.virtual_tenant_id)
        )

    async def _store_session_metadata(self, session: DemoSession):
        """Store basic session metadata in Redis with REDIS_SESSION_TTL."""
        await self.redis.set_session_data(
            session.session_id,
            "metadata",
            {
                "session_id": session.session_id,
                "virtual_tenant_id": str(session.virtual_tenant_id),
                "demo_account_type": session.demo_account_type,
                "expires_at": session.expires_at.isoformat(),
                "created_at": session.created_at.isoformat()
            },
            ttl=settings.REDIS_SESSION_TTL
        )

    async def get_active_sessions_count(self) -> int:
        """Return the number of active sessions (delegates to the repository)."""
        return await self.repository.get_active_sessions_count()

    async def get_session_stats(self) -> Dict[str, Any]:
        """Return session statistics (delegates to the repository)."""
        return await self.repository.get_session_stats()

    # ------------------------------------------------------------------
    # Cloning orchestration
    # ------------------------------------------------------------------

    async def trigger_orchestrated_cloning(
        self,
        session: DemoSession,
        base_tenant_id: str
    ) -> Dict[str, Any]:
        """Clone template data into the session's virtual tenant across services.

        Records ``cloning_started_at``, runs the orchestrator, and then maps
        its ``overall_status`` onto the session. If the orchestrator raises,
        the session is marked FAILED (best effort) so it does not stay
        PENDING forever and remains eligible for ``retry_failed_cloning``.
        The original exception is then re-raised.

        Args:
            session: Demo session to populate.
            base_tenant_id: Template tenant ID to clone from.

        Returns:
            The orchestrator's result dict.
        """
        logger.info(
            "Triggering orchestrated cloning",
            session_id=session.session_id,
            virtual_tenant_id=str(session.virtual_tenant_id)
        )

        # Mark cloning as started in the database and the Redis status cache.
        session.cloning_started_at = datetime.now(timezone.utc)
        await self.repository.update(session)
        await self._cache_session_status(session)

        try:
            result = await self.orchestrator.clone_all_services(
                base_tenant_id=base_tenant_id,
                virtual_tenant_id=str(session.virtual_tenant_id),
                demo_account_type=session.demo_account_type,
                session_id=session.session_id,
                session_metadata=session.session_metadata
            )
        except Exception:
            logger.error(
                "Orchestrated cloning raised an exception",
                session_id=session.session_id,
                exc_info=True
            )
            await self._mark_cloning_failed(session)
            raise

        await self._update_session_from_clone_result(session, result)

        return result

    async def _mark_cloning_failed(self, session: DemoSession) -> None:
        """Best-effort: record a FAILED outcome after the orchestrator raised.

        Errors here are logged and swallowed so they cannot mask the
        orchestrator's original exception.
        """
        try:
            await self._update_session_from_clone_result(session, {"overall_status": "failed"})
        except Exception:
            logger.error(
                "Failed to mark session as failed after cloning error",
                session_id=session.session_id,
                exc_info=True
            )

    async def _validate_base_tenant_exists(self, base_tenant_id: uuid.UUID, demo_account_type: str) -> bool:
        """Check that the base (template) tenant exists in the tenant service.

        Args:
            base_tenant_id: UUID of the base tenant to validate.
            demo_account_type: Demo account type, used for logging and messages.

        Returns:
            True if the tenant exists.

        Raises:
            ValueError: Nil UUID, tenant not found, or the tenant service
                could not be queried (original error chained as the cause).
        """
        logger.info(
            "Validating base tenant exists before cloning",
            base_tenant_id=str(base_tenant_id),
            demo_account_type=demo_account_type
        )

        # Cheap local check: reject the nil UUID before calling the service.
        if str(base_tenant_id) == "00000000-0000-0000-0000-000000000000":
            raise ValueError(f"Invalid base tenant ID: {base_tenant_id} for demo type: {demo_account_type}")

        try:
            tenant_client = self._tenant_client()
            tenant = await tenant_client.get_tenant(str(base_tenant_id))

            if not tenant:
                error_msg = (
                    f"Base tenant {base_tenant_id} does not exist for demo type {demo_account_type}. "
                    "Please verify the base_tenant_id in demo configuration."
                )
                logger.error(
                    "Base tenant validation failed",
                    base_tenant_id=str(base_tenant_id),
                    demo_account_type=demo_account_type
                )
                raise ValueError(error_msg)

            logger.info(
                "Base tenant validation passed",
                base_tenant_id=str(base_tenant_id),
                tenant_name=tenant.get("name", "unknown"),
                demo_account_type=demo_account_type
            )
            return True

        except ValueError:
            # Validation failures are already the right type; propagate untouched.
            raise
        except Exception as e:
            logger.error(
                f"Error validating base tenant: {e}",
                base_tenant_id=str(base_tenant_id),
                demo_account_type=demo_account_type,
                exc_info=True
            )
            raise ValueError(f"Cannot validate base tenant {base_tenant_id}: {str(e)}") from e

    async def _validate_child_template_tenants(self, child_configs: list) -> bool:
        """Check that every child template tenant exists in the tenant service.

        Args:
            child_configs: Child configurations, each expected to carry
                ``base_tenant_id`` and optionally ``name``.

        Returns:
            True if all child templates exist (or the list is empty).

        Raises:
            ValueError: A config lacks ``base_tenant_id``, a template is not
                found, or the tenant service could not be queried.
        """
        if not child_configs:
            logger.warning("No child configurations provided for validation")
            return True

        logger.info("Validating child template tenants", child_count=len(child_configs))

        try:
            tenant_client = self._tenant_client()

            for child_config in child_configs:
                child_base_id = child_config.get("base_tenant_id")
                child_name = child_config.get("name", "unknown")

                if not child_base_id:
                    raise ValueError(f"Child config missing base_tenant_id: {child_name}")

                # str() so a UUID in config is sent the same way as the base tenant ID.
                child_tenant = await tenant_client.get_tenant(str(child_base_id))

                if not child_tenant:
                    error_msg = (
                        f"Child template tenant {child_base_id} ('{child_name}') does not exist. "
                        "Please verify the base_tenant_id in demo configuration."
                    )
                    logger.error(
                        "Child template validation failed",
                        base_tenant_id=child_base_id,
                        child_name=child_name
                    )
                    raise ValueError(error_msg)

                logger.info(
                    "Child template validation passed",
                    base_tenant_id=child_base_id,
                    child_name=child_name,
                    tenant_name=child_tenant.get("name", "unknown")
                )

            logger.info("All child template tenants validated successfully")
            return True

        except ValueError:
            # Validation failures are already the right type; propagate untouched.
            raise
        except Exception as e:
            logger.error(
                f"Error validating child template tenants: {e}",
                exc_info=True
            )
            raise ValueError(f"Cannot validate child template tenants: {str(e)}") from e

    async def _update_session_from_clone_result(
        self,
        session: DemoSession,
        clone_result: Dict[str, Any]
    ):
        """Apply an orchestrator result to the session, persist it, and refresh the cache.

        ``overall_status`` of "ready"/"completed" maps to READY, "failed" to
        FAILED, and "partial" to PARTIAL. Any other value leaves the status
        unchanged and logs a warning.
        """
        status_map = {
            "ready": DemoSessionStatus.READY,
            "completed": DemoSessionStatus.READY,
            "failed": DemoSessionStatus.FAILED,
            "partial": DemoSessionStatus.PARTIAL,
        }
        overall_status = clone_result.get("overall_status")
        new_status = status_map.get(overall_status)
        if new_status is not None:
            session.status = new_status
        else:
            logger.warning(
                "Unrecognized clone overall_status; session status left unchanged",
                session_id=session.session_id,
                overall_status=overall_status
            )

        session.cloning_completed_at = datetime.now(timezone.utc)
        # The clone result might use 'total_records' or 'total_records_cloned'
        session.total_records_cloned = clone_result.get(
            "total_records_cloned", clone_result.get("total_records", 0)
        )
        session.cloning_progress = clone_result.get("services", {})

        # Mark legacy flags for backward compatibility
        if overall_status in ("ready", "completed", "partial"):
            session.data_cloned = True
            session.redis_populated = True

        await self.repository.update(session)

        # Cache status in Redis for fast polling
        await self._cache_session_status(session)

        logger.info(
            "Session updated with clone results",
            session_id=session.session_id,
            status=session.status.value,
            total_records=session.total_records_cloned
        )

    # ------------------------------------------------------------------
    # Status caching
    # ------------------------------------------------------------------

    async def _cache_session_status(self, session: DemoSession):
        """Write the session's status snapshot to Redis as JSON with a 2-hour TTL."""
        payload = json.dumps(self._build_status_payload(session))
        client = await self.redis.get_client()
        await client.setex(
            self._status_key(session.session_id),
            self.STATUS_CACHE_TTL_SECONDS,
            payload
        )

    async def get_session_status(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return the session's status snapshot, preferring the Redis cache.

        On a cache miss the session is loaded from the database and the
        snapshot is re-cached.

        Args:
            session_id: Session ID.

        Returns:
            Status information including per-service progress, or None if the
            session does not exist.
        """
        client = await self.redis.get_client()
        cached = await client.get(self._status_key(session_id))
        if cached:
            # json.loads accepts both str and bytes replies.
            return json.loads(cached)

        session = await self.get_session(session_id)
        if not session:
            return None

        await self._cache_session_status(session)
        return self._build_status_payload(session)

    async def retry_failed_cloning(
        self,
        session_id: str,
        services: Optional[list] = None
    ) -> Dict[str, Any]:
        """Re-run orchestrated cloning for a FAILED or PARTIAL session.

        Args:
            session_id: Session ID.
            services: Services the caller wants retried. NOTE(review): this is
                currently only logged; the orchestrator re-runs all services.

        Returns:
            The orchestrator's result dict.

        Raises:
            ValueError: If the session does not exist or is not FAILED/PARTIAL.
        """
        session = await self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")

        if session.status not in (DemoSessionStatus.FAILED, DemoSessionStatus.PARTIAL):
            raise ValueError(f"Cannot retry session in {session.status.value} state")

        logger.info(
            "Retrying failed cloning",
            session_id=session_id,
            services=services
        )

        # Prefer the currently configured template tenant. Fall back to the one
        # the session was created from if the account type is no longer
        # configured or its base_tenant_id is empty.
        demo_config = settings.DEMO_ACCOUNTS.get(session.demo_account_type) or {}
        base_tenant_id = str(demo_config.get("base_tenant_id") or session.base_demo_tenant_id)

        return await self.trigger_orchestrated_cloning(session, base_tenant_id)
|