Improve user delete flow

This commit is contained in:
Urtzi Alfaro
2025-08-02 17:09:53 +02:00
parent 277e8bec73
commit 3681429e11
10 changed files with 1334 additions and 210 deletions

View File

@@ -14,6 +14,7 @@ import uuid
from app.core.database import get_db, get_background_db_session
from app.services.training_service import TrainingService
from sqlalchemy import select, delete, func
from app.schemas.training import (
TrainingJobRequest,
SingleProductTrainingRequest
@@ -33,8 +34,7 @@ from app.services.messaging import (
)
# Import shared auth decorators (assuming they exist in your microservices)
from shared.auth.decorators import get_current_tenant_id_dep
from shared.auth.decorators import require_admin_role, get_current_user_dep, get_current_tenant_id_dep
logger = structlog.get_logger()
router = APIRouter()
@@ -224,6 +224,9 @@ async def start_single_product_training(
Uses the same pipeline but filters for specific product.
"""
training_service = TrainingService(db_session=db)
try:
# Validate tenant access
if tenant_id != current_tenant:
@@ -260,36 +263,6 @@ async def start_single_product_training(
detail="Single product training failed"
)
@router.post("/tenants/{tenant_id}/training/jobs/{job_id}/cancel")
async def cancel_training_job(
tenant_id: str = Path(..., description="Tenant ID"),
job_id: str = Path(..., description="Job ID"),
current_tenant: str = Depends(get_current_tenant_id_dep),
db: AsyncSession = Depends(get_db)
):
"""
Cancel a running training job.
"""
try:
# Validate tenant access
if tenant_id != current_tenant:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied to tenant resources"
)
# TODO: Implement job cancellation
logger.info(f"Cancelling training job {job_id} for tenant {tenant_id}")
return {"message": "Training job cancelled successfully"}
except Exception as e:
logger.error(f"Failed to cancel training job: {str(e)}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to cancel training job"
)
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}/logs")
async def get_training_logs(
tenant_id: str = Path(..., description="Tenant ID"),
@@ -337,4 +310,189 @@ async def health_check():
"service": "training",
"version": "1.0.0",
"timestamp": datetime.now().isoformat()
}
}
@router.post("/tenants/{tenant_id}/training/jobs/cancel")
async def cancel_tenant_training_jobs(
cancel_data: dict, # {"tenant_id": str}
current_user = Depends(get_current_user_dep),
_admin_check = Depends(require_admin_role),
db: AsyncSession = Depends(get_db)
):
"""Cancel all active training jobs for a tenant (admin only)"""
try:
tenant_id = cancel_data.get("tenant_id")
if not tenant_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="tenant_id is required"
)
tenant_uuid = uuid.UUID(tenant_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid tenant ID format"
)
try:
from app.models.training import TrainingJobQueue
# Find all active jobs for the tenant
active_jobs_query = select(TrainingJobQueue).where(
TrainingJobQueue.tenant_id == tenant_uuid,
TrainingJobQueue.status.in_(["queued", "running", "pending"])
)
active_jobs_result = await db.execute(active_jobs_query)
active_jobs = active_jobs_result.scalars().all()
jobs_cancelled = 0
cancelled_job_ids = []
errors = []
for job in active_jobs:
try:
job.status = "cancelled"
job.updated_at = datetime.utcnow()
job.cancelled_by = current_user.get("user_id")
jobs_cancelled += 1
cancelled_job_ids.append(str(job.id))
logger.info("Cancelled training job",
job_id=str(job.id),
tenant_id=tenant_id)
except Exception as e:
error_msg = f"Failed to cancel job {job.id}: {str(e)}"
errors.append(error_msg)
logger.error(error_msg)
if jobs_cancelled > 0:
await db.commit()
result = {
"success": True,
"tenant_id": tenant_id,
"jobs_cancelled": jobs_cancelled,
"cancelled_job_ids": cancelled_job_ids,
"errors": errors,
"cancelled_at": datetime.utcnow().isoformat()
}
if errors:
result["success"] = len(errors) < len(active_jobs)
return result
except Exception as e:
await db.rollback()
logger.error("Failed to cancel tenant training jobs",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to cancel training jobs"
)
@router.get("/tenants/{tenant_id}/training/jobs/active")
async def get_tenant_active_jobs(
tenant_id: str,
current_user = Depends(get_current_user_dep),
_admin_check = Depends(require_admin_role),
db: AsyncSession = Depends(get_db)
):
"""Get all active training jobs for a tenant (admin only)"""
try:
tenant_uuid = uuid.UUID(tenant_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid tenant ID format"
)
try:
from app.models.training import TrainingJobQueue
# Get active jobs
active_jobs_query = select(TrainingJobQueue).where(
TrainingJobQueue.tenant_id == tenant_uuid,
TrainingJobQueue.status.in_(["queued", "running", "pending"])
)
active_jobs_result = await db.execute(active_jobs_query)
active_jobs = active_jobs_result.scalars().all()
jobs = []
for job in active_jobs:
jobs.append({
"id": str(job.id),
"tenant_id": str(job.tenant_id),
"status": job.status,
"created_at": job.created_at.isoformat() if job.created_at else None,
"updated_at": job.updated_at.isoformat() if job.updated_at else None,
"started_at": job.started_at.isoformat() if job.started_at else None,
"progress": getattr(job, 'progress', 0)
})
return {
"tenant_id": tenant_id,
"active_jobs_count": len(jobs),
"jobs": jobs
}
except Exception as e:
logger.error("Failed to get tenant active jobs",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get active jobs"
)
@router.get("/tenants/{tenant_id}/training/jobs/count")
async def get_tenant_models_count(
tenant_id: str,
current_user = Depends(get_current_user_dep),
_admin_check = Depends(require_admin_role),
db: AsyncSession = Depends(get_db)
):
"""Get count of trained models for a tenant (admin only)"""
try:
tenant_uuid = uuid.UUID(tenant_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid tenant ID format"
)
try:
from app.models.training import TrainedModel, ModelArtifact
# Count models
models_count_query = select(func.count(TrainedModel.id)).where(
TrainedModel.tenant_id == tenant_uuid
)
models_count_result = await db.execute(models_count_query)
models_count = models_count_result.scalar()
# Count artifacts
artifacts_count_query = select(func.count(ModelArtifact.id)).where(
ModelArtifact.tenant_id == tenant_uuid
)
artifacts_count_result = await db.execute(artifacts_count_query)
artifacts_count = artifacts_count_result.scalar()
return {
"tenant_id": tenant_id,
"models_count": models_count,
"artifacts_count": artifacts_count,
"total_training_assets": models_count + artifacts_count
}
except Exception as e:
logger.error("Failed to get tenant models count",
tenant_id=tenant_id,
error=str(e))
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to get models count"
)

View File

@@ -96,6 +96,7 @@ class TrainingJobQueue(Base):
# Metadata
created_at = Column(DateTime, default=datetime.now)
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
cancelled_by = Column(String, nullable=True)
class ModelArtifact(Base):
"""