diff --git a/services/training/app/api/training.py b/services/training/app/api/training.py index 08554a19..ed578a29 100644 --- a/services/training/app/api/training.py +++ b/services/training/app/api/training.py @@ -13,7 +13,7 @@ from datetime import datetime, timezone import uuid from app.core.database import get_db, get_background_db_session -from app.services.training_service import TrainingService +from app.services.training_service import TrainingService, TrainingStatusManager from sqlalchemy import select, delete, func from app.schemas.training import ( TrainingJobRequest, @@ -172,10 +172,20 @@ async def execute_training_job_background( # ✅ FIX: Create training service with isolated DB session training_service = TrainingService(db_session=db_session) + status_manager = TrainingStatusManager(db_session=db_session) + # Publish progress event await publish_job_progress(job_id, tenant_id, 5, "Initializing training pipeline") try: + + await status_manager.update_job_status( + job_id=job_id, + status="running", + progress=0, + current_step="Initializing training pipeline" + ) + # Execute the actual training pipeline result = await training_service.start_training_job( tenant_id=tenant_id, @@ -185,6 +195,14 @@ async def execute_training_job_background( requested_end=requested_end ) + await status_manager.update_job_status( + job_id=job_id, + status="completed", + progress=100, + current_step="Training completed successfully", + results=result + ) + # Publish completion event await publish_job_completed( job_id=job_id, @@ -196,7 +214,15 @@ async def execute_training_job_background( except Exception as training_error: logger.error(f"❌ Training pipeline failed for job {job_id}: {str(training_error)}") - + + await status_manager.update_job_status( + job_id=job_id, + status="failed", + progress=0, + current_step="Training failed", + error_message=str(training_error) + ) + # Publish failure event await publish_job_failed( job_id=job_id, diff --git 
class TrainingStatusManager:
    """Persist training-job status changes on the job's ``ModelTrainingLog`` row.

    Updates are best-effort: any failure is logged and reported through the
    boolean return value of :meth:`update_job_status` rather than raised, so
    callers can keep publishing their own progress/failure events.
    """

    # Statuses that end a job; reaching one of them stamps ``end_time``.
    _TERMINAL_STATUSES = ("completed", "failed", "cancelled")

    def __init__(self, db_session: AsyncSession):
        """Store the session used for every lookup, commit and rollback."""
        self.db_session = db_session

    async def update_job_status(
        self,
        job_id: str,
        status: str,
        progress: int = None,
        current_step: str = None,
        error_message: str = None,
        results: dict = None
    ) -> bool:
        """Update training job status in database.

        Args:
            job_id: Value matched against ``ModelTrainingLog.job_id``.
            status: New status; always written to the record.
            progress: Written when not ``None`` (so ``0`` is stored).
            current_step: Written only when non-empty.
            error_message: Written only when non-empty.
            results: Written only when non-empty.

        Returns:
            ``True`` once the change is committed. ``False`` when no record
            matches ``job_id`` or when the lookup/commit fails; in the failure
            case a rollback of the session is attempted.
        """
        try:
            # Find the training log record
            lookup = select(ModelTrainingLog).where(
                ModelTrainingLog.job_id == job_id
            )
            found = await self.db_session.execute(lookup)
            training_log = found.scalar_one_or_none()

            if training_log is None:
                logger.error(f"Training log not found for job {job_id}")
                return False

            # Update status fields; optional ones are left untouched when the
            # caller did not supply a value.
            training_log.status = status
            if progress is not None:
                training_log.progress = progress
            if current_step:
                training_log.current_step = current_step
            if error_message:
                training_log.error_message = error_message
            if results:
                training_log.results = results

            # Take one timestamp so end_time and updated_at agree exactly.
            # NOTE(review): datetime.now() without a tz is presumably a naive
            # local time here - confirm that is what these columns expect.
            now = datetime.now()

            # Set end time for completed/failed/cancelled jobs
            if status in self._TERMINAL_STATUSES:
                training_log.end_time = now

            # Update timestamp
            training_log.updated_at = now

            # Commit changes
            await self.db_session.commit()
            await self.db_session.refresh(training_log)

            logger.info(f"Updated training job {job_id} status to {status}")
            return True

        except Exception as e:
            logger.error(f"Failed to update job status: {str(e)}")
            # The rollback itself can fail (e.g. the connection is gone).
            # Guard it so this method keeps its "never raises, returns bool"
            # behaviour and the caller's own error handling still runs.
            try:
                await self.db_session.rollback()
            except Exception as rollback_error:
                logger.error(
                    f"Failed to roll back job status update: {str(rollback_error)}"
                )
            return False