Fix training log race conditions and audit event error
Critical fixes for training session logging:
1. Training log race condition fix:
- Add explicit session commits after creating training logs
- Handle duplicate key errors gracefully when multiple sessions
try to create the same log simultaneously
- Implement retry logic to query for existing logs after
duplicate key violations
- Prevents "Training log not found" errors during training
2. Audit event async generator error fix:
- Replace incorrect next(get_db()) usage with proper
async context manager (database_manager.get_session())
- Fixes "'async_generator' object is not an iterator" error
- Ensures audit logging works correctly
These changes address race conditions in concurrent database
sessions and ensure training logs are properly synchronized
across the training pipeline.
This commit is contained in:
@@ -236,27 +236,27 @@ async def start_training_job(
|
||||
|
||||
# Log audit event for training job creation
|
||||
try:
|
||||
from app.core.database import get_db
|
||||
db = next(get_db())
|
||||
await audit_logger.log_event(
|
||||
db_session=db,
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user["user_id"],
|
||||
action=AuditAction.CREATE.value,
|
||||
resource_type="training_job",
|
||||
resource_id=job_id,
|
||||
severity=AuditSeverity.MEDIUM.value,
|
||||
description=f"Started training job (tier: {tier})",
|
||||
metadata={
|
||||
"job_id": job_id,
|
||||
"tier": tier,
|
||||
"estimated_dataset_size": estimated_dataset_size,
|
||||
"quota_usage": quota_result.get('current', 0) if quota_result else 0,
|
||||
"quota_limit": quota_limit if quota_limit else "unlimited"
|
||||
},
|
||||
endpoint="/jobs",
|
||||
method="POST"
|
||||
)
|
||||
from app.core.database import database_manager
|
||||
async with database_manager.get_session() as db:
|
||||
await audit_logger.log_event(
|
||||
db_session=db,
|
||||
tenant_id=tenant_id,
|
||||
user_id=current_user["user_id"],
|
||||
action=AuditAction.CREATE.value,
|
||||
resource_type="training_job",
|
||||
resource_id=job_id,
|
||||
severity=AuditSeverity.MEDIUM.value,
|
||||
description=f"Started training job (tier: {tier})",
|
||||
metadata={
|
||||
"job_id": job_id,
|
||||
"tier": tier,
|
||||
"estimated_dataset_size": estimated_dataset_size,
|
||||
"quota_usage": quota_result.get('current', 0) if quota_result else 0,
|
||||
"quota_limit": quota_limit if quota_limit else "unlimited"
|
||||
},
|
||||
endpoint="/jobs",
|
||||
method="POST"
|
||||
)
|
||||
except Exception as audit_error:
|
||||
logger.warning("Failed to log audit event", error=str(audit_error))
|
||||
|
||||
|
||||
@@ -174,14 +174,15 @@ class EnhancedTrainingService:
|
||||
await self._init_repositories(session)
|
||||
|
||||
try:
|
||||
# Check if training log already exists, create if not
|
||||
# Check if training log already exists, update if found, create if not
|
||||
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
|
||||
|
||||
|
||||
if existing_log:
|
||||
logger.info("Training log already exists, updating status", job_id=job_id)
|
||||
training_log = await self.training_log_repo.update_log_progress(
|
||||
job_id, 0, "initializing", "running"
|
||||
)
|
||||
await session.commit()
|
||||
else:
|
||||
# Create new training log entry
|
||||
log_data = {
|
||||
@@ -191,8 +192,21 @@ class EnhancedTrainingService:
|
||||
"progress": 0,
|
||||
"current_step": "initializing"
|
||||
}
|
||||
training_log = await self.training_log_repo.create_training_log(log_data)
|
||||
|
||||
try:
|
||||
training_log = await self.training_log_repo.create_training_log(log_data)
|
||||
await session.commit() # Explicit commit so other sessions can see it
|
||||
except Exception as create_error:
|
||||
# Handle race condition: log may have been created by another session
|
||||
if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower():
|
||||
logger.debug("Training log already exists (race condition), updating instead", job_id=job_id)
|
||||
await session.rollback()
|
||||
training_log = await self.training_log_repo.update_log_progress(
|
||||
job_id, 0, "initializing", "running"
|
||||
)
|
||||
await session.commit()
|
||||
else:
|
||||
raise
|
||||
|
||||
# Step 1: Prepare training dataset (includes sales data validation)
|
||||
logger.info("Step 1: Preparing and aligning training data (with validation)")
|
||||
await self.training_log_repo.update_log_progress(
|
||||
@@ -717,10 +731,10 @@ class EnhancedTrainingService:
|
||||
try:
|
||||
async with self.database_manager.get_session() as session:
|
||||
await self._init_repositories(session)
|
||||
|
||||
|
||||
# Check if log exists, create if not
|
||||
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
|
||||
|
||||
|
||||
if not existing_log:
|
||||
# Create initial log entry
|
||||
if not tenant_id:
|
||||
@@ -732,7 +746,7 @@ class EnhancedTrainingService:
|
||||
tenant_id = parts[2]
|
||||
except Exception:
|
||||
logger.warning(f"Could not extract tenant_id from job_id {job_id}")
|
||||
|
||||
|
||||
if tenant_id:
|
||||
log_data = {
|
||||
"job_id": job_id,
|
||||
@@ -742,15 +756,35 @@ class EnhancedTrainingService:
|
||||
"current_step": current_step or "initializing",
|
||||
"start_time": datetime.now(timezone.utc)
|
||||
}
|
||||
|
||||
|
||||
if error_message:
|
||||
log_data["error_message"] = error_message
|
||||
if results:
|
||||
# Ensure results are JSON-serializable before storing
|
||||
log_data["results"] = make_json_serializable(results)
|
||||
|
||||
await self.training_log_repo.create_training_log(log_data)
|
||||
logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id)
|
||||
|
||||
try:
|
||||
await self.training_log_repo.create_training_log(log_data)
|
||||
await session.commit() # Explicit commit so other sessions can see it
|
||||
logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id)
|
||||
except Exception as create_error:
|
||||
# Handle race condition: another session may have created the log
|
||||
if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower():
|
||||
logger.debug("Training log already exists (race condition), querying again", job_id=job_id)
|
||||
await session.rollback()
|
||||
# Query again to get the existing log
|
||||
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
|
||||
if existing_log:
|
||||
# Update the existing log instead
|
||||
await self.training_log_repo.update_log_progress(
|
||||
job_id=job_id,
|
||||
progress=progress,
|
||||
current_step=current_step,
|
||||
status=status
|
||||
)
|
||||
await session.commit()
|
||||
else:
|
||||
raise
|
||||
else:
|
||||
logger.error("Cannot create training log without tenant_id", job_id=job_id)
|
||||
return
|
||||
@@ -762,7 +796,7 @@ class EnhancedTrainingService:
|
||||
current_step=current_step,
|
||||
status=status
|
||||
)
|
||||
|
||||
|
||||
# Update additional fields if provided
|
||||
if error_message or results:
|
||||
update_data = {}
|
||||
@@ -773,10 +807,12 @@ class EnhancedTrainingService:
|
||||
update_data["results"] = make_json_serializable(results)
|
||||
if status in ["completed", "failed"]:
|
||||
update_data["end_time"] = datetime.now(timezone.utc)
|
||||
|
||||
|
||||
if update_data:
|
||||
await self.training_log_repo.update(existing_log.id, update_data)
|
||||
|
||||
|
||||
await session.commit() # Explicit commit after updates
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update job status using repository",
|
||||
job_id=job_id,
|
||||
|
||||
Reference in New Issue
Block a user