Add user role

This commit is contained in:
Urtzi Alfaro
2025-08-02 09:41:50 +02:00
parent d4687e6375
commit 277e8bec73
13 changed files with 1051 additions and 28 deletions

View File

@@ -32,7 +32,7 @@ async def register(
# ✅ DEBUG: Log incoming registration data (without password)
logger.info(f"Registration attempt for email: {user_data.email}")
logger.debug(f"Registration data - email: {user_data.email}, full_name: {user_data.full_name}")
logger.debug(f"Registration data - email: {user_data.email}, full_name: {user_data.full_name}, role: {user_data.role}")
try:
# ✅ DEBUG: Validate input data

View File

@@ -21,8 +21,7 @@ from app.services.admin_delete import AdminUserDeleteService
# Import unified authentication from shared library
from shared.auth.decorators import (
get_current_user_dep,
get_current_tenant_id_dep,
require_role # For admin-only endpoints
require_admin_role
)
logger = structlog.get_logger()
@@ -126,7 +125,7 @@ async def delete_admin_user(
user_id: str,
background_tasks: BackgroundTasks,
current_user = Depends(get_current_user_dep),
#_admin_check = Depends(require_admin_role),
_admin_check = Depends(require_admin_role),
db: AsyncSession = Depends(get_db)
):
"""
@@ -191,7 +190,7 @@ async def delete_admin_user(
async def preview_user_deletion(
user_id: str,
current_user = Depends(get_current_user_dep),
#_admin_check = Depends(require_admin_role),
_admin_check = Depends(require_admin_role),
db: AsyncSession = Depends(get_db)
):
"""

View File

@@ -31,6 +31,7 @@ class User(Base):
phone = Column(String(20))
language = Column(String(10), default="es")
timezone = Column(String(50), default="Europe/Madrid")
role = Column(String(20), default="user")
# REMOVED: All tenant relationships - these are handled by tenant service
# No tenant_memberships, tenants relationships

View File

@@ -18,6 +18,7 @@ class UserRegistration(BaseModel):
password: str = Field(..., min_length=8, max_length=128)
full_name: str = Field(..., min_length=1, max_length=255)
tenant_name: Optional[str] = Field(None, max_length=255)
role: Optional[str] = Field("user", pattern=r'^(user|admin|manager)$')
class UserLogin(BaseModel):
"""User login request"""

View File

@@ -48,7 +48,8 @@ class AuthService:
is_active=True,
is_verified=False,
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc)
updated_at=datetime.now(timezone.utc),
role=user_data.role
)
db.add(new_user)
@@ -115,7 +116,8 @@ class AuthService:
"full_name": new_user.full_name,
"is_active": new_user.is_active,
"is_verified": new_user.is_verified,
"created_at": new_user.created_at.isoformat()
"created_at": new_user.created_at.isoformat(),
"role": new_user.role
}
}
@@ -242,7 +244,8 @@ class AuthService:
"full_name": user.full_name,
"is_active": user.is_active,
"is_verified": user.is_verified,
"created_at": user.created_at.isoformat()
"created_at": user.created_at.isoformat(),
"role": user.role
}
}

View File

@@ -9,12 +9,14 @@ import structlog
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path
from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Optional
from datetime import date
from datetime import date, datetime
from sqlalchemy import select, delete, func
import uuid
from app.core.database import get_db
from shared.auth.decorators import (
get_current_user_dep,
get_current_tenant_id_dep
require_admin_role
)
from app.services.forecasting_service import ForecastingService
from app.schemas.forecasts import (
@@ -22,6 +24,7 @@ from app.schemas.forecasts import (
BatchForecastResponse, AlertResponse
)
from app.models.forecasts import Forecast, PredictionBatch, ForecastAlert
from app.services.messaging import publish_forecasts_deleted_event
logger = structlog.get_logger()
router = APIRouter()
@@ -318,4 +321,197 @@ async def acknowledge_alert(
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Internal server error"
)
)
@router.delete("/forecasts/tenant/{tenant_id}")
async def delete_tenant_forecasts_complete(
tenant_id: str,
current_user = Depends(get_current_user_dep),
_admin_check = Depends(require_admin_role),
db: AsyncSession = Depends(get_db)
):
"""
Delete all forecasts and predictions for a tenant.
**WARNING: This operation is irreversible!**
This endpoint:
1. Cancels any active prediction batches
2. Clears prediction cache
3. Deletes all forecast records
4. Deletes prediction batch records
5. Deletes model performance metrics
6. Publishes deletion event
Used by admin user deletion process to clean up all forecasting data.
"""
try:
tenant_uuid = uuid.UUID(tenant_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid tenant ID format"
)
try:
from app.models.forecasts import Forecast, PredictionBatch
from app.models.predictions import ModelPerformanceMetric, PredictionCache
deletion_stats = {
"tenant_id": tenant_id,
"deleted_at": datetime.utcnow().isoformat(),
"batches_cancelled": 0,
"forecasts_deleted": 0,
"prediction_batches_deleted": 0,
"performance_metrics_deleted": 0,
"cache_entries_deleted": 0,
"errors": []
}
# Step 1: Cancel active prediction batches
try:
active_batches_query = select(PredictionBatch).where(
PredictionBatch.tenant_id == tenant_uuid,
PredictionBatch.status.in_(["pending", "processing"])
)
active_batches_result = await db.execute(active_batches_query)
active_batches = active_batches_result.scalars().all()
for batch in active_batches:
batch.status = "cancelled"
batch.completed_at = datetime.utcnow()
deletion_stats["batches_cancelled"] += 1
if active_batches:
await db.commit()
logger.info("Cancelled active prediction batches",
tenant_id=tenant_id,
count=len(active_batches))
except Exception as e:
error_msg = f"Error cancelling prediction batches: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.error(error_msg)
# Step 2: Delete prediction cache
try:
cache_count_query = select(func.count(PredictionCache.id)).where(
PredictionCache.tenant_id == tenant_uuid
)
cache_count_result = await db.execute(cache_count_query)
cache_count = cache_count_result.scalar()
cache_delete_query = delete(PredictionCache).where(
PredictionCache.tenant_id == tenant_uuid
)
await db.execute(cache_delete_query)
deletion_stats["cache_entries_deleted"] = cache_count
logger.info("Deleted prediction cache entries",
tenant_id=tenant_id,
count=cache_count)
except Exception as e:
error_msg = f"Error deleting prediction cache: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.error(error_msg)
# Step 3: Delete model performance metrics
try:
metrics_count_query = select(func.count(ModelPerformanceMetric.id)).where(
ModelPerformanceMetric.tenant_id == tenant_uuid
)
metrics_count_result = await db.execute(metrics_count_query)
metrics_count = metrics_count_result.scalar()
metrics_delete_query = delete(ModelPerformanceMetric).where(
ModelPerformanceMetric.tenant_id == tenant_uuid
)
await db.execute(metrics_delete_query)
deletion_stats["performance_metrics_deleted"] = metrics_count
logger.info("Deleted performance metrics",
tenant_id=tenant_id,
count=metrics_count)
except Exception as e:
error_msg = f"Error deleting performance metrics: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.error(error_msg)
# Step 4: Delete prediction batches
try:
batches_count_query = select(func.count(PredictionBatch.id)).where(
PredictionBatch.tenant_id == tenant_uuid
)
batches_count_result = await db.execute(batches_count_query)
batches_count = batches_count_result.scalar()
batches_delete_query = delete(PredictionBatch).where(
PredictionBatch.tenant_id == tenant_uuid
)
await db.execute(batches_delete_query)
deletion_stats["prediction_batches_deleted"] = batches_count
logger.info("Deleted prediction batches",
tenant_id=tenant_id,
count=batches_count)
except Exception as e:
error_msg = f"Error deleting prediction batches: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.error(error_msg)
# Step 5: Delete forecasts (main data)
try:
forecasts_count_query = select(func.count(Forecast.id)).where(
Forecast.tenant_id == tenant_uuid
)
forecasts_count_result = await db.execute(forecasts_count_query)
forecasts_count = forecasts_count_result.scalar()
forecasts_delete_query = delete(Forecast).where(
Forecast.tenant_id == tenant_uuid
)
await db.execute(forecasts_delete_query)
deletion_stats["forecasts_deleted"] = forecasts_count
await db.commit()
logger.info("Deleted forecasts",
tenant_id=tenant_id,
count=forecasts_count)
except Exception as e:
await db.rollback()
error_msg = f"Error deleting forecasts: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.error(error_msg)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=error_msg
)
# Step 6: Publish deletion event
try:
await publish_forecasts_deleted_event(tenant_id, deletion_stats)
except Exception as e:
logger.warning("Failed to publish forecasts deletion event", error=str(e))
return {
"success": True,
"message": f"All forecasting data for tenant {tenant_id} deleted successfully",
"deletion_details": deletion_stats
}
except HTTPException:
raise
except Exception as e:
logger.error("Unexpected error deleting tenant forecasts",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete tenant forecasts: {str(e)}"
)

View File

@@ -9,6 +9,7 @@ import structlog
import json
from typing import Dict, Any
import asyncio
import datetime
from shared.messaging.rabbitmq import RabbitMQClient
from shared.messaging.events import (
@@ -132,4 +133,20 @@ async def handle_weather_updated(data: Dict[str, Any]):
# Could trigger re-forecasting if needed
except Exception as e:
logger.error("Error handling weather updated event", error=str(e))
logger.error("Error handling weather updated event", error=str(e))
async def publish_forecasts_deleted_event(tenant_id: str, deletion_stats: Dict[str, Any]):
"""Publish forecasts deletion event to message queue"""
try:
await rabbitmq_client.publish_event(
exchange="forecasting_events",
routing_key="forecasting.tenant.deleted",
message={
"event_type": "tenant_forecasts_deleted",
"tenant_id": tenant_id,
"timestamp": datetime.utcnow().isoformat(),
"deletion_stats": deletion_stats
}
)
except Exception as e:
logger.error("Failed to publish forecasts deletion event", error=str(e))

View File

@@ -8,18 +8,20 @@ from sqlalchemy.ext.asyncio import AsyncSession
from typing import List, Dict, Any
import structlog
from uuid import UUID
from sqlalchemy import select, delete, func
from datetime import datetime
import uuid
from app.core.database import get_db
from app.services.messaging import publish_tenant_deleted_event
from app.schemas.tenants import (
BakeryRegistration, TenantResponse, TenantAccessResponse,
TenantUpdate, TenantMemberResponse
)
from app.services.tenant_service import TenantService
# Import unified authentication
from shared.auth.decorators import (
get_current_user_dep,
get_current_tenant_id_dep,
require_role
require_admin_role
)
logger = structlog.get_logger()
@@ -163,4 +165,150 @@ async def add_team_member(
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to add team member"
)
@router.delete("/tenants/{tenant_id}")
async def delete_tenant_complete(
tenant_id: str,
current_user = Depends(get_current_user_dep),
_admin_check = Depends(require_admin_role),
db: AsyncSession = Depends(get_db)
):
"""
Delete a tenant completely with all associated data.
**WARNING: This operation is irreversible!**
This endpoint:
1. Validates tenant exists and user has permissions
2. Deletes all tenant memberships
3. Deletes tenant subscription data
4. Deletes the tenant record
5. Publishes deletion event
Used by admin user deletion process when a tenant has no other admins.
"""
try:
tenant_uuid = uuid.UUID(tenant_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid tenant ID format"
)
try:
from app.models.tenants import Tenant, TenantMember, Subscription
# Step 1: Verify tenant exists
tenant_query = select(Tenant).where(Tenant.id == tenant_uuid)
tenant_result = await db.execute(tenant_query)
tenant = tenant_result.scalar_one_or_none()
if not tenant:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Tenant {tenant_id} not found"
)
deletion_stats = {
"tenant_id": tenant_id,
"tenant_name": tenant.name,
"deleted_at": datetime.utcnow().isoformat(),
"memberships_deleted": 0,
"subscriptions_deleted": 0,
"errors": []
}
# Step 2: Delete all tenant memberships
try:
membership_count_query = select(func.count(TenantMember.id)).where(
TenantMember.tenant_id == tenant_uuid
)
membership_count_result = await db.execute(membership_count_query)
membership_count = membership_count_result.scalar()
membership_delete_query = delete(TenantMember).where(
TenantMember.tenant_id == tenant_uuid
)
await db.execute(membership_delete_query)
deletion_stats["memberships_deleted"] = membership_count
logger.info("Deleted tenant memberships",
tenant_id=tenant_id,
count=membership_count)
except Exception as e:
error_msg = f"Error deleting memberships: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.error(error_msg)
# Step 3: Delete subscription data
try:
subscription_count_query = select(func.count(Subscription.id)).where(
Subscription.tenant_id == tenant_uuid
)
subscription_count_result = await db.execute(subscription_count_query)
subscription_count = subscription_count_result.scalar()
subscription_delete_query = delete(Subscription).where(
Subscription.tenant_id == tenant_uuid
)
await db.execute(subscription_delete_query)
deletion_stats["subscriptions_deleted"] = subscription_count
logger.info("Deleted tenant subscriptions",
tenant_id=tenant_id,
count=subscription_count)
except Exception as e:
error_msg = f"Error deleting subscriptions: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.error(error_msg)
# Step 4: Delete the tenant record
try:
tenant_delete_query = delete(Tenant).where(Tenant.id == tenant_uuid)
tenant_result = await db.execute(tenant_delete_query)
if tenant_result.rowcount == 0:
raise Exception("Tenant record was not deleted")
await db.commit()
logger.info("Tenant deleted successfully",
tenant_id=tenant_id,
tenant_name=tenant.name)
except Exception as e:
await db.rollback()
error_msg = f"Error deleting tenant record: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.error(error_msg)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=error_msg
)
# Step 5: Publish tenant deletion event
try:
await publish_tenant_deleted_event(tenant_id, deletion_stats)
except Exception as e:
logger.warning("Failed to publish tenant deletion event", error=str(e))
return {
"success": True,
"message": f"Tenant {tenant_id} deleted successfully",
"deletion_details": deletion_stats
}
except HTTPException:
raise
except Exception as e:
logger.error("Unexpected error deleting tenant",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete tenant: {str(e)}"
)

View File

@@ -6,6 +6,7 @@ from shared.messaging.rabbitmq import RabbitMQClient
from app.core.config import settings
import structlog
from datetime import datetime
from typing import Dict, Any
logger = structlog.get_logger()
@@ -40,4 +41,20 @@ async def publish_member_added(tenant_id: str, user_id: str, role: str):
}
)
except Exception as e:
logger.error(f"Failed to publish tenant.member.added event: {e}")
logger.error(f"Failed to publish tenant.member.added event: {e}")
async def publish_tenant_deleted_event(tenant_id: str, deletion_stats: Dict[str, Any]):
"""Publish tenant deletion event to message queue"""
try:
await data_publisher.publish_event(
exchange="tenant_events",
routing_key="tenant.deleted",
message={
"event_type": "tenant_deleted",
"tenant_id": tenant_id,
"timestamp": datetime.utcnow().isoformat(),
"deletion_stats": deletion_stats
}
)
except Exception as e:
logger.error("Failed to publish tenant deletion event", error=str(e))

View File

@@ -12,9 +12,15 @@ from app.core.database import get_db
from app.schemas.training import TrainedModelResponse, ModelMetricsResponse
from app.services.training_service import TrainingService
from datetime import datetime
from sqlalchemy import select, delete, func
import uuid
import shutil
from app.services.messaging import publish_models_deleted_event
from shared.auth.decorators import (
get_current_tenant_id_dep
get_current_user_dep,
require_admin_role
)
logger = structlog.get_logger()
@@ -212,4 +218,244 @@ async def list_models(
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve models"
)
@router.delete("/models/tenant/{tenant_id}")
async def delete_tenant_models_complete(
tenant_id: str,
current_user = Depends(get_current_user_dep),
_admin_check = Depends(require_admin_role),
db: AsyncSession = Depends(get_db)
):
"""
Delete all trained models and artifacts for a tenant.
**WARNING: This operation is irreversible!**
This endpoint:
1. Cancels any active training jobs for the tenant
2. Deletes all model artifacts (files) from storage
3. Deletes model records from database
4. Deletes training logs and performance metrics
5. Publishes deletion event
Used by admin user deletion process to clean up all training data.
"""
try:
tenant_uuid = uuid.UUID(tenant_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid tenant ID format"
)
try:
from app.models.training import (
ModelTrainingLog,
TrainedModel,
ModelArtifact,
ModelPerformanceMetric,
TrainingJobQueue
)
from app.core.config import settings
deletion_stats = {
"tenant_id": tenant_id,
"deleted_at": datetime.utcnow().isoformat(),
"jobs_cancelled": 0,
"models_deleted": 0,
"artifacts_deleted": 0,
"artifacts_files_deleted": 0,
"training_logs_deleted": 0,
"performance_metrics_deleted": 0,
"storage_freed_bytes": 0,
"errors": []
}
# Step 1: Cancel active training jobs
try:
active_jobs_query = select(TrainingJobQueue).where(
TrainingJobQueue.tenant_id == tenant_uuid,
TrainingJobQueue.status.in_(["queued", "running", "pending"])
)
active_jobs_result = await db.execute(active_jobs_query)
active_jobs = active_jobs_result.scalars().all()
for job in active_jobs:
job.status = "cancelled"
job.updated_at = datetime.utcnow()
deletion_stats["jobs_cancelled"] += 1
if active_jobs:
await db.commit()
logger.info("Cancelled active training jobs",
tenant_id=tenant_id,
count=len(active_jobs))
except Exception as e:
error_msg = f"Error cancelling training jobs: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.error(error_msg)
# Step 2: Delete model artifact files from storage
try:
artifacts_query = select(ModelArtifact).where(
ModelArtifact.tenant_id == tenant_uuid
)
artifacts_result = await db.execute(artifacts_query)
artifacts = artifacts_result.scalars().all()
storage_freed = 0
files_deleted = 0
for artifact in artifacts:
try:
file_path = Path(artifact.file_path)
if file_path.exists():
file_size = file_path.stat().st_size
file_path.unlink() # Delete file
storage_freed += file_size
files_deleted += 1
logger.debug("Deleted artifact file",
file_path=str(file_path),
size_bytes=file_size)
# Also try to delete parent directories if empty
try:
if file_path.parent.exists() and not any(file_path.parent.iterdir()):
file_path.parent.rmdir()
except:
pass # Ignore errors cleaning up directories
except Exception as e:
error_msg = f"Error deleting artifact file {artifact.file_path}: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.warning(error_msg)
deletion_stats["artifacts_files_deleted"] = files_deleted
deletion_stats["storage_freed_bytes"] = storage_freed
logger.info("Deleted artifact files",
tenant_id=tenant_id,
files_deleted=files_deleted,
storage_freed_mb=storage_freed / (1024 * 1024))
except Exception as e:
error_msg = f"Error processing artifact files: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.error(error_msg)
# Step 3: Delete database records
try:
# Delete model performance metrics
metrics_count_query = select(func.count(ModelPerformanceMetric.id)).where(
ModelPerformanceMetric.tenant_id == tenant_uuid
)
metrics_count_result = await db.execute(metrics_count_query)
metrics_count = metrics_count_result.scalar()
metrics_delete_query = delete(ModelPerformanceMetric).where(
ModelPerformanceMetric.tenant_id == tenant_uuid
)
await db.execute(metrics_delete_query)
deletion_stats["performance_metrics_deleted"] = metrics_count
# Delete model artifacts records
artifacts_count_query = select(func.count(ModelArtifact.id)).where(
ModelArtifact.tenant_id == tenant_uuid
)
artifacts_count_result = await db.execute(artifacts_count_query)
artifacts_count = artifacts_count_result.scalar()
artifacts_delete_query = delete(ModelArtifact).where(
ModelArtifact.tenant_id == tenant_uuid
)
await db.execute(artifacts_delete_query)
deletion_stats["artifacts_deleted"] = artifacts_count
# Delete trained models
models_count_query = select(func.count(TrainedModel.id)).where(
TrainedModel.tenant_id == tenant_uuid
)
models_count_result = await db.execute(models_count_query)
models_count = models_count_result.scalar()
models_delete_query = delete(TrainedModel).where(
TrainedModel.tenant_id == tenant_uuid
)
await db.execute(models_delete_query)
deletion_stats["models_deleted"] = models_count
# Delete training logs
logs_count_query = select(func.count(ModelTrainingLog.id)).where(
ModelTrainingLog.tenant_id == tenant_uuid
)
logs_count_result = await db.execute(logs_count_query)
logs_count = logs_count_result.scalar()
logs_delete_query = delete(ModelTrainingLog).where(
ModelTrainingLog.tenant_id == tenant_uuid
)
await db.execute(logs_delete_query)
deletion_stats["training_logs_deleted"] = logs_count
# Delete job queue entries
queue_delete_query = delete(TrainingJobQueue).where(
TrainingJobQueue.tenant_id == tenant_uuid
)
await db.execute(queue_delete_query)
await db.commit()
logger.info("Deleted training database records",
tenant_id=tenant_id,
models=models_count,
artifacts=artifacts_count,
logs=logs_count,
metrics=metrics_count)
except Exception as e:
await db.rollback()
error_msg = f"Error deleting database records: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.error(error_msg)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=error_msg
)
# Step 4: Clean up tenant model directory
try:
tenant_model_dir = Path(settings.MODEL_STORAGE_PATH) / tenant_id
if tenant_model_dir.exists():
shutil.rmtree(tenant_model_dir)
logger.info("Deleted tenant model directory",
directory=str(tenant_model_dir))
except Exception as e:
error_msg = f"Error deleting model directory: {str(e)}"
deletion_stats["errors"].append(error_msg)
logger.warning(error_msg)
# Step 5: Publish deletion event
try:
await publish_models_deleted_event(tenant_id, deletion_stats)
except Exception as e:
logger.warning("Failed to publish models deletion event", error=str(e))
return {
"success": True,
"message": f"All training data for tenant {tenant_id} deleted successfully",
"deletion_details": deletion_stats
}
except HTTPException:
raise
except Exception as e:
logger.error("Unexpected error deleting tenant models",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to delete tenant models: {str(e)}"
)

View File

@@ -442,6 +442,24 @@ async def publish_data_validation_completed(
}
)
async def publish_models_deleted_event(tenant_id: str, deletion_stats: Dict[str, Any]):
"""Publish models deletion event to message queue"""
try:
await training_publisher.publish_event(
exchange="training_events",
routing_key="training.tenant.models.deleted",
message={
"event_type": "tenant_models_deleted",
"tenant_id": tenant_id,
"timestamp": datetime.utcnow().isoformat(),
"deletion_stats": deletion_stats
}
)
except Exception as e:
logger.error("Failed to publish models deletion event", error=str(e))
# =========================================
# UTILITY FUNCTIONS FOR BATCH PUBLISHING
# =========================================
@@ -549,4 +567,5 @@ class TrainingStatusPublisher:
async def job_failed(self, error: str, error_details: Optional[Dict] = None):
"""Publish job failure with clean error details"""
clean_error_details = safe_json_serialize(error_details) if error_details else None
await publish_job_failed(self.job_id, self.tenant_id, error, clean_error_details)
await publish_job_failed(self.job_id, self.tenant_id, error, clean_error_details)

View File

@@ -1,12 +1,16 @@
# ================================================================
# shared/auth/decorators.py - ENHANCED WITH ADMIN ROLE DECORATOR
# ================================================================
"""
Unified authentication decorators for microservices
Designed to work with gateway authentication middleware
Enhanced authentication decorators for microservices including admin role validation.
Designed to work with gateway authentication middleware and provide centralized
role-based access control across all services.
"""
from functools import wraps
from fastapi import HTTPException, status, Request, Depends
from fastapi.security import HTTPBearer
from typing import Callable, Optional, Dict, Any
from typing import Callable, Optional, Dict, Any, List
import structlog
logger = structlog.get_logger()
@@ -111,6 +115,208 @@ def require_role(role: str):
return decorator
def require_admin_role(func: Callable) -> Callable:
"""
Decorator to require admin role - simplified version for FastAPI dependencies
This decorator ensures only users with 'admin' role can access the endpoint.
Can be used as a FastAPI dependency or function decorator.
Usage as dependency:
@router.delete("/admin/users/{user_id}")
async def delete_user(
user_id: str,
current_user = Depends(get_current_user_dep),
_admin_check = Depends(require_admin_role),
):
# Admin-only logic here
Usage as decorator:
@require_admin_role
@router.delete("/admin/users/{user_id}")
async def delete_user(...):
# Admin-only logic here
"""
@wraps(func)
async def wrapper(*args, **kwargs):
# Find request object in arguments
request = None
current_user = None
# Extract request and current_user from arguments
for arg in args:
if isinstance(arg, Request):
request = arg
elif isinstance(arg, dict) and 'user_id' in arg:
current_user = arg
# Check kwargs for request and current_user
if not request:
request = kwargs.get('request')
if not current_user:
current_user = kwargs.get('current_user')
# If we still don't have current_user, try to get it from request
if not current_user and request:
current_user = get_current_user(request)
if not current_user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required"
)
# Check if user has admin role
user_role = current_user.get('role', '').lower()
if user_role != 'admin':
logger.warning("Non-admin user attempted admin operation",
user_id=current_user.get('user_id'),
role=user_role)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin role required"
)
logger.info("Admin operation authorized",
user_id=current_user.get('user_id'),
endpoint=func.__name__)
return await func(*args, **kwargs)
return wrapper
def require_roles(allowed_roles: List[str]):
"""
Decorator to require one of multiple roles
Args:
allowed_roles: List of roles that are allowed to access the endpoint
Usage:
@require_roles(['admin', 'manager'])
@router.post("/sensitive-operation")
async def sensitive_operation(...):
# Only admins and managers can access
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(*args, **kwargs):
request = None
current_user = None
# Extract request and current_user from arguments
for arg in args:
if isinstance(arg, Request):
request = arg
elif isinstance(arg, dict) and 'user_id' in arg:
current_user = arg
# Check kwargs
if not request:
request = kwargs.get('request')
if not current_user:
current_user = kwargs.get('current_user')
# Get user from request if not provided
if not current_user and request:
current_user = get_current_user(request)
if not current_user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required"
)
# Check if user has one of the allowed roles
user_role = current_user.get('role', '').lower()
allowed_roles_lower = [role.lower() for role in allowed_roles]
if user_role not in allowed_roles_lower:
logger.warning("Unauthorized role attempted restricted operation",
user_id=current_user.get('user_id'),
user_role=user_role,
allowed_roles=allowed_roles)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"One of these roles required: {', '.join(allowed_roles)}"
)
logger.info("Role-based operation authorized",
user_id=current_user.get('user_id'),
user_role=user_role,
endpoint=func.__name__)
return await func(*args, **kwargs)
return wrapper
return decorator
def require_tenant_admin(func: Callable) -> Callable:
"""
Decorator to require admin role within a specific tenant context
This checks that the user is an admin AND has access to the tenant
being operated on. Useful for tenant-scoped admin operations.
"""
@wraps(func)
async def wrapper(*args, **kwargs):
request = None
current_user = None
# Extract request and current_user from arguments
for arg in args:
if isinstance(arg, Request):
request = arg
elif isinstance(arg, dict) and 'user_id' in arg:
current_user = arg
if not request:
request = kwargs.get('request')
if not current_user:
current_user = kwargs.get('current_user')
if not current_user and request:
current_user = get_current_user(request)
if not current_user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required"
)
# Check admin role first
user_role = current_user.get('role', '').lower()
if user_role != 'admin':
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin role required"
)
# Check tenant access
tenant_id = get_current_tenant_id(request) if request else None
if not tenant_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Tenant context required"
)
# Additional tenant admin validation could go here
# For now, we trust that admin users have access to operate on any tenant
logger.info("Tenant admin operation authorized",
user_id=current_user.get('user_id'),
tenant_id=tenant_id,
endpoint=func.__name__)
return await func(*args, **kwargs)
return wrapper
def get_current_user(request: Request) -> Dict[str, Any]:
"""Get current user from request state or headers"""
if hasattr(request.state, 'user') and request.state.user:
@@ -145,14 +351,18 @@ def extract_user_from_headers(request: Request) -> Optional[Dict[str, Any]]:
"email": request.headers.get("x-user-email", ""),
"role": request.headers.get("x-user-role", "user"),
"tenant_id": request.headers.get("x-tenant-id"),
"permissions": request.headers.get("X-User-Permissions", "").split(",") if request.headers.get("X-User-Permissions") else []
"permissions": request.headers.get("X-User-Permissions", "").split(",") if request.headers.get("X-User-Permissions") else [],
"full_name": request.headers.get("x-user-full-name", "")
}
def extract_tenant_from_headers(request: Request) -> Optional[str]:
"""Extract tenant ID from headers"""
return request.headers.get("x-tenant-id")
# FastAPI Dependencies for injection
# ================================================================
# FASTAPI DEPENDENCY FUNCTIONS
# ================================================================
async def get_current_user_dep(request: Request) -> Dict[str, Any]:
"""FastAPI dependency to get current user"""
return get_current_user(request)
@@ -161,15 +371,180 @@ async def get_current_tenant_id_dep(request: Request) -> Optional[str]:
"""FastAPI dependency to get current tenant ID"""
return get_current_tenant_id(request)
async def require_admin_role_dep(
current_user: Dict[str, Any] = Depends(get_current_user_dep)
) -> Dict[str, Any]:
"""
FastAPI dependency that requires admin role
Usage:
@router.delete("/admin/users/{user_id}")
async def delete_user(
user_id: str,
admin_user: Dict[str, Any] = Depends(require_admin_role_dep)
):
# admin_user is guaranteed to have admin role
"""
user_role = current_user.get('role', '').lower()
if user_role != 'admin':
logger.warning("Non-admin user attempted admin operation",
user_id=current_user.get('user_id'),
role=user_role)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Admin role required"
)
logger.info("Admin operation authorized via dependency",
user_id=current_user.get('user_id'))
return current_user
async def require_roles_dep(allowed_roles: List[str]):
"""
FastAPI dependency factory that requires one of multiple roles
Usage:
require_manager_or_admin = require_roles_dep(['admin', 'manager'])
@router.post("/sensitive-operation")
async def sensitive_operation(
user: Dict[str, Any] = Depends(require_manager_or_admin)
):
# Only admins and managers can access
"""
async def check_roles(
current_user: Dict[str, Any] = Depends(get_current_user_dep)
) -> Dict[str, Any]:
user_role = current_user.get('role', '').lower()
allowed_roles_lower = [role.lower() for role in allowed_roles]
if user_role not in allowed_roles_lower:
logger.warning("Unauthorized role attempted restricted operation",
user_id=current_user.get('user_id'),
user_role=user_role,
allowed_roles=allowed_roles)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"One of these roles required: {', '.join(allowed_roles)}"
)
logger.info("Role-based operation authorized via dependency",
user_id=current_user.get('user_id'),
user_role=user_role)
return current_user
return check_roles
# ================================================================
# UTILITY FUNCTIONS FOR ROLE CHECKING
# ================================================================
def is_admin_user(user: Dict[str, Any]) -> bool:
"""Check if user has admin role"""
return user.get('role', '').lower() == 'admin'
def is_user_in_roles(user: Dict[str, Any], allowed_roles: List[str]) -> bool:
"""Check if user has one of the allowed roles"""
user_role = user.get('role', '').lower()
allowed_roles_lower = [role.lower() for role in allowed_roles]
return user_role in allowed_roles_lower
def get_user_permissions(user: Dict[str, Any]) -> List[str]:
"""Get user permissions list"""
return user.get('permissions', [])
def has_permission(user: Dict[str, Any], permission: str) -> bool:
"""Check if user has specific permission"""
permissions = get_user_permissions(user)
return permission in permissions
# ================================================================
# ADVANCED ROLE DECORATORS
# ================================================================
def require_permission(permission: str):
"""
Decorator to require specific permission
Usage:
@require_permission('delete_users')
@router.delete("/users/{user_id}")
async def delete_user(...):
# Only users with 'delete_users' permission can access
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(*args, **kwargs):
current_user = None
# Extract current_user from arguments
for arg in args:
if isinstance(arg, dict) and 'user_id' in arg:
current_user = arg
break
if not current_user:
current_user = kwargs.get('current_user')
if not current_user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication required"
)
# Check permission
if not has_permission(current_user, permission):
# Admins bypass permission checks
if not is_admin_user(current_user):
logger.warning("User lacks required permission",
user_id=current_user.get('user_id'),
required_permission=permission,
user_permissions=get_user_permissions(current_user))
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Permission '{permission}' required"
)
logger.info("Permission-based operation authorized",
user_id=current_user.get('user_id'),
permission=permission)
return await func(*args, **kwargs)
return wrapper
return decorator
# Export all decorators and functions
__all__ = [
# Main decorators
'require_authentication',
'require_tenant_access',
'require_tenant_access',
'require_role',
'get_current_user',
'get_current_tenant_id',
'require_admin_role',
'require_roles',
'require_tenant_admin',
'require_permission',
# FastAPI dependencies
'get_current_user_dep',
'get_current_tenant_id_dep',
'require_admin_role_dep',
'require_roles_dep',
# Utility functions
'get_current_user',
'get_current_tenant_id',
'extract_user_from_headers',
'extract_tenant_from_headers'
'extract_tenant_from_headers',
'is_admin_user',
'is_user_in_roles',
'get_user_permissions',
'has_permission'
]

View File

@@ -455,7 +455,8 @@ REGISTER_RESPONSE=$(curl -s -X POST "$API_BASE/api/v1/auth/register" \
-d "{
\"email\": \"$TEST_EMAIL\",
\"password\": \"$TEST_PASSWORD\",
\"full_name\": \"$TEST_NAME\"
\"full_name\": \"$TEST_NAME\",
\"role\": \"admin\"
}")
echo "Registration Response:"