Add training job status in the db
This commit is contained in:
@@ -13,7 +13,7 @@ from datetime import datetime, timezone
|
||||
import uuid
|
||||
|
||||
from app.core.database import get_db, get_background_db_session
|
||||
from app.services.training_service import TrainingService
|
||||
from app.services.training_service import TrainingService, TrainingStatusManager
|
||||
from sqlalchemy import select, delete, func
|
||||
from app.schemas.training import (
|
||||
TrainingJobRequest,
|
||||
@@ -172,10 +172,20 @@ async def execute_training_job_background(
|
||||
# ✅ FIX: Create training service with isolated DB session
|
||||
training_service = TrainingService(db_session=db_session)
|
||||
|
||||
status_manager = TrainingStatusManager(db_session=db_session)
|
||||
|
||||
# Publish progress event
|
||||
await publish_job_progress(job_id, tenant_id, 5, "Initializing training pipeline")
|
||||
|
||||
try:
|
||||
|
||||
await status_manager.update_job_status(
|
||||
job_id=job_id,
|
||||
status="running",
|
||||
progress=0,
|
||||
current_step="Initializing training pipeline"
|
||||
)
|
||||
|
||||
# Execute the actual training pipeline
|
||||
result = await training_service.start_training_job(
|
||||
tenant_id=tenant_id,
|
||||
@@ -185,6 +195,14 @@ async def execute_training_job_background(
|
||||
requested_end=requested_end
|
||||
)
|
||||
|
||||
await status_manager.update_job_status(
|
||||
job_id=job_id,
|
||||
status="completed",
|
||||
progress=100,
|
||||
current_step="Training completed successfully",
|
||||
results=result
|
||||
)
|
||||
|
||||
# Publish completion event
|
||||
await publish_job_completed(
|
||||
job_id=job_id,
|
||||
@@ -196,7 +214,15 @@ async def execute_training_job_background(
|
||||
|
||||
except Exception as training_error:
|
||||
logger.error(f"❌ Training pipeline failed for job {job_id}: {str(training_error)}")
|
||||
|
||||
|
||||
await status_manager.update_job_status(
|
||||
job_id=job_id,
|
||||
status="failed",
|
||||
progress=0,
|
||||
current_step="Training failed",
|
||||
error_message=str(training_error)
|
||||
)
|
||||
|
||||
# Publish failure event
|
||||
await publish_job_failed(
|
||||
job_id=job_id,
|
||||
|
||||
@@ -16,6 +16,9 @@ from app.services.training_orchestrator import TrainingDataOrchestrator
|
||||
|
||||
from app.core.database import get_db_session
|
||||
|
||||
from app.models.training import ModelTrainingLog
|
||||
from sqlalchemy import select, delete, text
|
||||
|
||||
from app.services.messaging import (
|
||||
publish_job_progress,
|
||||
publish_data_validation_started,
|
||||
@@ -374,4 +377,62 @@ class TrainingService:
|
||||
"completed_at": final_result.get("completed_at")
|
||||
}
|
||||
|
||||
return response_data
|
||||
return response_data
|
||||
|
||||
class TrainingStatusManager:
|
||||
"""Class to handle database status updates during training"""
|
||||
|
||||
def __init__(self, db_session: AsyncSession):
|
||||
self.db_session = db_session
|
||||
|
||||
async def update_job_status(
|
||||
self,
|
||||
job_id: str,
|
||||
status: str,
|
||||
progress: int = None,
|
||||
current_step: str = None,
|
||||
error_message: str = None,
|
||||
results: dict = None
|
||||
):
|
||||
"""Update training job status in database"""
|
||||
try:
|
||||
# Find the training log record
|
||||
query = select(ModelTrainingLog).where(
|
||||
ModelTrainingLog.job_id == job_id
|
||||
)
|
||||
result = await self.db_session.execute(query)
|
||||
training_log = result.scalar_one_or_none()
|
||||
|
||||
if not training_log:
|
||||
logger.error(f"Training log not found for job {job_id}")
|
||||
return False
|
||||
|
||||
# Update status fields
|
||||
training_log.status = status
|
||||
if progress is not None:
|
||||
training_log.progress = progress
|
||||
if current_step:
|
||||
training_log.current_step = current_step
|
||||
if error_message:
|
||||
training_log.error_message = error_message
|
||||
if results:
|
||||
training_log.results = results
|
||||
|
||||
# Set end time for completed/failed jobs
|
||||
if status in ["completed", "failed", "cancelled"]:
|
||||
training_log.end_time = datetime.now()
|
||||
|
||||
# Update timestamp
|
||||
training_log.updated_at = datetime.now()
|
||||
|
||||
# Commit changes
|
||||
await self.db_session.commit()
|
||||
await self.db_session.refresh(training_log)
|
||||
|
||||
logger.info(f"Updated training job {job_id} status to {status}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update job status: {str(e)}")
|
||||
await self.db_session.rollback()
|
||||
return False
|
||||
Reference in New Issue
Block a user