diff --git a/services/training/app/api/training_operations.py b/services/training/app/api/training_operations.py index 46e1d69d..ccea858f 100644 --- a/services/training/app/api/training_operations.py +++ b/services/training/app/api/training_operations.py @@ -236,27 +236,27 @@ async def start_training_job( # Log audit event for training job creation try: - from app.core.database import get_db - db = next(get_db()) - await audit_logger.log_event( - db_session=db, - tenant_id=tenant_id, - user_id=current_user["user_id"], - action=AuditAction.CREATE.value, - resource_type="training_job", - resource_id=job_id, - severity=AuditSeverity.MEDIUM.value, - description=f"Started training job (tier: {tier})", - metadata={ - "job_id": job_id, - "tier": tier, - "estimated_dataset_size": estimated_dataset_size, - "quota_usage": quota_result.get('current', 0) if quota_result else 0, - "quota_limit": quota_limit if quota_limit else "unlimited" - }, - endpoint="/jobs", - method="POST" - ) + from app.core.database import database_manager + async with database_manager.get_session() as db: + await audit_logger.log_event( + db_session=db, + tenant_id=tenant_id, + user_id=current_user["user_id"], + action=AuditAction.CREATE.value, + resource_type="training_job", + resource_id=job_id, + severity=AuditSeverity.MEDIUM.value, + description=f"Started training job (tier: {tier})", + metadata={ + "job_id": job_id, + "tier": tier, + "estimated_dataset_size": estimated_dataset_size, + "quota_usage": quota_result.get('current', 0) if quota_result else 0, + "quota_limit": quota_limit if quota_limit else "unlimited" + }, + endpoint="/jobs", + method="POST" + ) except Exception as audit_error: logger.warning("Failed to log audit event", error=str(audit_error)) diff --git a/services/training/app/services/training_service.py b/services/training/app/services/training_service.py index 36af1e54..e5d0b2a0 100644 --- a/services/training/app/services/training_service.py +++ b/services/training/app/services/training_service.py @@ -174,14 +174,15 @@ class EnhancedTrainingService: await self._init_repositories(session) try: - # Check if training log already exists, create if not + # Check if training log already exists, update if found, create if not existing_log = await self.training_log_repo.get_log_by_job_id(job_id) - + if existing_log: logger.info("Training log already exists, updating status", job_id=job_id) training_log = await self.training_log_repo.update_log_progress( job_id, 0, "initializing", "running" ) + await session.commit() else: # Create new training log entry log_data = { @@ -191,8 +192,21 @@ class EnhancedTrainingService: "progress": 0, "current_step": "initializing" } - training_log = await self.training_log_repo.create_training_log(log_data) - + try: + training_log = await self.training_log_repo.create_training_log(log_data) + await session.commit() # Explicit commit so other sessions can see it + except Exception as create_error: + # Handle race condition: log may have been created by another session + if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower(): + logger.debug("Training log already exists (race condition), updating instead", job_id=job_id) + await session.rollback() + training_log = await self.training_log_repo.update_log_progress( + job_id, 0, "initializing", "running" + ) + await session.commit() + else: + raise + # Step 1: Prepare training dataset (includes sales data validation) logger.info("Step 1: Preparing and aligning training data (with validation)") await self.training_log_repo.update_log_progress( @@ -717,10 +731,10 @@ class EnhancedTrainingService: try: async with self.database_manager.get_session() as session: await self._init_repositories(session) - + # Check if log exists, create if not existing_log = await self.training_log_repo.get_log_by_job_id(job_id) - + if not existing_log: # Create initial log entry if not tenant_id: @@ -732,7 +746,7 @@ class EnhancedTrainingService: tenant_id = parts[2] except Exception: logger.warning(f"Could not extract tenant_id from job_id {job_id}") - + if tenant_id: log_data = { "job_id": job_id, @@ -742,15 +756,35 @@ class EnhancedTrainingService: "current_step": current_step or "initializing", "start_time": datetime.now(timezone.utc) } - + if error_message: log_data["error_message"] = error_message if results: # Ensure results are JSON-serializable before storing log_data["results"] = make_json_serializable(results) - - await self.training_log_repo.create_training_log(log_data) - logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id) + + try: + await self.training_log_repo.create_training_log(log_data) + await session.commit() # Explicit commit so other sessions can see it + logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id) + except Exception as create_error: + # Handle race condition: another session may have created the log + if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower(): + logger.debug("Training log already exists (race condition), querying again", job_id=job_id) + await session.rollback() + # Query again to get the existing log + existing_log = await self.training_log_repo.get_log_by_job_id(job_id) + if existing_log: + # Update the existing log instead + await self.training_log_repo.update_log_progress( + job_id=job_id, + progress=progress, + current_step=current_step, + status=status + ) + await session.commit() + else: + raise else: logger.error("Cannot create training log without tenant_id", job_id=job_id) return @@ -762,7 +796,7 @@ class EnhancedTrainingService: current_step=current_step, status=status ) - + # Update additional fields if provided if error_message or results: update_data = {} @@ -773,10 +807,12 @@ class EnhancedTrainingService: update_data["results"] = make_json_serializable(results) if status in ["completed", "failed"]: update_data["end_time"] = datetime.now(timezone.utc) - + if update_data: await self.training_log_repo.update(existing_log.id, update_data) - + + await session.commit() # Explicit commit after updates + except Exception as e: logger.error("Failed to update job status using repository", job_id=job_id,