Fix training log race conditions and audit event error
Critical fixes for training session logging:
1. Training log race condition fix:
   - Add explicit session commits after creating or updating training
     logs so that other database sessions can see them
   - Handle duplicate key errors gracefully when multiple sessions
     try to create the same log simultaneously
   - On a duplicate key violation, roll back the failed insert and
     fall back to the existing log (re-querying for it where needed
     and updating it) instead of failing
   - Prevents "Training log not found" errors during training
2. Audit event async generator error fix:
   - Replace the incorrect `next(get_db())` call with the proper
     async context manager (`async with database_manager.get_session()`)
   - Fixes the "'async_generator' object is not an iterator" error
   - Ensures audit logging works correctly
These changes address race conditions in concurrent database
sessions and ensure training logs are properly synchronized
across the training pipeline.
This commit is contained in:
@@ -236,27 +236,27 @@ async def start_training_job(
|
|||||||
|
|
||||||
# Log audit event for training job creation
|
# Log audit event for training job creation
|
||||||
try:
|
try:
|
||||||
from app.core.database import get_db
|
from app.core.database import database_manager
|
||||||
db = next(get_db())
|
async with database_manager.get_session() as db:
|
||||||
await audit_logger.log_event(
|
await audit_logger.log_event(
|
||||||
db_session=db,
|
db_session=db,
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
user_id=current_user["user_id"],
|
user_id=current_user["user_id"],
|
||||||
action=AuditAction.CREATE.value,
|
action=AuditAction.CREATE.value,
|
||||||
resource_type="training_job",
|
resource_type="training_job",
|
||||||
resource_id=job_id,
|
resource_id=job_id,
|
||||||
severity=AuditSeverity.MEDIUM.value,
|
severity=AuditSeverity.MEDIUM.value,
|
||||||
description=f"Started training job (tier: {tier})",
|
description=f"Started training job (tier: {tier})",
|
||||||
metadata={
|
metadata={
|
||||||
"job_id": job_id,
|
"job_id": job_id,
|
||||||
"tier": tier,
|
"tier": tier,
|
||||||
"estimated_dataset_size": estimated_dataset_size,
|
"estimated_dataset_size": estimated_dataset_size,
|
||||||
"quota_usage": quota_result.get('current', 0) if quota_result else 0,
|
"quota_usage": quota_result.get('current', 0) if quota_result else 0,
|
||||||
"quota_limit": quota_limit if quota_limit else "unlimited"
|
"quota_limit": quota_limit if quota_limit else "unlimited"
|
||||||
},
|
},
|
||||||
endpoint="/jobs",
|
endpoint="/jobs",
|
||||||
method="POST"
|
method="POST"
|
||||||
)
|
)
|
||||||
except Exception as audit_error:
|
except Exception as audit_error:
|
||||||
logger.warning("Failed to log audit event", error=str(audit_error))
|
logger.warning("Failed to log audit event", error=str(audit_error))
|
||||||
|
|
||||||
|
|||||||
@@ -174,14 +174,15 @@ class EnhancedTrainingService:
|
|||||||
await self._init_repositories(session)
|
await self._init_repositories(session)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Check if training log already exists, create if not
|
# Check if training log already exists, update if found, create if not
|
||||||
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
|
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
|
||||||
|
|
||||||
if existing_log:
|
if existing_log:
|
||||||
logger.info("Training log already exists, updating status", job_id=job_id)
|
logger.info("Training log already exists, updating status", job_id=job_id)
|
||||||
training_log = await self.training_log_repo.update_log_progress(
|
training_log = await self.training_log_repo.update_log_progress(
|
||||||
job_id, 0, "initializing", "running"
|
job_id, 0, "initializing", "running"
|
||||||
)
|
)
|
||||||
|
await session.commit()
|
||||||
else:
|
else:
|
||||||
# Create new training log entry
|
# Create new training log entry
|
||||||
log_data = {
|
log_data = {
|
||||||
@@ -191,8 +192,21 @@ class EnhancedTrainingService:
|
|||||||
"progress": 0,
|
"progress": 0,
|
||||||
"current_step": "initializing"
|
"current_step": "initializing"
|
||||||
}
|
}
|
||||||
training_log = await self.training_log_repo.create_training_log(log_data)
|
try:
|
||||||
|
training_log = await self.training_log_repo.create_training_log(log_data)
|
||||||
|
await session.commit() # Explicit commit so other sessions can see it
|
||||||
|
except Exception as create_error:
|
||||||
|
# Handle race condition: log may have been created by another session
|
||||||
|
if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower():
|
||||||
|
logger.debug("Training log already exists (race condition), updating instead", job_id=job_id)
|
||||||
|
await session.rollback()
|
||||||
|
training_log = await self.training_log_repo.update_log_progress(
|
||||||
|
job_id, 0, "initializing", "running"
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
# Step 1: Prepare training dataset (includes sales data validation)
|
# Step 1: Prepare training dataset (includes sales data validation)
|
||||||
logger.info("Step 1: Preparing and aligning training data (with validation)")
|
logger.info("Step 1: Preparing and aligning training data (with validation)")
|
||||||
await self.training_log_repo.update_log_progress(
|
await self.training_log_repo.update_log_progress(
|
||||||
@@ -717,10 +731,10 @@ class EnhancedTrainingService:
|
|||||||
try:
|
try:
|
||||||
async with self.database_manager.get_session() as session:
|
async with self.database_manager.get_session() as session:
|
||||||
await self._init_repositories(session)
|
await self._init_repositories(session)
|
||||||
|
|
||||||
# Check if log exists, create if not
|
# Check if log exists, create if not
|
||||||
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
|
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
|
||||||
|
|
||||||
if not existing_log:
|
if not existing_log:
|
||||||
# Create initial log entry
|
# Create initial log entry
|
||||||
if not tenant_id:
|
if not tenant_id:
|
||||||
@@ -732,7 +746,7 @@ class EnhancedTrainingService:
|
|||||||
tenant_id = parts[2]
|
tenant_id = parts[2]
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning(f"Could not extract tenant_id from job_id {job_id}")
|
logger.warning(f"Could not extract tenant_id from job_id {job_id}")
|
||||||
|
|
||||||
if tenant_id:
|
if tenant_id:
|
||||||
log_data = {
|
log_data = {
|
||||||
"job_id": job_id,
|
"job_id": job_id,
|
||||||
@@ -742,15 +756,35 @@ class EnhancedTrainingService:
|
|||||||
"current_step": current_step or "initializing",
|
"current_step": current_step or "initializing",
|
||||||
"start_time": datetime.now(timezone.utc)
|
"start_time": datetime.now(timezone.utc)
|
||||||
}
|
}
|
||||||
|
|
||||||
if error_message:
|
if error_message:
|
||||||
log_data["error_message"] = error_message
|
log_data["error_message"] = error_message
|
||||||
if results:
|
if results:
|
||||||
# Ensure results are JSON-serializable before storing
|
# Ensure results are JSON-serializable before storing
|
||||||
log_data["results"] = make_json_serializable(results)
|
log_data["results"] = make_json_serializable(results)
|
||||||
|
|
||||||
await self.training_log_repo.create_training_log(log_data)
|
try:
|
||||||
logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id)
|
await self.training_log_repo.create_training_log(log_data)
|
||||||
|
await session.commit() # Explicit commit so other sessions can see it
|
||||||
|
logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id)
|
||||||
|
except Exception as create_error:
|
||||||
|
# Handle race condition: another session may have created the log
|
||||||
|
if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower():
|
||||||
|
logger.debug("Training log already exists (race condition), querying again", job_id=job_id)
|
||||||
|
await session.rollback()
|
||||||
|
# Query again to get the existing log
|
||||||
|
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
|
||||||
|
if existing_log:
|
||||||
|
# Update the existing log instead
|
||||||
|
await self.training_log_repo.update_log_progress(
|
||||||
|
job_id=job_id,
|
||||||
|
progress=progress,
|
||||||
|
current_step=current_step,
|
||||||
|
status=status
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
else:
|
||||||
|
raise
|
||||||
else:
|
else:
|
||||||
logger.error("Cannot create training log without tenant_id", job_id=job_id)
|
logger.error("Cannot create training log without tenant_id", job_id=job_id)
|
||||||
return
|
return
|
||||||
@@ -762,7 +796,7 @@ class EnhancedTrainingService:
|
|||||||
current_step=current_step,
|
current_step=current_step,
|
||||||
status=status
|
status=status
|
||||||
)
|
)
|
||||||
|
|
||||||
# Update additional fields if provided
|
# Update additional fields if provided
|
||||||
if error_message or results:
|
if error_message or results:
|
||||||
update_data = {}
|
update_data = {}
|
||||||
@@ -773,10 +807,12 @@ class EnhancedTrainingService:
|
|||||||
update_data["results"] = make_json_serializable(results)
|
update_data["results"] = make_json_serializable(results)
|
||||||
if status in ["completed", "failed"]:
|
if status in ["completed", "failed"]:
|
||||||
update_data["end_time"] = datetime.now(timezone.utc)
|
update_data["end_time"] = datetime.now(timezone.utc)
|
||||||
|
|
||||||
if update_data:
|
if update_data:
|
||||||
await self.training_log_repo.update(existing_log.id, update_data)
|
await self.training_log_repo.update(existing_log.id, update_data)
|
||||||
|
|
||||||
|
await session.commit() # Explicit commit after updates
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("Failed to update job status using repository",
|
logger.error("Failed to update job status using repository",
|
||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
|
|||||||
Reference in New Issue
Block a user