"""
|
|
Job Queue Repository
|
|
Repository for training job queue operations
|
|
"""

from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc, bindparam
from datetime import datetime, timedelta
import structlog

from .base import TrainingBaseRepository
from app.models.training import TrainingJobQueue
from shared.database.exceptions import DatabaseError, ValidationError

logger = structlog.get_logger()


class JobQueueRepository(TrainingBaseRepository):
    """Repository for training job queue operations"""

    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 60):
        # Job queue changes frequently, very short cache time (1 minute)
        super().__init__(TrainingJobQueue, session, cache_ttl)

    async def enqueue_job(self, job_data: Dict[str, Any]) -> TrainingJobQueue:
        """Add a job to the training queue"""
        try:
            # Validate job data
            validation_result = self._validate_training_data(
                job_data,
                ["job_id", "tenant_id", "job_type"]
            )

            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid job data: {validation_result['errors']}")

            # Set default values
            if "priority" not in job_data:
                job_data["priority"] = 1
            if "status" not in job_data:
                job_data["status"] = "queued"
            if "max_retries" not in job_data:
                job_data["max_retries"] = 3

            # Create queue entry
            queued_job = await self.create(job_data)

            logger.info("Job enqueued",
                        job_id=queued_job.job_id,
                        tenant_id=queued_job.tenant_id,
                        job_type=queued_job.job_type,
                        priority=queued_job.priority)

            return queued_job

        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to enqueue job",
                         job_id=job_data.get("job_id"),
                         error=str(e))
            raise DatabaseError(f"Failed to enqueue job: {str(e)}")

    async def get_next_job(self, job_types: Optional[List[str]] = None) -> Optional[TrainingJobQueue]:
        """Get the next job to process from the queue"""
        try:
            if job_types:
                # Filter by job type using a parameterized (expanding) IN clause
                # instead of interpolating values into the SQL string.
                query = text("""
                    SELECT * FROM training_job_queue
                    WHERE status = 'queued'
                    AND job_type IN :job_types
                    AND (scheduled_at IS NULL OR scheduled_at <= :now)
                    ORDER BY priority DESC, created_at ASC
                    LIMIT 1
                """).bindparams(bindparam("job_types", expanding=True))

                result = await self.session.execute(
                    query, {"job_types": job_types, "now": datetime.now()}
                )
                row = result.fetchone()

                if row:
                    record_dict = dict(row._mapping)
                    return self.model(**record_dict)
                return None
            else:
                # Simple case - get any queued job, highest priority first.
                # Note: this path does not check scheduled_at or break ties by created_at.
                jobs = await self.get_multi(
                    filters={"status": "queued"},
                    limit=1,
                    order_by="priority",
                    order_desc=True
                )
                return jobs[0] if jobs else None

        except Exception as e:
            logger.error("Failed to get next job from queue",
                         job_types=job_types,
                         error=str(e))
            raise DatabaseError(f"Failed to get next job: {str(e)}")

    async def start_job(self, job_id: str) -> Optional[TrainingJobQueue]:
        """Mark a job as started"""
        try:
            job = await self.get_by_job_id(job_id)
            if not job:
                logger.error(f"Job not found in queue: {job_id}")
                return None

            if job.status != "queued":
                logger.warning(f"Job {job_id} is not queued (status: {job.status})")
                return job

            updated_job = await self.update(job.id, {
                "status": "running",
                "started_at": datetime.now(),
                "updated_at": datetime.now()
            })

            logger.info("Job started",
                        job_id=job_id,
                        job_type=job.job_type)

            return updated_job

        except Exception as e:
            logger.error("Failed to start job",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to start job: {str(e)}")

    async def complete_job(self, job_id: str) -> Optional[TrainingJobQueue]:
        """Mark a job as completed"""
        try:
            job = await self.get_by_job_id(job_id)
            if not job:
                logger.error(f"Job not found in queue: {job_id}")
                return None

            updated_job = await self.update(job.id, {
                "status": "completed",
                "updated_at": datetime.now()
            })

            logger.info("Job completed",
                        job_id=job_id,
                        job_type=job.job_type)

            return updated_job

        except Exception as e:
            logger.error("Failed to complete job",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to complete job: {str(e)}")

    async def fail_job(self, job_id: str, error_message: Optional[str] = None) -> Optional[TrainingJobQueue]:
        """Mark a job as failed and handle retries"""
        try:
            job = await self.get_by_job_id(job_id)
            if not job:
                logger.error(f"Job not found in queue: {job_id}")
                return None

            # Increment retry count
            new_retry_count = job.retry_count + 1

            # Check if we should retry
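            # Note: new_retry_count is the number of failed attempts so far;
            # assuming retry_count starts at 0 for new jobs, the default
            # max_retries of 3 means a job is attempted three times in total
            # before being marked permanently failed.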
            if new_retry_count < job.max_retries:
                # Reset to queued for retry
                updated_job = await self.update(job.id, {
                    "status": "queued",
                    "retry_count": new_retry_count,
                    "updated_at": datetime.now(),
                    "started_at": None  # Reset started_at for retry
                })

                logger.info("Job failed, queued for retry",
                            job_id=job_id,
                            retry_count=new_retry_count,
                            max_retries=job.max_retries)
            else:
                # Mark as permanently failed
                updated_job = await self.update(job.id, {
                    "status": "failed",
                    "retry_count": new_retry_count,
                    "updated_at": datetime.now()
                })

                logger.error("Job permanently failed",
                             job_id=job_id,
                             retry_count=new_retry_count,
                             error_message=error_message)

            return updated_job

        except Exception as e:
            logger.error("Failed to handle job failure",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to handle job failure: {str(e)}")

    async def cancel_job(self, job_id: str, cancelled_by: Optional[str] = None) -> Optional[TrainingJobQueue]:
        """Cancel a job"""
        try:
            job = await self.get_by_job_id(job_id)
            if not job:
                logger.error(f"Job not found in queue: {job_id}")
                return None

            if job.status in ["completed", "failed"]:
                logger.warning(f"Cannot cancel job {job_id} with status {job.status}")
                return job

            updated_job = await self.update(job.id, {
                "status": "cancelled",
                "cancelled_by": cancelled_by,
                "updated_at": datetime.now()
            })

            logger.info("Job cancelled",
                        job_id=job_id,
                        cancelled_by=cancelled_by)

            return updated_job

        except Exception as e:
            logger.error("Failed to cancel job",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to cancel job: {str(e)}")

    async def get_queue_status(self, tenant_id: Optional[str] = None) -> Dict[str, Any]:
        """Get queue status and statistics"""
        try:
            base_filters = {}
            if tenant_id:
                base_filters["tenant_id"] = tenant_id

            # Get counts by status
            queued_jobs = await self.count(filters={**base_filters, "status": "queued"})
            running_jobs = await self.count(filters={**base_filters, "status": "running"})
            completed_jobs = await self.count(filters={**base_filters, "status": "completed"})
            failed_jobs = await self.count(filters={**base_filters, "status": "failed"})
            cancelled_jobs = await self.count(filters={**base_filters, "status": "cancelled"})

            # Get job counts by type
            type_query = text(f"""
                SELECT job_type, COUNT(*) as count
                FROM training_job_queue
                WHERE 1=1
                {' AND tenant_id = :tenant_id' if tenant_id else ''}
                GROUP BY job_type
                ORDER BY count DESC
            """)

            params = {"tenant_id": tenant_id} if tenant_id else {}
            result = await self.session.execute(type_query, params)
            jobs_by_type = {row.job_type: row.count for row in result.fetchall()}

            # Get average wait time (created -> started) for completed jobs
            wait_time_query = text(f"""
                SELECT
                    AVG(EXTRACT(EPOCH FROM (started_at - created_at))/60) as avg_wait_minutes
                FROM training_job_queue
                WHERE status = 'completed'
                AND started_at IS NOT NULL
                AND created_at IS NOT NULL
                {' AND tenant_id = :tenant_id' if tenant_id else ''}
            """)

            wait_result = await self.session.execute(wait_time_query, params)
            wait_row = wait_result.fetchone()
            avg_wait_time = float(wait_row.avg_wait_minutes) if wait_row and wait_row.avg_wait_minutes else 0.0

            return {
                "tenant_id": tenant_id,
                "queue_counts": {
                    "queued": queued_jobs,
                    "running": running_jobs,
                    "completed": completed_jobs,
                    "failed": failed_jobs,
                    "cancelled": cancelled_jobs,
                    "total": queued_jobs + running_jobs + completed_jobs + failed_jobs + cancelled_jobs
                },
                "jobs_by_type": jobs_by_type,
                "avg_wait_time_minutes": round(avg_wait_time, 2),
                "queue_health": {
                    "has_queued_jobs": queued_jobs > 0,
                    "has_running_jobs": running_jobs > 0,
                    # Failure rate over finished (completed + failed) jobs, excluding cancelled
                    "failure_rate": round((failed_jobs / max(completed_jobs + failed_jobs, 1)) * 100, 2)
                }
            }

        except Exception as e:
            logger.error("Failed to get queue status",
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "tenant_id": tenant_id,
                "queue_counts": {
                    "queued": 0, "running": 0, "completed": 0,
                    "failed": 0, "cancelled": 0, "total": 0
                },
                "jobs_by_type": {},
                "avg_wait_time_minutes": 0.0,
                "queue_health": {
                    "has_queued_jobs": False,
                    "has_running_jobs": False,
                    "failure_rate": 0.0
                }
            }

    async def get_jobs_by_tenant(
        self,
        tenant_id: str,
        status: Optional[str] = None,
        job_type: Optional[str] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[TrainingJobQueue]:
        """Get jobs for a tenant with optional filtering"""
        try:
            filters = {"tenant_id": tenant_id}
            if status:
                filters["status"] = status
            if job_type:
                filters["job_type"] = job_type

            return await self.get_multi(
                filters=filters,
                skip=skip,
                limit=limit,
                order_by="created_at",
                order_desc=True
            )

        except Exception as e:
            logger.error("Failed to get jobs by tenant",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get tenant jobs: {str(e)}")

    async def cleanup_old_jobs(self, days_old: int = 30, status_filter: Optional[str] = None) -> int:
        """Clean up old completed/failed/cancelled jobs"""
        try:
            cutoff_date = datetime.now() - timedelta(days=days_old)

            # Only clean up finished jobs by default
            default_statuses = ["completed", "failed", "cancelled"]

            if status_filter:
                status_condition = "status = :status"
                params = {"cutoff_date": cutoff_date, "status": status_filter}
            else:
                status_list = "', '".join(default_statuses)
                status_condition = f"status IN ('{status_list}')"
                params = {"cutoff_date": cutoff_date}

            query_text = f"""
                DELETE FROM training_job_queue
                WHERE created_at < :cutoff_date
                AND {status_condition}
            """

            result = await self.session.execute(text(query_text), params)
            deleted_count = result.rowcount

            logger.info("Cleaned up old queue jobs",
                        deleted_count=deleted_count,
                        days_old=days_old,
                        status_filter=status_filter)

            return deleted_count

        except Exception as e:
            logger.error("Failed to cleanup old queue jobs",
                         error=str(e))
            raise DatabaseError(f"Queue cleanup failed: {str(e)}")

    async def get_stuck_jobs(self, hours_stuck: int = 2) -> List[TrainingJobQueue]:
        """Get jobs that have been running for too long"""
        try:
            cutoff_time = datetime.now() - timedelta(hours=hours_stuck)

            query_text = """
                SELECT * FROM training_job_queue
                WHERE status = 'running'
                AND started_at IS NOT NULL
                AND started_at < :cutoff_time
                ORDER BY started_at ASC
            """

            result = await self.session.execute(text(query_text), {"cutoff_time": cutoff_time})

            stuck_jobs = []
            for row in result.fetchall():
                record_dict = dict(row._mapping)
                job = self.model(**record_dict)
                stuck_jobs.append(job)

            if stuck_jobs:
                logger.warning("Found stuck jobs",
                               count=len(stuck_jobs),
                               hours_stuck=hours_stuck)

            return stuck_jobs

        except Exception as e:
            logger.error("Failed to get stuck jobs",
                         hours_stuck=hours_stuck,
                         error=str(e))
            return []

    async def reset_stuck_jobs(self, hours_stuck: int = 2) -> int:
        """Reset stuck jobs back to queued status"""
        try:
            stuck_jobs = await self.get_stuck_jobs(hours_stuck)
            reset_count = 0

            for job in stuck_jobs:
                # Reset job to queued status
                await self.update(job.id, {
                    "status": "queued",
                    "started_at": None,
                    "updated_at": datetime.now()
                })
                reset_count += 1

            if reset_count > 0:
                logger.info("Reset stuck jobs",
                            reset_count=reset_count,
                            hours_stuck=hours_stuck)

            return reset_count

        except Exception as e:
            logger.error("Failed to reset stuck jobs",
                         hours_stuck=hours_stuck,
                         error=str(e))
            raise DatabaseError(f"Failed to reset stuck jobs: {str(e)}")