Add training job status in the db
This commit is contained in:
@@ -13,7 +13,7 @@ from datetime import datetime, timezone
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from app.core.database import get_db, get_background_db_session
|
from app.core.database import get_db, get_background_db_session
|
||||||
from app.services.training_service import TrainingService
|
from app.services.training_service import TrainingService, TrainingStatusManager
|
||||||
from sqlalchemy import select, delete, func
|
from sqlalchemy import select, delete, func
|
||||||
from app.schemas.training import (
|
from app.schemas.training import (
|
||||||
TrainingJobRequest,
|
TrainingJobRequest,
|
||||||
@@ -172,10 +172,20 @@ async def execute_training_job_background(
|
|||||||
# ✅ FIX: Create training service with isolated DB session
|
# ✅ FIX: Create training service with isolated DB session
|
||||||
training_service = TrainingService(db_session=db_session)
|
training_service = TrainingService(db_session=db_session)
|
||||||
|
|
||||||
|
status_manager = TrainingStatusManager(db_session=db_session)
|
||||||
|
|
||||||
# Publish progress event
|
# Publish progress event
|
||||||
await publish_job_progress(job_id, tenant_id, 5, "Initializing training pipeline")
|
await publish_job_progress(job_id, tenant_id, 5, "Initializing training pipeline")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
||||||
|
await status_manager.update_job_status(
|
||||||
|
job_id=job_id,
|
||||||
|
status="running",
|
||||||
|
progress=0,
|
||||||
|
current_step="Initializing training pipeline"
|
||||||
|
)
|
||||||
|
|
||||||
# Execute the actual training pipeline
|
# Execute the actual training pipeline
|
||||||
result = await training_service.start_training_job(
|
result = await training_service.start_training_job(
|
||||||
tenant_id=tenant_id,
|
tenant_id=tenant_id,
|
||||||
@@ -185,6 +195,14 @@ async def execute_training_job_background(
|
|||||||
requested_end=requested_end
|
requested_end=requested_end
|
||||||
)
|
)
|
||||||
|
|
||||||
|
await status_manager.update_job_status(
|
||||||
|
job_id=job_id,
|
||||||
|
status="completed",
|
||||||
|
progress=100,
|
||||||
|
current_step="Training completed successfully",
|
||||||
|
results=result
|
||||||
|
)
|
||||||
|
|
||||||
# Publish completion event
|
# Publish completion event
|
||||||
await publish_job_completed(
|
await publish_job_completed(
|
||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
@@ -196,7 +214,15 @@ async def execute_training_job_background(
|
|||||||
|
|
||||||
except Exception as training_error:
|
except Exception as training_error:
|
||||||
logger.error(f"❌ Training pipeline failed for job {job_id}: {str(training_error)}")
|
logger.error(f"❌ Training pipeline failed for job {job_id}: {str(training_error)}")
|
||||||
|
|
||||||
|
await status_manager.update_job_status(
|
||||||
|
job_id=job_id,
|
||||||
|
status="failed",
|
||||||
|
progress=0,
|
||||||
|
current_step="Training failed",
|
||||||
|
error_message=str(training_error)
|
||||||
|
)
|
||||||
|
|
||||||
# Publish failure event
|
# Publish failure event
|
||||||
await publish_job_failed(
|
await publish_job_failed(
|
||||||
job_id=job_id,
|
job_id=job_id,
|
||||||
|
|||||||
@@ -16,6 +16,9 @@ from app.services.training_orchestrator import TrainingDataOrchestrator
|
|||||||
|
|
||||||
from app.core.database import get_db_session
|
from app.core.database import get_db_session
|
||||||
|
|
||||||
|
from app.models.training import ModelTrainingLog
|
||||||
|
from sqlalchemy import select, delete, text
|
||||||
|
|
||||||
from app.services.messaging import (
|
from app.services.messaging import (
|
||||||
publish_job_progress,
|
publish_job_progress,
|
||||||
publish_data_validation_started,
|
publish_data_validation_started,
|
||||||
@@ -374,4 +377,62 @@ class TrainingService:
|
|||||||
"completed_at": final_result.get("completed_at")
|
"completed_at": final_result.get("completed_at")
|
||||||
}
|
}
|
||||||
|
|
||||||
return response_data
|
return response_data
|
||||||
|
|
||||||
|
class TrainingStatusManager:
|
||||||
|
"""Class to handle database status updates during training"""
|
||||||
|
|
||||||
|
def __init__(self, db_session: AsyncSession):
|
||||||
|
self.db_session = db_session
|
||||||
|
|
||||||
|
async def update_job_status(
|
||||||
|
self,
|
||||||
|
job_id: str,
|
||||||
|
status: str,
|
||||||
|
progress: int = None,
|
||||||
|
current_step: str = None,
|
||||||
|
error_message: str = None,
|
||||||
|
results: dict = None
|
||||||
|
):
|
||||||
|
"""Update training job status in database"""
|
||||||
|
try:
|
||||||
|
# Find the training log record
|
||||||
|
query = select(ModelTrainingLog).where(
|
||||||
|
ModelTrainingLog.job_id == job_id
|
||||||
|
)
|
||||||
|
result = await self.db_session.execute(query)
|
||||||
|
training_log = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if not training_log:
|
||||||
|
logger.error(f"Training log not found for job {job_id}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Update status fields
|
||||||
|
training_log.status = status
|
||||||
|
if progress is not None:
|
||||||
|
training_log.progress = progress
|
||||||
|
if current_step:
|
||||||
|
training_log.current_step = current_step
|
||||||
|
if error_message:
|
||||||
|
training_log.error_message = error_message
|
||||||
|
if results:
|
||||||
|
training_log.results = results
|
||||||
|
|
||||||
|
# Set end time for completed/failed jobs
|
||||||
|
if status in ["completed", "failed", "cancelled"]:
|
||||||
|
training_log.end_time = datetime.now()
|
||||||
|
|
||||||
|
# Update timestamp
|
||||||
|
training_log.updated_at = datetime.now()
|
||||||
|
|
||||||
|
# Commit changes
|
||||||
|
await self.db_session.commit()
|
||||||
|
await self.db_session.refresh(training_log)
|
||||||
|
|
||||||
|
logger.info(f"Updated training job {job_id} status to {status}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to update job status: {str(e)}")
|
||||||
|
await self.db_session.rollback()
|
||||||
|
return False
|
||||||
Reference in New Issue
Block a user