Add user role
This commit is contained in:
@@ -32,7 +32,7 @@ async def register(
|
|||||||
|
|
||||||
# ✅ DEBUG: Log incoming registration data (without password)
|
# ✅ DEBUG: Log incoming registration data (without password)
|
||||||
logger.info(f"Registration attempt for email: {user_data.email}")
|
logger.info(f"Registration attempt for email: {user_data.email}")
|
||||||
logger.debug(f"Registration data - email: {user_data.email}, full_name: {user_data.full_name}")
|
logger.debug(f"Registration data - email: {user_data.email}, full_name: {user_data.full_name}, role: {user_data.role}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# ✅ DEBUG: Validate input data
|
# ✅ DEBUG: Validate input data
|
||||||
|
|||||||
@@ -21,8 +21,7 @@ from app.services.admin_delete import AdminUserDeleteService
|
|||||||
# Import unified authentication from shared library
|
# Import unified authentication from shared library
|
||||||
from shared.auth.decorators import (
|
from shared.auth.decorators import (
|
||||||
get_current_user_dep,
|
get_current_user_dep,
|
||||||
get_current_tenant_id_dep,
|
require_admin_role
|
||||||
require_role # For admin-only endpoints
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = structlog.get_logger()
|
logger = structlog.get_logger()
|
||||||
@@ -126,7 +125,7 @@ async def delete_admin_user(
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
background_tasks: BackgroundTasks,
|
background_tasks: BackgroundTasks,
|
||||||
current_user = Depends(get_current_user_dep),
|
current_user = Depends(get_current_user_dep),
|
||||||
#_admin_check = Depends(require_admin_role),
|
_admin_check = Depends(require_admin_role),
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -191,7 +190,7 @@ async def delete_admin_user(
|
|||||||
async def preview_user_deletion(
|
async def preview_user_deletion(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
current_user = Depends(get_current_user_dep),
|
current_user = Depends(get_current_user_dep),
|
||||||
#_admin_check = Depends(require_admin_role),
|
_admin_check = Depends(require_admin_role),
|
||||||
db: AsyncSession = Depends(get_db)
|
db: AsyncSession = Depends(get_db)
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ class User(Base):
|
|||||||
phone = Column(String(20))
|
phone = Column(String(20))
|
||||||
language = Column(String(10), default="es")
|
language = Column(String(10), default="es")
|
||||||
timezone = Column(String(50), default="Europe/Madrid")
|
timezone = Column(String(50), default="Europe/Madrid")
|
||||||
|
role = Column(String(20), default="user")
|
||||||
|
|
||||||
# REMOVED: All tenant relationships - these are handled by tenant service
|
# REMOVED: All tenant relationships - these are handled by tenant service
|
||||||
# No tenant_memberships, tenants relationships
|
# No tenant_memberships, tenants relationships
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ class UserRegistration(BaseModel):
|
|||||||
password: str = Field(..., min_length=8, max_length=128)
|
password: str = Field(..., min_length=8, max_length=128)
|
||||||
full_name: str = Field(..., min_length=1, max_length=255)
|
full_name: str = Field(..., min_length=1, max_length=255)
|
||||||
tenant_name: Optional[str] = Field(None, max_length=255)
|
tenant_name: Optional[str] = Field(None, max_length=255)
|
||||||
|
role: Optional[str] = Field("user", pattern=r'^(user|admin|manager)$')
|
||||||
|
|
||||||
class UserLogin(BaseModel):
|
class UserLogin(BaseModel):
|
||||||
"""User login request"""
|
"""User login request"""
|
||||||
|
|||||||
@@ -48,7 +48,8 @@ class AuthService:
|
|||||||
is_active=True,
|
is_active=True,
|
||||||
is_verified=False,
|
is_verified=False,
|
||||||
created_at=datetime.now(timezone.utc),
|
created_at=datetime.now(timezone.utc),
|
||||||
updated_at=datetime.now(timezone.utc)
|
updated_at=datetime.now(timezone.utc),
|
||||||
|
role=user_data.role
|
||||||
)
|
)
|
||||||
|
|
||||||
db.add(new_user)
|
db.add(new_user)
|
||||||
@@ -115,7 +116,8 @@ class AuthService:
|
|||||||
"full_name": new_user.full_name,
|
"full_name": new_user.full_name,
|
||||||
"is_active": new_user.is_active,
|
"is_active": new_user.is_active,
|
||||||
"is_verified": new_user.is_verified,
|
"is_verified": new_user.is_verified,
|
||||||
"created_at": new_user.created_at.isoformat()
|
"created_at": new_user.created_at.isoformat(),
|
||||||
|
"role": new_user.role
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -242,7 +244,8 @@ class AuthService:
|
|||||||
"full_name": user.full_name,
|
"full_name": user.full_name,
|
||||||
"is_active": user.is_active,
|
"is_active": user.is_active,
|
||||||
"is_verified": user.is_verified,
|
"is_verified": user.is_verified,
|
||||||
"created_at": user.created_at.isoformat()
|
"created_at": user.created_at.isoformat(),
|
||||||
|
"role": user.role
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,12 +9,14 @@ import structlog
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path
|
from fastapi import APIRouter, Depends, HTTPException, status, Query, Path
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from datetime import date
|
from datetime import date, datetime
|
||||||
|
from sqlalchemy import select, delete, func
|
||||||
|
import uuid
|
||||||
|
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
from shared.auth.decorators import (
|
from shared.auth.decorators import (
|
||||||
get_current_user_dep,
|
get_current_user_dep,
|
||||||
get_current_tenant_id_dep
|
require_admin_role
|
||||||
)
|
)
|
||||||
from app.services.forecasting_service import ForecastingService
|
from app.services.forecasting_service import ForecastingService
|
||||||
from app.schemas.forecasts import (
|
from app.schemas.forecasts import (
|
||||||
@@ -22,6 +24,7 @@ from app.schemas.forecasts import (
|
|||||||
BatchForecastResponse, AlertResponse
|
BatchForecastResponse, AlertResponse
|
||||||
)
|
)
|
||||||
from app.models.forecasts import Forecast, PredictionBatch, ForecastAlert
|
from app.models.forecasts import Forecast, PredictionBatch, ForecastAlert
|
||||||
|
from app.services.messaging import publish_forecasts_deleted_event
|
||||||
|
|
||||||
logger = structlog.get_logger()
|
logger = structlog.get_logger()
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@@ -319,3 +322,196 @@ async def acknowledge_alert(
|
|||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Internal server error"
|
detail="Internal server error"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@router.delete("/forecasts/tenant/{tenant_id}")
|
||||||
|
async def delete_tenant_forecasts_complete(
|
||||||
|
tenant_id: str,
|
||||||
|
current_user = Depends(get_current_user_dep),
|
||||||
|
_admin_check = Depends(require_admin_role),
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Delete all forecasts and predictions for a tenant.
|
||||||
|
|
||||||
|
**WARNING: This operation is irreversible!**
|
||||||
|
|
||||||
|
This endpoint:
|
||||||
|
1. Cancels any active prediction batches
|
||||||
|
2. Clears prediction cache
|
||||||
|
3. Deletes all forecast records
|
||||||
|
4. Deletes prediction batch records
|
||||||
|
5. Deletes model performance metrics
|
||||||
|
6. Publishes deletion event
|
||||||
|
|
||||||
|
Used by admin user deletion process to clean up all forecasting data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
tenant_uuid = uuid.UUID(tenant_id)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Invalid tenant ID format"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from app.models.forecasts import Forecast, PredictionBatch
|
||||||
|
from app.models.predictions import ModelPerformanceMetric, PredictionCache
|
||||||
|
|
||||||
|
deletion_stats = {
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"deleted_at": datetime.utcnow().isoformat(),
|
||||||
|
"batches_cancelled": 0,
|
||||||
|
"forecasts_deleted": 0,
|
||||||
|
"prediction_batches_deleted": 0,
|
||||||
|
"performance_metrics_deleted": 0,
|
||||||
|
"cache_entries_deleted": 0,
|
||||||
|
"errors": []
|
||||||
|
}
|
||||||
|
|
||||||
|
# Step 1: Cancel active prediction batches
|
||||||
|
try:
|
||||||
|
active_batches_query = select(PredictionBatch).where(
|
||||||
|
PredictionBatch.tenant_id == tenant_uuid,
|
||||||
|
PredictionBatch.status.in_(["pending", "processing"])
|
||||||
|
)
|
||||||
|
active_batches_result = await db.execute(active_batches_query)
|
||||||
|
active_batches = active_batches_result.scalars().all()
|
||||||
|
|
||||||
|
for batch in active_batches:
|
||||||
|
batch.status = "cancelled"
|
||||||
|
batch.completed_at = datetime.utcnow()
|
||||||
|
deletion_stats["batches_cancelled"] += 1
|
||||||
|
|
||||||
|
if active_batches:
|
||||||
|
await db.commit()
|
||||||
|
logger.info("Cancelled active prediction batches",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
count=len(active_batches))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error cancelling prediction batches: {str(e)}"
|
||||||
|
deletion_stats["errors"].append(error_msg)
|
||||||
|
logger.error(error_msg)
|
||||||
|
|
||||||
|
# Step 2: Delete prediction cache
|
||||||
|
try:
|
||||||
|
cache_count_query = select(func.count(PredictionCache.id)).where(
|
||||||
|
PredictionCache.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
cache_count_result = await db.execute(cache_count_query)
|
||||||
|
cache_count = cache_count_result.scalar()
|
||||||
|
|
||||||
|
cache_delete_query = delete(PredictionCache).where(
|
||||||
|
PredictionCache.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
await db.execute(cache_delete_query)
|
||||||
|
deletion_stats["cache_entries_deleted"] = cache_count
|
||||||
|
|
||||||
|
logger.info("Deleted prediction cache entries",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
count=cache_count)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error deleting prediction cache: {str(e)}"
|
||||||
|
deletion_stats["errors"].append(error_msg)
|
||||||
|
logger.error(error_msg)
|
||||||
|
|
||||||
|
# Step 3: Delete model performance metrics
|
||||||
|
try:
|
||||||
|
metrics_count_query = select(func.count(ModelPerformanceMetric.id)).where(
|
||||||
|
ModelPerformanceMetric.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
metrics_count_result = await db.execute(metrics_count_query)
|
||||||
|
metrics_count = metrics_count_result.scalar()
|
||||||
|
|
||||||
|
metrics_delete_query = delete(ModelPerformanceMetric).where(
|
||||||
|
ModelPerformanceMetric.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
await db.execute(metrics_delete_query)
|
||||||
|
deletion_stats["performance_metrics_deleted"] = metrics_count
|
||||||
|
|
||||||
|
logger.info("Deleted performance metrics",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
count=metrics_count)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error deleting performance metrics: {str(e)}"
|
||||||
|
deletion_stats["errors"].append(error_msg)
|
||||||
|
logger.error(error_msg)
|
||||||
|
|
||||||
|
# Step 4: Delete prediction batches
|
||||||
|
try:
|
||||||
|
batches_count_query = select(func.count(PredictionBatch.id)).where(
|
||||||
|
PredictionBatch.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
batches_count_result = await db.execute(batches_count_query)
|
||||||
|
batches_count = batches_count_result.scalar()
|
||||||
|
|
||||||
|
batches_delete_query = delete(PredictionBatch).where(
|
||||||
|
PredictionBatch.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
await db.execute(batches_delete_query)
|
||||||
|
deletion_stats["prediction_batches_deleted"] = batches_count
|
||||||
|
|
||||||
|
logger.info("Deleted prediction batches",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
count=batches_count)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error deleting prediction batches: {str(e)}"
|
||||||
|
deletion_stats["errors"].append(error_msg)
|
||||||
|
logger.error(error_msg)
|
||||||
|
|
||||||
|
# Step 5: Delete forecasts (main data)
|
||||||
|
try:
|
||||||
|
forecasts_count_query = select(func.count(Forecast.id)).where(
|
||||||
|
Forecast.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
forecasts_count_result = await db.execute(forecasts_count_query)
|
||||||
|
forecasts_count = forecasts_count_result.scalar()
|
||||||
|
|
||||||
|
forecasts_delete_query = delete(Forecast).where(
|
||||||
|
Forecast.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
await db.execute(forecasts_delete_query)
|
||||||
|
deletion_stats["forecasts_deleted"] = forecasts_count
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
logger.info("Deleted forecasts",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
count=forecasts_count)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
error_msg = f"Error deleting forecasts: {str(e)}"
|
||||||
|
deletion_stats["errors"].append(error_msg)
|
||||||
|
logger.error(error_msg)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=error_msg
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 6: Publish deletion event
|
||||||
|
try:
|
||||||
|
await publish_forecasts_deleted_event(tenant_id, deletion_stats)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to publish forecasts deletion event", error=str(e))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": f"All forecasting data for tenant {tenant_id} deleted successfully",
|
||||||
|
"deletion_details": deletion_stats
|
||||||
|
}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Unexpected error deleting tenant forecasts",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
error=str(e))
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to delete tenant forecasts: {str(e)}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import structlog
|
|||||||
import json
|
import json
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import datetime
|
||||||
|
|
||||||
from shared.messaging.rabbitmq import RabbitMQClient
|
from shared.messaging.rabbitmq import RabbitMQClient
|
||||||
from shared.messaging.events import (
|
from shared.messaging.events import (
|
||||||
@@ -133,3 +134,19 @@ async def handle_weather_updated(data: Dict[str, Any]):
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Error handling weather updated event", error=str(e))
|
logger.error("Error handling weather updated event", error=str(e))
|
||||||
|
|
||||||
|
async def publish_forecasts_deleted_event(tenant_id: str, deletion_stats: Dict[str, Any]):
|
||||||
|
"""Publish forecasts deletion event to message queue"""
|
||||||
|
try:
|
||||||
|
await rabbitmq_client.publish_event(
|
||||||
|
exchange="forecasting_events",
|
||||||
|
routing_key="forecasting.tenant.deleted",
|
||||||
|
message={
|
||||||
|
"event_type": "tenant_forecasts_deleted",
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"timestamp": datetime.utcnow().isoformat(),
|
||||||
|
"deletion_stats": deletion_stats
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to publish forecasts deletion event", error=str(e))
|
||||||
@@ -8,18 +8,20 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any
|
||||||
import structlog
|
import structlog
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
from sqlalchemy import select, delete, func
|
||||||
|
from datetime import datetime
|
||||||
|
import uuid
|
||||||
|
|
||||||
from app.core.database import get_db
|
from app.core.database import get_db
|
||||||
|
from app.services.messaging import publish_tenant_deleted_event
|
||||||
from app.schemas.tenants import (
|
from app.schemas.tenants import (
|
||||||
BakeryRegistration, TenantResponse, TenantAccessResponse,
|
BakeryRegistration, TenantResponse, TenantAccessResponse,
|
||||||
TenantUpdate, TenantMemberResponse
|
TenantUpdate, TenantMemberResponse
|
||||||
)
|
)
|
||||||
from app.services.tenant_service import TenantService
|
from app.services.tenant_service import TenantService
|
||||||
# Import unified authentication
|
|
||||||
from shared.auth.decorators import (
|
from shared.auth.decorators import (
|
||||||
get_current_user_dep,
|
get_current_user_dep,
|
||||||
get_current_tenant_id_dep,
|
require_admin_role
|
||||||
require_role
|
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = structlog.get_logger()
|
logger = structlog.get_logger()
|
||||||
@@ -164,3 +166,149 @@ async def add_team_member(
|
|||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Failed to add team member"
|
detail="Failed to add team member"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@router.delete("/tenants/{tenant_id}")
|
||||||
|
async def delete_tenant_complete(
|
||||||
|
tenant_id: str,
|
||||||
|
current_user = Depends(get_current_user_dep),
|
||||||
|
_admin_check = Depends(require_admin_role),
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Delete a tenant completely with all associated data.
|
||||||
|
|
||||||
|
**WARNING: This operation is irreversible!**
|
||||||
|
|
||||||
|
This endpoint:
|
||||||
|
1. Validates tenant exists and user has permissions
|
||||||
|
2. Deletes all tenant memberships
|
||||||
|
3. Deletes tenant subscription data
|
||||||
|
4. Deletes the tenant record
|
||||||
|
5. Publishes deletion event
|
||||||
|
|
||||||
|
Used by admin user deletion process when a tenant has no other admins.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
tenant_uuid = uuid.UUID(tenant_id)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Invalid tenant ID format"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from app.models.tenants import Tenant, TenantMember, Subscription
|
||||||
|
|
||||||
|
# Step 1: Verify tenant exists
|
||||||
|
tenant_query = select(Tenant).where(Tenant.id == tenant_uuid)
|
||||||
|
tenant_result = await db.execute(tenant_query)
|
||||||
|
tenant = tenant_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not tenant:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Tenant {tenant_id} not found"
|
||||||
|
)
|
||||||
|
|
||||||
|
deletion_stats = {
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"tenant_name": tenant.name,
|
||||||
|
"deleted_at": datetime.utcnow().isoformat(),
|
||||||
|
"memberships_deleted": 0,
|
||||||
|
"subscriptions_deleted": 0,
|
||||||
|
"errors": []
|
||||||
|
}
|
||||||
|
|
||||||
|
# Step 2: Delete all tenant memberships
|
||||||
|
try:
|
||||||
|
membership_count_query = select(func.count(TenantMember.id)).where(
|
||||||
|
TenantMember.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
membership_count_result = await db.execute(membership_count_query)
|
||||||
|
membership_count = membership_count_result.scalar()
|
||||||
|
|
||||||
|
membership_delete_query = delete(TenantMember).where(
|
||||||
|
TenantMember.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
await db.execute(membership_delete_query)
|
||||||
|
deletion_stats["memberships_deleted"] = membership_count
|
||||||
|
|
||||||
|
logger.info("Deleted tenant memberships",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
count=membership_count)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error deleting memberships: {str(e)}"
|
||||||
|
deletion_stats["errors"].append(error_msg)
|
||||||
|
logger.error(error_msg)
|
||||||
|
|
||||||
|
# Step 3: Delete subscription data
|
||||||
|
try:
|
||||||
|
subscription_count_query = select(func.count(Subscription.id)).where(
|
||||||
|
Subscription.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
subscription_count_result = await db.execute(subscription_count_query)
|
||||||
|
subscription_count = subscription_count_result.scalar()
|
||||||
|
|
||||||
|
subscription_delete_query = delete(Subscription).where(
|
||||||
|
Subscription.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
await db.execute(subscription_delete_query)
|
||||||
|
deletion_stats["subscriptions_deleted"] = subscription_count
|
||||||
|
|
||||||
|
logger.info("Deleted tenant subscriptions",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
count=subscription_count)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error deleting subscriptions: {str(e)}"
|
||||||
|
deletion_stats["errors"].append(error_msg)
|
||||||
|
logger.error(error_msg)
|
||||||
|
|
||||||
|
# Step 4: Delete the tenant record
|
||||||
|
try:
|
||||||
|
tenant_delete_query = delete(Tenant).where(Tenant.id == tenant_uuid)
|
||||||
|
tenant_result = await db.execute(tenant_delete_query)
|
||||||
|
|
||||||
|
if tenant_result.rowcount == 0:
|
||||||
|
raise Exception("Tenant record was not deleted")
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
logger.info("Tenant deleted successfully",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
tenant_name=tenant.name)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
error_msg = f"Error deleting tenant record: {str(e)}"
|
||||||
|
deletion_stats["errors"].append(error_msg)
|
||||||
|
logger.error(error_msg)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=error_msg
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 5: Publish tenant deletion event
|
||||||
|
try:
|
||||||
|
await publish_tenant_deleted_event(tenant_id, deletion_stats)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to publish tenant deletion event", error=str(e))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": f"Tenant {tenant_id} deleted successfully",
|
||||||
|
"deletion_details": deletion_stats
|
||||||
|
}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Unexpected error deleting tenant",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
error=str(e))
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to delete tenant: {str(e)}"
|
||||||
|
)
|
||||||
@@ -6,6 +6,7 @@ from shared.messaging.rabbitmq import RabbitMQClient
|
|||||||
from app.core.config import settings
|
from app.core.config import settings
|
||||||
import structlog
|
import structlog
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
logger = structlog.get_logger()
|
logger = structlog.get_logger()
|
||||||
|
|
||||||
@@ -41,3 +42,19 @@ async def publish_member_added(tenant_id: str, user_id: str, role: str):
|
|||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Failed to publish tenant.member.added event: {e}")
|
logger.error(f"Failed to publish tenant.member.added event: {e}")
|
||||||
|
|
||||||
|
async def publish_tenant_deleted_event(tenant_id: str, deletion_stats: Dict[str, Any]):
|
||||||
|
"""Publish tenant deletion event to message queue"""
|
||||||
|
try:
|
||||||
|
await data_publisher.publish_event(
|
||||||
|
exchange="tenant_events",
|
||||||
|
routing_key="tenant.deleted",
|
||||||
|
message={
|
||||||
|
"event_type": "tenant_deleted",
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"timestamp": datetime.utcnow().isoformat(),
|
||||||
|
"deletion_stats": deletion_stats
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to publish tenant deletion event", error=str(e))
|
||||||
@@ -12,9 +12,15 @@ from app.core.database import get_db
|
|||||||
from app.schemas.training import TrainedModelResponse, ModelMetricsResponse
|
from app.schemas.training import TrainedModelResponse, ModelMetricsResponse
|
||||||
from app.services.training_service import TrainingService
|
from app.services.training_service import TrainingService
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from sqlalchemy import select, delete, func
|
||||||
|
import uuid
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
from app.services.messaging import publish_models_deleted_event
|
||||||
|
|
||||||
from shared.auth.decorators import (
|
from shared.auth.decorators import (
|
||||||
get_current_tenant_id_dep
|
get_current_user_dep,
|
||||||
|
require_admin_role
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = structlog.get_logger()
|
logger = structlog.get_logger()
|
||||||
@@ -213,3 +219,243 @@ async def list_models(
|
|||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
detail="Failed to retrieve models"
|
detail="Failed to retrieve models"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@router.delete("/models/tenant/{tenant_id}")
|
||||||
|
async def delete_tenant_models_complete(
|
||||||
|
tenant_id: str,
|
||||||
|
current_user = Depends(get_current_user_dep),
|
||||||
|
_admin_check = Depends(require_admin_role),
|
||||||
|
db: AsyncSession = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Delete all trained models and artifacts for a tenant.
|
||||||
|
|
||||||
|
**WARNING: This operation is irreversible!**
|
||||||
|
|
||||||
|
This endpoint:
|
||||||
|
1. Cancels any active training jobs for the tenant
|
||||||
|
2. Deletes all model artifacts (files) from storage
|
||||||
|
3. Deletes model records from database
|
||||||
|
4. Deletes training logs and performance metrics
|
||||||
|
5. Publishes deletion event
|
||||||
|
|
||||||
|
Used by admin user deletion process to clean up all training data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
tenant_uuid = uuid.UUID(tenant_id)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Invalid tenant ID format"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from app.models.training import (
|
||||||
|
ModelTrainingLog,
|
||||||
|
TrainedModel,
|
||||||
|
ModelArtifact,
|
||||||
|
ModelPerformanceMetric,
|
||||||
|
TrainingJobQueue
|
||||||
|
)
|
||||||
|
from app.core.config import settings
|
||||||
|
|
||||||
|
deletion_stats = {
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"deleted_at": datetime.utcnow().isoformat(),
|
||||||
|
"jobs_cancelled": 0,
|
||||||
|
"models_deleted": 0,
|
||||||
|
"artifacts_deleted": 0,
|
||||||
|
"artifacts_files_deleted": 0,
|
||||||
|
"training_logs_deleted": 0,
|
||||||
|
"performance_metrics_deleted": 0,
|
||||||
|
"storage_freed_bytes": 0,
|
||||||
|
"errors": []
|
||||||
|
}
|
||||||
|
|
||||||
|
# Step 1: Cancel active training jobs
|
||||||
|
try:
|
||||||
|
active_jobs_query = select(TrainingJobQueue).where(
|
||||||
|
TrainingJobQueue.tenant_id == tenant_uuid,
|
||||||
|
TrainingJobQueue.status.in_(["queued", "running", "pending"])
|
||||||
|
)
|
||||||
|
active_jobs_result = await db.execute(active_jobs_query)
|
||||||
|
active_jobs = active_jobs_result.scalars().all()
|
||||||
|
|
||||||
|
for job in active_jobs:
|
||||||
|
job.status = "cancelled"
|
||||||
|
job.updated_at = datetime.utcnow()
|
||||||
|
deletion_stats["jobs_cancelled"] += 1
|
||||||
|
|
||||||
|
if active_jobs:
|
||||||
|
await db.commit()
|
||||||
|
logger.info("Cancelled active training jobs",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
count=len(active_jobs))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error cancelling training jobs: {str(e)}"
|
||||||
|
deletion_stats["errors"].append(error_msg)
|
||||||
|
logger.error(error_msg)
|
||||||
|
|
||||||
|
# Step 2: Delete model artifact files from storage
|
||||||
|
try:
|
||||||
|
artifacts_query = select(ModelArtifact).where(
|
||||||
|
ModelArtifact.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
artifacts_result = await db.execute(artifacts_query)
|
||||||
|
artifacts = artifacts_result.scalars().all()
|
||||||
|
|
||||||
|
storage_freed = 0
|
||||||
|
files_deleted = 0
|
||||||
|
|
||||||
|
for artifact in artifacts:
|
||||||
|
try:
|
||||||
|
file_path = Path(artifact.file_path)
|
||||||
|
if file_path.exists():
|
||||||
|
file_size = file_path.stat().st_size
|
||||||
|
file_path.unlink() # Delete file
|
||||||
|
storage_freed += file_size
|
||||||
|
files_deleted += 1
|
||||||
|
logger.debug("Deleted artifact file",
|
||||||
|
file_path=str(file_path),
|
||||||
|
size_bytes=file_size)
|
||||||
|
|
||||||
|
# Also try to delete parent directories if empty
|
||||||
|
try:
|
||||||
|
if file_path.parent.exists() and not any(file_path.parent.iterdir()):
|
||||||
|
file_path.parent.rmdir()
|
||||||
|
except:
|
||||||
|
pass # Ignore errors cleaning up directories
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error deleting artifact file {artifact.file_path}: {str(e)}"
|
||||||
|
deletion_stats["errors"].append(error_msg)
|
||||||
|
logger.warning(error_msg)
|
||||||
|
|
||||||
|
deletion_stats["artifacts_files_deleted"] = files_deleted
|
||||||
|
deletion_stats["storage_freed_bytes"] = storage_freed
|
||||||
|
|
||||||
|
logger.info("Deleted artifact files",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
files_deleted=files_deleted,
|
||||||
|
storage_freed_mb=storage_freed / (1024 * 1024))
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error processing artifact files: {str(e)}"
|
||||||
|
deletion_stats["errors"].append(error_msg)
|
||||||
|
logger.error(error_msg)
|
||||||
|
|
||||||
|
# Step 3: Delete database records
|
||||||
|
try:
|
||||||
|
# Delete model performance metrics
|
||||||
|
metrics_count_query = select(func.count(ModelPerformanceMetric.id)).where(
|
||||||
|
ModelPerformanceMetric.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
metrics_count_result = await db.execute(metrics_count_query)
|
||||||
|
metrics_count = metrics_count_result.scalar()
|
||||||
|
|
||||||
|
metrics_delete_query = delete(ModelPerformanceMetric).where(
|
||||||
|
ModelPerformanceMetric.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
await db.execute(metrics_delete_query)
|
||||||
|
deletion_stats["performance_metrics_deleted"] = metrics_count
|
||||||
|
|
||||||
|
# Delete model artifacts records
|
||||||
|
artifacts_count_query = select(func.count(ModelArtifact.id)).where(
|
||||||
|
ModelArtifact.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
artifacts_count_result = await db.execute(artifacts_count_query)
|
||||||
|
artifacts_count = artifacts_count_result.scalar()
|
||||||
|
|
||||||
|
artifacts_delete_query = delete(ModelArtifact).where(
|
||||||
|
ModelArtifact.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
await db.execute(artifacts_delete_query)
|
||||||
|
deletion_stats["artifacts_deleted"] = artifacts_count
|
||||||
|
|
||||||
|
# Delete trained models
|
||||||
|
models_count_query = select(func.count(TrainedModel.id)).where(
|
||||||
|
TrainedModel.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
models_count_result = await db.execute(models_count_query)
|
||||||
|
models_count = models_count_result.scalar()
|
||||||
|
|
||||||
|
models_delete_query = delete(TrainedModel).where(
|
||||||
|
TrainedModel.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
await db.execute(models_delete_query)
|
||||||
|
deletion_stats["models_deleted"] = models_count
|
||||||
|
|
||||||
|
# Delete training logs
|
||||||
|
logs_count_query = select(func.count(ModelTrainingLog.id)).where(
|
||||||
|
ModelTrainingLog.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
logs_count_result = await db.execute(logs_count_query)
|
||||||
|
logs_count = logs_count_result.scalar()
|
||||||
|
|
||||||
|
logs_delete_query = delete(ModelTrainingLog).where(
|
||||||
|
ModelTrainingLog.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
await db.execute(logs_delete_query)
|
||||||
|
deletion_stats["training_logs_deleted"] = logs_count
|
||||||
|
|
||||||
|
# Delete job queue entries
|
||||||
|
queue_delete_query = delete(TrainingJobQueue).where(
|
||||||
|
TrainingJobQueue.tenant_id == tenant_uuid
|
||||||
|
)
|
||||||
|
await db.execute(queue_delete_query)
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
logger.info("Deleted training database records",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
models=models_count,
|
||||||
|
artifacts=artifacts_count,
|
||||||
|
logs=logs_count,
|
||||||
|
metrics=metrics_count)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
await db.rollback()
|
||||||
|
error_msg = f"Error deleting database records: {str(e)}"
|
||||||
|
deletion_stats["errors"].append(error_msg)
|
||||||
|
logger.error(error_msg)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=error_msg
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 4: Clean up tenant model directory
|
||||||
|
try:
|
||||||
|
tenant_model_dir = Path(settings.MODEL_STORAGE_PATH) / tenant_id
|
||||||
|
if tenant_model_dir.exists():
|
||||||
|
shutil.rmtree(tenant_model_dir)
|
||||||
|
logger.info("Deleted tenant model directory",
|
||||||
|
directory=str(tenant_model_dir))
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"Error deleting model directory: {str(e)}"
|
||||||
|
deletion_stats["errors"].append(error_msg)
|
||||||
|
logger.warning(error_msg)
|
||||||
|
|
||||||
|
# Step 5: Publish deletion event
|
||||||
|
try:
|
||||||
|
await publish_models_deleted_event(tenant_id, deletion_stats)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning("Failed to publish models deletion event", error=str(e))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": f"All training data for tenant {tenant_id} deleted successfully",
|
||||||
|
"deletion_details": deletion_stats
|
||||||
|
}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Unexpected error deleting tenant models",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
error=str(e))
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
detail=f"Failed to delete tenant models: {str(e)}"
|
||||||
|
)
|
||||||
@@ -442,6 +442,24 @@ async def publish_data_validation_completed(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def publish_models_deleted_event(tenant_id: str, deletion_stats: Dict[str, Any]):
|
||||||
|
"""Publish models deletion event to message queue"""
|
||||||
|
try:
|
||||||
|
await training_publisher.publish_event(
|
||||||
|
exchange="training_events",
|
||||||
|
routing_key="training.tenant.models.deleted",
|
||||||
|
message={
|
||||||
|
"event_type": "tenant_models_deleted",
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"timestamp": datetime.utcnow().isoformat(),
|
||||||
|
"deletion_stats": deletion_stats
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Failed to publish models deletion event", error=str(e))
|
||||||
|
|
||||||
|
|
||||||
# =========================================
|
# =========================================
|
||||||
# UTILITY FUNCTIONS FOR BATCH PUBLISHING
|
# UTILITY FUNCTIONS FOR BATCH PUBLISHING
|
||||||
# =========================================
|
# =========================================
|
||||||
@@ -550,3 +568,4 @@ class TrainingStatusPublisher:
|
|||||||
"""Publish job failure with clean error details"""
|
"""Publish job failure with clean error details"""
|
||||||
clean_error_details = safe_json_serialize(error_details) if error_details else None
|
clean_error_details = safe_json_serialize(error_details) if error_details else None
|
||||||
await publish_job_failed(self.job_id, self.tenant_id, error, clean_error_details)
|
await publish_job_failed(self.job_id, self.tenant_id, error, clean_error_details)
|
||||||
|
|
||||||
@@ -1,12 +1,16 @@
|
|||||||
|
# ================================================================
|
||||||
|
# shared/auth/decorators.py - ENHANCED WITH ADMIN ROLE DECORATOR
|
||||||
|
# ================================================================
|
||||||
"""
|
"""
|
||||||
Unified authentication decorators for microservices
|
Enhanced authentication decorators for microservices including admin role validation.
|
||||||
Designed to work with gateway authentication middleware
|
Designed to work with gateway authentication middleware and provide centralized
|
||||||
|
role-based access control across all services.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from fastapi import HTTPException, status, Request, Depends
|
from fastapi import HTTPException, status, Request, Depends
|
||||||
from fastapi.security import HTTPBearer
|
from fastapi.security import HTTPBearer
|
||||||
from typing import Callable, Optional, Dict, Any
|
from typing import Callable, Optional, Dict, Any, List
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
logger = structlog.get_logger()
|
logger = structlog.get_logger()
|
||||||
@@ -111,6 +115,208 @@ def require_role(role: str):
|
|||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
def require_admin_role(func: Callable) -> Callable:
|
||||||
|
"""
|
||||||
|
Decorator to require admin role - simplified version for FastAPI dependencies
|
||||||
|
|
||||||
|
This decorator ensures only users with 'admin' role can access the endpoint.
|
||||||
|
Can be used as a FastAPI dependency or function decorator.
|
||||||
|
|
||||||
|
Usage as dependency:
|
||||||
|
@router.delete("/admin/users/{user_id}")
|
||||||
|
async def delete_user(
|
||||||
|
user_id: str,
|
||||||
|
current_user = Depends(get_current_user_dep),
|
||||||
|
_admin_check = Depends(require_admin_role),
|
||||||
|
):
|
||||||
|
# Admin-only logic here
|
||||||
|
|
||||||
|
Usage as decorator:
|
||||||
|
@require_admin_role
|
||||||
|
@router.delete("/admin/users/{user_id}")
|
||||||
|
async def delete_user(...):
|
||||||
|
# Admin-only logic here
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
# Find request object in arguments
|
||||||
|
request = None
|
||||||
|
current_user = None
|
||||||
|
|
||||||
|
# Extract request and current_user from arguments
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, Request):
|
||||||
|
request = arg
|
||||||
|
elif isinstance(arg, dict) and 'user_id' in arg:
|
||||||
|
current_user = arg
|
||||||
|
|
||||||
|
# Check kwargs for request and current_user
|
||||||
|
if not request:
|
||||||
|
request = kwargs.get('request')
|
||||||
|
if not current_user:
|
||||||
|
current_user = kwargs.get('current_user')
|
||||||
|
|
||||||
|
# If we still don't have current_user, try to get it from request
|
||||||
|
if not current_user and request:
|
||||||
|
current_user = get_current_user(request)
|
||||||
|
|
||||||
|
if not current_user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Authentication required"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if user has admin role
|
||||||
|
user_role = current_user.get('role', '').lower()
|
||||||
|
|
||||||
|
if user_role != 'admin':
|
||||||
|
logger.warning("Non-admin user attempted admin operation",
|
||||||
|
user_id=current_user.get('user_id'),
|
||||||
|
role=user_role)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Admin role required"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Admin operation authorized",
|
||||||
|
user_id=current_user.get('user_id'),
|
||||||
|
endpoint=func.__name__)
|
||||||
|
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
def require_roles(allowed_roles: List[str]):
|
||||||
|
"""
|
||||||
|
Decorator to require one of multiple roles
|
||||||
|
|
||||||
|
Args:
|
||||||
|
allowed_roles: List of roles that are allowed to access the endpoint
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
@require_roles(['admin', 'manager'])
|
||||||
|
@router.post("/sensitive-operation")
|
||||||
|
async def sensitive_operation(...):
|
||||||
|
# Only admins and managers can access
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
request = None
|
||||||
|
current_user = None
|
||||||
|
|
||||||
|
# Extract request and current_user from arguments
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, Request):
|
||||||
|
request = arg
|
||||||
|
elif isinstance(arg, dict) and 'user_id' in arg:
|
||||||
|
current_user = arg
|
||||||
|
|
||||||
|
# Check kwargs
|
||||||
|
if not request:
|
||||||
|
request = kwargs.get('request')
|
||||||
|
if not current_user:
|
||||||
|
current_user = kwargs.get('current_user')
|
||||||
|
|
||||||
|
# Get user from request if not provided
|
||||||
|
if not current_user and request:
|
||||||
|
current_user = get_current_user(request)
|
||||||
|
|
||||||
|
if not current_user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Authentication required"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if user has one of the allowed roles
|
||||||
|
user_role = current_user.get('role', '').lower()
|
||||||
|
allowed_roles_lower = [role.lower() for role in allowed_roles]
|
||||||
|
|
||||||
|
if user_role not in allowed_roles_lower:
|
||||||
|
logger.warning("Unauthorized role attempted restricted operation",
|
||||||
|
user_id=current_user.get('user_id'),
|
||||||
|
user_role=user_role,
|
||||||
|
allowed_roles=allowed_roles)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=f"One of these roles required: {', '.join(allowed_roles)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Role-based operation authorized",
|
||||||
|
user_id=current_user.get('user_id'),
|
||||||
|
user_role=user_role,
|
||||||
|
endpoint=func.__name__)
|
||||||
|
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
def require_tenant_admin(func: Callable) -> Callable:
|
||||||
|
"""
|
||||||
|
Decorator to require admin role within a specific tenant context
|
||||||
|
|
||||||
|
This checks that the user is an admin AND has access to the tenant
|
||||||
|
being operated on. Useful for tenant-scoped admin operations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
request = None
|
||||||
|
current_user = None
|
||||||
|
|
||||||
|
# Extract request and current_user from arguments
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, Request):
|
||||||
|
request = arg
|
||||||
|
elif isinstance(arg, dict) and 'user_id' in arg:
|
||||||
|
current_user = arg
|
||||||
|
|
||||||
|
if not request:
|
||||||
|
request = kwargs.get('request')
|
||||||
|
if not current_user:
|
||||||
|
current_user = kwargs.get('current_user')
|
||||||
|
|
||||||
|
if not current_user and request:
|
||||||
|
current_user = get_current_user(request)
|
||||||
|
|
||||||
|
if not current_user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Authentication required"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check admin role first
|
||||||
|
user_role = current_user.get('role', '').lower()
|
||||||
|
if user_role != 'admin':
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Admin role required"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check tenant access
|
||||||
|
tenant_id = get_current_tenant_id(request) if request else None
|
||||||
|
if not tenant_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Tenant context required"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Additional tenant admin validation could go here
|
||||||
|
# For now, we trust that admin users have access to operate on any tenant
|
||||||
|
|
||||||
|
logger.info("Tenant admin operation authorized",
|
||||||
|
user_id=current_user.get('user_id'),
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
endpoint=func.__name__)
|
||||||
|
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
def get_current_user(request: Request) -> Dict[str, Any]:
|
def get_current_user(request: Request) -> Dict[str, Any]:
|
||||||
"""Get current user from request state or headers"""
|
"""Get current user from request state or headers"""
|
||||||
if hasattr(request.state, 'user') and request.state.user:
|
if hasattr(request.state, 'user') and request.state.user:
|
||||||
@@ -145,14 +351,18 @@ def extract_user_from_headers(request: Request) -> Optional[Dict[str, Any]]:
|
|||||||
"email": request.headers.get("x-user-email", ""),
|
"email": request.headers.get("x-user-email", ""),
|
||||||
"role": request.headers.get("x-user-role", "user"),
|
"role": request.headers.get("x-user-role", "user"),
|
||||||
"tenant_id": request.headers.get("x-tenant-id"),
|
"tenant_id": request.headers.get("x-tenant-id"),
|
||||||
"permissions": request.headers.get("X-User-Permissions", "").split(",") if request.headers.get("X-User-Permissions") else []
|
"permissions": request.headers.get("X-User-Permissions", "").split(",") if request.headers.get("X-User-Permissions") else [],
|
||||||
|
"full_name": request.headers.get("x-user-full-name", "")
|
||||||
}
|
}
|
||||||
|
|
||||||
def extract_tenant_from_headers(request: Request) -> Optional[str]:
|
def extract_tenant_from_headers(request: Request) -> Optional[str]:
|
||||||
"""Extract tenant ID from headers"""
|
"""Extract tenant ID from headers"""
|
||||||
return request.headers.get("x-tenant-id")
|
return request.headers.get("x-tenant-id")
|
||||||
|
|
||||||
# FastAPI Dependencies for injection
|
# ================================================================
|
||||||
|
# FASTAPI DEPENDENCY FUNCTIONS
|
||||||
|
# ================================================================
|
||||||
|
|
||||||
async def get_current_user_dep(request: Request) -> Dict[str, Any]:
|
async def get_current_user_dep(request: Request) -> Dict[str, Any]:
|
||||||
"""FastAPI dependency to get current user"""
|
"""FastAPI dependency to get current user"""
|
||||||
return get_current_user(request)
|
return get_current_user(request)
|
||||||
@@ -161,15 +371,180 @@ async def get_current_tenant_id_dep(request: Request) -> Optional[str]:
|
|||||||
"""FastAPI dependency to get current tenant ID"""
|
"""FastAPI dependency to get current tenant ID"""
|
||||||
return get_current_tenant_id(request)
|
return get_current_tenant_id(request)
|
||||||
|
|
||||||
|
async def require_admin_role_dep(
|
||||||
|
current_user: Dict[str, Any] = Depends(get_current_user_dep)
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
FastAPI dependency that requires admin role
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
@router.delete("/admin/users/{user_id}")
|
||||||
|
async def delete_user(
|
||||||
|
user_id: str,
|
||||||
|
admin_user: Dict[str, Any] = Depends(require_admin_role_dep)
|
||||||
|
):
|
||||||
|
# admin_user is guaranteed to have admin role
|
||||||
|
"""
|
||||||
|
|
||||||
|
user_role = current_user.get('role', '').lower()
|
||||||
|
|
||||||
|
if user_role != 'admin':
|
||||||
|
logger.warning("Non-admin user attempted admin operation",
|
||||||
|
user_id=current_user.get('user_id'),
|
||||||
|
role=user_role)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Admin role required"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Admin operation authorized via dependency",
|
||||||
|
user_id=current_user.get('user_id'))
|
||||||
|
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
async def require_roles_dep(allowed_roles: List[str]):
|
||||||
|
"""
|
||||||
|
FastAPI dependency factory that requires one of multiple roles
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
require_manager_or_admin = require_roles_dep(['admin', 'manager'])
|
||||||
|
|
||||||
|
@router.post("/sensitive-operation")
|
||||||
|
async def sensitive_operation(
|
||||||
|
user: Dict[str, Any] = Depends(require_manager_or_admin)
|
||||||
|
):
|
||||||
|
# Only admins and managers can access
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def check_roles(
|
||||||
|
current_user: Dict[str, Any] = Depends(get_current_user_dep)
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
user_role = current_user.get('role', '').lower()
|
||||||
|
allowed_roles_lower = [role.lower() for role in allowed_roles]
|
||||||
|
|
||||||
|
if user_role not in allowed_roles_lower:
|
||||||
|
logger.warning("Unauthorized role attempted restricted operation",
|
||||||
|
user_id=current_user.get('user_id'),
|
||||||
|
user_role=user_role,
|
||||||
|
allowed_roles=allowed_roles)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=f"One of these roles required: {', '.join(allowed_roles)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Role-based operation authorized via dependency",
|
||||||
|
user_id=current_user.get('user_id'),
|
||||||
|
user_role=user_role)
|
||||||
|
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
return check_roles
|
||||||
|
|
||||||
|
# ================================================================
|
||||||
|
# UTILITY FUNCTIONS FOR ROLE CHECKING
|
||||||
|
# ================================================================
|
||||||
|
|
||||||
|
def is_admin_user(user: Dict[str, Any]) -> bool:
|
||||||
|
"""Check if user has admin role"""
|
||||||
|
return user.get('role', '').lower() == 'admin'
|
||||||
|
|
||||||
|
def is_user_in_roles(user: Dict[str, Any], allowed_roles: List[str]) -> bool:
|
||||||
|
"""Check if user has one of the allowed roles"""
|
||||||
|
user_role = user.get('role', '').lower()
|
||||||
|
allowed_roles_lower = [role.lower() for role in allowed_roles]
|
||||||
|
return user_role in allowed_roles_lower
|
||||||
|
|
||||||
|
def get_user_permissions(user: Dict[str, Any]) -> List[str]:
|
||||||
|
"""Get user permissions list"""
|
||||||
|
return user.get('permissions', [])
|
||||||
|
|
||||||
|
def has_permission(user: Dict[str, Any], permission: str) -> bool:
|
||||||
|
"""Check if user has specific permission"""
|
||||||
|
permissions = get_user_permissions(user)
|
||||||
|
return permission in permissions
|
||||||
|
|
||||||
|
# ================================================================
|
||||||
|
# ADVANCED ROLE DECORATORS
|
||||||
|
# ================================================================
|
||||||
|
|
||||||
|
def require_permission(permission: str):
|
||||||
|
"""
|
||||||
|
Decorator to require specific permission
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
@require_permission('delete_users')
|
||||||
|
@router.delete("/users/{user_id}")
|
||||||
|
async def delete_user(...):
|
||||||
|
# Only users with 'delete_users' permission can access
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
@wraps(func)
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
current_user = None
|
||||||
|
|
||||||
|
# Extract current_user from arguments
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, dict) and 'user_id' in arg:
|
||||||
|
current_user = arg
|
||||||
|
break
|
||||||
|
|
||||||
|
if not current_user:
|
||||||
|
current_user = kwargs.get('current_user')
|
||||||
|
|
||||||
|
if not current_user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Authentication required"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check permission
|
||||||
|
if not has_permission(current_user, permission):
|
||||||
|
# Admins bypass permission checks
|
||||||
|
if not is_admin_user(current_user):
|
||||||
|
logger.warning("User lacks required permission",
|
||||||
|
user_id=current_user.get('user_id'),
|
||||||
|
required_permission=permission,
|
||||||
|
user_permissions=get_user_permissions(current_user))
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=f"Permission '{permission}' required"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("Permission-based operation authorized",
|
||||||
|
user_id=current_user.get('user_id'),
|
||||||
|
permission=permission)
|
||||||
|
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
# Export all decorators and functions
|
# Export all decorators and functions
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
# Main decorators
|
||||||
'require_authentication',
|
'require_authentication',
|
||||||
'require_tenant_access',
|
'require_tenant_access',
|
||||||
'require_role',
|
'require_role',
|
||||||
'get_current_user',
|
'require_admin_role',
|
||||||
'get_current_tenant_id',
|
'require_roles',
|
||||||
|
'require_tenant_admin',
|
||||||
|
'require_permission',
|
||||||
|
|
||||||
|
# FastAPI dependencies
|
||||||
'get_current_user_dep',
|
'get_current_user_dep',
|
||||||
'get_current_tenant_id_dep',
|
'get_current_tenant_id_dep',
|
||||||
|
'require_admin_role_dep',
|
||||||
|
'require_roles_dep',
|
||||||
|
|
||||||
|
# Utility functions
|
||||||
|
'get_current_user',
|
||||||
|
'get_current_tenant_id',
|
||||||
'extract_user_from_headers',
|
'extract_user_from_headers',
|
||||||
'extract_tenant_from_headers'
|
'extract_tenant_from_headers',
|
||||||
|
'is_admin_user',
|
||||||
|
'is_user_in_roles',
|
||||||
|
'get_user_permissions',
|
||||||
|
'has_permission'
|
||||||
]
|
]
|
||||||
@@ -455,7 +455,8 @@ REGISTER_RESPONSE=$(curl -s -X POST "$API_BASE/api/v1/auth/register" \
|
|||||||
-d "{
|
-d "{
|
||||||
\"email\": \"$TEST_EMAIL\",
|
\"email\": \"$TEST_EMAIL\",
|
||||||
\"password\": \"$TEST_PASSWORD\",
|
\"password\": \"$TEST_PASSWORD\",
|
||||||
\"full_name\": \"$TEST_NAME\"
|
\"full_name\": \"$TEST_NAME\",
|
||||||
|
\"role\": \"admin\"
|
||||||
}")
|
}")
|
||||||
|
|
||||||
echo "Registration Response:"
|
echo "Registration Response:"
|
||||||
|
|||||||
Reference in New Issue
Block a user