# services/forecasting/app/services/tenant_deletion_service.py
"""
Tenant Data Deletion Service for Forecasting Service
Handles deletion of all forecasting-related data for a tenant
"""
from typing import Dict

from sqlalchemy import select, func, delete
from sqlalchemy.ext.asyncio import AsyncSession
import structlog

from shared.services.tenant_deletion import (
    BaseTenantDataDeletionService,
    TenantDataDeletionResult
)
from app.models import (
    Forecast,
    PredictionBatch,
    ModelPerformanceMetric,
    PredictionCache,
    AuditLog
)

logger = structlog.get_logger(__name__)


class ForecastingTenantDeletionService(BaseTenantDataDeletionService):
    """Service for deleting all forecasting-related data for a tenant"""

    # (model, result key) pairs, in the order the preview reports them.
    _PREVIEW_TARGETS = (
        (Forecast, "forecasts"),
        (PredictionBatch, "prediction_batches"),
        (ModelPerformanceMetric, "model_performance_metrics"),
        (PredictionCache, "prediction_cache"),
        (AuditLog, "audit_logs"),
    )

    # (model, result key, "starting" log event, "finished" log event),
    # in the order the deletes are issued. Event names are spelled out in
    # full so they stay greppable.
    _DELETION_STEPS = (
        (
            PredictionCache,
            "prediction_cache",
            "forecasting.tenant_deletion.deleting_cache",
            "forecasting.tenant_deletion.cache_deleted",
        ),
        (
            ModelPerformanceMetric,
            "model_performance_metrics",
            "forecasting.tenant_deletion.deleting_metrics",
            "forecasting.tenant_deletion.metrics_deleted",
        ),
        (
            PredictionBatch,
            "prediction_batches",
            "forecasting.tenant_deletion.deleting_batches",
            "forecasting.tenant_deletion.batches_deleted",
        ),
        (
            Forecast,
            "forecasts",
            "forecasting.tenant_deletion.deleting_forecasts",
            "forecasting.tenant_deletion.forecasts_deleted",
        ),
        (
            AuditLog,
            "audit_logs",
            "forecasting.tenant_deletion.deleting_audit_logs",
            "forecasting.tenant_deletion.audit_logs_deleted",
        ),
    )

    def __init__(self, db: AsyncSession):
        """
        Bind the service to a database session

        Args:
            db: AsyncSession used for every query issued by this service
        """
        # NOTE(review): the base-class initializer is not invoked here;
        # confirm BaseTenantDataDeletionService needs no setup of its own.
        self.db = db
        self.service_name = "forecasting"

    async def get_tenant_data_preview(self, tenant_id: str) -> Dict[str, int]:
        """
        Get counts of what would be deleted for a tenant (dry-run)

        Args:
            tenant_id: The tenant ID to preview deletion for

        Returns:
            Dictionary with entity names and their counts

        Raises:
            Exception: Any error raised while counting is logged and re-raised
        """
        logger.info("forecasting.tenant_deletion.preview", tenant_id=tenant_id)

        preview: Dict[str, int] = {}

        try:
            for model, key in self._PREVIEW_TARGETS:
                row_total = await self.db.scalar(
                    select(func.count(model.id)).where(
                        model.tenant_id == tenant_id
                    )
                )
                # scalar() may yield None; report that as zero rows.
                preview[key] = row_total or 0

            logger.info(
                "forecasting.tenant_deletion.preview_complete",
                tenant_id=tenant_id,
                preview=preview
            )

        except Exception as e:
            logger.error(
                "forecasting.tenant_deletion.preview_error",
                tenant_id=tenant_id,
                error=str(e),
                exc_info=True
            )
            raise

        return preview

    async def delete_tenant_data(self, tenant_id: str) -> TenantDataDeletionResult:
        """
        Permanently delete all forecasting data for a tenant

        Deletion order:
        1. PredictionCache (independent)
        2. ModelPerformanceMetric (independent)
        3. PredictionBatch (independent)
        4. Forecast (independent)
        5. AuditLog (independent)

        Note: All tables are independent with no foreign key relationships

        All deletes run in one transaction. If any step or the commit fails,
        the transaction is rolled back, the error is recorded on the result,
        and the result reports no deleted counts (nothing was persisted).

        Args:
            tenant_id: The tenant ID to delete data for

        Returns:
            TenantDataDeletionResult with deletion counts and any errors
        """
        logger.info("forecasting.tenant_deletion.started", tenant_id=tenant_id)

        result = TenantDataDeletionResult(tenant_id=tenant_id, service_name=self.service_name)

        # Row counts are staged here and copied onto the result only after the
        # commit succeeds, so a rolled-back run never reports rows as deleted.
        staged_counts: Dict[str, int] = {}

        try:
            for model, key, start_event, done_event in self._DELETION_STEPS:
                logger.info(start_event, tenant_id=tenant_id)
                outcome = await self.db.execute(
                    delete(model).where(
                        model.tenant_id == tenant_id
                    )
                )
                # rowcount is whatever the driver reports for this DELETE.
                staged_counts[key] = outcome.rowcount
                logger.info(
                    done_event,
                    tenant_id=tenant_id,
                    count=outcome.rowcount
                )

            # Commit the transaction
            await self.db.commit()

            # The deletes are now durable; publish the counts.
            result.deleted_counts.update(staged_counts)

            # Calculate total deleted
            total_deleted = sum(result.deleted_counts.values())

            logger.info(
                "forecasting.tenant_deletion.completed",
                tenant_id=tenant_id,
                total_deleted=total_deleted,
                breakdown=result.deleted_counts
            )

            result.success = True

        except Exception as e:
            await self.db.rollback()
            error_msg = f"Failed to delete forecasting data for tenant {tenant_id}: {str(e)}"
            logger.error(
                "forecasting.tenant_deletion.failed",
                tenant_id=tenant_id,
                error=str(e),
                exc_info=True
            )
            result.errors.append(error_msg)
            result.success = False

        return result


def get_forecasting_tenant_deletion_service(
    db: AsyncSession
) -> ForecastingTenantDeletionService:
    """
    Factory function to create ForecastingTenantDeletionService instance

    Args:
        db: AsyncSession database session

    Returns:
        ForecastingTenantDeletionService instance
    """
    return ForecastingTenantDeletionService(db)