Files
bakery-ia/services/demo_session/app/services/session_manager.py
2025-12-05 20:07:01 +01:00

557 lines
20 KiB
Python

"""
Demo Session Manager
Handles creation, extension, and destruction of demo sessions
"""
import json
import secrets
import uuid
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, Any

import structlog
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models import DemoSession, DemoSessionStatus, CloningStatus
from app.core.redis_wrapper import DemoRedisWrapper
from app.core import settings
from app.services.clone_orchestrator import CloneOrchestrator
from app.repositories.demo_session_repository import DemoSessionRepository
# Module-level structlog logger used by DemoSessionManager.
logger = structlog.get_logger()
class DemoSessionManager:
    """Manages demo session lifecycle.

    Coordinates the database (via DemoSessionRepository), Redis (session
    metadata plus a JSON status cache for fast polling) and the
    CloneOrchestrator that copies template-tenant data into a per-session
    virtual tenant.
    """

    # TTL in seconds of the Redis status cache written by _cache_session_status.
    STATUS_CACHE_TTL_SECONDS = 7200

    def __init__(self, db: AsyncSession, redis: DemoRedisWrapper):
        """Bind the manager to a database session and a Redis wrapper.

        Args:
            db: Async SQLAlchemy session backing the repository.
            redis: Redis wrapper used for session data and status caching.
        """
        self.db = db
        self.redis = redis
        self.repository = DemoSessionRepository(db)
        # Pass Redis so the orchestrator can publish real-time progress updates
        self.orchestrator = CloneOrchestrator(redis_manager=redis)

    async def create_session(
        self,
        demo_account_type: str,
        subscription_tier: Optional[str] = None,
        user_id: Optional[str] = None,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None
    ) -> DemoSession:
        """
        Create a new demo session in PENDING state.

        Validates the configured base (template) tenant and, for enterprise
        demos, every child template tenant. Then persists the session row and
        stores its metadata in Redis. Data cloning is started separately via
        trigger_orchestrated_cloning.

        Args:
            demo_account_type: Key into settings.DEMO_ACCOUNTS
                ('professional' or 'enterprise')
            subscription_tier: Force specific subscription tier; defaults to
                the tier from the demo account config
            user_id: Optional user ID (UUID string) if authenticated
            ip_address: Client IP address
            user_agent: Client user agent

        Returns:
            Created demo session

        Raises:
            ValueError: Unknown account type, missing base tenant config,
                failed template-tenant validation, or a user_id that is not a
                valid UUID
        """
        logger.info("Creating demo session",
                    demo_account_type=demo_account_type,
                    subscription_tier=subscription_tier)

        session_id = f"demo_{secrets.token_urlsafe(16)}"
        virtual_tenant_id = uuid.uuid4()
        # Parse early so a malformed user_id fails before any remote calls.
        parsed_user_id = uuid.UUID(user_id) if user_id else None

        demo_config = settings.DEMO_ACCOUNTS.get(demo_account_type)
        if not demo_config:
            raise ValueError(f"Invalid demo account type: {demo_account_type}")

        # An explicit tier overrides the one from the demo config
        effective_subscription_tier = subscription_tier or demo_config.get("subscription_tier")

        base_tenant_id_str = demo_config.get("base_tenant_id")
        if not base_tenant_id_str:
            raise ValueError(f"Base tenant ID not configured for demo account type: {demo_account_type}")
        base_tenant_id = uuid.UUID(base_tenant_id_str)

        # Prevent cloning from a non-existent base tenant
        await self._validate_base_tenant_exists(base_tenant_id, demo_account_type)

        is_enterprise = demo_account_type == 'enterprise'
        child_configs = demo_config.get('children', []) if is_enterprise else []
        child_tenant_ids = []
        if is_enterprise:
            # Validate child templates before generating their virtual IDs
            await self._validate_child_template_tenants(child_configs)
            child_tenant_ids = [uuid.uuid4() for _ in child_configs]

        # Single timestamp so created_at/last_activity_at/expires_at agree
        now = datetime.now(timezone.utc)
        session_data = {
            "session_id": session_id,
            "user_id": parsed_user_id,
            "ip_address": ip_address,
            "user_agent": user_agent,
            "base_demo_tenant_id": base_tenant_id,
            "virtual_tenant_id": virtual_tenant_id,
            "demo_account_type": demo_account_type,
            "status": DemoSessionStatus.PENDING,  # Pending until cloning completes
            "created_at": now,
            "expires_at": now + timedelta(minutes=settings.DEMO_SESSION_DURATION_MINUTES),
            "last_activity_at": now,
            "data_cloned": False,
            "redis_populated": False,
            "session_metadata": {
                "demo_config": demo_config,
                "subscription_tier": effective_subscription_tier,
                "extension_count": 0,
                "is_enterprise": is_enterprise,
                "child_tenant_ids": [str(tid) for tid in child_tenant_ids],
                "child_configs": child_configs
            }
        }
        session = await self.repository.create(session_data)

        await self._store_session_metadata(session)

        logger.info(
            "Demo session created",
            session_id=session_id,
            virtual_tenant_id=str(virtual_tenant_id),
            demo_account_type=demo_account_type,
            is_enterprise=is_enterprise,
            child_tenant_count=len(child_tenant_ids),
            expires_at=session.expires_at.isoformat()
        )
        return session

    async def get_session(self, session_id: str) -> Optional[DemoSession]:
        """Get session by session_id, or None if it does not exist."""
        return await self.repository.get_by_session_id(session_id)

    async def get_session_by_virtual_tenant(self, virtual_tenant_id: str) -> Optional[DemoSession]:
        """Get session by virtual tenant ID (UUID string; ValueError if malformed)."""
        return await self.repository.get_by_virtual_tenant_id(uuid.UUID(virtual_tenant_id))

    async def extend_session(self, session_id: str) -> DemoSession:
        """
        Extend session expiration time.

        The new expiry is now + DEMO_SESSION_DURATION_MINUTES (not added to
        the old expiry). The number of extensions is capped by
        DEMO_SESSION_MAX_EXTENSIONS.

        Args:
            session_id: Session ID to extend

        Returns:
            Updated session

        Raises:
            ValueError: If session is missing, not in a usable state, or has
                reached the extension limit
        """
        session = await self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")

        # Cloning moves sessions to READY or PARTIAL (see
        # _update_session_from_clone_result) and nothing in this class sets
        # ACTIVE, so accepting only ACTIVE made cloned sessions unextendable.
        # NOTE(review): confirm ACTIVE is still set elsewhere (legacy path).
        extendable_statuses = (
            DemoSessionStatus.ACTIVE,
            DemoSessionStatus.READY,
            DemoSessionStatus.PARTIAL,
        )
        if session.status not in extendable_statuses:
            raise ValueError(f"Cannot extend {session.status.value} session")

        metadata = dict(session.session_metadata or {})
        extension_count = metadata.get("extension_count", 0)
        if extension_count >= settings.DEMO_SESSION_MAX_EXTENSIONS:
            raise ValueError(f"Maximum extensions ({settings.DEMO_SESSION_MAX_EXTENSIONS}) reached")

        now = datetime.now(timezone.utc)
        new_expires_at = now + timedelta(minutes=settings.DEMO_SESSION_DURATION_MINUTES)
        session.expires_at = new_expires_at
        session.last_activity_at = now
        metadata["extension_count"] = extension_count + 1
        # Assign a new dict rather than mutating in place. SQLAlchemy's JSON
        # type does not detect in-place changes (unless the column uses
        # sqlalchemy.ext.mutable), so the count would never be persisted.
        session.session_metadata = metadata
        session = await self.repository.update(session)

        await self.redis.extend_session_ttl(
            session_id,
            settings.REDIS_SESSION_TTL
        )
        # The cached status carries expires_at; refresh it for pollers.
        await self._cache_session_status(session)

        logger.info(
            "Session extended",
            session_id=session_id,
            new_expires_at=new_expires_at.isoformat(),
            extension_count=extension_count + 1
        )
        return session

    async def update_activity(self, session_id: str):
        """Update last activity timestamp."""
        await self.repository.update_activity(session_id)

    async def mark_data_cloned(self, session_id: str):
        """Mark session as having data cloned."""
        await self.repository.mark_data_cloned(session_id)

    async def mark_redis_populated(self, session_id: str):
        """Mark session as having Redis data populated."""
        await self.repository.mark_redis_populated(session_id)

    async def destroy_session(self, session_id: str):
        """
        Destroy a demo session and clean up its resources.

        Marks the session destroyed via the repository, deletes its Redis
        data, and drops the status cache. A missing session is logged and
        ignored.

        Args:
            session_id: Session ID to destroy
        """
        session = await self.get_session(session_id)
        if not session:
            logger.warning("Session not found for destruction", session_id=session_id)
            return

        await self.repository.destroy(session_id)
        await self.redis.delete_session_data(session_id)

        # The status cache is written directly under its own key by
        # _cache_session_status. Drop it so get_session_status does not serve
        # the pre-destruction status until the cache TTL expires.
        client = await self.redis.get_client()
        await client.delete(self._status_key(session_id))

        logger.info(
            "Session destroyed",
            session_id=session_id,
            virtual_tenant_id=str(session.virtual_tenant_id)
        )

    async def _store_session_metadata(self, session: DemoSession):
        """Store core session metadata in Redis with the session TTL."""
        await self.redis.set_session_data(
            session.session_id,
            "metadata",
            {
                "session_id": session.session_id,
                "virtual_tenant_id": str(session.virtual_tenant_id),
                "demo_account_type": session.demo_account_type,
                "expires_at": session.expires_at.isoformat(),
                "created_at": session.created_at.isoformat()
            },
            ttl=settings.REDIS_SESSION_TTL
        )

    async def get_active_sessions_count(self) -> int:
        """Get count of active sessions."""
        return await self.repository.get_active_sessions_count()

    async def get_session_stats(self) -> Dict[str, Any]:
        """Get session statistics."""
        return await self.repository.get_session_stats()

    async def trigger_orchestrated_cloning(
        self,
        session: DemoSession,
        base_tenant_id: str
    ) -> Dict[str, Any]:
        """
        Trigger orchestrated cloning across all services.

        Records the cloning start time, runs the orchestrator, and updates the
        session from its result. If the orchestrator raises, the session is
        marked FAILED (so it can be retried) and the exception propagates.

        Args:
            session: Demo session
            base_tenant_id: Template tenant ID to clone from

        Returns:
            Orchestration result
        """
        logger.info(
            "Triggering orchestrated cloning",
            session_id=session.session_id,
            virtual_tenant_id=str(session.virtual_tenant_id)
        )

        # Record the start in both the database and the Redis status cache
        session.cloning_started_at = datetime.now(timezone.utc)
        await self.repository.update(session)
        await self._cache_session_status(session)

        try:
            result = await self.orchestrator.clone_all_services(
                base_tenant_id=base_tenant_id,
                virtual_tenant_id=str(session.virtual_tenant_id),
                demo_account_type=session.demo_account_type,
                session_id=session.session_id,
                session_metadata=session.session_metadata
            )
        except Exception:
            logger.error(
                "Orchestrated cloning raised an exception",
                session_id=session.session_id,
                exc_info=True
            )
            await self._mark_cloning_failed(session)
            raise

        await self._update_session_from_clone_result(session, result)
        return result

    async def _mark_cloning_failed(self, session: DemoSession):
        """Best-effort: record FAILED status after the orchestrator raised.

        Without this the session would stay PENDING forever, and
        retry_failed_cloning only accepts FAILED/PARTIAL sessions. Errors are
        logged, not raised, so the caller can re-raise the original exception.
        """
        session_id = session.session_id
        try:
            session.status = DemoSessionStatus.FAILED
            session.cloning_completed_at = datetime.now(timezone.utc)
            await self.repository.update(session)
            await self._cache_session_status(session)
        except Exception:
            logger.error(
                "Failed to record cloning failure",
                session_id=session_id,
                exc_info=True
            )

    @staticmethod
    def _get_tenant_client():
        """Build a tenant-service client from the service settings.

        Imported lazily, as in the original code; presumably this avoids an
        import cycle or a hard dependency at module load. TODO confirm.
        """
        from shared.clients.tenant_client import TenantServiceClient
        return TenantServiceClient(settings)

    async def _validate_base_tenant_exists(self, base_tenant_id: uuid.UUID, demo_account_type: str) -> bool:
        """
        Validate that the base tenant exists in the tenant service before cloning.

        Args:
            base_tenant_id: The UUID of the base tenant to validate
            demo_account_type: The demo account type (used for logging/errors)

        Returns:
            True if the tenant exists

        Raises:
            ValueError: If the ID is the nil UUID, the tenant is missing, or
                the tenant service lookup fails (original error chained)
        """
        logger.info(
            "Validating base tenant exists before cloning",
            base_tenant_id=str(base_tenant_id),
            demo_account_type=demo_account_type
        )

        # Reject the nil UUID without a remote call
        if str(base_tenant_id) == "00000000-0000-0000-0000-000000000000":
            raise ValueError(f"Invalid base tenant ID: {base_tenant_id} for demo type: {demo_account_type}")

        try:
            tenant_client = self._get_tenant_client()
            tenant = await tenant_client.get_tenant(str(base_tenant_id))
            if not tenant:
                error_msg = (
                    f"Base tenant {base_tenant_id} does not exist for demo type {demo_account_type}. "
                    f"Please verify the base_tenant_id in demo configuration."
                )
                logger.error(
                    "Base tenant validation failed",
                    base_tenant_id=str(base_tenant_id),
                    demo_account_type=demo_account_type
                )
                raise ValueError(error_msg)

            logger.info(
                "Base tenant validation passed",
                base_tenant_id=str(base_tenant_id),
                tenant_name=tenant.get("name", "unknown"),
                demo_account_type=demo_account_type
            )
            return True
        except ValueError:
            # Validation failures already carry a precise message
            raise
        except Exception as e:
            logger.error(
                f"Error validating base tenant: {e}",
                base_tenant_id=str(base_tenant_id),
                demo_account_type=demo_account_type,
                exc_info=True
            )
            raise ValueError(f"Cannot validate base tenant {base_tenant_id}: {str(e)}") from e

    async def _validate_child_template_tenants(self, child_configs: list) -> bool:
        """
        Validate that all child template tenants exist before cloning.

        Prevents silent failures when child base tenants are missing.

        Args:
            child_configs: List of child config dicts, each with 'base_tenant_id'
                and optionally 'name'

        Returns:
            True if all child templates exist (or there are none)

        Raises:
            ValueError: If a config lacks base_tenant_id, a template is missing,
                or the tenant service lookup fails (original error chained)
        """
        if not child_configs:
            logger.warning("No child configurations provided for validation")
            return True

        logger.info("Validating child template tenants", child_count=len(child_configs))

        try:
            tenant_client = self._get_tenant_client()
            for child_config in child_configs:
                child_base_id = child_config.get("base_tenant_id")
                child_name = child_config.get("name", "unknown")
                if not child_base_id:
                    raise ValueError(f"Child config missing base_tenant_id: {child_name}")

                child_tenant = await tenant_client.get_tenant(child_base_id)
                if not child_tenant:
                    error_msg = (
                        f"Child template tenant {child_base_id} ('{child_name}') does not exist. "
                        f"Please verify the base_tenant_id in demo configuration."
                    )
                    logger.error(
                        "Child template validation failed",
                        base_tenant_id=child_base_id,
                        child_name=child_name
                    )
                    raise ValueError(error_msg)

                logger.info(
                    "Child template validation passed",
                    base_tenant_id=child_base_id,
                    child_name=child_name,
                    tenant_name=child_tenant.get("name", "unknown")
                )

            logger.info("All child template tenants validated successfully")
            return True
        except ValueError:
            # Validation failures already carry a precise message
            raise
        except Exception as e:
            logger.error(
                f"Error validating child template tenants: {e}",
                exc_info=True
            )
            raise ValueError(f"Cannot validate child template tenants: {str(e)}") from e

    async def _update_session_from_clone_result(
        self,
        session: DemoSession,
        clone_result: Dict[str, Any]
    ):
        """Persist and cache the outcome of an orchestrated clone.

        Maps the orchestrator's 'overall_status' onto the session status
        ('ready'/'completed' -> READY, 'failed' -> FAILED, 'partial' -> PARTIAL).
        Any other value leaves the status unchanged.
        """
        overall_status = clone_result.get("overall_status")
        status_map = {
            "ready": DemoSessionStatus.READY,
            "completed": DemoSessionStatus.READY,
            "failed": DemoSessionStatus.FAILED,
            "partial": DemoSessionStatus.PARTIAL,
        }
        new_status = status_map.get(overall_status)
        if new_status is not None:
            session.status = new_status

        session.cloning_completed_at = datetime.now(timezone.utc)
        # The clone result might use 'total_records' or 'total_records_cloned'
        session.total_records_cloned = clone_result.get(
            "total_records_cloned",
            clone_result.get("total_records", 0)
        )
        session.cloning_progress = clone_result.get("services", {})

        # Legacy flags kept for backward compatibility
        if overall_status in ("ready", "completed", "partial"):
            session.data_cloned = True
            session.redis_populated = True

        await self.repository.update(session)
        await self._cache_session_status(session)

        logger.info(
            "Session updated with clone results",
            session_id=session.session_id,
            status=session.status.value,
            total_records=session.total_records_cloned
        )

    @staticmethod
    def _status_key(session_id: str) -> str:
        """Redis key of the cached status payload for a session."""
        return f"session:{session_id}:status"

    @staticmethod
    def _build_status_payload(session: DemoSession) -> Dict[str, Any]:
        """Build the JSON-serialisable status dict shared by cache and DB paths."""
        def _iso(value):
            return value.isoformat() if value else None

        return {
            "session_id": session.session_id,
            "status": session.status.value,
            "progress": session.cloning_progress,
            "total_records_cloned": session.total_records_cloned,
            "cloning_started_at": _iso(session.cloning_started_at),
            "cloning_completed_at": _iso(session.cloning_completed_at),
            "expires_at": session.expires_at.isoformat()
        }

    async def _cache_session_status(self, session: DemoSession) -> Dict[str, Any]:
        """Cache session status in Redis for fast status checks.

        Returns:
            The status payload that was cached
        """
        status_data = self._build_status_payload(session)
        client = await self.redis.get_client()
        await client.setex(
            self._status_key(session.session_id),
            self.STATUS_CACHE_TTL_SECONDS,
            json.dumps(status_data)
        )
        return status_data

    async def get_session_status(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Get current session status with cloning progress.

        Serves the Redis status cache when present; otherwise loads the
        session from the database and re-populates the cache.

        Args:
            session_id: Session ID

        Returns:
            Status information including per-service progress, or None if the
            session does not exist
        """
        client = await self.redis.get_client()
        cached = await client.get(self._status_key(session_id))
        if cached:
            return json.loads(cached)

        session = await self.get_session(session_id)
        if not session:
            return None

        return await self._cache_session_status(session)

    async def retry_failed_cloning(
        self,
        session_id: str,
        services: Optional[list] = None
    ) -> Dict[str, Any]:
        """
        Retry failed cloning operations.

        Args:
            session_id: Session ID
            services: Services requested for retry. NOTE: currently only
                logged; the orchestrator re-runs cloning for all services.

        Returns:
            Retry result (the orchestration result)

        Raises:
            ValueError: If the session is missing or not FAILED/PARTIAL
        """
        session = await self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")

        if session.status not in [DemoSessionStatus.FAILED, DemoSessionStatus.PARTIAL]:
            raise ValueError(f"Cannot retry session in {session.status.value} state")

        logger.info(
            "Retrying failed cloning",
            session_id=session_id,
            services=services
        )

        # Prefer the configured template tenant. Fall back to the one recorded
        # on the session when the config (or its base_tenant_id) is missing,
        # instead of crashing on a None config.
        demo_config = settings.DEMO_ACCOUNTS.get(session.demo_account_type) or {}
        base_tenant_id = demo_config.get("base_tenant_id") or str(session.base_demo_tenant_id)

        return await self.trigger_orchestrated_cloning(session, base_tenant_id)