"""
|
|
Demo Cleanup Service
|
|
Handles automatic cleanup of expired sessions
|
|
"""
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
from datetime import datetime, timezone, timedelta
|
|
import structlog
|
|
import httpx
|
|
import asyncio
|
|
import os
|
|
|
|
from app.models import DemoSession, DemoSessionStatus
|
|
from datetime import datetime, timezone, timedelta
|
|
from app.core.redis_wrapper import DemoRedisWrapper
|
|
from shared.auth.jwt_handler import JWTHandler
|
|
|
|
logger = structlog.get_logger()
|
|
|
|
|
|
class DemoCleanupService:
    """Handles cleanup of expired demo sessions"""

    def __init__(self, db: AsyncSession, redis: DemoRedisWrapper):
        self.db = db
        self.redis = redis
        from app.core.config import settings
        # ✅ Security: JWT service tokens used for all internal communication
        # No longer using internal API keys

        # JWT handler for creating service tokens
        self.jwt_handler = JWTHandler(settings.JWT_SECRET_KEY, settings.JWT_ALGORITHM)

        # Service URLs for cleanup
        self.services = [
            ("tenant", os.getenv("TENANT_SERVICE_URL", "http://tenant-service:8000")),
            ("auth", os.getenv("AUTH_SERVICE_URL", "http://auth-service:8000")),
            ("inventory", os.getenv("INVENTORY_SERVICE_URL", "http://inventory-service:8000")),
            ("recipes", os.getenv("RECIPES_SERVICE_URL", "http://recipes-service:8000")),
            ("suppliers", os.getenv("SUPPLIERS_SERVICE_URL", "http://suppliers-service:8000")),
            ("production", os.getenv("PRODUCTION_SERVICE_URL", "http://production-service:8000")),
            ("procurement", os.getenv("PROCUREMENT_SERVICE_URL", "http://procurement-service:8000")),
            ("sales", os.getenv("SALES_SERVICE_URL", "http://sales-service:8000")),
            ("orders", os.getenv("ORDERS_SERVICE_URL", "http://orders-service:8000")),
            ("forecasting", os.getenv("FORECASTING_SERVICE_URL", "http://forecasting-service:8000")),
            ("orchestrator", os.getenv("ORCHESTRATOR_SERVICE_URL", "http://orchestrator-service:8000")),
        ]

    async def cleanup_session(self, session: DemoSession) -> dict:
        """
        Delete all data for a demo session across all services.

        Returns:
            {
                "success": bool,
                "total_deleted": int,
                "duration_ms": int,
                "details": {service: {records_deleted, duration_ms}},
                "errors": []
            }
        """
        start_time = datetime.now(timezone.utc)
        virtual_tenant_id = str(session.virtual_tenant_id)
        session_id = session.session_id

        logger.info(
            "Starting demo session cleanup",
            session_id=session_id,
            virtual_tenant_id=virtual_tenant_id,
            demo_account_type=session.demo_account_type
        )

        # Delete from all services in parallel
        tasks = [
            self._delete_from_service(name, url, virtual_tenant_id)
            for name, url in self.services
        ]

        service_results = await asyncio.gather(*tasks, return_exceptions=True)

        # Aggregate results
        total_deleted = 0
        details = {}
        errors = []

        for (service_name, _), result in zip(self.services, service_results):
            if isinstance(result, Exception):
                errors.append(f"{service_name}: {str(result)}")
                details[service_name] = {"status": "error", "error": str(result)}
            else:
                total_deleted += result.get("records_deleted", {}).get("total", 0)
                details[service_name] = result

        # Delete from Redis
        await self._delete_redis_cache(virtual_tenant_id)

        # Delete child tenants if enterprise
        if session.demo_account_type == "enterprise" and session.session_metadata:
            child_tenant_ids = session.session_metadata.get("child_tenant_ids", [])
            logger.info(
                "Deleting child tenant data",
                session_id=session_id,
                child_count=len(child_tenant_ids)
            )

            for child_tenant_id in child_tenant_ids:
                child_results = await self._delete_from_all_services(str(child_tenant_id))

                # Aggregate child deletion results
                for (service_name, _), child_result in zip(self.services, child_results):
                    if isinstance(child_result, Exception):
                        logger.warning(
                            "Failed to delete child tenant data from service",
                            service=service_name,
                            child_tenant_id=child_tenant_id,
                            error=str(child_result)
                        )
                    else:
                        child_deleted = child_result.get("records_deleted", {}).get("total", 0)
                        total_deleted += child_deleted

                        # Update details to track child deletions
                        if service_name not in details:
                            details[service_name] = {"child_deletions": []}
                        if "child_deletions" not in details[service_name]:
                            details[service_name]["child_deletions"] = []
                        details[service_name]["child_deletions"].append({
                            "child_tenant_id": str(child_tenant_id),
                            "records_deleted": child_deleted
                        })

        duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)

        success = len(errors) == 0

        logger.info(
            "Demo session cleanup completed",
            session_id=session_id,
            virtual_tenant_id=virtual_tenant_id,
            success=success,
            total_deleted=total_deleted,
            duration_ms=duration_ms,
            error_count=len(errors)
        )

        return {
            "success": success,
            "total_deleted": total_deleted,
            "duration_ms": duration_ms,
            "details": details,
            "errors": errors
        }

    async def _delete_from_service(
        self,
        service_name: str,
        service_url: str,
        virtual_tenant_id: str
    ) -> dict:
        """Delete all data from a single service"""
        try:
            # Create JWT service token with tenant context
            service_token = self.jwt_handler.create_service_token(
                service_name="demo-session",
                tenant_id=virtual_tenant_id
            )

            async with httpx.AsyncClient(timeout=30.0) as client:
                response = await client.delete(
                    f"{service_url}/internal/demo/tenant/{virtual_tenant_id}",
                    headers={
                        "Authorization": f"Bearer {service_token}",
                        "X-Service": "demo-session-service"
                    }
                )

                if response.status_code == 200:
                    return response.json()
                elif response.status_code == 404:
                    # Already deleted or never existed - idempotent
                    return {
                        "service": service_name,
                        "status": "not_found",
                        "records_deleted": {"total": 0}
                    }
                else:
                    raise Exception(f"HTTP {response.status_code}: {response.text}")

        except Exception as e:
            logger.error(
                "Failed to delete from service",
                service=service_name,
                virtual_tenant_id=virtual_tenant_id,
                error=str(e)
            )
            raise

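    # A note on the internal contract assumed above: each service is expected to
    # expose DELETE /internal/demo/tenant/{tenant_id}. Judging by how
    # cleanup_session aggregates results, a successful 200 response body looks
    # roughly like the sketch below (field names other than
    # "records_deleted.total" are an assumption, not a confirmed schema):
    #
    #     {
    #         "service": "inventory",
    #         "status": "deleted",
    #         "records_deleted": {"total": 42}
    #     }
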
    async def _delete_redis_cache(self, virtual_tenant_id: str):
        """Delete all Redis keys for a virtual tenant"""
        try:
            client = await self.redis.get_client()
            pattern = f"*:{virtual_tenant_id}:*"
            keys = await client.keys(pattern)
            if keys:
                await client.delete(*keys)
                logger.debug("Deleted Redis cache", tenant_id=virtual_tenant_id, keys_deleted=len(keys))
        except Exception as e:
            logger.warning("Failed to delete Redis cache", error=str(e), tenant_id=virtual_tenant_id)

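    # Note: client.keys() runs a single blocking KEYS command that scans the whole
    # keyspace, which can stall Redis when the dataset is large. If the wrapped
    # client is redis-py's asyncio client, an equivalent non-blocking loop could
    # use SCAN instead — a sketch only, not used by this service:
    #
    #     keys = [k async for k in client.scan_iter(match=f"*:{virtual_tenant_id}:*")]
    #     if keys:
    #         await client.delete(*keys)
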
    async def _delete_from_all_services(self, virtual_tenant_id: str):
        """Delete data from all services for a tenant"""
        tasks = [
            self._delete_from_service(name, url, virtual_tenant_id)
            for name, url in self.services
        ]
        return await asyncio.gather(*tasks, return_exceptions=True)

    async def _delete_tenant_data(self, tenant_id: str, session_id: str) -> dict:
        """Delete demo data for a tenant across all services"""
        logger.info("Deleting tenant data", tenant_id=tenant_id, session_id=session_id)

        results = {}

        async def delete_from_service(service_name: str, service_url: str):
            try:
                # Create JWT service token with tenant context
                service_token = self.jwt_handler.create_service_token(
                    service_name="demo-session",
                    tenant_id=tenant_id
                )

                async with httpx.AsyncClient(timeout=30.0) as client:
                    response = await client.delete(
                        f"{service_url}/internal/demo/tenant/{tenant_id}",
                        headers={
                            "Authorization": f"Bearer {service_token}",
                            "X-Service": "demo-session-service"
                        }
                    )

                    if response.status_code == 200:
                        logger.debug(f"Deleted data from {service_name}", tenant_id=tenant_id)
                        return {"service": service_name, "status": "deleted"}
                    else:
                        logger.warning(
                            f"Failed to delete from {service_name}",
                            status_code=response.status_code,
                            tenant_id=tenant_id
                        )
                        return {"service": service_name, "status": "failed", "error": f"HTTP {response.status_code}"}
            except Exception as e:
                logger.warning(
                    f"Exception deleting from {service_name}",
                    error=str(e),
                    tenant_id=tenant_id
                )
                return {"service": service_name, "status": "failed", "error": str(e)}

        # Delete from all services in parallel
        tasks = [delete_from_service(name, url) for name, url in self.services]
        service_results = await asyncio.gather(*tasks, return_exceptions=True)

        for result in service_results:
            if isinstance(result, Exception):
                logger.error("Service deletion failed", error=str(result))
            elif isinstance(result, dict):
                results[result["service"]] = result

        return results

    async def cleanup_expired_sessions(self) -> dict:
        """
        Find and clean up all expired sessions.
        Also cleans up sessions stuck in PENDING for too long (>5 minutes).

        Returns:
            Cleanup statistics
        """
        logger.info("Starting demo session cleanup")

        start_time = datetime.now(timezone.utc)

        now = datetime.now(timezone.utc)
        stuck_threshold = now - timedelta(minutes=5)  # Sessions pending > 5 min are stuck

        # Find expired sessions (any status except EXPIRED and DESTROYED)
        result = await self.db.execute(
            select(DemoSession).where(
                DemoSession.status.in_([
                    DemoSessionStatus.PENDING,
                    DemoSessionStatus.READY,
                    DemoSessionStatus.PARTIAL,
                    DemoSessionStatus.FAILED,
                    DemoSessionStatus.ACTIVE  # Legacy status, kept for compatibility
                ]),
                DemoSession.expires_at < now
            )
        )
        expired_sessions = result.scalars().all()

        # Also find sessions stuck in PENDING
        stuck_result = await self.db.execute(
            select(DemoSession).where(
                DemoSession.status == DemoSessionStatus.PENDING,
                DemoSession.created_at < stuck_threshold
            )
        )
        stuck_sessions = stuck_result.scalars().all()

        # Combine both lists
        all_sessions_to_cleanup = list(expired_sessions) + list(stuck_sessions)

        stats = {
            "total_expired": len(expired_sessions),
            "total_stuck": len(stuck_sessions),
            "total_to_cleanup": len(all_sessions_to_cleanup),
            "cleaned_up": 0,
            "failed": 0,
            "errors": []
        }

        for session in all_sessions_to_cleanup:
            try:
                # Mark as expired
                session.status = DemoSessionStatus.EXPIRED
                await self.db.commit()

                # Check if this is an enterprise demo with children
                is_enterprise = session.demo_account_type == "enterprise"
                child_tenant_ids = []

                if is_enterprise and session.session_metadata:
                    child_tenant_ids = session.session_metadata.get("child_tenant_ids", [])

                # Delete child tenants first (for enterprise demos)
                if child_tenant_ids:
                    logger.info(
                        "Cleaning up enterprise demo children",
                        session_id=session.session_id,
                        child_count=len(child_tenant_ids)
                    )
                    for child_id in child_tenant_ids:
                        try:
                            await self._delete_tenant_data(child_id, session.session_id)
                        except Exception as child_error:
                            logger.error(
                                "Failed to delete child tenant",
                                child_id=child_id,
                                error=str(child_error)
                            )

                # Delete parent/main session data
                await self._delete_tenant_data(
                    str(session.virtual_tenant_id),
                    session.session_id
                )

                # Delete Redis data
                await self.redis.delete_session_data(session.session_id)

                stats["cleaned_up"] += 1

                logger.info(
                    "Session cleaned up",
                    session_id=session.session_id,
                    is_enterprise=is_enterprise,
                    children_deleted=len(child_tenant_ids),
                    age_minutes=(now - session.created_at).total_seconds() / 60
                )

            except Exception as e:
                stats["failed"] += 1
                stats["errors"].append({
                    "session_id": session.session_id,
                    "error": str(e)
                })
                logger.error(
                    "Failed to cleanup session",
                    session_id=session.session_id,
                    error=str(e)
                )

        logger.info("Demo session cleanup completed", stats=stats)

        return stats

    async def cleanup_old_destroyed_sessions(self, days: int = 7) -> int:
        """
        Delete destroyed session records older than the specified number of days.

        Args:
            days: Number of days to keep destroyed sessions

        Returns:
            Number of deleted records
        """
        cutoff_date = datetime.now(timezone.utc) - timedelta(days=days)

        result = await self.db.execute(
            select(DemoSession).where(
                DemoSession.status == DemoSessionStatus.DESTROYED,
                DemoSession.destroyed_at < cutoff_date
            )
        )
        old_sessions = result.scalars().all()

        for session in old_sessions:
            await self.db.delete(session)

        await self.db.commit()

        logger.info(
            "Old destroyed sessions deleted",
            count=len(old_sessions),
            older_than_days=days
        )

        return len(old_sessions)

    async def get_cleanup_stats(self) -> dict:
        """Get cleanup statistics"""
        result = await self.db.execute(select(DemoSession))
        all_sessions = result.scalars().all()

        now = datetime.now(timezone.utc)

        # Count by status
        pending_count = len([s for s in all_sessions if s.status == DemoSessionStatus.PENDING])
        ready_count = len([s for s in all_sessions if s.status == DemoSessionStatus.READY])
        partial_count = len([s for s in all_sessions if s.status == DemoSessionStatus.PARTIAL])
        failed_count = len([s for s in all_sessions if s.status == DemoSessionStatus.FAILED])
        active_count = len([s for s in all_sessions if s.status == DemoSessionStatus.ACTIVE])
        expired_count = len([s for s in all_sessions if s.status == DemoSessionStatus.EXPIRED])
        destroyed_count = len([s for s in all_sessions if s.status == DemoSessionStatus.DESTROYED])

        # Find sessions that should be expired but aren't marked yet
        should_be_expired = len([
            s for s in all_sessions
            if s.status in [
                DemoSessionStatus.PENDING,
                DemoSessionStatus.READY,
                DemoSessionStatus.PARTIAL,
                DemoSessionStatus.FAILED,
                DemoSessionStatus.ACTIVE
            ] and s.expires_at < now
        ])

        return {
            "total_sessions": len(all_sessions),
            "by_status": {
                "pending": pending_count,
                "ready": ready_count,
                "partial": partial_count,
                "failed": failed_count,
                "active": active_count,  # Legacy
                "expired": expired_count,
                "destroyed": destroyed_count
            },
            "pending_cleanup": should_be_expired
        }
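

# Illustrative usage sketch (not wired up anywhere in this module): one way to run
# the cleanup on a fixed interval from an asyncio background task. The caller is
# assumed to supply a working AsyncSession and DemoRedisWrapper; the 5-minute
# interval and 7-day retention below are example values, not requirements.
async def _example_cleanup_loop(
    db: AsyncSession,
    redis: DemoRedisWrapper,
    interval_seconds: int = 300,
) -> None:
    service = DemoCleanupService(db, redis)
    while True:
        # Expire and tear down sessions past expires_at or stuck in PENDING
        await service.cleanup_expired_sessions()
        # Purge DESTROYED session records older than a week
        await service.cleanup_old_destroyed_sessions(days=7)
        await asyncio.sleep(interval_seconds)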