Fix training log race conditions and audit event error

Critical fixes for training session logging:

1. Training log race condition fix:
   - Add explicit session commits after creating training logs
   - Handle duplicate key errors gracefully when multiple sessions
     try to create the same log simultaneously
   - Implement retry logic to query for existing logs after
     duplicate key violations
   - Prevents "Training log not found" errors during training

2. Audit event async generator error fix:
   - Replace incorrect next(get_db()) usage with proper
     async context manager (database_manager.get_session())
   - Fixes "'async_generator' object is not an iterator" error
   - Ensures audit logging works correctly

These changes address race conditions in concurrent database
sessions and ensure training logs are properly synchronized
across the training pipeline.
This commit is contained in:
Claude
2025-11-05 13:24:22 +00:00
parent 15025fdf1d
commit 8df90338b2
2 changed files with 71 additions and 35 deletions

View File

@@ -236,27 +236,27 @@ async def start_training_job(
# Log audit event for training job creation
try:
from app.core.database import get_db
db = next(get_db())
await audit_logger.log_event(
db_session=db,
tenant_id=tenant_id,
user_id=current_user["user_id"],
action=AuditAction.CREATE.value,
resource_type="training_job",
resource_id=job_id,
severity=AuditSeverity.MEDIUM.value,
description=f"Started training job (tier: {tier})",
metadata={
"job_id": job_id,
"tier": tier,
"estimated_dataset_size": estimated_dataset_size,
"quota_usage": quota_result.get('current', 0) if quota_result else 0,
"quota_limit": quota_limit if quota_limit else "unlimited"
},
endpoint="/jobs",
method="POST"
)
from app.core.database import database_manager
async with database_manager.get_session() as db:
await audit_logger.log_event(
db_session=db,
tenant_id=tenant_id,
user_id=current_user["user_id"],
action=AuditAction.CREATE.value,
resource_type="training_job",
resource_id=job_id,
severity=AuditSeverity.MEDIUM.value,
description=f"Started training job (tier: {tier})",
metadata={
"job_id": job_id,
"tier": tier,
"estimated_dataset_size": estimated_dataset_size,
"quota_usage": quota_result.get('current', 0) if quota_result else 0,
"quota_limit": quota_limit if quota_limit else "unlimited"
},
endpoint="/jobs",
method="POST"
)
except Exception as audit_error:
logger.warning("Failed to log audit event", error=str(audit_error))

View File

@@ -174,14 +174,15 @@ class EnhancedTrainingService:
await self._init_repositories(session)
try:
# Check if training log already exists, create if not
# Check if training log already exists, update if found, create if not
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
if existing_log:
logger.info("Training log already exists, updating status", job_id=job_id)
training_log = await self.training_log_repo.update_log_progress(
job_id, 0, "initializing", "running"
)
await session.commit()
else:
# Create new training log entry
log_data = {
@@ -191,8 +192,21 @@ class EnhancedTrainingService:
"progress": 0,
"current_step": "initializing"
}
training_log = await self.training_log_repo.create_training_log(log_data)
try:
training_log = await self.training_log_repo.create_training_log(log_data)
await session.commit() # Explicit commit so other sessions can see it
except Exception as create_error:
# Handle race condition: log may have been created by another session
if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower():
logger.debug("Training log already exists (race condition), updating instead", job_id=job_id)
await session.rollback()
training_log = await self.training_log_repo.update_log_progress(
job_id, 0, "initializing", "running"
)
await session.commit()
else:
raise
# Step 1: Prepare training dataset (includes sales data validation)
logger.info("Step 1: Preparing and aligning training data (with validation)")
await self.training_log_repo.update_log_progress(
@@ -717,10 +731,10 @@ class EnhancedTrainingService:
try:
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
# Check if log exists, create if not
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
if not existing_log:
# Create initial log entry
if not tenant_id:
@@ -732,7 +746,7 @@ class EnhancedTrainingService:
tenant_id = parts[2]
except Exception:
logger.warning(f"Could not extract tenant_id from job_id {job_id}")
if tenant_id:
log_data = {
"job_id": job_id,
@@ -742,15 +756,35 @@ class EnhancedTrainingService:
"current_step": current_step or "initializing",
"start_time": datetime.now(timezone.utc)
}
if error_message:
log_data["error_message"] = error_message
if results:
# Ensure results are JSON-serializable before storing
log_data["results"] = make_json_serializable(results)
await self.training_log_repo.create_training_log(log_data)
logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id)
try:
await self.training_log_repo.create_training_log(log_data)
await session.commit() # Explicit commit so other sessions can see it
logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id)
except Exception as create_error:
# Handle race condition: another session may have created the log
if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower():
logger.debug("Training log already exists (race condition), querying again", job_id=job_id)
await session.rollback()
# Query again to get the existing log
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
if existing_log:
# Update the existing log instead
await self.training_log_repo.update_log_progress(
job_id=job_id,
progress=progress,
current_step=current_step,
status=status
)
await session.commit()
else:
raise
else:
logger.error("Cannot create training log without tenant_id", job_id=job_id)
return
@@ -762,7 +796,7 @@ class EnhancedTrainingService:
current_step=current_step,
status=status
)
# Update additional fields if provided
if error_message or results:
update_data = {}
@@ -773,10 +807,12 @@ class EnhancedTrainingService:
update_data["results"] = make_json_serializable(results)
if status in ["completed", "failed"]:
update_data["end_time"] = datetime.now(timezone.utc)
if update_data:
await self.training_log_repo.update(existing_log.id, update_data)
await session.commit() # Explicit commit after updates
except Exception as e:
logger.error("Failed to update job status using repository",
job_id=job_id,