From 277e8bec73c22bbeaa60696bba47307788eb9bec Mon Sep 17 00:00:00 2001 From: Urtzi Alfaro Date: Sat, 2 Aug 2025 09:41:50 +0200 Subject: [PATCH] Add user role --- services/auth/app/api/auth.py | 2 +- services/auth/app/api/users.py | 7 +- services/auth/app/models/users.py | 1 + services/auth/app/schemas/auth.py | 1 + services/auth/app/services/auth_service.py | 9 +- services/forecasting/app/api/forecasts.py | 202 ++++++++- .../forecasting/app/services/messaging.py | 19 +- services/tenant/app/api/tenants.py | 154 ++++++- services/tenant/app/services/messaging.py | 19 +- services/training/app/api/models.py | 248 ++++++++++- services/training/app/services/messaging.py | 21 +- shared/auth/decorators.py | 393 +++++++++++++++++- tests/test_onboarding_flow.sh | 3 +- 13 files changed, 1051 insertions(+), 28 deletions(-) diff --git a/services/auth/app/api/auth.py b/services/auth/app/api/auth.py index e1318bf8..eab80ce9 100644 --- a/services/auth/app/api/auth.py +++ b/services/auth/app/api/auth.py @@ -32,7 +32,7 @@ async def register( # ✅ DEBUG: Log incoming registration data (without password) logger.info(f"Registration attempt for email: {user_data.email}") - logger.debug(f"Registration data - email: {user_data.email}, full_name: {user_data.full_name}") + logger.debug(f"Registration data - email: {user_data.email}, full_name: {user_data.full_name}, role: {user_data.role}") try: # ✅ DEBUG: Validate input data diff --git a/services/auth/app/api/users.py b/services/auth/app/api/users.py index a286d0f2..96b8ba99 100644 --- a/services/auth/app/api/users.py +++ b/services/auth/app/api/users.py @@ -21,8 +21,7 @@ from app.services.admin_delete import AdminUserDeleteService # Import unified authentication from shared library from shared.auth.decorators import ( get_current_user_dep, - get_current_tenant_id_dep, - require_role # For admin-only endpoints + require_admin_role ) logger = structlog.get_logger() @@ -126,7 +125,7 @@ async def delete_admin_user( user_id: str, background_tasks: BackgroundTasks, current_user = Depends(get_current_user_dep), - #_admin_check = Depends(require_admin_role), + _admin_check = Depends(require_admin_role), db: AsyncSession = Depends(get_db) ): """ @@ -191,7 +190,7 @@ async def delete_admin_user( async def preview_user_deletion( user_id: str, current_user = Depends(get_current_user_dep), - #_admin_check = Depends(require_admin_role), + _admin_check = Depends(require_admin_role), db: AsyncSession = Depends(get_db) ): """ diff --git a/services/auth/app/models/users.py b/services/auth/app/models/users.py index 27f8104b..2dfdebce 100644 --- a/services/auth/app/models/users.py +++ b/services/auth/app/models/users.py @@ -31,6 +31,7 @@ class User(Base): phone = Column(String(20)) language = Column(String(10), default="es") timezone = Column(String(50), default="Europe/Madrid") + role = Column(String(20), default="user") # REMOVED: All tenant relationships - these are handled by tenant service # No tenant_memberships, tenants relationships diff --git a/services/auth/app/schemas/auth.py b/services/auth/app/schemas/auth.py index d4d037dd..1784cf44 100644 --- a/services/auth/app/schemas/auth.py +++ b/services/auth/app/schemas/auth.py @@ -18,6 +18,7 @@ class UserRegistration(BaseModel): password: str = Field(..., min_length=8, max_length=128) full_name: str = Field(..., min_length=1, max_length=255) tenant_name: Optional[str] = Field(None, max_length=255) + role: Optional[str] = Field("user", pattern=r'^(user|admin|manager)$') class UserLogin(BaseModel): """User login request""" diff --git a/services/auth/app/services/auth_service.py b/services/auth/app/services/auth_service.py index a1e59c2a..5bbd1db4 100644 --- a/services/auth/app/services/auth_service.py +++ b/services/auth/app/services/auth_service.py @@ -48,7 +48,8 @@ class AuthService: is_active=True, is_verified=False, created_at=datetime.now(timezone.utc), - updated_at=datetime.now(timezone.utc) + updated_at=datetime.now(timezone.utc), + role=user_data.role ) db.add(new_user) @@ -115,7 +116,8 @@ class AuthService: "full_name": new_user.full_name, "is_active": new_user.is_active, "is_verified": new_user.is_verified, - "created_at": new_user.created_at.isoformat() + "created_at": new_user.created_at.isoformat(), + "role": new_user.role } } @@ -242,7 +244,8 @@ class AuthService: "full_name": user.full_name, "is_active": user.is_active, "is_verified": user.is_verified, - "created_at": user.created_at.isoformat() + "created_at": user.created_at.isoformat(), + "role": user.role } } diff --git a/services/forecasting/app/api/forecasts.py b/services/forecasting/app/api/forecasts.py index c260b667..171fd5c1 100644 --- a/services/forecasting/app/api/forecasts.py +++ b/services/forecasting/app/api/forecasts.py @@ -9,12 +9,14 @@ import structlog from fastapi import APIRouter, Depends, HTTPException, status, Query, Path from sqlalchemy.ext.asyncio import AsyncSession from typing import List, Optional -from datetime import date +from datetime import date, datetime +from sqlalchemy import select, delete, func +import uuid from app.core.database import get_db from shared.auth.decorators import ( get_current_user_dep, - get_current_tenant_id_dep + require_admin_role ) from app.services.forecasting_service import ForecastingService from app.schemas.forecasts import ( @@ -22,6 +24,7 @@ from app.schemas.forecasts import ( BatchForecastResponse, AlertResponse ) from app.models.forecasts import Forecast, PredictionBatch, ForecastAlert +from app.services.messaging import publish_forecasts_deleted_event logger = structlog.get_logger() router = APIRouter() @@ -318,4 +321,197 @@ async def acknowledge_alert( raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal server error" - ) \ No newline at end of file + ) + +@router.delete("/forecasts/tenant/{tenant_id}") +async def delete_tenant_forecasts_complete( + tenant_id: str, + current_user = Depends(get_current_user_dep), + _admin_check = Depends(require_admin_role), + db: AsyncSession = Depends(get_db) +): + """ + Delete all forecasts and predictions for a tenant. + + **WARNING: This operation is irreversible!** + + This endpoint: + 1. Cancels any active prediction batches + 2. Clears prediction cache + 3. Deletes all forecast records + 4. Deletes prediction batch records + 5. Deletes model performance metrics + 6. Publishes deletion event + + Used by admin user deletion process to clean up all forecasting data. + """ + + try: + tenant_uuid = uuid.UUID(tenant_id) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid tenant ID format" + ) + + try: + from app.models.forecasts import Forecast, PredictionBatch + from app.models.predictions import ModelPerformanceMetric, PredictionCache + + deletion_stats = { + "tenant_id": tenant_id, + "deleted_at": datetime.utcnow().isoformat(), + "batches_cancelled": 0, + "forecasts_deleted": 0, + "prediction_batches_deleted": 0, + "performance_metrics_deleted": 0, + "cache_entries_deleted": 0, + "errors": [] + } + + # Step 1: Cancel active prediction batches + try: + active_batches_query = select(PredictionBatch).where( + PredictionBatch.tenant_id == tenant_uuid, + PredictionBatch.status.in_(["pending", "processing"]) + ) + active_batches_result = await db.execute(active_batches_query) + active_batches = active_batches_result.scalars().all() + + for batch in active_batches: + batch.status = "cancelled" + batch.completed_at = datetime.utcnow() + deletion_stats["batches_cancelled"] += 1 + + if active_batches: + await db.commit() + logger.info("Cancelled active prediction batches", + tenant_id=tenant_id, + count=len(active_batches)) + + except Exception as e: + error_msg = f"Error cancelling prediction batches: {str(e)}" + deletion_stats["errors"].append(error_msg) + logger.error(error_msg) + + # Step 2: Delete prediction cache + try: + cache_count_query = select(func.count(PredictionCache.id)).where( + PredictionCache.tenant_id == tenant_uuid + ) + cache_count_result = await db.execute(cache_count_query) + cache_count = cache_count_result.scalar() + + cache_delete_query = delete(PredictionCache).where( + PredictionCache.tenant_id == tenant_uuid + ) + await db.execute(cache_delete_query) + deletion_stats["cache_entries_deleted"] = cache_count + + logger.info("Deleted prediction cache entries", + tenant_id=tenant_id, + count=cache_count) + + except Exception as e: + error_msg = f"Error deleting prediction cache: {str(e)}" + deletion_stats["errors"].append(error_msg) + logger.error(error_msg) + + # Step 3: Delete model performance metrics + try: + metrics_count_query = select(func.count(ModelPerformanceMetric.id)).where( + ModelPerformanceMetric.tenant_id == tenant_uuid + ) + metrics_count_result = await db.execute(metrics_count_query) + metrics_count = metrics_count_result.scalar() + + metrics_delete_query = delete(ModelPerformanceMetric).where( + ModelPerformanceMetric.tenant_id == tenant_uuid + ) + await db.execute(metrics_delete_query) + deletion_stats["performance_metrics_deleted"] = metrics_count + + logger.info("Deleted performance metrics", + tenant_id=tenant_id, + count=metrics_count) + + except Exception as e: + error_msg = f"Error deleting performance metrics: {str(e)}" + deletion_stats["errors"].append(error_msg) + logger.error(error_msg) + + # Step 4: Delete prediction batches + try: + batches_count_query = select(func.count(PredictionBatch.id)).where( + PredictionBatch.tenant_id == tenant_uuid + ) + batches_count_result = await db.execute(batches_count_query) + batches_count = batches_count_result.scalar() + + batches_delete_query = delete(PredictionBatch).where( + PredictionBatch.tenant_id == tenant_uuid + ) + await db.execute(batches_delete_query) + deletion_stats["prediction_batches_deleted"] = batches_count + + logger.info("Deleted prediction batches", + tenant_id=tenant_id, + count=batches_count) + + except Exception as e: + error_msg = f"Error deleting prediction batches: {str(e)}" + deletion_stats["errors"].append(error_msg) + logger.error(error_msg) + + # Step 5: Delete forecasts (main data) + try: + forecasts_count_query = select(func.count(Forecast.id)).where( + Forecast.tenant_id == tenant_uuid + ) + forecasts_count_result = await db.execute(forecasts_count_query) + forecasts_count = forecasts_count_result.scalar() + + forecasts_delete_query = delete(Forecast).where( + Forecast.tenant_id == tenant_uuid + ) + await db.execute(forecasts_delete_query) + deletion_stats["forecasts_deleted"] = forecasts_count + + await db.commit() + + logger.info("Deleted forecasts", + tenant_id=tenant_id, + count=forecasts_count) + + except Exception as e: + await db.rollback() + error_msg = f"Error deleting forecasts: {str(e)}" + deletion_stats["errors"].append(error_msg) + logger.error(error_msg) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=error_msg + ) + + # Step 6: Publish deletion event + try: + await publish_forecasts_deleted_event(tenant_id, deletion_stats) + except Exception as e: + logger.warning("Failed to publish forecasts deletion event", error=str(e)) + + return { + "success": True, + "message": f"All forecasting data for tenant {tenant_id} deleted successfully", + "deletion_details": deletion_stats + } + + except HTTPException: + raise + except Exception as e: + logger.error("Unexpected error deleting tenant forecasts", + tenant_id=tenant_id, + error=str(e)) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to delete tenant forecasts: {str(e)}" + ) diff --git a/services/forecasting/app/services/messaging.py b/services/forecasting/app/services/messaging.py index 06139ec6..0a329788 100644 --- a/services/forecasting/app/services/messaging.py +++ b/services/forecasting/app/services/messaging.py @@ -9,6 +9,7 @@ import structlog import json from typing import Dict, Any import asyncio +import datetime from shared.messaging.rabbitmq import RabbitMQClient from shared.messaging.events import ( @@ -132,4 +133,20 @@ async def handle_weather_updated(data: Dict[str, Any]): # Could trigger re-forecasting if needed except Exception as e: - logger.error("Error handling weather updated event", error=str(e)) \ No newline at end of file + logger.error("Error handling weather updated event", error=str(e)) + +async def publish_forecasts_deleted_event(tenant_id: str, deletion_stats: Dict[str, Any]): + """Publish forecasts deletion event to message queue""" + try: + await rabbitmq_client.publish_event( + exchange="forecasting_events", + routing_key="forecasting.tenant.deleted", + message={ + "event_type": "tenant_forecasts_deleted", + "tenant_id": tenant_id, + "timestamp": datetime.utcnow().isoformat(), + "deletion_stats": deletion_stats + } + ) + except Exception as e: + logger.error("Failed to publish forecasts deletion event", error=str(e)) \ No newline at end of file diff --git a/services/tenant/app/api/tenants.py b/services/tenant/app/api/tenants.py index 1d8beedb..e65edf05 100644 --- a/services/tenant/app/api/tenants.py +++ b/services/tenant/app/api/tenants.py @@ -8,18 +8,20 @@ from sqlalchemy.ext.asyncio import AsyncSession from typing import List, Dict, Any import structlog from uuid import UUID +from sqlalchemy import select, delete, func +from datetime import datetime +import uuid from app.core.database import get_db +from app.services.messaging import publish_tenant_deleted_event from app.schemas.tenants import ( BakeryRegistration, TenantResponse, TenantAccessResponse, TenantUpdate, TenantMemberResponse ) from app.services.tenant_service import TenantService -# Import unified authentication from shared.auth.decorators import ( get_current_user_dep, - get_current_tenant_id_dep, - require_role + require_admin_role ) logger = structlog.get_logger() @@ -163,4 +165,150 @@ async def add_team_member( raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to add team member" + ) + +@router.delete("/tenants/{tenant_id}") +async def delete_tenant_complete( + tenant_id: str, + current_user = Depends(get_current_user_dep), + _admin_check = Depends(require_admin_role), + db: AsyncSession = Depends(get_db) +): + """ + Delete a tenant completely with all associated data. + + **WARNING: This operation is irreversible!** + + This endpoint: + 1. Validates tenant exists and user has permissions + 2. Deletes all tenant memberships + 3. Deletes tenant subscription data + 4. Deletes the tenant record + 5. Publishes deletion event + + Used by admin user deletion process when a tenant has no other admins. + """ + + try: + tenant_uuid = uuid.UUID(tenant_id) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid tenant ID format" + ) + + try: + from app.models.tenants import Tenant, TenantMember, Subscription + + # Step 1: Verify tenant exists + tenant_query = select(Tenant).where(Tenant.id == tenant_uuid) + tenant_result = await db.execute(tenant_query) + tenant = tenant_result.scalar_one_or_none() + + if not tenant: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Tenant {tenant_id} not found" + ) + + deletion_stats = { + "tenant_id": tenant_id, + "tenant_name": tenant.name, + "deleted_at": datetime.utcnow().isoformat(), + "memberships_deleted": 0, + "subscriptions_deleted": 0, + "errors": [] + } + + # Step 2: Delete all tenant memberships + try: + membership_count_query = select(func.count(TenantMember.id)).where( + TenantMember.tenant_id == tenant_uuid + ) + membership_count_result = await db.execute(membership_count_query) + membership_count = membership_count_result.scalar() + + membership_delete_query = delete(TenantMember).where( + TenantMember.tenant_id == tenant_uuid + ) + await db.execute(membership_delete_query) + deletion_stats["memberships_deleted"] = membership_count + + logger.info("Deleted tenant memberships", + tenant_id=tenant_id, + count=membership_count) + + except Exception as e: + error_msg = f"Error deleting memberships: {str(e)}" + deletion_stats["errors"].append(error_msg) + logger.error(error_msg) + + # Step 3: Delete subscription data + try: + subscription_count_query = select(func.count(Subscription.id)).where( + Subscription.tenant_id == tenant_uuid + ) + subscription_count_result = await db.execute(subscription_count_query) + subscription_count = subscription_count_result.scalar() + + subscription_delete_query = delete(Subscription).where( + Subscription.tenant_id == tenant_uuid + ) + await db.execute(subscription_delete_query) + deletion_stats["subscriptions_deleted"] = subscription_count + + logger.info("Deleted tenant subscriptions", + tenant_id=tenant_id, + count=subscription_count) + + except Exception as e: + error_msg = f"Error deleting subscriptions: {str(e)}" + deletion_stats["errors"].append(error_msg) + logger.error(error_msg) + + # Step 4: Delete the tenant record + try: + tenant_delete_query = delete(Tenant).where(Tenant.id == tenant_uuid) + tenant_result = await db.execute(tenant_delete_query) + + if tenant_result.rowcount == 0: + raise Exception("Tenant record was not deleted") + + await db.commit() + + logger.info("Tenant deleted successfully", + tenant_id=tenant_id, + tenant_name=tenant.name) + + except Exception as e: + await db.rollback() + error_msg = f"Error deleting tenant record: {str(e)}" + deletion_stats["errors"].append(error_msg) + logger.error(error_msg) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=error_msg + ) + + # Step 5: Publish tenant deletion event + try: + await publish_tenant_deleted_event(tenant_id, deletion_stats) + except Exception as e: + logger.warning("Failed to publish tenant deletion event", error=str(e)) + + return { + "success": True, + "message": f"Tenant {tenant_id} deleted successfully", + "deletion_details": deletion_stats + } + + except HTTPException: + raise + except Exception as e: + logger.error("Unexpected error deleting tenant", + tenant_id=tenant_id, + error=str(e)) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to delete tenant: {str(e)}" ) \ No newline at end of file diff --git a/services/tenant/app/services/messaging.py b/services/tenant/app/services/messaging.py index 850b0422..12a76913 100644 --- a/services/tenant/app/services/messaging.py +++ b/services/tenant/app/services/messaging.py @@ -6,6 +6,7 @@ from shared.messaging.rabbitmq import RabbitMQClient from app.core.config import settings import structlog from datetime import datetime +from typing import Dict, Any logger = structlog.get_logger() @@ -40,4 +41,20 @@ async def publish_member_added(tenant_id: str, user_id: str, role: str): } ) except Exception as e: - logger.error(f"Failed to publish tenant.member.added event: {e}") \ No newline at end of file + logger.error(f"Failed to publish tenant.member.added event: {e}") + +async def publish_tenant_deleted_event(tenant_id: str, deletion_stats: Dict[str, Any]): + """Publish tenant deletion event to message queue""" + try: + await data_publisher.publish_event( + exchange="tenant_events", + routing_key="tenant.deleted", + message={ + "event_type": "tenant_deleted", + "tenant_id": tenant_id, + "timestamp": datetime.utcnow().isoformat(), + "deletion_stats": deletion_stats + } + ) + except Exception as e: + logger.error("Failed to publish tenant deletion event", error=str(e)) \ No newline at end of file diff --git a/services/training/app/api/models.py b/services/training/app/api/models.py index e589b63e..9c99f635 100644 --- a/services/training/app/api/models.py +++ b/services/training/app/api/models.py @@ -12,9 +12,15 @@ from app.core.database import get_db from app.schemas.training import TrainedModelResponse, ModelMetricsResponse from app.services.training_service import TrainingService from datetime import datetime +from sqlalchemy import select, delete, func +import uuid +import shutil + +from app.services.messaging import publish_models_deleted_event from shared.auth.decorators import ( - get_current_tenant_id_dep + get_current_user_dep, + require_admin_role ) logger = structlog.get_logger() @@ -212,4 +218,244 @@ async def list_models( raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to retrieve models" + ) + +@router.delete("/models/tenant/{tenant_id}") +async def delete_tenant_models_complete( + tenant_id: str, + current_user = Depends(get_current_user_dep), + _admin_check = Depends(require_admin_role), + db: AsyncSession = Depends(get_db) +): + """ + Delete all trained models and artifacts for a tenant. + + **WARNING: This operation is irreversible!** + + This endpoint: + 1. Cancels any active training jobs for the tenant + 2. Deletes all model artifacts (files) from storage + 3. Deletes model records from database + 4. Deletes training logs and performance metrics + 5. Publishes deletion event + + Used by admin user deletion process to clean up all training data. + """ + + try: + tenant_uuid = uuid.UUID(tenant_id) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid tenant ID format" + ) + + try: + from app.models.training import ( + ModelTrainingLog, + TrainedModel, + ModelArtifact, + ModelPerformanceMetric, + TrainingJobQueue + ) + from app.core.config import settings + + deletion_stats = { + "tenant_id": tenant_id, + "deleted_at": datetime.utcnow().isoformat(), + "jobs_cancelled": 0, + "models_deleted": 0, + "artifacts_deleted": 0, + "artifacts_files_deleted": 0, + "training_logs_deleted": 0, + "performance_metrics_deleted": 0, + "storage_freed_bytes": 0, + "errors": [] + } + + # Step 1: Cancel active training jobs + try: + active_jobs_query = select(TrainingJobQueue).where( + TrainingJobQueue.tenant_id == tenant_uuid, + TrainingJobQueue.status.in_(["queued", "running", "pending"]) + ) + active_jobs_result = await db.execute(active_jobs_query) + active_jobs = active_jobs_result.scalars().all() + + for job in active_jobs: + job.status = "cancelled" + job.updated_at = datetime.utcnow() + deletion_stats["jobs_cancelled"] += 1 + + if active_jobs: + await db.commit() + logger.info("Cancelled active training jobs", + tenant_id=tenant_id, + count=len(active_jobs)) + + except Exception as e: + error_msg = f"Error cancelling training jobs: {str(e)}" + deletion_stats["errors"].append(error_msg) + logger.error(error_msg) + + # Step 2: Delete model artifact files from storage + try: + artifacts_query = select(ModelArtifact).where( + ModelArtifact.tenant_id == tenant_uuid + ) + artifacts_result = await db.execute(artifacts_query) + artifacts = artifacts_result.scalars().all() + + storage_freed = 0 + files_deleted = 0 + + for artifact in artifacts: + try: + file_path = Path(artifact.file_path) + if file_path.exists(): + file_size = file_path.stat().st_size + file_path.unlink() # Delete file + storage_freed += file_size + files_deleted += 1 + logger.debug("Deleted artifact file", + file_path=str(file_path), + size_bytes=file_size) + + # Also try to delete parent directories if empty + try: + if file_path.parent.exists() and not any(file_path.parent.iterdir()): + file_path.parent.rmdir() + except: + pass # Ignore errors cleaning up directories + + except Exception as e: + error_msg = f"Error deleting artifact file {artifact.file_path}: {str(e)}" + deletion_stats["errors"].append(error_msg) + logger.warning(error_msg) + + deletion_stats["artifacts_files_deleted"] = files_deleted + deletion_stats["storage_freed_bytes"] = storage_freed + + logger.info("Deleted artifact files", + tenant_id=tenant_id, + files_deleted=files_deleted, + storage_freed_mb=storage_freed / (1024 * 1024)) + + except Exception as e: + error_msg = f"Error processing artifact files: {str(e)}" + deletion_stats["errors"].append(error_msg) + logger.error(error_msg) + + # Step 3: Delete database records + try: + # Delete model performance metrics + metrics_count_query = select(func.count(ModelPerformanceMetric.id)).where( + ModelPerformanceMetric.tenant_id == tenant_uuid + ) + metrics_count_result = await db.execute(metrics_count_query) + metrics_count = metrics_count_result.scalar() + + metrics_delete_query = delete(ModelPerformanceMetric).where( + ModelPerformanceMetric.tenant_id == tenant_uuid + ) + await db.execute(metrics_delete_query) + deletion_stats["performance_metrics_deleted"] = metrics_count + + # Delete model artifacts records + artifacts_count_query = select(func.count(ModelArtifact.id)).where( + ModelArtifact.tenant_id == tenant_uuid + ) + artifacts_count_result = await db.execute(artifacts_count_query) + artifacts_count = artifacts_count_result.scalar() + + artifacts_delete_query = delete(ModelArtifact).where( + ModelArtifact.tenant_id == tenant_uuid + ) + await db.execute(artifacts_delete_query) + deletion_stats["artifacts_deleted"] = artifacts_count + + # Delete trained models + models_count_query = select(func.count(TrainedModel.id)).where( + TrainedModel.tenant_id == tenant_uuid + ) + models_count_result = await db.execute(models_count_query) + models_count = models_count_result.scalar() + + models_delete_query = delete(TrainedModel).where( + TrainedModel.tenant_id == tenant_uuid + ) + await db.execute(models_delete_query) + deletion_stats["models_deleted"] = models_count + + # Delete training logs + logs_count_query = select(func.count(ModelTrainingLog.id)).where( + ModelTrainingLog.tenant_id == tenant_uuid + ) + logs_count_result = await db.execute(logs_count_query) + logs_count = logs_count_result.scalar() + + logs_delete_query = delete(ModelTrainingLog).where( + ModelTrainingLog.tenant_id == tenant_uuid + ) + await db.execute(logs_delete_query) + deletion_stats["training_logs_deleted"] = logs_count + + # Delete job queue entries + queue_delete_query = delete(TrainingJobQueue).where( + TrainingJobQueue.tenant_id == tenant_uuid + ) + await db.execute(queue_delete_query) + + await db.commit() + + logger.info("Deleted training database records", + tenant_id=tenant_id, + models=models_count, + artifacts=artifacts_count, + logs=logs_count, + metrics=metrics_count) + + except Exception as e: + await db.rollback() + error_msg = f"Error deleting database records: {str(e)}" + deletion_stats["errors"].append(error_msg) + logger.error(error_msg) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=error_msg + ) + + # Step 4: Clean up tenant model directory + try: + tenant_model_dir = Path(settings.MODEL_STORAGE_PATH) / tenant_id + if tenant_model_dir.exists(): + shutil.rmtree(tenant_model_dir) + logger.info("Deleted tenant model directory", + directory=str(tenant_model_dir)) + except Exception as e: + error_msg = f"Error deleting model directory: {str(e)}" + deletion_stats["errors"].append(error_msg) + logger.warning(error_msg) + + # Step 5: Publish deletion event + try: + await publish_models_deleted_event(tenant_id, deletion_stats) + except Exception as e: + logger.warning("Failed to publish models deletion event", error=str(e)) + + return { + "success": True, + "message": f"All training data for tenant {tenant_id} deleted successfully", + "deletion_details": deletion_stats + } + + except HTTPException: + raise + except Exception as e: + logger.error("Unexpected error deleting tenant models", + tenant_id=tenant_id, + error=str(e)) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to delete tenant models: {str(e)}" ) \ No newline at end of file diff --git a/services/training/app/services/messaging.py b/services/training/app/services/messaging.py index 749370f1..6b834470 100644 --- a/services/training/app/services/messaging.py +++ b/services/training/app/services/messaging.py @@ -442,6 +442,24 @@ async def publish_data_validation_completed( } ) + +async def publish_models_deleted_event(tenant_id: str, deletion_stats: Dict[str, Any]): + """Publish models deletion event to message queue""" + try: + await training_publisher.publish_event( + exchange="training_events", + routing_key="training.tenant.models.deleted", + message={ + "event_type": "tenant_models_deleted", + "tenant_id": tenant_id, + "timestamp": datetime.utcnow().isoformat(), + "deletion_stats": deletion_stats + } + ) + except Exception as e: + logger.error("Failed to publish models deletion event", error=str(e)) + + # ========================================= # UTILITY FUNCTIONS FOR BATCH PUBLISHING # ========================================= @@ -549,4 +567,5 @@ class TrainingStatusPublisher: async def job_failed(self, error: str, error_details: Optional[Dict] = None): """Publish job failure with clean error details""" clean_error_details = safe_json_serialize(error_details) if error_details else None - await publish_job_failed(self.job_id, self.tenant_id, error, clean_error_details) \ No newline at end of file + await publish_job_failed(self.job_id, self.tenant_id, error, clean_error_details) + \ No newline at end of file diff --git a/shared/auth/decorators.py b/shared/auth/decorators.py index 543c1bf8..eed35df6 100644 --- a/shared/auth/decorators.py +++ b/shared/auth/decorators.py @@ -1,12 +1,16 @@ +# ================================================================ +# shared/auth/decorators.py - ENHANCED WITH ADMIN ROLE DECORATOR +# ================================================================ """ -Unified authentication decorators for microservices -Designed to work with gateway authentication middleware +Enhanced authentication decorators for microservices including admin role validation. +Designed to work with gateway authentication middleware and provide centralized +role-based access control across all services. """ from functools import wraps from fastapi import HTTPException, status, Request, Depends from fastapi.security import HTTPBearer -from typing import Callable, Optional, Dict, Any +from typing import Callable, Optional, Dict, Any, List import structlog logger = structlog.get_logger() @@ -111,6 +115,208 @@ def require_role(role: str): return decorator +def require_admin_role(func: Callable) -> Callable: + """ + Decorator to require admin role - simplified version for FastAPI dependencies + + This decorator ensures only users with 'admin' role can access the endpoint. + Can be used as a FastAPI dependency or function decorator. + + Usage as dependency: + @router.delete("/admin/users/{user_id}") + async def delete_user( + user_id: str, + current_user = Depends(get_current_user_dep), + _admin_check = Depends(require_admin_role), + ): + # Admin-only logic here + + Usage as decorator: + @require_admin_role + @router.delete("/admin/users/{user_id}") + async def delete_user(...): + # Admin-only logic here + """ + + @wraps(func) + async def wrapper(*args, **kwargs): + # Find request object in arguments + request = None + current_user = None + + # Extract request and current_user from arguments + for arg in args: + if isinstance(arg, Request): + request = arg + elif isinstance(arg, dict) and 'user_id' in arg: + current_user = arg + + # Check kwargs for request and current_user + if not request: + request = kwargs.get('request') + if not current_user: + current_user = kwargs.get('current_user') + + # If we still don't have current_user, try to get it from request + if not current_user and request: + current_user = get_current_user(request) + + if not current_user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authentication required" + ) + + # Check if user has admin role + user_role = current_user.get('role', '').lower() + + if user_role != 'admin': + logger.warning("Non-admin user attempted admin operation", + user_id=current_user.get('user_id'), + role=user_role) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Admin role required" + ) + + logger.info("Admin operation authorized", + user_id=current_user.get('user_id'), + endpoint=func.__name__) + + return await func(*args, **kwargs) + + return wrapper + +def require_roles(allowed_roles: List[str]): + """ + Decorator to require one of multiple roles + + Args: + allowed_roles: List of roles that are allowed to access the endpoint + + Usage: + @require_roles(['admin', 'manager']) + @router.post("/sensitive-operation") + async def sensitive_operation(...): + # Only admins and managers can access + """ + + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs): + request = None + current_user = None + + # Extract request and current_user from arguments + for arg in args: + if isinstance(arg, Request): + request = arg + elif isinstance(arg, dict) and 'user_id' in arg: + current_user = arg + + # Check kwargs + if not request: + request = kwargs.get('request') + if not current_user: + current_user = kwargs.get('current_user') + + # Get user from request if not provided + if not current_user and request: + current_user = get_current_user(request) + + if not current_user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authentication required" + ) + + # Check if user has one of the allowed roles + user_role = current_user.get('role', '').lower() + allowed_roles_lower = [role.lower() for role in allowed_roles] + + if user_role not in allowed_roles_lower: + logger.warning("Unauthorized role attempted restricted operation", + user_id=current_user.get('user_id'), + user_role=user_role, + allowed_roles=allowed_roles) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"One of these roles required: {', '.join(allowed_roles)}" + ) + + logger.info("Role-based operation authorized", + user_id=current_user.get('user_id'), + user_role=user_role, + endpoint=func.__name__) + + return await func(*args, **kwargs) + + return wrapper + + return decorator + +def require_tenant_admin(func: Callable) -> Callable: + """ + Decorator to require admin role within a specific tenant context + + This checks that the user is an admin AND has access to the tenant + being operated on. Useful for tenant-scoped admin operations. + """ + + @wraps(func) + async def wrapper(*args, **kwargs): + request = None + current_user = None + + # Extract request and current_user from arguments + for arg in args: + if isinstance(arg, Request): + request = arg + elif isinstance(arg, dict) and 'user_id' in arg: + current_user = arg + + if not request: + request = kwargs.get('request') + if not current_user: + current_user = kwargs.get('current_user') + + if not current_user and request: + current_user = get_current_user(request) + + if not current_user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authentication required" + ) + + # Check admin role first + user_role = current_user.get('role', '').lower() + if user_role != 'admin': + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Admin role required" + ) + + # Check tenant access + tenant_id = get_current_tenant_id(request) if request else None + if not tenant_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Tenant context required" + ) + + # Additional tenant admin validation could go here + # For now, we trust that admin users have access to operate on any tenant + + logger.info("Tenant admin operation authorized", + user_id=current_user.get('user_id'), + tenant_id=tenant_id, + endpoint=func.__name__) + + return await func(*args, **kwargs) + + return wrapper + def get_current_user(request: Request) -> Dict[str, Any]: """Get current user from request state or headers""" if hasattr(request.state, 'user') and request.state.user: @@ -145,14 +351,18 @@ def extract_user_from_headers(request: Request) -> Optional[Dict[str, Any]]: "email": request.headers.get("x-user-email", ""), "role": request.headers.get("x-user-role", "user"), "tenant_id": request.headers.get("x-tenant-id"), - "permissions": request.headers.get("X-User-Permissions", "").split(",") if request.headers.get("X-User-Permissions") else [] + "permissions": request.headers.get("X-User-Permissions", "").split(",") if request.headers.get("X-User-Permissions") else [], + "full_name": request.headers.get("x-user-full-name", "") } def extract_tenant_from_headers(request: Request) -> Optional[str]: """Extract tenant ID from headers""" return request.headers.get("x-tenant-id") -# FastAPI Dependencies for injection +# ================================================================ +# FASTAPI DEPENDENCY FUNCTIONS +# ================================================================ + async def get_current_user_dep(request: Request) -> Dict[str, Any]: """FastAPI dependency to get current user""" return get_current_user(request) @@ -161,15 +371,180 @@ async def get_current_tenant_id_dep(request: Request) -> Optional[str]: """FastAPI dependency to get current tenant ID""" return get_current_tenant_id(request) +async def require_admin_role_dep( + current_user: Dict[str, Any] = Depends(get_current_user_dep) +) -> Dict[str, Any]: + """ + FastAPI dependency that requires admin role + + Usage: + @router.delete("/admin/users/{user_id}") + async def delete_user( + user_id: str, + admin_user: Dict[str, Any] = Depends(require_admin_role_dep) + ): + # admin_user is guaranteed to have admin role + """ + + user_role = current_user.get('role', '').lower() + + if user_role != 'admin': + logger.warning("Non-admin user attempted admin operation", + user_id=current_user.get('user_id'), + role=user_role) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Admin role required" + ) + + logger.info("Admin operation authorized via dependency", + user_id=current_user.get('user_id')) + + return current_user + +async def require_roles_dep(allowed_roles: List[str]): + """ + FastAPI dependency factory that requires one of multiple roles + + Usage: + require_manager_or_admin = require_roles_dep(['admin', 'manager']) + + @router.post("/sensitive-operation") + async def sensitive_operation( + user: Dict[str, Any] = Depends(require_manager_or_admin) + ): + # Only admins and managers can access + """ + + async def check_roles( + current_user: Dict[str, Any] = Depends(get_current_user_dep) + ) -> Dict[str, Any]: + user_role = current_user.get('role', '').lower() + allowed_roles_lower = [role.lower() for role in allowed_roles] + + if user_role not in allowed_roles_lower: + logger.warning("Unauthorized role attempted restricted operation", + user_id=current_user.get('user_id'), + user_role=user_role, + allowed_roles=allowed_roles) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"One of these roles required: {', '.join(allowed_roles)}" + ) + + logger.info("Role-based operation authorized via dependency", + user_id=current_user.get('user_id'), + user_role=user_role) + + return current_user + + return check_roles + +# ================================================================ +# UTILITY FUNCTIONS FOR ROLE CHECKING +# ================================================================ + +def is_admin_user(user: Dict[str, Any]) -> bool: + """Check if user has admin role""" + return user.get('role', '').lower() == 'admin' + +def is_user_in_roles(user: Dict[str, Any], allowed_roles: List[str]) -> bool: + """Check if user has one of the allowed roles""" + user_role = user.get('role', '').lower() + allowed_roles_lower = [role.lower() for role in allowed_roles] + return user_role in allowed_roles_lower + +def get_user_permissions(user: Dict[str, Any]) -> List[str]: + """Get user permissions list""" + return user.get('permissions', []) + +def has_permission(user: Dict[str, Any], permission: str) -> bool: + """Check if user has specific permission""" + permissions = get_user_permissions(user) + return permission in permissions + +# ================================================================ +# ADVANCED ROLE DECORATORS +# ================================================================ + +def require_permission(permission: str): + """ + Decorator to require specific permission + + Usage: + @require_permission('delete_users') + @router.delete("/users/{user_id}") + async def delete_user(...): + # Only users with 'delete_users' permission can access + """ + + def decorator(func: Callable) -> Callable: + @wraps(func) + async def wrapper(*args, **kwargs): + current_user = None + + # Extract current_user from arguments + for arg in args: + if isinstance(arg, dict) and 'user_id' in arg: + current_user = arg + break + + if not current_user: + current_user = kwargs.get('current_user') + + if not current_user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authentication required" + ) + + # Check permission + if not has_permission(current_user, permission): + # Admins bypass permission checks + if not is_admin_user(current_user): + logger.warning("User lacks required permission", + user_id=current_user.get('user_id'), + required_permission=permission, + user_permissions=get_user_permissions(current_user)) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Permission '{permission}' required" + ) + + logger.info("Permission-based operation authorized", + user_id=current_user.get('user_id'), + permission=permission) + + return await func(*args, **kwargs) + + return wrapper + + return decorator + # Export all decorators and functions __all__ = [ + # Main decorators 'require_authentication', - 'require_tenant_access', + 'require_tenant_access', 'require_role', - 'get_current_user', - 'get_current_tenant_id', + 'require_admin_role', + 'require_roles', + 'require_tenant_admin', + 'require_permission', + + # FastAPI dependencies 'get_current_user_dep', 'get_current_tenant_id_dep', + 'require_admin_role_dep', + 'require_roles_dep', + + # Utility functions + 'get_current_user', + 'get_current_tenant_id', 'extract_user_from_headers', - 'extract_tenant_from_headers' + 'extract_tenant_from_headers', + 'is_admin_user', + 'is_user_in_roles', + 'get_user_permissions', + 'has_permission' ] \ No newline at end of file diff --git a/tests/test_onboarding_flow.sh b/tests/test_onboarding_flow.sh index 62e1b11b..36345065 100755 --- a/tests/test_onboarding_flow.sh +++ b/tests/test_onboarding_flow.sh @@ -455,7 +455,8 @@ REGISTER_RESPONSE=$(curl -s -X POST "$API_BASE/api/v1/auth/register" \ -d "{ \"email\": \"$TEST_EMAIL\", \"password\": \"$TEST_PASSWORD\", - \"full_name\": \"$TEST_NAME\" + \"full_name\": \"$TEST_NAME\", + \"role\": \"admin\" }") echo "Registration Response:"