397 lines
13 KiB
Python
397 lines
13 KiB
Python
"""
|
|
Demo Session Manager
|
|
Handles creation, extension, and destruction of demo sessions
|
|
"""
|
|
|
|
import json
import secrets
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

import structlog
from sqlalchemy.ext.asyncio import AsyncSession

from app.core import settings
from app.core.redis_wrapper import DemoRedisWrapper
from app.models import CloningStatus, DemoSession, DemoSessionStatus
from app.repositories.demo_session_repository import DemoSessionRepository
from app.services.clone_orchestrator import CloneOrchestrator
|
|
|
|
# Module-level structlog logger used by every DemoSessionManager method below.
logger = structlog.get_logger()
|
|
|
|
|
|
class DemoSessionManager:
    """Manage the demo session lifecycle.

    Sessions are persisted through ``DemoSessionRepository``. Session metadata
    and a status payload (used for fast polling) are mirrored into Redis.
    Template tenant data is cloned into each session's virtual tenant via
    ``CloneOrchestrator``.
    """

    # Lifetime (seconds) of the cached status payload written by
    # _cache_session_status (previously a hard-coded 7200 = 2 hours).
    STATUS_CACHE_TTL_SECONDS = 7200

    def __init__(self, db: AsyncSession, redis: DemoRedisWrapper):
        self.db = db
        self.redis = redis
        self.repository = DemoSessionRepository(db)
        self.orchestrator = CloneOrchestrator()

    @staticmethod
    def _status_cache_key(session_id: str) -> str:
        """Return the Redis key holding the cached status payload of a session."""
        return f"session:{session_id}:status"

    async def create_session(
        self,
        demo_account_type: str,
        user_id: Optional[str] = None,
        ip_address: Optional[str] = None,
        user_agent: Optional[str] = None
    ) -> DemoSession:
        """
        Create a new demo session.

        The session is stored with status PENDING; data cloning is started
        separately via ``trigger_orchestrated_cloning``.

        Args:
            demo_account_type: Key into ``settings.DEMO_ACCOUNTS``
                (e.g. 'individual_bakery' or 'central_baker')
            user_id: Optional user ID (UUID string) if authenticated
            ip_address: Client IP address
            user_agent: Client user agent

        Returns:
            Created demo session

        Raises:
            ValueError: If the account type is unknown, its base tenant ID is
                not configured, or a supplied ID is not a valid UUID string
        """
        logger.info("Creating demo session", demo_account_type=demo_account_type)

        # Generate unique session ID and the virtual tenant that will receive cloned data
        session_id = f"demo_{secrets.token_urlsafe(16)}"
        virtual_tenant_id = uuid.uuid4()

        # Resolve the template (base) tenant for this account type
        demo_config = settings.DEMO_ACCOUNTS.get(demo_account_type)
        if not demo_config:
            raise ValueError(f"Invalid demo account type: {demo_account_type}")

        base_tenant_id_str = demo_config.get("base_tenant_id")
        if not base_tenant_id_str:
            raise ValueError(f"Base tenant ID not configured for demo account type: {demo_account_type}")

        base_tenant_id = uuid.UUID(base_tenant_id_str)

        # Single timestamp so created_at, last_activity_at and expires_at agree
        now = datetime.now(timezone.utc)

        session_data = {
            "session_id": session_id,
            "user_id": uuid.UUID(user_id) if user_id else None,
            "ip_address": ip_address,
            "user_agent": user_agent,
            "base_demo_tenant_id": base_tenant_id,
            "virtual_tenant_id": virtual_tenant_id,
            "demo_account_type": demo_account_type,
            "status": DemoSessionStatus.PENDING,  # Start as pending until cloning completes
            "created_at": now,
            "expires_at": now + timedelta(minutes=settings.DEMO_SESSION_DURATION_MINUTES),
            "last_activity_at": now,
            "data_cloned": False,
            "redis_populated": False,
            "session_metadata": {
                "demo_config": demo_config,
                "extension_count": 0
            }
        }

        session = await self.repository.create(session_data)

        # Mirror session metadata into Redis
        await self._store_session_metadata(session)

        logger.info(
            "Demo session created",
            session_id=session_id,
            virtual_tenant_id=str(virtual_tenant_id),
            expires_at=session.expires_at.isoformat()
        )

        return session

    async def get_session(self, session_id: str) -> Optional[DemoSession]:
        """Get session by session_id, or None if it does not exist."""
        return await self.repository.get_by_session_id(session_id)

    async def get_session_by_virtual_tenant(self, virtual_tenant_id: str) -> Optional[DemoSession]:
        """
        Get session by virtual tenant ID.

        Raises:
            ValueError: If ``virtual_tenant_id`` is not a valid UUID string
        """
        return await self.repository.get_by_virtual_tenant_id(uuid.UUID(virtual_tenant_id))

    async def extend_session(self, session_id: str) -> DemoSession:
        """
        Extend session expiration time.

        The new expiry is "now + DEMO_SESSION_DURATION_MINUTES" (not added to
        the previous expiry), and at most DEMO_SESSION_MAX_EXTENSIONS
        extensions are allowed per session.

        Args:
            session_id: Session ID to extend

        Returns:
            Updated session

        Raises:
            ValueError: If the session does not exist, is not ACTIVE, or has
                reached the maximum number of extensions
        """
        session = await self.get_session(session_id)

        if not session:
            raise ValueError(f"Session not found: {session_id}")

        # NOTE(review): nothing in this class moves a session to ACTIVE
        # (cloning results map to READY / PARTIAL / FAILED). Confirm another
        # component sets ACTIVE; otherwise no session can ever be extended.
        if session.status != DemoSessionStatus.ACTIVE:
            raise ValueError(f"Cannot extend {session.status.value} session")

        # Work on a copy: tolerates a NULL metadata column and lets us assign a
        # new object below.
        metadata = dict(session.session_metadata or {})
        extension_count = metadata.get("extension_count", 0)
        if extension_count >= settings.DEMO_SESSION_MAX_EXTENSIONS:
            raise ValueError(f"Maximum extensions ({settings.DEMO_SESSION_MAX_EXTENSIONS}) reached")

        now = datetime.now(timezone.utc)
        new_expires_at = now + timedelta(minutes=settings.DEMO_SESSION_DURATION_MINUTES)

        session.expires_at = new_expires_at
        session.last_activity_at = now
        # Reassign instead of mutating in place: SQLAlchemy does not detect
        # in-place changes to a plain JSON column (unless it is wrapped in a
        # Mutable type), so the incremented count could otherwise be silently
        # dropped and the extension limit never enforced.
        metadata["extension_count"] = extension_count + 1
        session.session_metadata = metadata

        session = await self.repository.update(session)

        # Extend Redis TTL
        await self.redis.extend_session_ttl(
            session_id,
            settings.REDIS_SESSION_TTL
        )

        # Refresh the cached status payload so pollers see the new expires_at
        await self._cache_session_status(session)

        logger.info(
            "Session extended",
            session_id=session_id,
            new_expires_at=new_expires_at.isoformat(),
            extension_count=extension_count + 1
        )

        return session

    async def update_activity(self, session_id: str):
        """Update last activity timestamp (delegates to the repository)."""
        await self.repository.update_activity(session_id)

    async def mark_data_cloned(self, session_id: str):
        """Mark session as having data cloned (delegates to the repository)."""
        await self.repository.mark_data_cloned(session_id)

    async def mark_redis_populated(self, session_id: str):
        """Mark session as having Redis data populated (delegates to the repository)."""
        await self.repository.mark_redis_populated(session_id)

    async def destroy_session(self, session_id: str):
        """
        Destroy a demo session and clean up its resources.

        Unknown session IDs are logged and ignored.

        Args:
            session_id: Session ID to destroy
        """
        session = await self.get_session(session_id)

        if not session:
            logger.warning("Session not found for destruction", session_id=session_id)
            return

        # Update session status via repository
        await self.repository.destroy(session_id)

        # Delete Redis data
        await self.redis.delete_session_data(session_id)

        # The status payload is written directly with the raw client under its
        # own key, so remove it explicitly; otherwise get_session_status would
        # keep returning the pre-destruction status from cache until it expires.
        client = await self.redis.get_client()
        await client.delete(self._status_cache_key(session_id))

        logger.info(
            "Session destroyed",
            session_id=session_id,
            virtual_tenant_id=str(session.virtual_tenant_id)
        )

    async def _store_session_metadata(self, session: DemoSession):
        """Store session metadata in Redis with the session TTL."""
        await self.redis.set_session_data(
            session.session_id,
            "metadata",
            {
                "session_id": session.session_id,
                "virtual_tenant_id": str(session.virtual_tenant_id),
                "demo_account_type": session.demo_account_type,
                "expires_at": session.expires_at.isoformat(),
                "created_at": session.created_at.isoformat()
            },
            ttl=settings.REDIS_SESSION_TTL
        )

    async def get_active_sessions_count(self) -> int:
        """Get count of active sessions."""
        return await self.repository.get_active_sessions_count()

    async def get_session_stats(self) -> Dict[str, Any]:
        """Get session statistics."""
        return await self.repository.get_session_stats()

    async def trigger_orchestrated_cloning(
        self,
        session: DemoSession,
        base_tenant_id: str
    ) -> Dict[str, Any]:
        """
        Trigger orchestrated cloning across all services.

        Args:
            session: Demo session
            base_tenant_id: Template tenant ID to clone from

        Returns:
            Orchestration result

        Raises:
            Exception: Any exception from the orchestrator is re-raised after
                the session has been marked FAILED.
        """
        logger.info(
            "Triggering orchestrated cloning",
            session_id=session.session_id,
            virtual_tenant_id=str(session.virtual_tenant_id)
        )

        # Mark cloning as started
        session.cloning_started_at = datetime.now(timezone.utc)
        await self.repository.update(session)

        try:
            result = await self.orchestrator.clone_all_services(
                base_tenant_id=base_tenant_id,
                virtual_tenant_id=str(session.virtual_tenant_id),
                demo_account_type=session.demo_account_type,
                session_id=session.session_id
            )
        except Exception:
            # Without this the session would stay PENDING forever, and
            # retry_failed_cloning only accepts FAILED/PARTIAL sessions.
            logger.exception(
                "Orchestrated cloning raised; marking session as failed",
                session_id=session.session_id
            )
            await self._update_session_from_clone_result(session, {"overall_status": "failed"})
            raise

        # Update session with results
        await self._update_session_from_clone_result(session, result)

        return result

    async def _update_session_from_clone_result(
        self,
        session: DemoSession,
        clone_result: Dict[str, Any]
    ):
        """
        Update session with cloning results, persist it, and refresh the status cache.

        An unrecognised ``overall_status`` leaves ``session.status`` unchanged.
        """
        overall_status = clone_result.get("overall_status")

        # Map orchestrator overall status to session status
        status_map = {
            "ready": DemoSessionStatus.READY,
            "failed": DemoSessionStatus.FAILED,
            "partial": DemoSessionStatus.PARTIAL,
        }
        if overall_status in status_map:
            session.status = status_map[overall_status]

        # Update cloning metadata
        session.cloning_completed_at = datetime.now(timezone.utc)
        session.total_records_cloned = clone_result.get("total_records_cloned", 0)
        session.cloning_progress = clone_result.get("services", {})

        # Mark legacy flags for backward compatibility
        if overall_status in ("ready", "partial"):
            session.data_cloned = True
            session.redis_populated = True

        await self.repository.update(session)

        # Cache status in Redis for fast polling
        await self._cache_session_status(session)

        logger.info(
            "Session updated with clone results",
            session_id=session.session_id,
            status=session.status.value,
            total_records=session.total_records_cloned
        )

    @staticmethod
    def _build_status_payload(session: DemoSession) -> Dict[str, Any]:
        """Build the JSON-friendly status payload shared by the cache and get_session_status."""
        started = session.cloning_started_at
        completed = session.cloning_completed_at
        return {
            "session_id": session.session_id,
            "status": session.status.value,
            "progress": session.cloning_progress,
            "total_records_cloned": session.total_records_cloned,
            "cloning_started_at": started.isoformat() if started else None,
            "cloning_completed_at": completed.isoformat() if completed else None,
            "expires_at": session.expires_at.isoformat()
        }

    async def _cache_session_status(self, session: DemoSession):
        """Cache the session status payload in Redis as JSON for fast status checks."""
        client = await self.redis.get_client()
        await client.setex(
            self._status_cache_key(session.session_id),
            self.STATUS_CACHE_TTL_SECONDS,
            json.dumps(self._build_status_payload(session))
        )

    async def get_session_status(self, session_id: str) -> Optional[Dict[str, Any]]:
        """
        Get current session status with cloning progress.

        Reads the Redis cache first; on a miss, loads the session from the
        database and repopulates the cache.

        Args:
            session_id: Session ID

        Returns:
            Status information including per-service progress, or None if
            the session does not exist
        """
        client = await self.redis.get_client()
        cached = await client.get(self._status_cache_key(session_id))

        if cached:
            return json.loads(cached)

        # Fall back to database
        session = await self.get_session(session_id)
        if not session:
            return None

        await self._cache_session_status(session)

        return self._build_status_payload(session)

    async def retry_failed_cloning(
        self,
        session_id: str,
        services: Optional[list] = None
    ) -> Dict[str, Any]:
        """
        Retry failed cloning operations.

        Args:
            session_id: Session ID
            services: Specific services to retry (defaults to all failed).
                NOTE(review): currently only logged; the full orchestration
                is re-run regardless of this value.

        Returns:
            Retry result

        Raises:
            ValueError: If the session does not exist or is not in FAILED or
                PARTIAL state
        """
        session = await self.get_session(session_id)
        if not session:
            raise ValueError(f"Session not found: {session_id}")

        if session.status not in (DemoSessionStatus.FAILED, DemoSessionStatus.PARTIAL):
            raise ValueError(f"Cannot retry session in {session.status.value} state")

        logger.info(
            "Retrying failed cloning",
            session_id=session_id,
            services=services
        )

        # Prefer the configured template tenant; fall back to the one stored on
        # the session if the account type or its base_tenant_id is no longer
        # configured (previously a missing config raised AttributeError).
        demo_config = settings.DEMO_ACCOUNTS.get(session.demo_account_type) or {}
        base_tenant_id = demo_config.get("base_tenant_id") or str(session.base_demo_tenant_id)

        # Trigger new cloning attempt
        return await self.trigger_orchestrated_cloning(session, base_tenant_id)