Improve user delete flow
This commit is contained in:
@@ -14,6 +14,7 @@ import uuid
|
||||
|
||||
from app.core.database import get_db, get_background_db_session
|
||||
from app.services.training_service import TrainingService
|
||||
from sqlalchemy import select, delete, func
|
||||
from app.schemas.training import (
|
||||
TrainingJobRequest,
|
||||
SingleProductTrainingRequest
|
||||
@@ -33,8 +34,7 @@ from app.services.messaging import (
|
||||
)
|
||||
|
||||
|
||||
# Import shared auth decorators (assuming they exist in your microservices)
|
||||
from shared.auth.decorators import get_current_tenant_id_dep
|
||||
from shared.auth.decorators import require_admin_role, get_current_user_dep, get_current_tenant_id_dep
|
||||
|
||||
logger = structlog.get_logger()
|
||||
router = APIRouter()
|
||||
@@ -224,6 +224,9 @@ async def start_single_product_training(
|
||||
|
||||
Uses the same pipeline but filters for specific product.
|
||||
"""
|
||||
|
||||
training_service = TrainingService(db_session=db)
|
||||
|
||||
try:
|
||||
# Validate tenant access
|
||||
if tenant_id != current_tenant:
|
||||
@@ -260,36 +263,6 @@ async def start_single_product_training(
|
||||
detail="Single product training failed"
|
||||
)
|
||||
|
||||
@router.post("/tenants/{tenant_id}/training/jobs/{job_id}/cancel")
async def cancel_training_job(
    tenant_id: str = Path(..., description="Tenant ID"),
    job_id: str = Path(..., description="Job ID"),
    current_tenant: str = Depends(get_current_tenant_id_dep),
    db: AsyncSession = Depends(get_db)
):
    """
    Cancel a running training job.

    Currently a placeholder: after the tenant check it only logs the request
    and returns a success message; no job state is modified (see TODO below).

    Args:
        tenant_id: Tenant ID taken from the URL path.
        job_id: Job ID taken from the URL path.
        current_tenant: Tenant ID resolved by ``get_current_tenant_id_dep``.
        db: Database session (unused until cancellation is implemented).

    Returns:
        A dict with a single ``"message"`` key.

    Raises:
        HTTPException: 403 when ``tenant_id`` differs from ``current_tenant``;
            500 when an unexpected error occurs while cancelling.
    """
    # Validate tenant access BEFORE entering the try block: HTTPException is an
    # ordinary Exception subclass, so raising it inside the try would let the
    # generic handler below turn the intended 403 into a 500.
    if tenant_id != current_tenant:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied to tenant resources"
        )

    try:
        # TODO: Implement job cancellation
        logger.info(f"Cancelling training job {job_id} for tenant {tenant_id}")

        return {"message": "Training job cancelled successfully"}

    except HTTPException:
        # Preserve deliberate HTTP errors (status code and detail) unchanged.
        raise
    except Exception as e:
        logger.error(f"Failed to cancel training job: {str(e)}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to cancel training job"
        )
|
||||
|
||||
@router.get("/tenants/{tenant_id}/training/jobs/{job_id}/logs")
|
||||
async def get_training_logs(
|
||||
tenant_id: str = Path(..., description="Tenant ID"),
|
||||
@@ -337,4 +310,189 @@ async def health_check():
|
||||
"service": "training",
|
||||
"version": "1.0.0",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
@router.post("/tenants/{tenant_id}/training/jobs/cancel")
async def cancel_tenant_training_jobs(
    cancel_data: dict,  # {"tenant_id": str}
    current_user = Depends(get_current_user_dep),
    _admin_check = Depends(require_admin_role),
    db: AsyncSession = Depends(get_db)
):
    """Cancel all active training jobs for a tenant (admin only).

    Every ``TrainingJobQueue`` row of the tenant whose status is ``queued``,
    ``running`` or ``pending`` is marked ``cancelled`` and stamped with the
    acting user's ID. Changes are committed once, after all rows are processed.

    Args:
        cancel_data: Request body; must contain a ``"tenant_id"`` UUID string.
        current_user: Mapping from ``get_current_user_dep``; its ``"user_id"``
            entry is stored in ``cancelled_by``.
        _admin_check: Present only to enforce ``require_admin_role``.
        db: Async database session.

    Returns:
        A summary dict with ``success``, ``tenant_id``, ``jobs_cancelled``,
        ``cancelled_job_ids``, ``errors`` and ``cancelled_at``.

    Raises:
        HTTPException: 400 when ``tenant_id`` is missing or not a valid UUID
            string; 500 when the query or commit fails (after a rollback).
    """
    # NOTE(review): the tenant is read from the request body and the
    # {tenant_id} path segment is never consulted - confirm whether the two
    # should be required to match.
    tenant_id = cancel_data.get("tenant_id")
    if not tenant_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="tenant_id is required"
        )

    try:
        tenant_uuid = uuid.UUID(tenant_id)
    except (ValueError, AttributeError, TypeError):
        # The body is untyped JSON: a non-string value (number, list, ...)
        # makes uuid.UUID raise AttributeError/TypeError rather than
        # ValueError. All of them mean "bad client input", hence a 400.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid tenant ID format"
        )

    try:
        from app.models.training import TrainingJobQueue

        # Find all active jobs for the tenant
        active_jobs_query = select(TrainingJobQueue).where(
            TrainingJobQueue.tenant_id == tenant_uuid,
            TrainingJobQueue.status.in_(["queued", "running", "pending"])
        )
        active_jobs_result = await db.execute(active_jobs_query)
        active_jobs = active_jobs_result.scalars().all()

        jobs_cancelled = 0
        cancelled_job_ids = []
        errors = []

        for job in active_jobs:
            # Best effort: a failure on one job is recorded and must not
            # prevent the remaining jobs from being cancelled.
            try:
                job.status = "cancelled"
                job.updated_at = datetime.utcnow()
                job.cancelled_by = current_user.get("user_id")
                jobs_cancelled += 1
                cancelled_job_ids.append(str(job.id))

                logger.info("Cancelled training job",
                            job_id=str(job.id),
                            tenant_id=tenant_id)

            except Exception as e:
                error_msg = f"Failed to cancel job {job.id}: {str(e)}"
                errors.append(error_msg)
                logger.error(error_msg)

        # Nothing to persist when no job was touched.
        if jobs_cancelled > 0:
            await db.commit()

        result = {
            "success": True,
            "tenant_id": tenant_id,
            "jobs_cancelled": jobs_cancelled,
            "cancelled_job_ids": cancelled_job_ids,
            "errors": errors,
            "cancelled_at": datetime.utcnow().isoformat()
        }

        if errors:
            # Partial success still counts as success; only "every job
            # failed" is reported as a failure.
            result["success"] = len(errors) < len(active_jobs)

        return result

    except Exception as e:
        await db.rollback()
        logger.error("Failed to cancel tenant training jobs",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to cancel training jobs"
        )
|
||||
|
||||
@router.get("/tenants/{tenant_id}/training/jobs/active")
async def get_tenant_active_jobs(
    tenant_id: str,
    current_user = Depends(get_current_user_dep),
    _admin_check = Depends(require_admin_role),
    db: AsyncSession = Depends(get_db)
):
    """List a tenant's active training jobs (admin only).

    "Active" means a ``TrainingJobQueue`` row whose status is ``queued``,
    ``running`` or ``pending``.

    Returns:
        A dict with the echoed ``tenant_id``, ``active_jobs_count`` and a
        ``jobs`` list of per-job summaries (timestamps as ISO-8601 strings or
        ``None``).

    Raises:
        HTTPException: 400 when ``tenant_id`` is not a valid UUID; 500 when
            the lookup fails.
    """
    try:
        tenant_uuid = uuid.UUID(tenant_id)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid tenant ID format"
        )

    def _iso_or_none(moment):
        # Unset timestamps are reported as None instead of being formatted.
        return moment.isoformat() if moment else None

    try:
        from app.models.training import TrainingJobQueue

        # Select only this tenant's jobs that are still in flight.
        stmt = select(TrainingJobQueue).where(
            TrainingJobQueue.tenant_id == tenant_uuid,
            TrainingJobQueue.status.in_(["queued", "running", "pending"])
        )
        rows = (await db.execute(stmt)).scalars().all()

        jobs = [
            {
                "id": str(row.id),
                "tenant_id": str(row.tenant_id),
                "status": row.status,
                "created_at": _iso_or_none(row.created_at),
                "updated_at": _iso_or_none(row.updated_at),
                "started_at": _iso_or_none(row.started_at),
                # Falls back to 0 when the row has no progress attribute.
                "progress": getattr(row, 'progress', 0)
            }
            for row in rows
        ]

        return {
            "tenant_id": tenant_id,
            "active_jobs_count": len(jobs),
            "jobs": jobs
        }

    except Exception as e:
        logger.error("Failed to get tenant active jobs",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get active jobs"
        )
|
||||
|
||||
@router.get("/tenants/{tenant_id}/training/jobs/count")
async def get_tenant_models_count(
    tenant_id: str,
    current_user = Depends(get_current_user_dep),
    _admin_check = Depends(require_admin_role),
    db: AsyncSession = Depends(get_db)
):
    """Report how many trained models and model artifacts a tenant owns (admin only).

    Returns:
        A dict with the echoed ``tenant_id``, ``models_count``,
        ``artifacts_count`` and their sum as ``total_training_assets``.

    Raises:
        HTTPException: 400 when ``tenant_id`` is not a valid UUID; 500 when
            either count query fails.
    """
    try:
        tenant_uuid = uuid.UUID(tenant_id)
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid tenant ID format"
        )

    try:
        from app.models.training import TrainedModel, ModelArtifact

        async def _count_rows(model):
            # COUNT(id) over the rows of `model` that belong to this tenant.
            stmt = select(func.count(model.id)).where(
                model.tenant_id == tenant_uuid
            )
            outcome = await db.execute(stmt)
            return outcome.scalar()

        # Models are counted first, then artifacts.
        models_count = await _count_rows(TrainedModel)
        artifacts_count = await _count_rows(ModelArtifact)

        return {
            "tenant_id": tenant_id,
            "models_count": models_count,
            "artifacts_count": artifacts_count,
            "total_training_assets": models_count + artifacts_count
        }

    except Exception as e:
        logger.error("Failed to get tenant models count",
                     tenant_id=tenant_id,
                     error=str(e))
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Failed to get models count"
        )
|
||||
@@ -96,6 +96,7 @@ class TrainingJobQueue(Base):
|
||||
# Metadata
|
||||
created_at = Column(DateTime, default=datetime.now)
|
||||
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
|
||||
cancelled_by = Column(String, nullable=True)
|
||||
|
||||
class ModelArtifact(Base):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user