Improve the frontend and repository layer
This commit is contained in:
7
services/demo_session/app/repositories/__init__.py
Normal file
7
services/demo_session/app/repositories/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
"""
|
||||
Demo Session Repositories
|
||||
"""
|
||||
|
||||
from .demo_session_repository import DemoSessionRepository
|
||||
|
||||
__all__ = ["DemoSessionRepository"]
|
||||
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
Demo Session Repository
|
||||
Data access layer for demo sessions
|
||||
"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, List, Dict, Any
|
||||
from uuid import UUID
|
||||
import structlog
|
||||
|
||||
from app.models import DemoSession, DemoSessionStatus
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class DemoSessionRepository:
|
||||
"""Repository for DemoSession data access"""
|
||||
|
||||
def __init__(self, db: AsyncSession):
|
||||
self.db = db
|
||||
|
||||
async def create(self, session_data: Dict[str, Any]) -> DemoSession:
|
||||
"""
|
||||
Create a new demo session
|
||||
|
||||
Args:
|
||||
session_data: Dictionary with session attributes
|
||||
|
||||
Returns:
|
||||
Created DemoSession instance
|
||||
"""
|
||||
session = DemoSession(**session_data)
|
||||
self.db.add(session)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(session)
|
||||
return session
|
||||
|
||||
async def get_by_session_id(self, session_id: str) -> Optional[DemoSession]:
|
||||
"""
|
||||
Get session by session_id
|
||||
|
||||
Args:
|
||||
session_id: Session ID string
|
||||
|
||||
Returns:
|
||||
DemoSession or None if not found
|
||||
"""
|
||||
result = await self.db.execute(
|
||||
select(DemoSession).where(DemoSession.session_id == session_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_by_virtual_tenant_id(self, virtual_tenant_id: UUID) -> Optional[DemoSession]:
|
||||
"""
|
||||
Get session by virtual tenant ID
|
||||
|
||||
Args:
|
||||
virtual_tenant_id: Virtual tenant UUID
|
||||
|
||||
Returns:
|
||||
DemoSession or None if not found
|
||||
"""
|
||||
result = await self.db.execute(
|
||||
select(DemoSession).where(DemoSession.virtual_tenant_id == virtual_tenant_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def update(self, session: DemoSession) -> DemoSession:
|
||||
"""
|
||||
Update an existing session
|
||||
|
||||
Args:
|
||||
session: DemoSession instance with updates
|
||||
|
||||
Returns:
|
||||
Updated DemoSession instance
|
||||
"""
|
||||
await self.db.commit()
|
||||
await self.db.refresh(session)
|
||||
return session
|
||||
|
||||
async def update_fields(self, session_id: str, **fields) -> None:
|
||||
"""
|
||||
Update specific fields of a session
|
||||
|
||||
Args:
|
||||
session_id: Session ID to update
|
||||
**fields: Field names and values to update
|
||||
"""
|
||||
await self.db.execute(
|
||||
update(DemoSession)
|
||||
.where(DemoSession.session_id == session_id)
|
||||
.values(**fields)
|
||||
)
|
||||
await self.db.commit()
|
||||
|
||||
async def update_activity(self, session_id: str) -> None:
|
||||
"""
|
||||
Update last activity timestamp and increment request count
|
||||
|
||||
Args:
|
||||
session_id: Session ID to update
|
||||
"""
|
||||
await self.db.execute(
|
||||
update(DemoSession)
|
||||
.where(DemoSession.session_id == session_id)
|
||||
.values(
|
||||
last_activity_at=datetime.now(timezone.utc),
|
||||
request_count=DemoSession.request_count + 1
|
||||
)
|
||||
)
|
||||
await self.db.commit()
|
||||
|
||||
async def mark_data_cloned(self, session_id: str) -> None:
|
||||
"""
|
||||
Mark session as having data cloned
|
||||
|
||||
Args:
|
||||
session_id: Session ID to update
|
||||
"""
|
||||
await self.update_fields(session_id, data_cloned=True)
|
||||
|
||||
async def mark_redis_populated(self, session_id: str) -> None:
|
||||
"""
|
||||
Mark session as having Redis data populated
|
||||
|
||||
Args:
|
||||
session_id: Session ID to update
|
||||
"""
|
||||
await self.update_fields(session_id, redis_populated=True)
|
||||
|
||||
async def destroy(self, session_id: str) -> None:
|
||||
"""
|
||||
Mark session as destroyed
|
||||
|
||||
Args:
|
||||
session_id: Session ID to destroy
|
||||
"""
|
||||
await self.update_fields(
|
||||
session_id,
|
||||
status=DemoSessionStatus.DESTROYED,
|
||||
destroyed_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
async def get_active_sessions_count(self) -> int:
|
||||
"""
|
||||
Get count of active sessions
|
||||
|
||||
Returns:
|
||||
Number of active sessions
|
||||
"""
|
||||
result = await self.db.execute(
|
||||
select(DemoSession).where(DemoSession.status == DemoSessionStatus.ACTIVE)
|
||||
)
|
||||
return len(result.scalars().all())
|
||||
|
||||
async def get_all_sessions(self) -> List[DemoSession]:
|
||||
"""
|
||||
Get all demo sessions
|
||||
|
||||
Returns:
|
||||
List of all DemoSession instances
|
||||
"""
|
||||
result = await self.db.execute(select(DemoSession))
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_sessions_by_status(self, status: DemoSessionStatus) -> List[DemoSession]:
|
||||
"""
|
||||
Get sessions by status
|
||||
|
||||
Args:
|
||||
status: DemoSessionStatus to filter by
|
||||
|
||||
Returns:
|
||||
List of DemoSession instances with the specified status
|
||||
"""
|
||||
result = await self.db.execute(
|
||||
select(DemoSession).where(DemoSession.status == status)
|
||||
)
|
||||
return result.scalars().all()
|
||||
|
||||
async def get_session_stats(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Get session statistics
|
||||
|
||||
Returns:
|
||||
Dictionary with session statistics
|
||||
"""
|
||||
all_sessions = await self.get_all_sessions()
|
||||
active_sessions = [s for s in all_sessions if s.status == DemoSessionStatus.ACTIVE]
|
||||
|
||||
return {
|
||||
"total_sessions": len(all_sessions),
|
||||
"active_sessions": len(active_sessions),
|
||||
"expired_sessions": len([s for s in all_sessions if s.status == DemoSessionStatus.EXPIRED]),
|
||||
"destroyed_sessions": len([s for s in all_sessions if s.status == DemoSessionStatus.DESTROYED]),
|
||||
"avg_duration_minutes": sum(
|
||||
(s.destroyed_at - s.created_at).total_seconds() / 60
|
||||
for s in all_sessions if s.destroyed_at
|
||||
) / max(len([s for s in all_sessions if s.destroyed_at]), 1),
|
||||
"total_requests": sum(s.request_count for s in all_sessions)
|
||||
}
|
||||
@@ -4,7 +4,6 @@ Handles creation, extension, and destruction of demo sessions
|
||||
"""
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, update
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Dict, Any
|
||||
import uuid
|
||||
@@ -15,6 +14,7 @@ from app.models import DemoSession, DemoSessionStatus, CloningStatus
|
||||
from app.core.redis_wrapper import DemoRedisWrapper
|
||||
from app.core import settings
|
||||
from app.services.clone_orchestrator import CloneOrchestrator
|
||||
from app.repositories.demo_session_repository import DemoSessionRepository
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
@@ -25,6 +25,7 @@ class DemoSessionManager:
|
||||
def __init__(self, db: AsyncSession, redis: DemoRedisWrapper):
|
||||
self.db = db
|
||||
self.redis = redis
|
||||
self.repository = DemoSessionRepository(db)
|
||||
self.orchestrator = CloneOrchestrator()
|
||||
|
||||
async def create_session(
|
||||
@@ -66,32 +67,30 @@ class DemoSessionManager:
|
||||
|
||||
base_tenant_id = uuid.UUID(base_tenant_id_str)
|
||||
|
||||
# Create session record
|
||||
session = DemoSession(
|
||||
session_id=session_id,
|
||||
user_id=uuid.UUID(user_id) if user_id else None,
|
||||
ip_address=ip_address,
|
||||
user_agent=user_agent,
|
||||
base_demo_tenant_id=base_tenant_id,
|
||||
virtual_tenant_id=virtual_tenant_id,
|
||||
demo_account_type=demo_account_type,
|
||||
status=DemoSessionStatus.PENDING, # Start as pending until cloning completes
|
||||
created_at=datetime.now(timezone.utc),
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(
|
||||
# Create session record using repository
|
||||
session_data = {
|
||||
"session_id": session_id,
|
||||
"user_id": uuid.UUID(user_id) if user_id else None,
|
||||
"ip_address": ip_address,
|
||||
"user_agent": user_agent,
|
||||
"base_demo_tenant_id": base_tenant_id,
|
||||
"virtual_tenant_id": virtual_tenant_id,
|
||||
"demo_account_type": demo_account_type,
|
||||
"status": DemoSessionStatus.PENDING, # Start as pending until cloning completes
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"expires_at": datetime.now(timezone.utc) + timedelta(
|
||||
minutes=settings.DEMO_SESSION_DURATION_MINUTES
|
||||
),
|
||||
last_activity_at=datetime.now(timezone.utc),
|
||||
data_cloned=False,
|
||||
redis_populated=False,
|
||||
session_metadata={
|
||||
"last_activity_at": datetime.now(timezone.utc),
|
||||
"data_cloned": False,
|
||||
"redis_populated": False,
|
||||
"session_metadata": {
|
||||
"demo_config": demo_config,
|
||||
"extension_count": 0
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
self.db.add(session)
|
||||
await self.db.commit()
|
||||
await self.db.refresh(session)
|
||||
session = await self.repository.create(session_data)
|
||||
|
||||
# Store session metadata in Redis
|
||||
await self._store_session_metadata(session)
|
||||
@@ -107,19 +106,11 @@ class DemoSessionManager:
|
||||
|
||||
async def get_session(self, session_id: str) -> Optional[DemoSession]:
|
||||
"""Get session by session_id"""
|
||||
result = await self.db.execute(
|
||||
select(DemoSession).where(DemoSession.session_id == session_id)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
return await self.repository.get_by_session_id(session_id)
|
||||
|
||||
async def get_session_by_virtual_tenant(self, virtual_tenant_id: str) -> Optional[DemoSession]:
|
||||
"""Get session by virtual tenant ID"""
|
||||
result = await self.db.execute(
|
||||
select(DemoSession).where(
|
||||
DemoSession.virtual_tenant_id == uuid.UUID(virtual_tenant_id)
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
return await self.repository.get_by_virtual_tenant_id(uuid.UUID(virtual_tenant_id))
|
||||
|
||||
async def extend_session(self, session_id: str) -> DemoSession:
|
||||
"""
|
||||
@@ -156,8 +147,7 @@ class DemoSessionManager:
|
||||
session.last_activity_at = datetime.now(timezone.utc)
|
||||
session.session_metadata["extension_count"] = extension_count + 1
|
||||
|
||||
await self.db.commit()
|
||||
await self.db.refresh(session)
|
||||
session = await self.repository.update(session)
|
||||
|
||||
# Extend Redis TTL
|
||||
await self.redis.extend_session_ttl(
|
||||
@@ -176,33 +166,15 @@ class DemoSessionManager:
|
||||
|
||||
async def update_activity(self, session_id: str):
|
||||
"""Update last activity timestamp"""
|
||||
await self.db.execute(
|
||||
update(DemoSession)
|
||||
.where(DemoSession.session_id == session_id)
|
||||
.values(
|
||||
last_activity_at=datetime.now(timezone.utc),
|
||||
request_count=DemoSession.request_count + 1
|
||||
)
|
||||
)
|
||||
await self.db.commit()
|
||||
await self.repository.update_activity(session_id)
|
||||
|
||||
async def mark_data_cloned(self, session_id: str):
|
||||
"""Mark session as having data cloned"""
|
||||
await self.db.execute(
|
||||
update(DemoSession)
|
||||
.where(DemoSession.session_id == session_id)
|
||||
.values(data_cloned=True)
|
||||
)
|
||||
await self.db.commit()
|
||||
await self.repository.mark_data_cloned(session_id)
|
||||
|
||||
async def mark_redis_populated(self, session_id: str):
|
||||
"""Mark session as having Redis data populated"""
|
||||
await self.db.execute(
|
||||
update(DemoSession)
|
||||
.where(DemoSession.session_id == session_id)
|
||||
.values(redis_populated=True)
|
||||
)
|
||||
await self.db.commit()
|
||||
await self.repository.mark_redis_populated(session_id)
|
||||
|
||||
async def destroy_session(self, session_id: str):
|
||||
"""
|
||||
@@ -217,11 +189,8 @@ class DemoSessionManager:
|
||||
logger.warning("Session not found for destruction", session_id=session_id)
|
||||
return
|
||||
|
||||
# Update session status
|
||||
session.status = DemoSessionStatus.DESTROYED
|
||||
session.destroyed_at = datetime.now(timezone.utc)
|
||||
|
||||
await self.db.commit()
|
||||
# Update session status via repository
|
||||
await self.repository.destroy(session_id)
|
||||
|
||||
# Delete Redis data
|
||||
await self.redis.delete_session_data(session_id)
|
||||
@@ -229,10 +198,7 @@ class DemoSessionManager:
|
||||
logger.info(
|
||||
"Session destroyed",
|
||||
session_id=session_id,
|
||||
virtual_tenant_id=str(session.virtual_tenant_id),
|
||||
duration_seconds=(
|
||||
session.destroyed_at - session.created_at
|
||||
).total_seconds()
|
||||
virtual_tenant_id=str(session.virtual_tenant_id)
|
||||
)
|
||||
|
||||
async def _store_session_metadata(self, session: DemoSession):
|
||||
@@ -252,29 +218,11 @@ class DemoSessionManager:
|
||||
|
||||
async def get_active_sessions_count(self) -> int:
|
||||
"""Get count of active sessions"""
|
||||
result = await self.db.execute(
|
||||
select(DemoSession).where(DemoSession.status == DemoSessionStatus.ACTIVE)
|
||||
)
|
||||
return len(result.scalars().all())
|
||||
return await self.repository.get_active_sessions_count()
|
||||
|
||||
async def get_session_stats(self) -> Dict[str, Any]:
|
||||
"""Get session statistics"""
|
||||
result = await self.db.execute(select(DemoSession))
|
||||
all_sessions = result.scalars().all()
|
||||
|
||||
active_sessions = [s for s in all_sessions if s.status == DemoSessionStatus.ACTIVE]
|
||||
|
||||
return {
|
||||
"total_sessions": len(all_sessions),
|
||||
"active_sessions": len(active_sessions),
|
||||
"expired_sessions": len([s for s in all_sessions if s.status == DemoSessionStatus.EXPIRED]),
|
||||
"destroyed_sessions": len([s for s in all_sessions if s.status == DemoSessionStatus.DESTROYED]),
|
||||
"avg_duration_minutes": sum(
|
||||
(s.destroyed_at - s.created_at).total_seconds() / 60
|
||||
for s in all_sessions if s.destroyed_at
|
||||
) / max(len([s for s in all_sessions if s.destroyed_at]), 1),
|
||||
"total_requests": sum(s.request_count for s in all_sessions)
|
||||
}
|
||||
return await self.repository.get_session_stats()
|
||||
|
||||
async def trigger_orchestrated_cloning(
|
||||
self,
|
||||
@@ -299,7 +247,7 @@ class DemoSessionManager:
|
||||
|
||||
# Mark cloning as started
|
||||
session.cloning_started_at = datetime.now(timezone.utc)
|
||||
await self.db.commit()
|
||||
await self.repository.update(session)
|
||||
|
||||
# Run orchestration
|
||||
result = await self.orchestrator.clone_all_services(
|
||||
@@ -340,8 +288,7 @@ class DemoSessionManager:
|
||||
session.data_cloned = True
|
||||
session.redis_populated = True
|
||||
|
||||
await self.db.commit()
|
||||
await self.db.refresh(session)
|
||||
await self.repository.update(session)
|
||||
|
||||
# Cache status in Redis for fast polling
|
||||
await self._cache_session_status(session)
|
||||
|
||||
Reference in New Issue
Block a user