Add training job status in the db

This commit is contained in:
Urtzi Alfaro
2025-08-03 14:55:13 +02:00
parent b0d83720fd
commit 935f45a283
2 changed files with 90 additions and 3 deletions

View File

@@ -16,6 +16,9 @@ from app.services.training_orchestrator import TrainingDataOrchestrator
from app.core.database import get_db_session
from app.models.training import ModelTrainingLog
from sqlalchemy import select, delete, text
from app.services.messaging import (
publish_job_progress,
publish_data_validation_started,
@@ -374,4 +377,62 @@ class TrainingService:
"completed_at": final_result.get("completed_at")
}
return response_data
return response_data
class TrainingStatusManager:
"""Class to handle database status updates during training"""
def __init__(self, db_session: AsyncSession):
self.db_session = db_session
async def update_job_status(
self,
job_id: str,
status: str,
progress: int = None,
current_step: str = None,
error_message: str = None,
results: dict = None
):
"""Update training job status in database"""
try:
# Find the training log record
query = select(ModelTrainingLog).where(
ModelTrainingLog.job_id == job_id
)
result = await self.db_session.execute(query)
training_log = result.scalar_one_or_none()
if not training_log:
logger.error(f"Training log not found for job {job_id}")
return False
# Update status fields
training_log.status = status
if progress is not None:
training_log.progress = progress
if current_step:
training_log.current_step = current_step
if error_message:
training_log.error_message = error_message
if results:
training_log.results = results
# Set end time for completed/failed jobs
if status in ["completed", "failed", "cancelled"]:
training_log.end_time = datetime.now()
# Update timestamp
training_log.updated_at = datetime.now()
# Commit changes
await self.db_session.commit()
await self.db_session.refresh(training_log)
logger.info(f"Updated training job {job_id} status to {status}")
return True
except Exception as e:
logger.error(f"Failed to update job status: {str(e)}")
await self.db_session.rollback()
return False