Improve user delete flow
This commit is contained in:
@@ -322,30 +322,15 @@ async def acknowledge_alert(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Internal server error"
|
||||
)
|
||||
|
||||
@router.delete("/forecasts/tenant/{tenant_id}")
|
||||
async def delete_tenant_forecasts_complete(
|
||||
|
||||
@router.delete("/tenants/{tenant_id}/forecasts")
|
||||
async def delete_tenant_forecasts(
|
||||
tenant_id: str,
|
||||
current_user = Depends(get_current_user_dep),
|
||||
_admin_check = Depends(require_admin_role),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Delete all forecasts and predictions for a tenant.
|
||||
|
||||
**WARNING: This operation is irreversible!**
|
||||
|
||||
This endpoint:
|
||||
1. Cancels any active prediction batches
|
||||
2. Clears prediction cache
|
||||
3. Deletes all forecast records
|
||||
4. Deletes prediction batch records
|
||||
5. Deletes model performance metrics
|
||||
6. Publishes deletion event
|
||||
|
||||
Used by admin user deletion process to clean up all forecasting data.
|
||||
"""
|
||||
|
||||
"""Delete all forecasts and predictions for a tenant (admin only)"""
|
||||
try:
|
||||
tenant_uuid = uuid.UUID(tenant_id)
|
||||
except ValueError:
|
||||
@@ -355,163 +340,155 @@ async def delete_tenant_forecasts_complete(
|
||||
)
|
||||
|
||||
try:
|
||||
from app.models.forecasts import Forecast, PredictionBatch
|
||||
from app.models.predictions import ModelPerformanceMetric, PredictionCache
|
||||
from app.models.forecasts import Forecast, Prediction, PredictionBatch
|
||||
|
||||
deletion_stats = {
|
||||
"tenant_id": tenant_id,
|
||||
"deleted_at": datetime.utcnow().isoformat(),
|
||||
"batches_cancelled": 0,
|
||||
"forecasts_deleted": 0,
|
||||
"prediction_batches_deleted": 0,
|
||||
"performance_metrics_deleted": 0,
|
||||
"cache_entries_deleted": 0,
|
||||
"predictions_deleted": 0,
|
||||
"batches_deleted": 0,
|
||||
"errors": []
|
||||
}
|
||||
|
||||
# Step 1: Cancel active prediction batches
|
||||
# Count before deletion
|
||||
forecasts_count_query = select(func.count(Forecast.id)).where(
|
||||
Forecast.tenant_id == tenant_uuid
|
||||
)
|
||||
forecasts_count_result = await db.execute(forecasts_count_query)
|
||||
forecasts_count = forecasts_count_result.scalar()
|
||||
|
||||
predictions_count_query = select(func.count(Prediction.id)).where(
|
||||
Prediction.tenant_id == tenant_uuid
|
||||
)
|
||||
predictions_count_result = await db.execute(predictions_count_query)
|
||||
predictions_count = predictions_count_result.scalar()
|
||||
|
||||
batches_count_query = select(func.count(PredictionBatch.id)).where(
|
||||
PredictionBatch.tenant_id == tenant_uuid
|
||||
)
|
||||
batches_count_result = await db.execute(batches_count_query)
|
||||
batches_count = batches_count_result.scalar()
|
||||
|
||||
# Delete predictions first (they may reference forecasts)
|
||||
try:
|
||||
active_batches_query = select(PredictionBatch).where(
|
||||
PredictionBatch.tenant_id == tenant_uuid,
|
||||
PredictionBatch.status.in_(["pending", "processing"])
|
||||
predictions_delete_query = delete(Prediction).where(
|
||||
Prediction.tenant_id == tenant_uuid
|
||||
)
|
||||
active_batches_result = await db.execute(active_batches_query)
|
||||
active_batches = active_batches_result.scalars().all()
|
||||
|
||||
for batch in active_batches:
|
||||
batch.status = "cancelled"
|
||||
batch.completed_at = datetime.utcnow()
|
||||
deletion_stats["batches_cancelled"] += 1
|
||||
|
||||
if active_batches:
|
||||
await db.commit()
|
||||
logger.info("Cancelled active prediction batches",
|
||||
tenant_id=tenant_id,
|
||||
count=len(active_batches))
|
||||
predictions_delete_result = await db.execute(predictions_delete_query)
|
||||
deletion_stats["predictions_deleted"] = predictions_delete_result.rowcount
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error cancelling prediction batches: {str(e)}"
|
||||
error_msg = f"Error deleting predictions: {str(e)}"
|
||||
deletion_stats["errors"].append(error_msg)
|
||||
logger.error(error_msg)
|
||||
|
||||
# Step 2: Delete prediction cache
|
||||
# Delete prediction batches
|
||||
try:
|
||||
cache_count_query = select(func.count(PredictionCache.id)).where(
|
||||
PredictionCache.tenant_id == tenant_uuid
|
||||
)
|
||||
cache_count_result = await db.execute(cache_count_query)
|
||||
cache_count = cache_count_result.scalar()
|
||||
|
||||
cache_delete_query = delete(PredictionCache).where(
|
||||
PredictionCache.tenant_id == tenant_uuid
|
||||
)
|
||||
await db.execute(cache_delete_query)
|
||||
deletion_stats["cache_entries_deleted"] = cache_count
|
||||
|
||||
logger.info("Deleted prediction cache entries",
|
||||
tenant_id=tenant_id,
|
||||
count=cache_count)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error deleting prediction cache: {str(e)}"
|
||||
deletion_stats["errors"].append(error_msg)
|
||||
logger.error(error_msg)
|
||||
|
||||
# Step 3: Delete model performance metrics
|
||||
try:
|
||||
metrics_count_query = select(func.count(ModelPerformanceMetric.id)).where(
|
||||
ModelPerformanceMetric.tenant_id == tenant_uuid
|
||||
)
|
||||
metrics_count_result = await db.execute(metrics_count_query)
|
||||
metrics_count = metrics_count_result.scalar()
|
||||
|
||||
metrics_delete_query = delete(ModelPerformanceMetric).where(
|
||||
ModelPerformanceMetric.tenant_id == tenant_uuid
|
||||
)
|
||||
await db.execute(metrics_delete_query)
|
||||
deletion_stats["performance_metrics_deleted"] = metrics_count
|
||||
|
||||
logger.info("Deleted performance metrics",
|
||||
tenant_id=tenant_id,
|
||||
count=metrics_count)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error deleting performance metrics: {str(e)}"
|
||||
deletion_stats["errors"].append(error_msg)
|
||||
logger.error(error_msg)
|
||||
|
||||
# Step 4: Delete prediction batches
|
||||
try:
|
||||
batches_count_query = select(func.count(PredictionBatch.id)).where(
|
||||
PredictionBatch.tenant_id == tenant_uuid
|
||||
)
|
||||
batches_count_result = await db.execute(batches_count_query)
|
||||
batches_count = batches_count_result.scalar()
|
||||
|
||||
batches_delete_query = delete(PredictionBatch).where(
|
||||
PredictionBatch.tenant_id == tenant_uuid
|
||||
)
|
||||
await db.execute(batches_delete_query)
|
||||
deletion_stats["prediction_batches_deleted"] = batches_count
|
||||
|
||||
logger.info("Deleted prediction batches",
|
||||
tenant_id=tenant_id,
|
||||
count=batches_count)
|
||||
batches_delete_result = await db.execute(batches_delete_query)
|
||||
deletion_stats["batches_deleted"] = batches_delete_result.rowcount
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error deleting prediction batches: {str(e)}"
|
||||
deletion_stats["errors"].append(error_msg)
|
||||
logger.error(error_msg)
|
||||
|
||||
# Step 5: Delete forecasts (main data)
|
||||
# Delete forecasts
|
||||
try:
|
||||
forecasts_count_query = select(func.count(Forecast.id)).where(
|
||||
Forecast.tenant_id == tenant_uuid
|
||||
)
|
||||
forecasts_count_result = await db.execute(forecasts_count_query)
|
||||
forecasts_count = forecasts_count_result.scalar()
|
||||
|
||||
forecasts_delete_query = delete(Forecast).where(
|
||||
Forecast.tenant_id == tenant_uuid
|
||||
)
|
||||
await db.execute(forecasts_delete_query)
|
||||
deletion_stats["forecasts_deleted"] = forecasts_count
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info("Deleted forecasts",
|
||||
tenant_id=tenant_id,
|
||||
count=forecasts_count)
|
||||
forecasts_delete_result = await db.execute(forecasts_delete_query)
|
||||
deletion_stats["forecasts_deleted"] = forecasts_delete_result.rowcount
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
error_msg = f"Error deleting forecasts: {str(e)}"
|
||||
deletion_stats["errors"].append(error_msg)
|
||||
logger.error(error_msg)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=error_msg
|
||||
)
|
||||
|
||||
# Step 6: Publish deletion event
|
||||
try:
|
||||
await publish_forecasts_deleted_event(tenant_id, deletion_stats)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to publish forecasts deletion event", error=str(e))
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"All forecasting data for tenant {tenant_id} deleted successfully",
|
||||
"deletion_details": deletion_stats
|
||||
logger.info("Deleted tenant forecasting data",
|
||||
tenant_id=tenant_id,
|
||||
forecasts=deletion_stats["forecasts_deleted"],
|
||||
predictions=deletion_stats["predictions_deleted"],
|
||||
batches=deletion_stats["batches_deleted"])
|
||||
|
||||
deletion_stats["success"] = len(deletion_stats["errors"]) == 0
|
||||
deletion_stats["expected_counts"] = {
|
||||
"forecasts": forecasts_count,
|
||||
"predictions": predictions_count,
|
||||
"batches": batches_count
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
return deletion_stats
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Unexpected error deleting tenant forecasts",
|
||||
await db.rollback()
|
||||
logger.error("Failed to delete tenant forecasts",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to delete tenant forecasts: {str(e)}"
|
||||
detail="Failed to delete tenant forecasts"
|
||||
)
|
||||
|
||||
@router.get("/tenants/{tenant_id}/forecasts/count")
|
||||
async def get_tenant_forecasts_count(
|
||||
tenant_id: str,
|
||||
current_user = Depends(get_current_user_dep),
|
||||
_admin_check = Depends(require_admin_role),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Get count of forecasts and predictions for a tenant (admin only)"""
|
||||
try:
|
||||
tenant_uuid = uuid.UUID(tenant_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid tenant ID format"
|
||||
)
|
||||
|
||||
try:
|
||||
from app.models.forecasts import Forecast, Prediction, PredictionBatch
|
||||
|
||||
# Count forecasts
|
||||
forecasts_count_query = select(func.count(Forecast.id)).where(
|
||||
Forecast.tenant_id == tenant_uuid
|
||||
)
|
||||
forecasts_count_result = await db.execute(forecasts_count_query)
|
||||
forecasts_count = forecasts_count_result.scalar()
|
||||
|
||||
# Count predictions
|
||||
predictions_count_query = select(func.count(Prediction.id)).where(
|
||||
Prediction.tenant_id == tenant_uuid
|
||||
)
|
||||
predictions_count_result = await db.execute(predictions_count_query)
|
||||
predictions_count = predictions_count_result.scalar()
|
||||
|
||||
# Count batches
|
||||
batches_count_query = select(func.count(PredictionBatch.id)).where(
|
||||
PredictionBatch.tenant_id == tenant_uuid
|
||||
)
|
||||
batches_count_result = await db.execute(batches_count_query)
|
||||
batches_count = batches_count_result.scalar()
|
||||
|
||||
return {
|
||||
"tenant_id": tenant_id,
|
||||
"forecasts_count": forecasts_count,
|
||||
"predictions_count": predictions_count,
|
||||
"batches_count": batches_count,
|
||||
"total_forecasting_assets": forecasts_count + predictions_count + batches_count
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get tenant forecasts count",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to get forecasts count"
|
||||
)
|
||||
@@ -10,11 +10,15 @@ from fastapi import APIRouter, Depends, HTTPException, status, Query
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from typing import List, Dict, Any
|
||||
from datetime import date, datetime, timedelta
|
||||
from sqlalchemy import select, delete, func
|
||||
import uuid
|
||||
|
||||
from app.core.database import get_db
|
||||
from shared.auth.decorators import (
|
||||
get_current_user_dep,
|
||||
get_current_tenant_id_dep
|
||||
get_current_tenant_id_dep,
|
||||
get_current_user_dep,
|
||||
require_admin_role
|
||||
)
|
||||
from app.services.prediction_service import PredictionService
|
||||
from app.schemas.forecasts import ForecastRequest
|
||||
@@ -140,3 +144,128 @@ async def get_quick_prediction(
|
||||
detail="Internal server error"
|
||||
)
|
||||
|
||||
@router.post("/tenants/{tenant_id}/predictions/cancel-batches")
|
||||
async def cancel_tenant_prediction_batches(
|
||||
tenant_id: str,
|
||||
current_user = Depends(get_current_user_dep),
|
||||
_admin_check = Depends(require_admin_role),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Cancel all active prediction batches for a tenant (admin only)"""
|
||||
try:
|
||||
tenant_uuid = uuid.UUID(tenant_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid tenant ID format"
|
||||
)
|
||||
|
||||
try:
|
||||
from app.models.forecasts import PredictionBatch
|
||||
|
||||
# Find active prediction batches
|
||||
active_batches_query = select(PredictionBatch).where(
|
||||
PredictionBatch.tenant_id == tenant_uuid,
|
||||
PredictionBatch.status.in_(["queued", "running", "pending"])
|
||||
)
|
||||
active_batches_result = await db.execute(active_batches_query)
|
||||
active_batches = active_batches_result.scalars().all()
|
||||
|
||||
batches_cancelled = 0
|
||||
cancelled_batch_ids = []
|
||||
errors = []
|
||||
|
||||
for batch in active_batches:
|
||||
try:
|
||||
batch.status = "cancelled"
|
||||
batch.updated_at = datetime.utcnow()
|
||||
batch.cancelled_by = current_user.get("user_id")
|
||||
batches_cancelled += 1
|
||||
cancelled_batch_ids.append(str(batch.id))
|
||||
|
||||
logger.info("Cancelled prediction batch",
|
||||
batch_id=str(batch.id),
|
||||
tenant_id=tenant_id)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to cancel batch {batch.id}: {str(e)}"
|
||||
errors.append(error_msg)
|
||||
logger.error(error_msg)
|
||||
|
||||
if batches_cancelled > 0:
|
||||
await db.commit()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"tenant_id": tenant_id,
|
||||
"batches_cancelled": batches_cancelled,
|
||||
"cancelled_batch_ids": cancelled_batch_ids,
|
||||
"errors": errors,
|
||||
"cancelled_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("Failed to cancel tenant prediction batches",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to cancel prediction batches"
|
||||
)
|
||||
|
||||
@router.delete("/tenants/{tenant_id}/predictions/cache")
|
||||
async def clear_tenant_prediction_cache(
|
||||
tenant_id: str,
|
||||
current_user = Depends(get_current_user_dep),
|
||||
_admin_check = Depends(require_admin_role),
|
||||
db: AsyncSession = Depends(get_db)
|
||||
):
|
||||
"""Clear all prediction cache for a tenant (admin only)"""
|
||||
try:
|
||||
tenant_uuid = uuid.UUID(tenant_id)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid tenant ID format"
|
||||
)
|
||||
|
||||
try:
|
||||
from app.models.forecasts import PredictionCache
|
||||
|
||||
# Count cache entries before deletion
|
||||
cache_count_query = select(func.count(PredictionCache.id)).where(
|
||||
PredictionCache.tenant_id == tenant_uuid
|
||||
)
|
||||
cache_count_result = await db.execute(cache_count_query)
|
||||
cache_count = cache_count_result.scalar()
|
||||
|
||||
# Delete cache entries
|
||||
cache_delete_query = delete(PredictionCache).where(
|
||||
PredictionCache.tenant_id == tenant_uuid
|
||||
)
|
||||
cache_delete_result = await db.execute(cache_delete_query)
|
||||
|
||||
await db.commit()
|
||||
|
||||
logger.info("Cleared tenant prediction cache",
|
||||
tenant_id=tenant_id,
|
||||
cache_cleared=cache_delete_result.rowcount)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"tenant_id": tenant_id,
|
||||
"cache_cleared": cache_delete_result.rowcount,
|
||||
"expected_count": cache_count,
|
||||
"cleared_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
await db.rollback()
|
||||
logger.error("Failed to clear tenant prediction cache",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to clear prediction cache"
|
||||
)
|
||||
@@ -81,6 +81,8 @@ class PredictionBatch(Base):
|
||||
error_message = Column(Text)
|
||||
processing_time_ms = Column(Integer)
|
||||
|
||||
cancelled_by = Column(String, nullable=True)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<PredictionBatch(id={self.id}, status={self.status})>"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user