Improve user delete flow

This commit is contained in:
Urtzi Alfaro
2025-08-02 17:09:53 +02:00
parent 277e8bec73
commit 3681429e11
10 changed files with 1334 additions and 210 deletions

View File

@@ -322,30 +322,15 @@ async def acknowledge_alert(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal server error"
)
@router.delete("/forecasts/tenant/{tenant_id}")
async def delete_tenant_forecasts_complete(
@router.delete("/tenants/{tenant_id}/forecasts")
async def delete_tenant_forecasts(
tenant_id: str,
current_user = Depends(get_current_user_dep),
_admin_check = Depends(require_admin_role),
db: AsyncSession = Depends(get_db)
):
"""
Delete all forecasts and predictions for a tenant.
**WARNING: This operation is irreversible!**
This endpoint:
1. Cancels any active prediction batches
2. Clears prediction cache
3. Deletes all forecast records
4. Deletes prediction batch records
5. Deletes model performance metrics
6. Publishes deletion event
Used by admin user deletion process to clean up all forecasting data.
"""
"""Delete all forecasts and predictions for a tenant (admin only)"""
try:
tenant_uuid = uuid.UUID(tenant_id)
except ValueError:
@@ -355,163 +340,155 @@ async def delete_tenant_forecasts_complete(
)
try:
from app.models.forecasts import Forecast, PredictionBatch
from app.models.predictions import ModelPerformanceMetric, PredictionCache
from app.models.forecasts import Forecast, Prediction, PredictionBatch
deletion_stats = {
"tenant_id": tenant_id,
"deleted_at": datetime.utcnow().isoformat(),
"batches_cancelled": 0,
"forecasts_deleted": 0,
"prediction_batches_deleted": 0,
"performance_metrics_deleted": 0,
"cache_entries_deleted": 0,
"predictions_deleted": 0,
"batches_deleted": 0,
"errors": []
}
# Step 1: Cancel active prediction batches
# Count before deletion
forecasts_count_query = select(func.count(Forecast.id)).where(
Forecast.tenant_id == tenant_uuid
)
forecasts_count_result = await db.execute(forecasts_count_query)
forecasts_count = forecasts_count_result.scalar()
predictions_count_query = select(func.count(Prediction.id)).where(
Prediction.tenant_id == tenant_uuid
)
predictions_count_result = await db.execute(predictions_count_query)
predictions_count = predictions_count_result.scalar()
batches_count_query = select(func.count(PredictionBatch.id)).where(
PredictionBatch.tenant_id == tenant_uuid
)
batches_count_result = await db.execute(batches_count_query)
batches_count = batches_count_result.scalar()
# Delete predictions first (they may reference forecasts)
try:
active_batches_query = select(PredictionBatch).where(
PredictionBatch.tenant_id == tenant_uuid,
PredictionBatch.status.in_(["pending", "processing"])
predictions_delete_query = delete(Prediction).where(
Prediction.tenant_id == tenant_uuid
)
active_batches_result = await db.execute(active_batches_query)
active_batches = active_batches_result.scalars().all()
for batch in active_batches:
batch.status = "cancelled"
batch.completed_at = datetime.utcnow()
deletion_stats["batches_cancelled"] += 1
if active_batches:
await db.commit()
logger.info("Cancelled active prediction batches",
tenant_id=tenant_id,
count=len(active_batches))
predictions_delete_result = await db.execute(predictions_delete_query)
deletion_stats["predictions_deleted"] = predictions_delete_result.rowcount
except Exception as e:
error_msg = f"Error cancelling prediction batches: {str(e)}"
error_msg = f"Error deleting predictions: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.error(error_msg)
# Step 2: Delete prediction cache
# Delete prediction batches
try:
cache_count_query = select(func.count(PredictionCache.id)).where(
PredictionCache.tenant_id == tenant_uuid
)
cache_count_result = await db.execute(cache_count_query)
cache_count = cache_count_result.scalar()
cache_delete_query = delete(PredictionCache).where(
PredictionCache.tenant_id == tenant_uuid
)
await db.execute(cache_delete_query)
deletion_stats["cache_entries_deleted"] = cache_count
logger.info("Deleted prediction cache entries",
tenant_id=tenant_id,
count=cache_count)
except Exception as e:
error_msg = f"Error deleting prediction cache: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.error(error_msg)
# Step 3: Delete model performance metrics
try:
metrics_count_query = select(func.count(ModelPerformanceMetric.id)).where(
ModelPerformanceMetric.tenant_id == tenant_uuid
)
metrics_count_result = await db.execute(metrics_count_query)
metrics_count = metrics_count_result.scalar()
metrics_delete_query = delete(ModelPerformanceMetric).where(
ModelPerformanceMetric.tenant_id == tenant_uuid
)
await db.execute(metrics_delete_query)
deletion_stats["performance_metrics_deleted"] = metrics_count
logger.info("Deleted performance metrics",
tenant_id=tenant_id,
count=metrics_count)
except Exception as e:
error_msg = f"Error deleting performance metrics: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.error(error_msg)
# Step 4: Delete prediction batches
try:
batches_count_query = select(func.count(PredictionBatch.id)).where(
PredictionBatch.tenant_id == tenant_uuid
)
batches_count_result = await db.execute(batches_count_query)
batches_count = batches_count_result.scalar()
batches_delete_query = delete(PredictionBatch).where(
PredictionBatch.tenant_id == tenant_uuid
)
await db.execute(batches_delete_query)
deletion_stats["prediction_batches_deleted"] = batches_count
logger.info("Deleted prediction batches",
tenant_id=tenant_id,
count=batches_count)
batches_delete_result = await db.execute(batches_delete_query)
deletion_stats["batches_deleted"] = batches_delete_result.rowcount
except Exception as e:
error_msg = f"Error deleting prediction batches: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.error(error_msg)
# Step 5: Delete forecasts (main data)
# Delete forecasts
try:
forecasts_count_query = select(func.count(Forecast.id)).where(
Forecast.tenant_id == tenant_uuid
)
forecasts_count_result = await db.execute(forecasts_count_query)
forecasts_count = forecasts_count_result.scalar()
forecasts_delete_query = delete(Forecast).where(
Forecast.tenant_id == tenant_uuid
)
await db.execute(forecasts_delete_query)
deletion_stats["forecasts_deleted"] = forecasts_count
await db.commit()
logger.info("Deleted forecasts",
tenant_id=tenant_id,
count=forecasts_count)
forecasts_delete_result = await db.execute(forecasts_delete_query)
deletion_stats["forecasts_deleted"] = forecasts_delete_result.rowcount
except Exception as e:
await db.rollback()
error_msg = f"Error deleting forecasts: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.error(error_msg)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=error_msg
)
# Step 6: Publish deletion event
try:
await publish_forecasts_deleted_event(tenant_id, deletion_stats)
except Exception as e:
logger.warning("Failed to publish forecasts deletion event", error=str(e))
await db.commit()
return {
"success": True,
"message": f"All forecasting data for tenant {tenant_id} deleted successfully",
"deletion_details": deletion_stats
logger.info("Deleted tenant forecasting data",
tenant_id=tenant_id,
forecasts=deletion_stats["forecasts_deleted"],
predictions=deletion_stats["predictions_deleted"],
batches=deletion_stats["batches_deleted"])
deletion_stats["success"] = len(deletion_stats["errors"]) == 0
deletion_stats["expected_counts"] = {
"forecasts": forecasts_count,
"predictions": predictions_count,
"batches": batches_count
}
except HTTPException:
raise
return deletion_stats
except Exception as e:
logger.error("Unexpected error deleting tenant forecasts",
await db.rollback()
logger.error("Failed to delete tenant forecasts",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete tenant forecasts: {str(e)}"
detail="Failed to delete tenant forecasts"
)
@router.get("/tenants/{tenant_id}/forecasts/count")
async def get_tenant_forecasts_count(
tenant_id: str,
current_user = Depends(get_current_user_dep),
_admin_check = Depends(require_admin_role),
db: AsyncSession = Depends(get_db)
):
"""Get count of forecasts and predictions for a tenant (admin only)"""
try:
tenant_uuid = uuid.UUID(tenant_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid tenant ID format"
)
try:
from app.models.forecasts import Forecast, Prediction, PredictionBatch
# Count forecasts
forecasts_count_query = select(func.count(Forecast.id)).where(
Forecast.tenant_id == tenant_uuid
)
forecasts_count_result = await db.execute(forecasts_count_query)
forecasts_count = forecasts_count_result.scalar()
# Count predictions
predictions_count_query = select(func.count(Prediction.id)).where(
Prediction.tenant_id == tenant_uuid
)
predictions_count_result = await db.execute(predictions_count_query)
predictions_count = predictions_count_result.scalar()
# Count batches
batches_count_query = select(func.count(PredictionBatch.id)).where(
PredictionBatch.tenant_id == tenant_uuid
)
batches_count_result = await db.execute(batches_count_query)
batches_count = batches_count_result.scalar()
return {
"tenant_id": tenant_id,
"forecasts_count": forecasts_count,
"predictions_count": predictions_count,
"batches_count": batches_count,
"total_forecasting_assets": forecasts_count + predictions_count + batches_count
}
except Exception as e:
logger.error("Failed to get tenant forecasts count",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get forecasts count"
)

View File

@@ -10,11 +10,15 @@ from fastapi import APIRouter, Depends, HTTPException, status, Query
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Dict, Any
from datetime import date, datetime, timedelta
from sqlalchemy import select, delete, func
import uuid
from app.core.database import get_db
from shared.auth.decorators import (
get_current_user_dep,
get_current_tenant_id_dep
get_current_tenant_id_dep,
get_current_user_dep,
require_admin_role
)
from app.services.prediction_service import PredictionService
from app.schemas.forecasts import ForecastRequest
@@ -140,3 +144,128 @@ async def get_quick_prediction(
detail="Internal server error"
)
@router.post("/tenants/{tenant_id}/predictions/cancel-batches")
async def cancel_tenant_prediction_batches(
tenant_id: str,
current_user = Depends(get_current_user_dep),
_admin_check = Depends(require_admin_role),
db: AsyncSession = Depends(get_db)
):
"""Cancel all active prediction batches for a tenant (admin only)"""
try:
tenant_uuid = uuid.UUID(tenant_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid tenant ID format"
)
try:
from app.models.forecasts import PredictionBatch
# Find active prediction batches
active_batches_query = select(PredictionBatch).where(
PredictionBatch.tenant_id == tenant_uuid,
PredictionBatch.status.in_(["queued", "running", "pending"])
)
active_batches_result = await db.execute(active_batches_query)
active_batches = active_batches_result.scalars().all()
batches_cancelled = 0
cancelled_batch_ids = []
errors = []
for batch in active_batches:
try:
batch.status = "cancelled"
batch.updated_at = datetime.utcnow()
batch.cancelled_by = current_user.get("user_id")
batches_cancelled += 1
cancelled_batch_ids.append(str(batch.id))
logger.info("Cancelled prediction batch",
batch_id=str(batch.id),
tenant_id=tenant_id)
except Exception as e:
error_msg = f"Failed to cancel batch {batch.id}: {str(e)}"
errors.append(error_msg)
logger.error(error_msg)
if batches_cancelled > 0:
await db.commit()
return {
"success": True,
"tenant_id": tenant_id,
"batches_cancelled": batches_cancelled,
"cancelled_batch_ids": cancelled_batch_ids,
"errors": errors,
"cancelled_at": datetime.utcnow().isoformat()
}
except Exception as e:
await db.rollback()
logger.error("Failed to cancel tenant prediction batches",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to cancel prediction batches"
)
@router.delete("/tenants/{tenant_id}/predictions/cache")
async def clear_tenant_prediction_cache(
tenant_id: str,
current_user = Depends(get_current_user_dep),
_admin_check = Depends(require_admin_role),
db: AsyncSession = Depends(get_db)
):
"""Clear all prediction cache for a tenant (admin only)"""
try:
tenant_uuid = uuid.UUID(tenant_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid tenant ID format"
)
try:
from app.models.forecasts import PredictionCache
# Count cache entries before deletion
cache_count_query = select(func.count(PredictionCache.id)).where(
PredictionCache.tenant_id == tenant_uuid
)
cache_count_result = await db.execute(cache_count_query)
cache_count = cache_count_result.scalar()
# Delete cache entries
cache_delete_query = delete(PredictionCache).where(
PredictionCache.tenant_id == tenant_uuid
)
cache_delete_result = await db.execute(cache_delete_query)
await db.commit()
logger.info("Cleared tenant prediction cache",
tenant_id=tenant_id,
cache_cleared=cache_delete_result.rowcount)
return {
"success": True,
"tenant_id": tenant_id,
"cache_cleared": cache_delete_result.rowcount,
"expected_count": cache_count,
"cleared_at": datetime.utcnow().isoformat()
}
except Exception as e:
await db.rollback()
logger.error("Failed to clear tenant prediction cache",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to clear prediction cache"
)

View File

@@ -81,6 +81,8 @@ class PredictionBatch(Base):
error_message = Column(Text)
processing_time_ms = Column(Integer)
cancelled_by = Column(String, nullable=True)
def __repr__(self):
return f"<PredictionBatch(id={self.id}, status={self.status})>"