"""
Internal Demo Cloning API for Forecasting Service
Service-to-service endpoint for cloning forecast data
"""
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
import structlog
import uuid
from datetime import datetime, timezone, timedelta
from typing import Optional
import sys
from pathlib import Path
import json
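# Put the repository root on sys.path so the top-level `shared` package resolves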
sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent))
from shared.utils.demo_dates import adjust_date_for_demo, resolve_time_marker
from app.core.database import get_db
from app.models.forecasts import Forecast, PredictionBatch
logger = structlog.get_logger()
router = APIRouter(prefix="/internal/demo", tags=["internal"])
# Base demo tenant ID (professional profile template)
DEMO_TENANT_PROFESSIONAL = "a1b2c3d4-e5f6-47a8-b9c0-d1e2f3a4b5c6"
def parse_date_field(date_value, session_time: datetime, field_name: str = "date") -> Optional[datetime]:
"""
Parse date field, handling both ISO strings and BASE_TS markers.
Supports:
- BASE_TS markers: "BASE_TS + 1h30m", "BASE_TS - 2d"
- ISO 8601 strings: "2025-01-15T06:00:00Z"
- None values (returns None)
Returns timezone-aware datetime or None.
"""
if not date_value:
return None
# Check if it's a BASE_TS marker
if isinstance(date_value, str) and date_value.startswith("BASE_TS"):
try:
return resolve_time_marker(date_value, session_time)
except ValueError as e:
logger.warning(
f"Invalid BASE_TS marker in {field_name}",
marker=date_value,
error=str(e)
)
return None
# Handle regular ISO date strings
try:
if isinstance(date_value, str):
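            # fromisoformat() before Python 3.11 rejects a trailing 'Z', so normalize it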
original_date = datetime.fromisoformat(date_value.replace('Z', '+00:00'))
elif hasattr(date_value, 'isoformat'):
original_date = date_value
else:
logger.warning(f"Unsupported date format in {field_name}", date_value=date_value)
return None
return adjust_date_for_demo(original_date, session_time)
except (ValueError, AttributeError) as e:
logger.warning(
f"Invalid date format in {field_name}",
date_value=date_value,
error=str(e)
)
return None
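
# Illustrative behaviour of parse_date_field (a sketch; the exact shifts depend on
# resolve_time_marker / adjust_date_for_demo in shared.utils.demo_dates):
#
#   parse_date_field("BASE_TS + 1h30m", session_time)      # -> session_time + 1h30m
#   parse_date_field("BASE_TS - 2d", session_time)         # -> session_time - 2 days
#   parse_date_field("2025-01-15T06:00:00Z", session_time) # -> ISO date, shifted for the demo
#   parse_date_field(None, session_time)                   # -> None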
def align_to_week_start(target_date: datetime) -> datetime:
"""Align forecast date to Monday (start of week)"""
if target_date:
days_since_monday = target_date.weekday()
return target_date - timedelta(days=days_since_monday)
return target_date
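
# e.g. align_to_week_start(Thu 2025-01-16 06:00) -> Mon 2025-01-13 06:00
# (time of day is preserved; a Monday input is returned unchanged)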
@router.post("/clone")
async def clone_demo_data(
base_tenant_id: str,
virtual_tenant_id: str,
demo_account_type: str,
session_id: Optional[str] = None,
session_created_at: Optional[str] = None,
db: AsyncSession = Depends(get_db)
):
"""
Clone forecasting service data for a virtual demo tenant
2025-12-13 23:57:54 +01:00
This endpoint creates fresh demo data by:
1. Loading seed data from JSON files
2. Applying XOR-based ID transformation
3. Adjusting dates relative to session creation time
4. Creating records in the virtual tenant
Args:
2025-12-13 23:57:54 +01:00
base_tenant_id: Template tenant UUID (for reference)
virtual_tenant_id: Target virtual tenant UUID
demo_account_type: Type of demo account
session_id: Originating session ID for tracing
2025-12-13 23:57:54 +01:00
session_created_at: Session creation timestamp for date adjustment
db: Database session
Returns:
2025-12-13 23:57:54 +01:00
Dictionary with cloning results
Raises:
HTTPException: On validation or cloning errors
"""
start_time = datetime.now(timezone.utc)
try:
# Validate UUIDs
virtual_uuid = uuid.UUID(virtual_tenant_id)
# Parse session creation time for date adjustment
if session_created_at:
try:
session_time = datetime.fromisoformat(session_created_at.replace('Z', '+00:00'))
except (ValueError, AttributeError):
session_time = start_time
else:
session_time = start_time
logger.info(
"Starting forecasting data cloning with date adjustment",
base_tenant_id=base_tenant_id,
virtual_tenant_id=str(virtual_uuid),
demo_account_type=demo_account_type,
session_id=session_id,
session_time=session_time.isoformat()
)
# Load seed data using shared utility
try:
from shared.utils.seed_data_paths import get_seed_data_path
if demo_account_type == "enterprise":
profile = "enterprise"
else:
profile = "professional"
json_file = get_seed_data_path(profile, "10-forecasting.json")
except ImportError:
# Fallback to original path
seed_data_dir = Path(__file__).parent.parent.parent.parent / "shared" / "demo" / "fixtures"
if demo_account_type == "enterprise":
json_file = seed_data_dir / "enterprise" / "parent" / "10-forecasting.json"
else:
json_file = seed_data_dir / "professional" / "10-forecasting.json"
if not json_file.exists():
raise HTTPException(
status_code=404,
detail=f"Seed data file not found: {json_file}"
)
# Load JSON data
with open(json_file, 'r', encoding='utf-8') as f:
seed_data = json.load(f)
# Check if data already exists for this virtual tenant (idempotency)
existing_check = await db.execute(
select(Forecast).where(Forecast.tenant_id == virtual_uuid).limit(1)
)
existing_forecast = existing_check.scalar_one_or_none()
if existing_forecast:
logger.warning(
"Demo data already exists, skipping clone",
virtual_tenant_id=str(virtual_uuid)
)
return {
"status": "skipped",
"reason": "Data already exists",
"records_cloned": 0
}
# Track cloning statistics
stats = {
"forecasts": 0,
"prediction_batches": 0
}
        # Transform and insert forecasts
        from shared.utils.demo_id_transformer import transform_id
        for forecast_data in seed_data.get('forecasts', []):
            # Transform ID using XOR (validate the seed UUID first)
            try:
                uuid.UUID(forecast_data['id'])
                transformed_id = transform_id(forecast_data['id'], virtual_uuid)
except ValueError as e:
logger.error("Failed to parse UUIDs for ID transformation",
forecast_id=forecast_data['id'],
virtual_tenant_id=virtual_tenant_id,
error=str(e))
raise HTTPException(
status_code=400,
detail=f"Invalid UUID format in forecast data: {str(e)}"
)
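            # transform_id is expected to XOR the seed UUID with the tenant UUID
            # (per the module name), giving each tenant distinct but deterministic
            # IDs: re-cloning the same seed for the same tenant reproduces them.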
            # Transform dates using the proper parse_date_field function
            for date_field in ['forecast_date', 'created_at']:
                if date_field in forecast_data:
                    original_value = forecast_data[date_field]
                    try:
                        parsed_date = parse_date_field(
                            original_value,
                            session_time,
                            date_field
                        )
                        if parsed_date:
                            forecast_data[date_field] = parsed_date
                        else:
                            # Parsing failed: fall back to session_time, but log
                            # the value we actually received
                            forecast_data[date_field] = session_time
                            logger.warning("Using fallback date for failed parsing",
                                           date_field=date_field,
                                           original_value=original_value)
                    except Exception as e:
                        logger.warning("Failed to parse date, using fallback",
                                       date_field=date_field,
                                       date_value=original_value,
                                       error=str(e))
                        forecast_data[date_field] = session_time
            # Map product_id to inventory_product_id if needed
            inventory_product_id_str = forecast_data.get('inventory_product_id') or forecast_data.get('product_id')
            # Convert to UUID if it's a string
            if isinstance(inventory_product_id_str, str):
                inventory_product_id = uuid.UUID(inventory_product_id_str)
            else:
                inventory_product_id = inventory_product_id_str

            # Map predicted_quantity to predicted_demand if needed
            predicted_demand = forecast_data.get('predicted_demand') or forecast_data.get('predicted_quantity')

            # Set default location if not provided in seed data
            location = forecast_data.get('location') or "Main Bakery"
            # Get or calculate forecast date
            forecast_date = forecast_data.get('forecast_date')
            if not forecast_date:
                forecast_date = session_time

            # Calculate day_of_week from forecast_date if not provided
            # day_of_week should be 0-6 (Monday=0, Sunday=6)
            day_of_week = forecast_data.get('day_of_week')
            if day_of_week is None and forecast_date:
                day_of_week = forecast_date.weekday()

            # Derive is_weekend from day_of_week only when the seed data does not provide it
            is_weekend = forecast_data.get('is_weekend')
            if is_weekend is None:
                is_weekend = day_of_week >= 5 if day_of_week is not None else False  # Saturday=5, Sunday=6
            # Create the forecast record
            new_forecast = Forecast(
                id=transformed_id,
                tenant_id=virtual_uuid,
                inventory_product_id=inventory_product_id,
                product_name=forecast_data.get('product_name'),
                location=location,
                forecast_date=forecast_date,
                created_at=forecast_data.get('created_at', session_time),
                predicted_demand=predicted_demand,
                # Default confidence band: +/-20% around the prediction, floored at 0
                confidence_lower=forecast_data.get('confidence_lower', max(0.0, float(predicted_demand or 0.0) * 0.8)),
                confidence_upper=forecast_data.get('confidence_upper', max(0.0, float(predicted_demand or 0.0) * 1.2)),
                confidence_level=forecast_data.get('confidence_level', 0.8),
                model_id=forecast_data.get('model_id') or 'default-fallback-model',
                model_version=forecast_data.get('model_version') or '1.0',
                algorithm=forecast_data.get('algorithm', 'prophet'),
                business_type=forecast_data.get('business_type', 'individual'),
                day_of_week=day_of_week,
                is_holiday=forecast_data.get('is_holiday', False),
                is_weekend=is_weekend,
                weather_temperature=forecast_data.get('weather_temperature'),
                weather_precipitation=forecast_data.get('weather_precipitation'),
                weather_description=forecast_data.get('weather_description'),
                traffic_volume=forecast_data.get('traffic_volume'),
                processing_time_ms=forecast_data.get('processing_time_ms'),
                features_used=forecast_data.get('features_used')
            )
db.add(new_forecast)
stats["forecasts"] += 1
        # Transform and insert prediction batches
        for batch_data in seed_data.get('prediction_batches', []):
            # Transform ID using XOR (validate the seed UUID first)
            try:
                uuid.UUID(batch_data['id'])
                transformed_id = transform_id(batch_data['id'], virtual_uuid)
except ValueError as e:
logger.error("Failed to parse UUIDs for ID transformation",
batch_id=batch_data['id'],
virtual_tenant_id=virtual_tenant_id,
error=str(e))
raise HTTPException(
status_code=400,
detail=f"Invalid UUID format in batch data: {str(e)}"
)
            # Handle field mapping: batch_id -> batch_name, total_forecasts -> total_products
            batch_name = batch_data.get('batch_name') or batch_data.get('batch_id') or f"Batch-{transformed_id}"
            total_products = batch_data.get('total_products') or batch_data.get('total_forecasts') or 0
            # Compare status case-insensitively (seed data may use 'COMPLETED' or 'completed')
            completed_products = batch_data.get('completed_products') or (total_products if (batch_data.get('status') or '').upper() == 'COMPLETED' else 0)

            # Parse dates (handle created_at or prediction_date for requested_at);
            # fall back to session_time when nothing parses
            requested_at_raw = batch_data.get('requested_at') or batch_data.get('created_at') or batch_data.get('prediction_date')
            requested_at = parse_date_field(requested_at_raw, session_time, 'requested_at') or session_time
            completed_at_raw = batch_data.get('completed_at')
            completed_at = parse_date_field(completed_at_raw, session_time, 'completed_at') if completed_at_raw else None
            # Create the prediction batch record
            new_batch = PredictionBatch(
                id=transformed_id,
                tenant_id=virtual_uuid,
                batch_name=batch_name,
                requested_at=requested_at,
                completed_at=completed_at,
                status=batch_data.get('status', 'completed'),
                total_products=total_products,
                completed_products=completed_products,
                failed_products=batch_data.get('failed_products', 0),
                forecast_days=batch_data.get('forecast_days', 7),
                business_type=batch_data.get('business_type', 'individual'),
                error_message=batch_data.get('error_message'),
                processing_time_ms=batch_data.get('processing_time_ms'),
                cancelled_by=batch_data.get('cancelled_by')
            )
db.add(new_batch)
stats["prediction_batches"] += 1
# Commit all changes
await db.commit()
total_records = sum(stats.values())
duration_ms = int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
logger.info(
"Forecasting data cloned successfully",
virtual_tenant_id=str(virtual_uuid),
records_cloned=total_records,
duration_ms=duration_ms,
forecasts_cloned=stats["forecasts"],
batches_cloned=stats["prediction_batches"]
)
return {
"service": "forecasting",
"status": "completed",
"records_cloned": total_records,
"duration_ms": duration_ms,
"details": {
"forecasts": stats["forecasts"],
"prediction_batches": stats["prediction_batches"],
"virtual_tenant_id": str(virtual_uuid)
}
}
except ValueError as e:
logger.error("Invalid UUID format", error=str(e), virtual_tenant_id=virtual_tenant_id)
raise HTTPException(status_code=400, detail=f"Invalid UUID: {str(e)}")
except Exception as e:
logger.error(
"Failed to clone forecasting data",
error=str(e),
virtual_tenant_id=virtual_tenant_id,
exc_info=True
)
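        # Failures are reported in the response body (HTTP 200) rather than
        # raised, presumably so the demo orchestrator can aggregate per-service
        # results without one failure aborting the whole clone.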
# Rollback on error
await db.rollback()
return {
"service": "forecasting",
"status": "failed",
"records_cloned": 0,
"duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000),
"error": str(e)
}
@router.get("/clone/health")
async def clone_health_check():
"""
Health check for internal cloning endpoint
Used by orchestrator to verify service availability
"""
return {
"service": "forecasting",
"clone_endpoint": "available",
"version": "2.0.0"
}
@router.delete("/tenant/{virtual_tenant_id}")
async def delete_demo_tenant_data(
virtual_tenant_id: uuid.UUID,
db: AsyncSession = Depends(get_db)
):
"""
Delete all demo data for a virtual tenant.
This endpoint is idempotent - safe to call multiple times.
"""
from sqlalchemy import delete
start_time = datetime.now(timezone.utc)
records_deleted = {
"forecasts": 0,
"prediction_batches": 0,
"total": 0
}
try:
# Delete in reverse dependency order
# 1. Delete prediction batches
result = await db.execute(
delete(PredictionBatch)
.where(PredictionBatch.tenant_id == virtual_tenant_id)
)
records_deleted["prediction_batches"] = result.rowcount
# 2. Delete forecasts
result = await db.execute(
delete(Forecast)
.where(Forecast.tenant_id == virtual_tenant_id)
)
records_deleted["forecasts"] = result.rowcount
records_deleted["total"] = sum(records_deleted.values())
await db.commit()
logger.info(
"demo_data_deleted",
service="forecasting",
virtual_tenant_id=str(virtual_tenant_id),
records_deleted=records_deleted
)
return {
"service": "forecasting",
"status": "deleted",
"virtual_tenant_id": str(virtual_tenant_id),
"records_deleted": records_deleted,
"duration_ms": int((datetime.now(timezone.utc) - start_time).total_seconds() * 1000)
}
except Exception as e:
await db.rollback()
logger.error(
"demo_data_deletion_failed",
service="forecasting",
virtual_tenant_id=str(virtual_tenant_id),
error=str(e)
)
raise HTTPException(
status_code=500,
detail=f"Failed to delete demo data: {str(e)}"
)