"""
|
|
|
|
|
Training Log Repository
|
|
|
|
|
Repository for model training log operations
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
from typing import Optional, List, Dict, Any
|
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
from sqlalchemy import select, and_, text, desc
|
|
|
|
|
from datetime import datetime, timedelta
|
|
|
|
|
import structlog
|
|
|
|
|
|
|
|
|
|
from .base import TrainingBaseRepository
|
|
|
|
|
from app.models.training import ModelTrainingLog
|
|
|
|
|
from shared.database.exceptions import DatabaseError, ValidationError
|
|
|
|
|
|
|
|
|
|
logger = structlog.get_logger()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TrainingLogRepository(TrainingBaseRepository):
    """Repository for training log operations"""

    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 300):
        # Training logs change frequently, shorter cache time (5 minutes)
        super().__init__(ModelTrainingLog, session, cache_ttl)

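    # Typical lifecycle, shown as a sketch only (the job/tenant IDs and the result
    # payload below are illustrative values, not ones used elsewhere in the code):
    #
    #   repo = TrainingLogRepository(session)
    #   await repo.create_training_log(
    #       {"job_id": "job-123", "tenant_id": "tenant-abc", "status": "pending"}
    #   )
    #   await repo.update_log_progress("job-123", progress=50, current_step="training")
    #   await repo.complete_training_log("job-123", results={"accuracy": 0.91})
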
    async def create_training_log(self, log_data: Dict[str, Any]) -> ModelTrainingLog:
        """Create a new training log entry"""
        try:
            # Validate log data
            validation_result = self._validate_training_data(
                log_data,
                ["job_id", "tenant_id", "status"]
            )

            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid training log data: {validation_result['errors']}")

            # Set default values
            if "progress" not in log_data:
                log_data["progress"] = 0
            if "current_step" not in log_data:
                log_data["current_step"] = "initializing"

            # Create log entry
            log_entry = await self.create(log_data)

            logger.info("Training log created",
                        job_id=log_entry.job_id,
                        tenant_id=log_entry.tenant_id,
                        status=log_entry.status)

            return log_entry

        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to create training log",
                         job_id=log_data.get("job_id"),
                         error=str(e))
            raise DatabaseError(f"Failed to create training log: {str(e)}")

    async def get_log_by_job_id(self, job_id: str) -> Optional[ModelTrainingLog]:
        """Get training log by job ID"""
        return await self.get_by_job_id(job_id)

    async def update_log_progress(
        self,
        job_id: str,
        progress: int,
        current_step: Optional[str] = None,
        status: Optional[str] = None
    ) -> Optional[ModelTrainingLog]:
        """Update training log progress"""
        try:
            update_data = {"progress": progress, "updated_at": datetime.now()}

            if current_step:
                update_data["current_step"] = current_step
            if status:
                update_data["status"] = status

            log_entry = await self.get_by_job_id(job_id)
            if not log_entry:
                logger.error(f"Training log not found for job {job_id}")
                return None

            updated_log = await self.update(log_entry.id, update_data)

            logger.debug("Training log progress updated",
                         job_id=job_id,
                         progress=progress,
                         step=current_step)

            return updated_log

        except Exception as e:
            logger.error("Failed to update training log progress",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to update progress: {str(e)}")

    async def complete_training_log(
        self,
        job_id: str,
        results: Optional[Dict[str, Any]] = None,
        error_message: Optional[str] = None
    ) -> Optional[ModelTrainingLog]:
        """Mark training log as completed or failed"""
        try:
            status = "failed" if error_message else "completed"

            update_data = {
                "status": status,
                "end_time": datetime.now(),
                "updated_at": datetime.now()
            }

            # Only force progress to 100 on success; a failed job keeps its last
            # reported progress instead of having it overwritten with NULL.
            if status == "completed":
                update_data["progress"] = 100

            if results:
                update_data["results"] = results
            if error_message:
                update_data["error_message"] = error_message

            log_entry = await self.get_by_job_id(job_id)
            if not log_entry:
                logger.error(f"Training log not found for job {job_id}")
                return None

            updated_log = await self.update(log_entry.id, update_data)

            logger.info("Training log completed",
                        job_id=job_id,
                        status=status,
                        has_results=bool(results))

            return updated_log

        except Exception as e:
            logger.error("Failed to complete training log",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to complete training log: {str(e)}")

    async def get_logs_by_tenant(
        self,
        tenant_id: str,
        status: Optional[str] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[ModelTrainingLog]:
        """Get training logs for a tenant"""
        try:
            filters = {"tenant_id": tenant_id}
            if status:
                filters["status"] = status

            return await self.get_multi(
                filters=filters,
                skip=skip,
                limit=limit,
                order_by="created_at",
                order_desc=True
            )

        except Exception as e:
            logger.error("Failed to get logs by tenant",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get training logs: {str(e)}")

    async def get_active_jobs(self, tenant_id: Optional[str] = None) -> List[ModelTrainingLog]:
        """Get currently running training jobs"""
        try:
            filters = {"status": "running"}
            if tenant_id:
                filters["tenant_id"] = tenant_id

            return await self.get_multi(
                filters=filters,
                order_by="start_time",
                order_desc=True
            )

        except Exception as e:
            logger.error("Failed to get active jobs",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get active jobs: {str(e)}")

    async def cancel_job(self, job_id: str, cancelled_by: Optional[str] = None) -> Optional[ModelTrainingLog]:
        """Cancel a training job"""
        try:
            update_data = {
                "status": "cancelled",
                "end_time": datetime.now(),
                "updated_at": datetime.now()
            }

            if cancelled_by:
                update_data["error_message"] = f"Cancelled by {cancelled_by}"

            log_entry = await self.get_by_job_id(job_id)
            if not log_entry:
                logger.error(f"Training log not found for job {job_id}")
                return None

            # Only cancel if the job is still pending or running
            if log_entry.status not in ["pending", "running"]:
                logger.warning(f"Cannot cancel job {job_id} with status {log_entry.status}")
                return log_entry

            updated_log = await self.update(log_entry.id, update_data)

            logger.info("Training job cancelled",
                        job_id=job_id,
                        cancelled_by=cancelled_by)

            return updated_log

        except Exception as e:
            logger.error("Failed to cancel training job",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to cancel job: {str(e)}")

    async def get_job_statistics(self, tenant_id: Optional[str] = None) -> Dict[str, Any]:
        """Get training job statistics"""
        try:
            base_filters = {}
            if tenant_id:
                base_filters["tenant_id"] = tenant_id

            # Get counts by status
            total_jobs = await self.count(filters=base_filters)
            completed_jobs = await self.count(filters={**base_filters, "status": "completed"})
            failed_jobs = await self.count(filters={**base_filters, "status": "failed"})
            running_jobs = await self.count(filters={**base_filters, "status": "running"})
            pending_jobs = await self.count(filters={**base_filters, "status": "pending"})

            # Get recent activity (jobs in the last 7 days)
            seven_days_ago = datetime.now() - timedelta(days=7)
            recent_jobs = len(await self.get_records_by_date_range(
                seven_days_ago,
                datetime.now(),
                limit=1000  # High limit to get an accurate count
            ))

            # Calculate success rate
            finished_jobs = completed_jobs + failed_jobs
            success_rate = (completed_jobs / finished_jobs * 100) if finished_jobs > 0 else 0

            return {
                "total_jobs": total_jobs,
                "completed_jobs": completed_jobs,
                "failed_jobs": failed_jobs,
                "running_jobs": running_jobs,
                "pending_jobs": pending_jobs,
                "cancelled_jobs": total_jobs - completed_jobs - failed_jobs - running_jobs - pending_jobs,
                "success_rate": round(success_rate, 2),
                "recent_jobs_7d": recent_jobs
            }

        except Exception as e:
            logger.error("Failed to get job statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "total_jobs": 0,
                "completed_jobs": 0,
                "failed_jobs": 0,
                "running_jobs": 0,
                "pending_jobs": 0,
                "cancelled_jobs": 0,
                "success_rate": 0.0,
                "recent_jobs_7d": 0
            }

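    # Worked example of the derived fields above (illustrative numbers only):
    # with completed_jobs=8 and failed_jobs=2, finished_jobs=10 and
    # success_rate = 8 / 10 * 100 = 80.0; cancelled_jobs is whatever remains of
    # total_jobs after subtracting the four explicitly counted statuses.
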
    async def cleanup_old_logs(self, days_old: int = 90) -> int:
        """Clean up old completed/failed training logs"""
        return await self.cleanup_old_records(
            days_old=days_old,
            status_filter=None  # Clean up all old records regardless of status
        )

    async def get_job_duration_stats(self, tenant_id: Optional[str] = None) -> Dict[str, Any]:
        """Get job duration statistics"""
        try:
            # Use raw SQL for complex duration calculations
            tenant_filter = "AND tenant_id = :tenant_id" if tenant_id else ""
            params = {"tenant_id": tenant_id} if tenant_id else {}

            query = text(f"""
                SELECT
                    AVG(EXTRACT(EPOCH FROM (end_time - start_time))/60) as avg_duration_minutes,
                    MIN(EXTRACT(EPOCH FROM (end_time - start_time))/60) as min_duration_minutes,
                    MAX(EXTRACT(EPOCH FROM (end_time - start_time))/60) as max_duration_minutes,
                    COUNT(*) as completed_jobs_with_duration
                FROM model_training_logs
                WHERE status = 'completed'
                  AND start_time IS NOT NULL
                  AND end_time IS NOT NULL
                  {tenant_filter}
            """)

            result = await self.session.execute(query, params)
            row = result.fetchone()

            if row and row.completed_jobs_with_duration > 0:
                return {
                    "avg_duration_minutes": round(float(row.avg_duration_minutes or 0), 2),
                    "min_duration_minutes": round(float(row.min_duration_minutes or 0), 2),
                    "max_duration_minutes": round(float(row.max_duration_minutes or 0), 2),
                    "completed_jobs_with_duration": int(row.completed_jobs_with_duration)
                }

            return {
                "avg_duration_minutes": 0.0,
                "min_duration_minutes": 0.0,
                "max_duration_minutes": 0.0,
                "completed_jobs_with_duration": 0
            }

        except Exception as e:
            logger.error("Failed to get job duration statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "avg_duration_minutes": 0.0,
                "min_duration_minutes": 0.0,
                "max_duration_minutes": 0.0,
                "completed_jobs_with_duration": 0
            }

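    # Note on the duration query above: EXTRACT(EPOCH FROM (end_time - start_time))
    # returns the interval in seconds (PostgreSQL), so dividing by 60 reports
    # durations in minutes; e.g. a 10:00-10:45 run gives 2700 s / 60 = 45.0 minutes.
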
    async def get_start_time(self, job_id: str) -> Optional[datetime]:
        """Get the start time for a training job"""
        try:
            log_entry = await self.get_by_job_id(job_id)
            if log_entry and log_entry.start_time:
                return log_entry.start_time
            return None
        except Exception as e:
            logger.error("Failed to get start time",
                         job_id=job_id,
                         error=str(e))
            return None

    async def create_job_atomic(
        self,
        job_id: str,
        tenant_id: str,
        config: Optional[Dict[str, Any]] = None
    ) -> tuple[Optional[ModelTrainingLog], bool]:
        """
        Atomically create a training job, respecting the unique constraint.

        The method first checks for an existing active job for the tenant and
        then attempts the insert; if another pod creates a job for the same
        tenant in the meantime, the resulting unique-constraint violation is
        caught and the existing job is returned instead. The database constraint
        (idx_unique_active_training_per_tenant) ensures only one active job per
        tenant can exist.

        Args:
            job_id: Unique job identifier
            tenant_id: Tenant identifier
            config: Optional job configuration

        Returns:
            Tuple of (job, created):
            - If created: (new_job, True)
            - If conflict (existing active job): (existing_job, False)
            - If error: raises DatabaseError
        """
        try:
            # First, try to find an existing active job
            existing = await self.get_active_jobs(tenant_id=tenant_id)
            pending = await self.get_logs_by_tenant(tenant_id=tenant_id, status="pending", limit=1)

            if existing or pending:
                # Return the existing job
                active_job = existing[0] if existing else pending[0]
                logger.info("Found existing active job, skipping creation",
                            existing_job_id=active_job.job_id,
                            tenant_id=tenant_id,
                            requested_job_id=job_id)
                return (active_job, False)

            # Try to create the new job.
            # If another pod created one in the meantime, the unique constraint will prevent this.
            log_data = {
                "job_id": job_id,
                "tenant_id": tenant_id,
                "status": "pending",
                "progress": 0,
                "current_step": "initializing",
                "config": config or {}
            }

            try:
                new_job = await self.create_training_log(log_data)
                await self.session.commit()
                logger.info("Created new training job atomically",
                            job_id=job_id,
                            tenant_id=tenant_id)
                return (new_job, True)
            except Exception as create_error:
                error_str = str(create_error).lower()
                # Check if this is a unique constraint violation
                if "unique" in error_str or "duplicate" in error_str or "constraint" in error_str:
                    await self.session.rollback()
                    # Another pod created a job; fetch it
                    logger.info("Unique constraint hit, fetching existing job",
                                tenant_id=tenant_id,
                                requested_job_id=job_id)
                    existing = await self.get_active_jobs(tenant_id=tenant_id)
                    pending = await self.get_logs_by_tenant(tenant_id=tenant_id, status="pending", limit=1)
                    if existing or pending:
                        active_job = existing[0] if existing else pending[0]
                        return (active_job, False)
                    # If still no job is found, something went wrong
                    raise DatabaseError(f"Constraint violation but no active job found: {create_error}")
                else:
                    raise

        except DatabaseError:
            raise
        except Exception as e:
            logger.error("Failed to create job atomically",
                         job_id=job_id,
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to create training job atomically: {str(e)}")

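    # Caller sketch (hypothetical API handler; intended to be safe when run
    # concurrently on several pods):
    #
    #   job, created = await repo.create_job_atomic(job_id=new_job_id, tenant_id=tenant_id)
    #   if not created:
    #       # Another pod already owns the active job for this tenant; reuse it.
    #       return job.job_id
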
    async def recover_stale_jobs(self, stale_threshold_minutes: int = 60) -> List[ModelTrainingLog]:
        """
        Find and mark stale running jobs as failed.

        This is used during service startup to clean up jobs that were
        running when a pod crashed. With multiple replicas, only stale
        jobs (not updated recently) should be marked as failed.

        Args:
            stale_threshold_minutes: Jobs not updated for this long are considered stale

        Returns:
            List of jobs that were marked as failed
        """
        try:
            stale_cutoff = datetime.now() - timedelta(minutes=stale_threshold_minutes)

            # Find running jobs that haven't been updated recently
            query = text("""
                SELECT id, job_id, tenant_id, status, updated_at
                FROM model_training_logs
                WHERE status IN ('running', 'pending')
                  AND updated_at < :stale_cutoff
            """)

            result = await self.session.execute(query, {"stale_cutoff": stale_cutoff})
            stale_jobs = result.fetchall()

            recovered_jobs = []
            for row in stale_jobs:
                try:
                    # Mark as failed
                    update_query = text("""
                        UPDATE model_training_logs
                        SET status = 'failed',
                            error_message = :error_msg,
                            end_time = :end_time,
                            updated_at = :updated_at
                        WHERE id = :id AND status IN ('running', 'pending')
                    """)

                    await self.session.execute(update_query, {
                        "id": row.id,
                        "error_msg": f"Job recovered as failed - not updated since {row.updated_at.isoformat()}. Pod may have crashed.",
                        "end_time": datetime.now(),
                        "updated_at": datetime.now()
                    })

                    logger.warning("Recovered stale training job",
                                   job_id=row.job_id,
                                   tenant_id=str(row.tenant_id),
                                   last_updated=row.updated_at.isoformat() if row.updated_at else "unknown")

                    # Fetch the updated job to return
                    job = await self.get_by_job_id(row.job_id)
                    if job:
                        recovered_jobs.append(job)

                except Exception as job_error:
                    logger.error("Failed to recover individual stale job",
                                 job_id=row.job_id,
                                 error=str(job_error))

            if recovered_jobs:
                await self.session.commit()
                logger.info("Stale job recovery completed",
                            recovered_count=len(recovered_jobs),
                            stale_threshold_minutes=stale_threshold_minutes)

            return recovered_jobs

        except Exception as e:
            logger.error("Failed to recover stale jobs",
                         error=str(e))
            await self.session.rollback()
            return []
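
# Startup sketch (assumption: each pod runs this once during service startup,
# before accepting new training work):
#
#   recovered = await repo.recover_stale_jobs(stale_threshold_minutes=60)
#   if recovered:
#       logger.warning("Recovered stale jobs on startup", count=len(recovered))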