fix demo session 1

This commit is contained in:
Urtzi Alfaro
2026-01-02 11:12:50 +01:00
parent 507031deaf
commit cf0176673c
15 changed files with 136 additions and 107 deletions

View File

@@ -249,6 +249,13 @@ build_python_service('orchestrator-service', 'orchestrator')
# Demo Services # Demo Services
build_python_service('demo-session-service', 'demo_session') build_python_service('demo-session-service', 'demo_session')
# Tell Tilt that demo-cleanup-worker uses the demo-session-service image
k8s_image_json_path(
'bakery/demo-session-service',
'{.spec.template.spec.containers[?(@.name=="worker")].image}',
name='demo-cleanup-worker'
)
# ============================================================================= # =============================================================================
# INFRASTRUCTURE RESOURCES # INFRASTRUCTURE RESOURCES
# ============================================================================= # =============================================================================

View File

@@ -21,8 +21,8 @@ spec:
spec: spec:
containers: containers:
- name: worker - name: worker
image: demo-session-service:latest image: bakery/demo-session-service
imagePullPolicy: Never imagePullPolicy: IfNotPresent
command: command:
- python - python
- -m - -m

View File

@@ -24,7 +24,7 @@ from app.models.users import User
from app.core.config import settings from app.core.config import settings
logger = structlog.get_logger() logger = structlog.get_logger()
router = APIRouter() router = APIRouter(prefix="/internal/demo", tags=["internal"])
# Base demo tenant IDs # Base demo tenant IDs
DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6" DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"
@@ -38,7 +38,7 @@ def verify_internal_api_key(x_internal_api_key: Optional[str] = Header(None)):
return True return True
@router.post("/internal/demo/clone") @router.post("/clone")
async def clone_demo_data( async def clone_demo_data(
base_tenant_id: str, base_tenant_id: str,
virtual_tenant_id: str, virtual_tenant_id: str,

View File

@@ -16,7 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import DatabaseManager from app.core.database import DatabaseManager
from app.core.redis_wrapper import DemoRedisWrapper from app.core.redis_wrapper import DemoRedisWrapper
from app.services.data_cloner import DemoDataCloner from app.services.cleanup_service import DemoCleanupService
from app.models.demo_session import DemoSession, DemoSessionStatus from app.models.demo_session import DemoSession, DemoSessionStatus
logger = structlog.get_logger() logger = structlog.get_logger()
@@ -122,93 +122,63 @@ class CleanupWorker:
"""Execute cleanup for list of sessions with parallelization""" """Execute cleanup for list of sessions with parallelization"""
async with get_db_session() as db: async with get_db_session() as db:
redis = DemoRedisWrapper() redis = DemoRedisWrapper()
data_cloner = DemoDataCloner(db, redis) cleanup_service = DemoCleanupService(db, redis)
try: # Get sessions to cleanup
# Get sessions to cleanup result = await db.execute(
result = await db.execute( select(DemoSession).where(
select(DemoSession).where( DemoSession.session_id.in_(session_ids)
DemoSession.session_id.in_(session_ids)
)
) )
sessions = result.scalars().all() )
sessions = result.scalars().all()
stats = { stats = {
"cleaned_up": 0, "cleaned_up": 0,
"failed": 0, "failed": 0,
"errors": [] "errors": []
} }
# Process each session # Process each session
for session in sessions: for session in sessions:
try: try:
# Mark session as expired # Mark session as expired
session.status = DemoSessionStatus.EXPIRED session.status = DemoSessionStatus.EXPIRED
await db.commit() await db.commit()
# Check if this is an enterprise demo with children # Use cleanup service to delete all session data
child_tenant_ids = [] cleanup_result = await cleanup_service.cleanup_session(session)
if session.demo_account_type == "enterprise" and session.session_metadata:
child_tenant_ids = session.session_metadata.get("child_tenant_ids", [])
# Delete child tenants in parallel (for enterprise demos)
if child_tenant_ids:
logger.info(
"Cleaning up enterprise demo children",
session_id=session.session_id,
child_count=len(child_tenant_ids)
)
child_tasks = [
data_cloner.delete_session_data(
str(child_id),
session.session_id
)
for child_id in child_tenant_ids
]
child_results = await asyncio.gather(*child_tasks, return_exceptions=True)
# Log any child deletion failures
for child_id, result in zip(child_tenant_ids, child_results):
if isinstance(result, Exception):
logger.error(
"Failed to delete child tenant",
child_id=child_id,
error=str(result)
)
# Delete parent/main session data
await data_cloner.delete_session_data(
str(session.virtual_tenant_id),
session.session_id
)
if cleanup_result["success"]:
stats["cleaned_up"] += 1 stats["cleaned_up"] += 1
logger.info( logger.info(
"Session cleaned up", "Session cleaned up",
session_id=session.session_id, session_id=session.session_id,
is_enterprise=(session.demo_account_type == "enterprise"), is_enterprise=(session.demo_account_type == "enterprise"),
children_deleted=len(child_tenant_ids) total_deleted=cleanup_result["total_deleted"],
duration_ms=cleanup_result["duration_ms"]
) )
else:
except Exception as e:
stats["failed"] += 1 stats["failed"] += 1
stats["errors"].append({ stats["errors"].append({
"session_id": session.session_id, "session_id": session.session_id,
"error": str(e) "error": "Cleanup completed with errors",
"details": cleanup_result["errors"]
}) })
logger.error(
"Failed to cleanup session",
session_id=session.session_id,
error=str(e),
exc_info=True
)
return stats except Exception as e:
stats["failed"] += 1
stats["errors"].append({
"session_id": session.session_id,
"error": str(e)
})
logger.error(
"Failed to cleanup session",
session_id=session.session_id,
error=str(e),
exc_info=True
)
finally: return stats
# Always close HTTP client
await data_cloner.close()
async def _mark_job_complete(self, job_id: str, stats: Dict[str, Any]): async def _mark_job_complete(self, job_id: str, stats: Dict[str, Any]):
"""Mark job as complete in Redis""" """Mark job as complete in Redis"""

View File

@@ -98,8 +98,37 @@ class DemoCleanupService:
# Delete child tenants if enterprise # Delete child tenants if enterprise
if session.demo_account_type == "enterprise" and session.session_metadata: if session.demo_account_type == "enterprise" and session.session_metadata:
child_tenant_ids = session.session_metadata.get("child_tenant_ids", []) child_tenant_ids = session.session_metadata.get("child_tenant_ids", [])
logger.info(
"Deleting child tenant data",
session_id=session_id,
child_count=len(child_tenant_ids)
)
for child_tenant_id in child_tenant_ids: for child_tenant_id in child_tenant_ids:
await self._delete_from_all_services(child_tenant_id) child_results = await self._delete_from_all_services(str(child_tenant_id))
# Aggregate child deletion results
for (service_name, _), child_result in zip(self.services, child_results):
if isinstance(child_result, Exception):
logger.warning(
"Failed to delete child tenant data from service",
service=service_name,
child_tenant_id=child_tenant_id,
error=str(child_result)
)
else:
child_deleted = child_result.get("records_deleted", {}).get("total", 0)
total_deleted += child_deleted
# Update details to track child deletions
if service_name not in details:
details[service_name] = {"child_deletions": []}
if "child_deletions" not in details[service_name]:
details[service_name]["child_deletions"] = []
details[service_name]["child_deletions"].append({
"child_tenant_id": str(child_tenant_id),
"records_deleted": child_deleted
})
duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000) duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)

View File

@@ -217,11 +217,16 @@ async def clone_demo_data(
"actual_arrival": actual_arrival.isoformat() if actual_arrival else None "actual_arrival": actual_arrival.isoformat() if actual_arrival else None
}) })
# Make route_number unique per virtual tenant to prevent conflicts across demo sessions
# Append last 6 chars of virtual_tenant_id to ensure uniqueness
base_route_number = route_data.get('route_number', 'ROUTE-001')
unique_route_number = f"{base_route_number}-{str(virtual_uuid)[-6:]}"
# Create new delivery route # Create new delivery route
new_route = DeliveryRoute( new_route = DeliveryRoute(
id=transformed_id, id=transformed_id,
tenant_id=virtual_uuid, tenant_id=virtual_uuid,
route_number=route_data.get('route_number'), route_number=unique_route_number,
route_date=route_date, route_date=route_date,
vehicle_id=route_data.get('vehicle_id'), vehicle_id=route_data.get('vehicle_id'),
driver_id=route_data.get('driver_id'), driver_id=route_data.get('driver_id'),
@@ -294,6 +299,11 @@ async def clone_demo_data(
# (In production, items are in the linked purchase order) # (In production, items are in the linked purchase order)
items_json = json.dumps(shipment_data.get('items', [])) if shipment_data.get('items') else None items_json = json.dumps(shipment_data.get('items', [])) if shipment_data.get('items') else None
# Make shipment_number unique per virtual tenant to prevent conflicts across demo sessions
# Append last 6 chars of virtual_tenant_id to ensure uniqueness
base_shipment_number = shipment_data.get('shipment_number', 'SHIP-001')
unique_shipment_number = f"{base_shipment_number}-{str(virtual_uuid)[-6:]}"
# Create new shipment # Create new shipment
new_shipment = Shipment( new_shipment = Shipment(
id=transformed_id, id=transformed_id,
@@ -302,7 +312,7 @@ async def clone_demo_data(
child_tenant_id=shipment_data.get('child_tenant_id'), child_tenant_id=shipment_data.get('child_tenant_id'),
purchase_order_id=purchase_order_id, # Link to internal transfer PO purchase_order_id=purchase_order_id, # Link to internal transfer PO
delivery_route_id=delivery_route_id, # MUST use transformed ID delivery_route_id=delivery_route_id, # MUST use transformed ID
shipment_number=shipment_data.get('shipment_number'), shipment_number=unique_shipment_number,
shipment_date=shipment_date, shipment_date=shipment_date,
status=shipment_data.get('status', 'pending'), status=shipment_data.get('status', 'pending'),
total_weight_kg=shipment_data.get('total_weight_kg'), total_weight_kg=shipment_data.get('total_weight_kg'),

View File

@@ -22,7 +22,7 @@ from app.core.database import get_db
from app.models.forecasts import Forecast, PredictionBatch from app.models.forecasts import Forecast, PredictionBatch
logger = structlog.get_logger() logger = structlog.get_logger()
router = APIRouter() router = APIRouter(prefix="/internal/demo", tags=["internal"])
# Base demo tenant IDs # Base demo tenant IDs
DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6" DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"
@@ -91,7 +91,7 @@ def align_to_week_start(target_date: datetime) -> datetime:
return target_date return target_date
@router.post("/internal/demo/clone") @router.post("/clone")
async def clone_demo_data( async def clone_demo_data(
base_tenant_id: str, base_tenant_id: str,
virtual_tenant_id: str, virtual_tenant_id: str,

View File

@@ -5,6 +5,7 @@ Handles internal demo data cloning operations
from fastapi import APIRouter, Depends, HTTPException, Header from fastapi import APIRouter, Depends, HTTPException, Header
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, delete
from typing import Optional from typing import Optional
import structlog import structlog
import json import json
@@ -19,7 +20,7 @@ from app.models import Ingredient, Stock, ProductType
from shared.utils.demo_dates import adjust_date_for_demo, resolve_time_marker, calculate_edge_case_times from shared.utils.demo_dates import adjust_date_for_demo, resolve_time_marker, calculate_edge_case_times
logger = structlog.get_logger() logger = structlog.get_logger()
router = APIRouter() router = APIRouter(prefix="/internal/demo", tags=["internal"])
async def verify_internal_api_key(x_internal_api_key: str = Header(None)): async def verify_internal_api_key(x_internal_api_key: str = Header(None)):
@@ -77,7 +78,7 @@ def parse_date_field(date_value, session_time: datetime, field_name: str = "date
return None return None
@router.post("/internal/demo/clone") @router.post("/clone")
async def clone_demo_data_internal( async def clone_demo_data_internal(
base_tenant_id: str, base_tenant_id: str,
virtual_tenant_id: str, virtual_tenant_id: str,
@@ -183,7 +184,6 @@ async def clone_demo_data_internal(
seed_data = json.load(f) seed_data = json.load(f)
# Check if data already exists for this virtual tenant (idempotency) # Check if data already exists for this virtual tenant (idempotency)
from sqlalchemy import select, delete
existing_check = await db.execute( existing_check = await db.execute(
select(Ingredient).where(Ingredient.tenant_id == virtual_tenant_id).limit(1) select(Ingredient).where(Ingredient.tenant_id == virtual_tenant_id).limit(1)
) )
@@ -547,30 +547,40 @@ async def delete_demo_tenant_data(
""" """
start_time = datetime.now(timezone.utc) start_time = datetime.now(timezone.utc)
from app.models.inventory import StockMovement
records_deleted = { records_deleted = {
"ingredients": 0, "stock_movements": 0,
"stock": 0, "stock": 0,
"ingredients": 0,
"total": 0 "total": 0
} }
try: try:
# Delete in reverse dependency order # Delete in reverse dependency order
# 1. Delete stock batches (depends on ingredients) # 1. Delete stock movements (depends on stock and ingredients)
result = await db.execute(
delete(StockMovement)
.where(StockMovement.tenant_id == virtual_tenant_id)
)
records_deleted["stock_movements"] = result.rowcount
# 2. Delete stock batches (depends on ingredients)
result = await db.execute( result = await db.execute(
delete(Stock) delete(Stock)
.where(Stock.tenant_id == virtual_tenant_id) .where(Stock.tenant_id == virtual_tenant_id)
) )
records_deleted["stock"] = result.rowcount records_deleted["stock"] = result.rowcount
# 2. Delete ingredients # 3. Delete ingredients
result = await db.execute( result = await db.execute(
delete(Ingredient) delete(Ingredient)
.where(Ingredient.tenant_id == virtual_tenant_id) .where(Ingredient.tenant_id == virtual_tenant_id)
) )
records_deleted["ingredients"] = result.rowcount records_deleted["ingredients"] = result.rowcount
records_deleted["total"] = sum(records_deleted.values()) records_deleted["total"] = records_deleted["stock_movements"] + records_deleted["stock"] + records_deleted["ingredients"]
await db.commit() await db.commit()
@@ -603,7 +613,7 @@ async def delete_demo_tenant_data(
) )
@router.get("/internal/count") @router.get("/count")
async def get_ingredient_count( async def get_ingredient_count(
tenant_id: str, tenant_id: str,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),

View File

@@ -26,7 +26,7 @@ from shared.utils.demo_dates import adjust_date_for_demo
from app.core.config import settings from app.core.config import settings
router = APIRouter() router = APIRouter(prefix="/internal/demo", tags=["internal"])
logger = structlog.get_logger() logger = structlog.get_logger()
@@ -154,7 +154,7 @@ async def load_fixture_data_for_tenant(
return 1 return 1
@router.post("/internal/demo/clone") @router.post("/clone")
async def clone_demo_data( async def clone_demo_data(
base_tenant_id: str, base_tenant_id: str,
virtual_tenant_id: str, virtual_tenant_id: str,

View File

@@ -29,7 +29,7 @@ from shared.utils.demo_dates import (
from app.core.config import settings from app.core.config import settings
logger = structlog.get_logger() logger = structlog.get_logger()
router = APIRouter() router = APIRouter(prefix="/internal/demo", tags=["internal"])
# Base demo tenant IDs # Base demo tenant IDs
DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6" DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"
@@ -43,7 +43,7 @@ def verify_internal_api_key(x_internal_api_key: Optional[str] = Header(None)):
return True return True
@router.post("/internal/demo/clone") @router.post("/clone")
async def clone_demo_data( async def clone_demo_data(
base_tenant_id: str, base_tenant_id: str,
virtual_tenant_id: str, virtual_tenant_id: str,

View File

@@ -28,7 +28,7 @@ from app.models.recipes import (
from app.core.config import settings from app.core.config import settings
logger = structlog.get_logger() logger = structlog.get_logger()
router = APIRouter() router = APIRouter(prefix="/internal/demo", tags=["internal"])
# Base demo tenant IDs # Base demo tenant IDs
DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6" DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"
@@ -98,7 +98,7 @@ def verify_internal_api_key(x_internal_api_key: Optional[str] = Header(None)):
return True return True
@router.post("/internal/demo/clone") @router.post("/clone")
async def clone_demo_data( async def clone_demo_data(
base_tenant_id: str, base_tenant_id: str,
virtual_tenant_id: str, virtual_tenant_id: str,
@@ -434,7 +434,7 @@ async def delete_demo_tenant_data(
) )
@router.get("/internal/count") @router.get("/count")
async def get_recipe_count( async def get_recipe_count(
tenant_id: str, tenant_id: str,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),

View File

@@ -25,7 +25,7 @@ from app.models.sales import SalesData
from app.core.config import settings from app.core.config import settings
logger = structlog.get_logger() logger = structlog.get_logger()
router = APIRouter() router = APIRouter(prefix="/internal/demo", tags=["internal"])
# Base demo tenant IDs # Base demo tenant IDs
DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6" DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"
@@ -95,7 +95,7 @@ def verify_internal_api_key(x_internal_api_key: Optional[str] = Header(None)):
return True return True
@router.post("/internal/demo/clone") @router.post("/clone")
async def clone_demo_data( async def clone_demo_data(
base_tenant_id: str, base_tenant_id: str,
virtual_tenant_id: str, virtual_tenant_id: str,

View File

@@ -24,7 +24,7 @@ sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
from shared.utils.demo_dates import adjust_date_for_demo, resolve_time_marker from shared.utils.demo_dates import adjust_date_for_demo, resolve_time_marker
logger = structlog.get_logger() logger = structlog.get_logger()
router = APIRouter() router = APIRouter(prefix="/internal/demo", tags=["internal"])
# Base demo tenant IDs # Base demo tenant IDs
DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6" DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"
@@ -94,7 +94,7 @@ def verify_internal_api_key(x_internal_api_key: Optional[str] = Header(None)):
return True return True
@router.post("/internal/demo/clone") @router.post("/clone")
async def clone_demo_data( async def clone_demo_data(
base_tenant_id: str, base_tenant_id: str,
virtual_tenant_id: str, virtual_tenant_id: str,
@@ -409,7 +409,7 @@ async def delete_demo_tenant_data(
) )
@router.get("/internal/count") @router.get("/count")
async def get_supplier_count( async def get_supplier_count(
tenant_id: str, tenant_id: str,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),

View File

@@ -22,7 +22,7 @@ from shared.utils.demo_dates import adjust_date_for_demo, resolve_time_marker
from app.core.config import settings from app.core.config import settings
logger = structlog.get_logger() logger = structlog.get_logger()
router = APIRouter() router = APIRouter(prefix="/internal/demo", tags=["internal"])
# Base demo tenant IDs # Base demo tenant IDs
DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6" DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"
@@ -92,7 +92,7 @@ def verify_internal_api_key(x_internal_api_key: Optional[str] = Header(None)):
return True return True
@router.post("/internal/demo/clone") @router.post("/clone")
async def clone_demo_data( async def clone_demo_data(
base_tenant_id: str, base_tenant_id: str,
virtual_tenant_id: str, virtual_tenant_id: str,
@@ -546,7 +546,7 @@ async def clone_demo_data(
} }
@router.post("/internal/demo/create-child") @router.post("/create-child")
async def create_child_outlet( async def create_child_outlet(
request: dict, request: dict,
db: AsyncSession = Depends(get_db), db: AsyncSession = Depends(get_db),

View File

@@ -91,9 +91,12 @@ For more details, see services/forecasting/README.md
from typing import Dict, Any, Optional, List from typing import Dict, Any, Optional, List
from datetime import date from datetime import date
import structlog
from .base_service_client import BaseServiceClient from .base_service_client import BaseServiceClient
from shared.config.base import BaseServiceSettings from shared.config.base import BaseServiceSettings
logger = structlog.get_logger()
class ForecastServiceClient(BaseServiceClient): class ForecastServiceClient(BaseServiceClient):
"""Client for communicating with the forecasting service""" """Client for communicating with the forecasting service"""
@@ -367,13 +370,13 @@ class ForecastServiceClient(BaseServiceClient):
) )
if result: if result:
self.logger.info( logger.info(
"Demand insights triggered successfully via internal endpoint", "Demand insights triggered successfully via internal endpoint",
tenant_id=tenant_id, tenant_id=tenant_id,
insights_posted=result.get("insights_posted", 0) insights_posted=result.get("insights_posted", 0)
) )
else: else:
self.logger.warning( logger.warning(
"Demand insights internal endpoint returned no result", "Demand insights internal endpoint returned no result",
tenant_id=tenant_id tenant_id=tenant_id
) )
@@ -381,8 +384,8 @@ class ForecastServiceClient(BaseServiceClient):
return result return result
except Exception as e: except Exception as e:
self.logger.error( logger.error(
"Error triggering demand insights via internal endpoint", "Failed to trigger demand insights",
tenant_id=tenant_id, tenant_id=tenant_id,
error=str(e) error=str(e)
) )