demo seed change

This commit is contained in:
Urtzi Alfaro
2025-12-13 23:57:54 +01:00
parent f3688dfb04
commit ff830a3415
299 changed files with 20328 additions and 19485 deletions

View File

@@ -13,6 +13,7 @@ from typing import Optional
import os
import sys
from pathlib import Path
import json
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
from shared.utils.demo_dates import adjust_date_for_demo, BASE_REFERENCE_DATE
@@ -21,7 +22,7 @@ from app.core.database import get_db
from app.models.forecasts import Forecast, PredictionBatch
logger = structlog.get_logger()
router = APIRouter(prefix="/internal/demo", tags=["internal"])
router = APIRouter()
# Base demo tenant IDs
DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"
@@ -36,7 +37,7 @@ def verify_internal_api_key(x_internal_api_key: Optional[str] = Header(None)):
return True
@router.post("/clone")
@router.post("/internal/demo/clone")
async def clone_demo_data(
base_tenant_id: str,
virtual_tenant_id: str,
@@ -49,50 +50,95 @@ async def clone_demo_data(
"""
Clone forecasting service data for a virtual demo tenant
Clones:
- Forecasts (historical predictions)
- Prediction batches (batch prediction records)
This endpoint creates fresh demo data by:
1. Loading seed data from JSON files
2. Applying XOR-based ID transformation
3. Adjusting dates relative to session creation time
4. Creating records in the virtual tenant
Args:
base_tenant_id: Template tenant UUID to clone from
base_tenant_id: Template tenant UUID (for reference)
virtual_tenant_id: Target virtual tenant UUID
demo_account_type: Type of demo account
session_id: Originating session ID for tracing
session_created_at: ISO timestamp when demo session was created (for date adjustment)
session_created_at: Session creation timestamp for date adjustment
db: Database session
Returns:
Cloning status and record counts
Dictionary with cloning results
Raises:
HTTPException: On validation or cloning errors
"""
start_time = datetime.now(timezone.utc)
# Parse session_created_at or fallback to now
if session_created_at:
try:
session_time = datetime.fromisoformat(session_created_at.replace('Z', '+00:00'))
except (ValueError, AttributeError) as e:
logger.warning(
"Invalid session_created_at format, using current time",
session_created_at=session_created_at,
error=str(e)
)
session_time = datetime.now(timezone.utc)
else:
logger.warning("session_created_at not provided, using current time")
session_time = datetime.now(timezone.utc)
logger.info(
"Starting forecasting data cloning",
base_tenant_id=base_tenant_id,
virtual_tenant_id=virtual_tenant_id,
demo_account_type=demo_account_type,
session_id=session_id,
session_time=session_time.isoformat()
)
try:
# Validate UUIDs
base_uuid = uuid.UUID(base_tenant_id)
virtual_uuid = uuid.UUID(virtual_tenant_id)
# Parse session creation time for date adjustment
if session_created_at:
try:
session_time = datetime.fromisoformat(session_created_at.replace('Z', '+00:00'))
except (ValueError, AttributeError):
session_time = start_time
else:
session_time = start_time
logger.info(
"Starting forecasting data cloning with date adjustment",
base_tenant_id=base_tenant_id,
virtual_tenant_id=str(virtual_uuid),
demo_account_type=demo_account_type,
session_id=session_id,
session_time=session_time.isoformat()
)
# Load seed data using shared utility
try:
from shared.utils.seed_data_paths import get_seed_data_path
if demo_account_type == "enterprise":
profile = "enterprise"
else:
profile = "professional"
json_file = get_seed_data_path(profile, "10-forecasting.json")
except ImportError:
# Fallback to original path
seed_data_dir = Path(__file__).parent.parent.parent.parent / "shared" / "demo" / "fixtures"
if demo_account_type == "enterprise":
json_file = seed_data_dir / "enterprise" / "parent" / "10-forecasting.json"
else:
json_file = seed_data_dir / "professional" / "10-forecasting.json"
if not json_file.exists():
raise HTTPException(
status_code=404,
detail=f"Seed data file not found: {json_file}"
)
# Load JSON data
with open(json_file, 'r', encoding='utf-8') as f:
seed_data = json.load(f)
# Check if data already exists for this virtual tenant (idempotency)
existing_check = await db.execute(
select(Forecast).where(Forecast.tenant_id == virtual_uuid).limit(1)
)
existing_forecast = existing_check.scalar_one_or_none()
if existing_forecast:
logger.warning(
"Demo data already exists, skipping clone",
virtual_tenant_id=str(virtual_uuid)
)
return {
"status": "skipped",
"reason": "Data already exists",
"records_cloned": 0
}
# Track cloning statistics
stats = {
@@ -100,93 +146,150 @@ async def clone_demo_data(
"prediction_batches": 0
}
# Clone Forecasts
result = await db.execute(
select(Forecast).where(Forecast.tenant_id == base_uuid)
)
base_forecasts = result.scalars().all()
# Transform and insert forecasts
for forecast_data in seed_data.get('forecasts', []):
# Transform ID using XOR
from shared.utils.demo_id_transformer import transform_id
try:
forecast_uuid = uuid.UUID(forecast_data['id'])
tenant_uuid = uuid.UUID(virtual_tenant_id)
transformed_id = transform_id(forecast_data['id'], tenant_uuid)
except ValueError as e:
logger.error("Failed to parse UUIDs for ID transformation",
forecast_id=forecast_data['id'],
virtual_tenant_id=virtual_tenant_id,
error=str(e))
raise HTTPException(
status_code=400,
detail=f"Invalid UUID format in forecast data: {str(e)}"
)
# Transform dates
for date_field in ['forecast_date', 'created_at']:
if date_field in forecast_data:
try:
date_value = forecast_data[date_field]
if isinstance(date_value, str):
original_date = datetime.fromisoformat(date_value)
elif hasattr(date_value, 'isoformat'):
original_date = date_value
else:
logger.warning("Skipping invalid date format",
date_field=date_field,
date_value=date_value)
continue
adjusted_forecast_date = adjust_date_for_demo(
original_date,
session_time,
BASE_REFERENCE_DATE
)
forecast_data[date_field] = adjusted_forecast_date
except (ValueError, AttributeError) as e:
logger.warning("Failed to parse date, skipping",
date_field=date_field,
date_value=forecast_data[date_field],
error=str(e))
forecast_data.pop(date_field, None)
# Create forecast
# Map product_id to inventory_product_id if needed
inventory_product_id = forecast_data.get('inventory_product_id') or forecast_data.get('product_id')
logger.info(
"Found forecasts to clone",
count=len(base_forecasts),
base_tenant=str(base_uuid)
)
for forecast in base_forecasts:
adjusted_forecast_date = adjust_date_for_demo(
forecast.forecast_date,
session_time,
BASE_REFERENCE_DATE
) if forecast.forecast_date else None
# Map predicted_quantity to predicted_demand if needed
predicted_demand = forecast_data.get('predicted_demand') or forecast_data.get('predicted_quantity')
new_forecast = Forecast(
id=uuid.uuid4(),
id=transformed_id,
tenant_id=virtual_uuid,
inventory_product_id=forecast.inventory_product_id, # Keep product reference
product_name=forecast.product_name,
location=forecast.location,
forecast_date=adjusted_forecast_date,
created_at=session_time,
predicted_demand=forecast.predicted_demand,
confidence_lower=forecast.confidence_lower,
confidence_upper=forecast.confidence_upper,
confidence_level=forecast.confidence_level,
model_id=forecast.model_id,
model_version=forecast.model_version,
algorithm=forecast.algorithm,
business_type=forecast.business_type,
day_of_week=forecast.day_of_week,
is_holiday=forecast.is_holiday,
is_weekend=forecast.is_weekend,
weather_temperature=forecast.weather_temperature,
weather_precipitation=forecast.weather_precipitation,
weather_description=forecast.weather_description,
traffic_volume=forecast.traffic_volume,
processing_time_ms=forecast.processing_time_ms,
features_used=forecast.features_used
inventory_product_id=inventory_product_id,
product_name=forecast_data.get('product_name'),
location=forecast_data.get('location'),
forecast_date=forecast_data.get('forecast_date'),
created_at=forecast_data.get('created_at', session_time),
predicted_demand=predicted_demand,
confidence_lower=forecast_data.get('confidence_lower'),
confidence_upper=forecast_data.get('confidence_upper'),
confidence_level=forecast_data.get('confidence_level', 0.8),
model_id=forecast_data.get('model_id'),
model_version=forecast_data.get('model_version'),
algorithm=forecast_data.get('algorithm', 'prophet'),
business_type=forecast_data.get('business_type', 'individual'),
day_of_week=forecast_data.get('day_of_week'),
is_holiday=forecast_data.get('is_holiday', False),
is_weekend=forecast_data.get('is_weekend', False),
weather_temperature=forecast_data.get('weather_temperature'),
weather_precipitation=forecast_data.get('weather_precipitation'),
weather_description=forecast_data.get('weather_description'),
traffic_volume=forecast_data.get('traffic_volume'),
processing_time_ms=forecast_data.get('processing_time_ms'),
features_used=forecast_data.get('features_used')
)
db.add(new_forecast)
stats["forecasts"] += 1
# Clone Prediction Batches
result = await db.execute(
select(PredictionBatch).where(PredictionBatch.tenant_id == base_uuid)
)
base_batches = result.scalars().all()
logger.info(
"Found prediction batches to clone",
count=len(base_batches),
base_tenant=str(base_uuid)
)
for batch in base_batches:
adjusted_requested_at = adjust_date_for_demo(
batch.requested_at,
session_time,
BASE_REFERENCE_DATE
) if batch.requested_at else None
adjusted_completed_at = adjust_date_for_demo(
batch.completed_at,
session_time,
BASE_REFERENCE_DATE
) if batch.completed_at else None
# Transform and insert prediction batches
for batch_data in seed_data.get('prediction_batches', []):
# Transform ID using XOR
from shared.utils.demo_id_transformer import transform_id
try:
batch_uuid = uuid.UUID(batch_data['id'])
tenant_uuid = uuid.UUID(virtual_tenant_id)
transformed_id = transform_id(batch_data['id'], tenant_uuid)
except ValueError as e:
logger.error("Failed to parse UUIDs for ID transformation",
batch_id=batch_data['id'],
virtual_tenant_id=virtual_tenant_id,
error=str(e))
raise HTTPException(
status_code=400,
detail=f"Invalid UUID format in batch data: {str(e)}"
)
# Transform dates
for date_field in ['requested_at', 'completed_at']:
if date_field in batch_data:
try:
date_value = batch_data[date_field]
if isinstance(date_value, str):
original_date = datetime.fromisoformat(date_value)
elif hasattr(date_value, 'isoformat'):
original_date = date_value
else:
logger.warning("Skipping invalid date format",
date_field=date_field,
date_value=date_value)
continue
adjusted_batch_date = adjust_date_for_demo(
original_date,
session_time,
BASE_REFERENCE_DATE
)
batch_data[date_field] = adjusted_batch_date
except (ValueError, AttributeError) as e:
logger.warning("Failed to parse date, skipping",
date_field=date_field,
date_value=batch_data[date_field],
error=str(e))
batch_data.pop(date_field, None)
# Create prediction batch
new_batch = PredictionBatch(
id=uuid.uuid4(),
id=transformed_id,
tenant_id=virtual_uuid,
batch_name=batch.batch_name,
requested_at=adjusted_requested_at,
completed_at=adjusted_completed_at,
status=batch.status,
total_products=batch.total_products,
completed_products=batch.completed_products,
failed_products=batch.failed_products,
forecast_days=batch.forecast_days,
business_type=batch.business_type,
error_message=batch.error_message,
processing_time_ms=batch.processing_time_ms,
cancelled_by=batch.cancelled_by
batch_name=batch_data.get('batch_name'),
requested_at=batch_data.get('requested_at'),
completed_at=batch_data.get('completed_at'),
status=batch_data.get('status'),
total_products=batch_data.get('total_products'),
completed_products=batch_data.get('completed_products'),
failed_products=batch_data.get('failed_products'),
forecast_days=batch_data.get('forecast_days'),
business_type=batch_data.get('business_type'),
error_message=batch_data.get('error_message'),
processing_time_ms=batch_data.get('processing_time_ms'),
cancelled_by=batch_data.get('cancelled_by')
)
db.add(new_batch)
stats["prediction_batches"] += 1
@@ -198,11 +301,12 @@ async def clone_demo_data(
duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
logger.info(
"Forecasting data cloning completed",
virtual_tenant_id=virtual_tenant_id,
total_records=total_records,
stats=stats,
duration_ms=duration_ms
"Forecasting data cloned successfully",
virtual_tenant_id=str(virtual_uuid),
records_cloned=total_records,
duration_ms=duration_ms,
forecasts_cloned=stats["forecasts"],
batches_cloned=stats["prediction_batches"]
)
return {
@@ -210,11 +314,15 @@ async def clone_demo_data(
"status": "completed",
"records_cloned": total_records,
"duration_ms": duration_ms,
"details": stats
"details": {
"forecasts": stats["forecasts"],
"prediction_batches": stats["prediction_batches"],
"virtual_tenant_id": str(virtual_uuid)
}
}
except ValueError as e:
logger.error("Invalid UUID format", error=str(e))
logger.error("Invalid UUID format", error=str(e), virtual_tenant_id=virtual_tenant_id)
raise HTTPException(status_code=400, detail=f"Invalid UUID: {str(e)}")
except Exception as e:
@@ -248,3 +356,73 @@ async def clone_health_check(_: bool = Depends(verify_internal_api_key)):
"clone_endpoint": "available",
"version": "2.0.0"
}
@router.delete("/tenant/{virtual_tenant_id}")
async def delete_demo_tenant_data(
virtual_tenant_id: uuid.UUID,
db: AsyncSession = Depends(get_db),
_: bool = Depends(verify_internal_api_key)
):
"""
Delete all demo data for a virtual tenant.
This endpoint is idempotent - safe to call multiple times.
"""
from sqlalchemy import delete
start_time = datetime.now(timezone.utc)
records_deleted = {
"forecasts": 0,
"prediction_batches": 0,
"total": 0
}
try:
# Delete in reverse dependency order
# 1. Delete prediction batches
result = await db.execute(
delete(PredictionBatch)
.where(PredictionBatch.tenant_id == virtual_tenant_id)
)
records_deleted["prediction_batches"] = result.rowcount
# 2. Delete forecasts
result = await db.execute(
delete(Forecast)
.where(Forecast.tenant_id == virtual_tenant_id)
)
records_deleted["forecasts"] = result.rowcount
records_deleted["total"] = sum(records_deleted.values())
await db.commit()
logger.info(
"demo_data_deleted",
service="forecasting",
virtual_tenant_id=str(virtual_tenant_id),
records_deleted=records_deleted
)
return {
"service": "forecasting",
"status": "deleted",
"virtual_tenant_id": str(virtual_tenant_id),
"records_deleted": records_deleted,
"duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
}
except Exception as e:
await db.rollback()
logger.error(
"demo_data_deletion_failed",
service="forecasting",
virtual_tenant_id=str(virtual_tenant_id),
error=str(e)
)
raise HTTPException(
status_code=500,
detail=f"Failed to delete demo data: {str(e)}"
)