REFACTOR - Database logic
services/training/app/repositories/job_queue_repository.py (new file, 445 lines)
@@ -0,0 +1,445 @@
"""
Job Queue Repository
Repository for training job queue operations
"""

from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text, bindparam
from datetime import datetime, timedelta
import structlog

from .base import TrainingBaseRepository
from app.models.training import TrainingJobQueue
from shared.database.exceptions import DatabaseError, ValidationError

logger = structlog.get_logger()


class JobQueueRepository(TrainingBaseRepository):
    """Repository for training job queue operations"""

    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 60):
        # The job queue changes frequently, so use a very short cache TTL (1 minute)
        super().__init__(TrainingJobQueue, session, cache_ttl)

    async def enqueue_job(self, job_data: Dict[str, Any]) -> TrainingJobQueue:
        """Add a job to the training queue"""
        try:
            # Validate required fields before inserting
            validation_result = self._validate_training_data(
                job_data,
                ["job_id", "tenant_id", "job_type"]
            )

            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid job data: {validation_result['errors']}")

            # Apply defaults for optional fields
            job_data.setdefault("priority", 1)
            job_data.setdefault("status", "queued")
            job_data.setdefault("max_retries", 3)

            # Create the queue entry
            queued_job = await self.create(job_data)

            logger.info("Job enqueued",
                        job_id=queued_job.job_id,
                        tenant_id=queued_job.tenant_id,
                        job_type=queued_job.job_type,
                        priority=queued_job.priority)

            return queued_job

        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to enqueue job",
                         job_id=job_data.get("job_id"),
                         error=str(e))
            raise DatabaseError(f"Failed to enqueue job: {str(e)}")

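    # Example call (illustrative sketch only; the required keys mirror the
    # validation above, the optional key mirrors the defaults, and "fine_tune"
    # is a hypothetical job type):
    #
    #   job = await repo.enqueue_job({
    #       "job_id": "job-123",        # required
    #       "tenant_id": "tenant-a",    # required
    #       "job_type": "fine_tune",    # required
    #       "priority": 5,              # optional, defaults to 1
    #   })
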
    async def get_next_job(self, job_types: Optional[List[str]] = None) -> Optional[TrainingJobQueue]:
        """Get the next job to process from the queue"""
        try:
            if job_types:
                # Filter by job type using a bound, expanding IN parameter
                # (avoids interpolating values directly into the SQL string)
                query = text("""
                    SELECT * FROM training_job_queue
                    WHERE status = 'queued'
                    AND job_type IN :job_types
                    AND (scheduled_at IS NULL OR scheduled_at <= :now)
                    ORDER BY priority DESC, created_at ASC
                    LIMIT 1
                """).bindparams(bindparam("job_types", expanding=True))

                result = await self.session.execute(
                    query, {"job_types": job_types, "now": datetime.now()}
                )
                row = result.fetchone()

                if row:
                    return self.model(**dict(row._mapping))
                return None
            else:
                # Simple case: take the highest-priority queued job
                jobs = await self.get_multi(
                    filters={"status": "queued"},
                    limit=1,
                    order_by="priority",
                    order_desc=True
                )
                return jobs[0] if jobs else None

        except Exception as e:
            logger.error("Failed to get next job from queue",
                         job_types=job_types,
                         error=str(e))
            raise DatabaseError(f"Failed to get next job: {str(e)}")

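    # NOTE: get_next_job() reads a job without claiming it, so two workers
    # polling at the same time can both receive the same row before start_job()
    # runs. A common remedy on PostgreSQL (which the EXTRACT(EPOCH ...)
    # statistics below already assume) is to claim the job atomically in a
    # single statement. Illustrative sketch only, not part of this repository's
    # current API:
    #
    #   UPDATE training_job_queue
    #   SET status = 'running', started_at = NOW()
    #   WHERE id = (
    #       SELECT id FROM training_job_queue
    #       WHERE status = 'queued'
    #       ORDER BY priority DESC, created_at ASC
    #       LIMIT 1
    #       FOR UPDATE SKIP LOCKED
    #   )
    #   RETURNING *;
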
    async def start_job(self, job_id: str) -> Optional[TrainingJobQueue]:
        """Mark a job as started"""
        try:
            job = await self.get_by_job_id(job_id)
            if not job:
                logger.error(f"Job not found in queue: {job_id}")
                return None

            if job.status != "queued":
                logger.warning(f"Job {job_id} is not queued (status: {job.status})")
                return job

            updated_job = await self.update(job.id, {
                "status": "running",
                "started_at": datetime.now(),
                "updated_at": datetime.now()
            })

            logger.info("Job started",
                        job_id=job_id,
                        job_type=job.job_type)

            return updated_job

        except Exception as e:
            logger.error("Failed to start job",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to start job: {str(e)}")

    async def complete_job(self, job_id: str) -> Optional[TrainingJobQueue]:
        """Mark a job as completed"""
        try:
            job = await self.get_by_job_id(job_id)
            if not job:
                logger.error(f"Job not found in queue: {job_id}")
                return None

            updated_job = await self.update(job.id, {
                "status": "completed",
                "updated_at": datetime.now()
            })

            logger.info("Job completed",
                        job_id=job_id,
                        job_type=job.job_type)  # job is guaranteed non-None here

            return updated_job

        except Exception as e:
            logger.error("Failed to complete job",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to complete job: {str(e)}")

    async def fail_job(self, job_id: str, error_message: Optional[str] = None) -> Optional[TrainingJobQueue]:
        """Mark a job as failed and handle retries"""
        try:
            job = await self.get_by_job_id(job_id)
            if not job:
                logger.error(f"Job not found in queue: {job_id}")
                return None

            new_retry_count = job.retry_count + 1

            if new_retry_count < job.max_retries:
                # Retries remain: put the job back in the queue
                updated_job = await self.update(job.id, {
                    "status": "queued",
                    "retry_count": new_retry_count,
                    "updated_at": datetime.now(),
                    "started_at": None  # reset so the retry gets a fresh start time
                })

                logger.info("Job failed, queued for retry",
                            job_id=job_id,
                            retry_count=new_retry_count,
                            max_retries=job.max_retries)
            else:
                # Retries exhausted: mark as permanently failed
                updated_job = await self.update(job.id, {
                    "status": "failed",
                    "retry_count": new_retry_count,
                    "updated_at": datetime.now()
                })

                logger.error("Job permanently failed",
                             job_id=job_id,
                             retry_count=new_retry_count,
                             error_message=error_message)

            return updated_job

        except Exception as e:
            logger.error("Failed to handle job failure",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to handle job failure: {str(e)}")

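    # NOTE: failed jobs are re-queued immediately. Because get_next_job()
    # already honours scheduled_at, an exponential backoff could be layered on
    # by also setting scheduled_at when re-queuing, e.g. (illustrative sketch
    # only):
    #
    #   "scheduled_at": datetime.now() + timedelta(minutes=2 ** new_retry_count)
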
    async def cancel_job(self, job_id: str, cancelled_by: Optional[str] = None) -> Optional[TrainingJobQueue]:
        """Cancel a job"""
        try:
            job = await self.get_by_job_id(job_id)
            if not job:
                logger.error(f"Job not found in queue: {job_id}")
                return None

            if job.status in ["completed", "failed"]:
                logger.warning(f"Cannot cancel job {job_id} with status {job.status}")
                return job

            updated_job = await self.update(job.id, {
                "status": "cancelled",
                "cancelled_by": cancelled_by,
                "updated_at": datetime.now()
            })

            logger.info("Job cancelled",
                        job_id=job_id,
                        cancelled_by=cancelled_by)

            return updated_job

        except Exception as e:
            logger.error("Failed to cancel job",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to cancel job: {str(e)}")

    async def get_queue_status(self, tenant_id: Optional[str] = None) -> Dict[str, Any]:
        """Get queue status and statistics"""
        try:
            base_filters = {}
            if tenant_id:
                base_filters["tenant_id"] = tenant_id

            # Counts by status
            queued_jobs = await self.count(filters={**base_filters, "status": "queued"})
            running_jobs = await self.count(filters={**base_filters, "status": "running"})
            completed_jobs = await self.count(filters={**base_filters, "status": "completed"})
            failed_jobs = await self.count(filters={**base_filters, "status": "failed"})
            cancelled_jobs = await self.count(filters={**base_filters, "status": "cancelled"})

            # Counts by job type (the conditional only toggles the SQL clause;
            # the tenant_id value itself is always passed as a bound parameter)
            type_query = text(f"""
                SELECT job_type, COUNT(*) as count
                FROM training_job_queue
                WHERE 1=1
                {' AND tenant_id = :tenant_id' if tenant_id else ''}
                GROUP BY job_type
                ORDER BY count DESC
            """)

            params = {"tenant_id": tenant_id} if tenant_id else {}
            result = await self.session.execute(type_query, params)
            jobs_by_type = {row.job_type: row.count for row in result.fetchall()}

            # Average wait time (created -> started) for completed jobs, in minutes
            wait_time_query = text(f"""
                SELECT
                    AVG(EXTRACT(EPOCH FROM (started_at - created_at))/60) as avg_wait_minutes
                FROM training_job_queue
                WHERE status = 'completed'
                AND started_at IS NOT NULL
                AND created_at IS NOT NULL
                {' AND tenant_id = :tenant_id' if tenant_id else ''}
            """)

            wait_result = await self.session.execute(wait_time_query, params)
            wait_row = wait_result.fetchone()
            avg_wait_time = float(wait_row.avg_wait_minutes) if wait_row and wait_row.avg_wait_minutes else 0.0

            return {
                "tenant_id": tenant_id,
                "queue_counts": {
                    "queued": queued_jobs,
                    "running": running_jobs,
                    "completed": completed_jobs,
                    "failed": failed_jobs,
                    "cancelled": cancelled_jobs,
                    "total": queued_jobs + running_jobs + completed_jobs + failed_jobs + cancelled_jobs
                },
                "jobs_by_type": jobs_by_type,
                "avg_wait_time_minutes": round(avg_wait_time, 2),
                "queue_health": {
                    "has_queued_jobs": queued_jobs > 0,
                    "has_running_jobs": running_jobs > 0,
                    "failure_rate": round((failed_jobs / max(completed_jobs + failed_jobs, 1)) * 100, 2)
                }
            }

        except Exception as e:
            # Status reporting should not take the caller down; log the error
            # and return an empty snapshot instead of raising
            logger.error("Failed to get queue status",
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "tenant_id": tenant_id,
                "queue_counts": {
                    "queued": 0, "running": 0, "completed": 0,
                    "failed": 0, "cancelled": 0, "total": 0
                },
                "jobs_by_type": {},
                "avg_wait_time_minutes": 0.0,
                "queue_health": {
                    "has_queued_jobs": False,
                    "has_running_jobs": False,
                    "failure_rate": 0.0
                }
            }

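    # Example (illustrative): callers can drive monitoring off the returned
    # snapshot, e.g.
    #
    #   status = await repo.get_queue_status(tenant_id="tenant-a")
    #   if status["queue_health"]["failure_rate"] > 50.0:
    #       ...  # alert: more than half of finished jobs failed
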
    async def get_jobs_by_tenant(
        self,
        tenant_id: str,
        status: Optional[str] = None,
        job_type: Optional[str] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[TrainingJobQueue]:
        """Get jobs for a tenant with optional filtering"""
        try:
            filters = {"tenant_id": tenant_id}
            if status:
                filters["status"] = status
            if job_type:
                filters["job_type"] = job_type

            return await self.get_multi(
                filters=filters,
                skip=skip,
                limit=limit,
                order_by="created_at",
                order_desc=True
            )

        except Exception as e:
            logger.error("Failed to get jobs by tenant",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get tenant jobs: {str(e)}")

    async def cleanup_old_jobs(self, days_old: int = 30, status_filter: Optional[str] = None) -> int:
        """Clean up old completed/failed/cancelled jobs"""
        try:
            cutoff_date = datetime.now() - timedelta(days=days_old)

            # Only clean up finished jobs by default
            statuses = [status_filter] if status_filter else ["completed", "failed", "cancelled"]

            # Use a bound, expanding IN parameter rather than string interpolation
            query = text("""
                DELETE FROM training_job_queue
                WHERE created_at < :cutoff_date
                AND status IN :statuses
            """).bindparams(bindparam("statuses", expanding=True))

            result = await self.session.execute(
                query, {"cutoff_date": cutoff_date, "statuses": statuses}
            )
            deleted_count = result.rowcount

            logger.info("Cleaned up old queue jobs",
                        deleted_count=deleted_count,
                        days_old=days_old,
                        status_filter=status_filter)

            return deleted_count

        except Exception as e:
            logger.error("Failed to cleanup old queue jobs",
                         error=str(e))
            raise DatabaseError(f"Queue cleanup failed: {str(e)}")

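    # Example (illustrative): typically invoked from a periodic maintenance
    # task, e.g.
    #
    #   deleted = await repo.cleanup_old_jobs(days_old=30)
    #   deleted_failed = await repo.cleanup_old_jobs(days_old=7, status_filter="failed")
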
    async def get_stuck_jobs(self, hours_stuck: int = 2) -> List[TrainingJobQueue]:
        """Get jobs that have been running for too long"""
        try:
            cutoff_time = datetime.now() - timedelta(hours=hours_stuck)

            query_text = """
                SELECT * FROM training_job_queue
                WHERE status = 'running'
                AND started_at IS NOT NULL
                AND started_at < :cutoff_time
                ORDER BY started_at ASC
            """

            result = await self.session.execute(text(query_text), {"cutoff_time": cutoff_time})

            stuck_jobs = [self.model(**dict(row._mapping)) for row in result.fetchall()]

            if stuck_jobs:
                logger.warning("Found stuck jobs",
                               count=len(stuck_jobs),
                               hours_stuck=hours_stuck)

            return stuck_jobs

        except Exception as e:
            # Monitoring helper: log and return an empty list rather than raising
            logger.error("Failed to get stuck jobs",
                         hours_stuck=hours_stuck,
                         error=str(e))
            return []

    async def reset_stuck_jobs(self, hours_stuck: int = 2) -> int:
        """Reset stuck jobs back to queued status"""
        try:
            stuck_jobs = await self.get_stuck_jobs(hours_stuck)
            reset_count = 0

            for job in stuck_jobs:
                # Return the job to the queue so another worker can pick it up
                await self.update(job.id, {
                    "status": "queued",
                    "started_at": None,
                    "updated_at": datetime.now()
                })
                reset_count += 1

            if reset_count > 0:
                logger.info("Reset stuck jobs",
                            reset_count=reset_count,
                            hours_stuck=hours_stuck)

            return reset_count

        except Exception as e:
            logger.error("Failed to reset stuck jobs",
                         hours_stuck=hours_stuck,
                         error=str(e))
            raise DatabaseError(f"Failed to reset stuck jobs: {str(e)}")
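

# Example worker loop (illustrative sketch only; `session_factory`, the
# "fine_tune" job type, and the 5-second polling interval are assumptions,
# not part of this module):
#
#   import asyncio
#
#   async def worker(session_factory):
#       while True:
#           async with session_factory() as session:
#               repo = JobQueueRepository(session)
#               job = await repo.get_next_job(job_types=["fine_tune"])
#               if job is None:
#                   await asyncio.sleep(5)
#                   continue
#               await repo.start_job(job.job_id)
#               try:
#                   ...  # run the actual training work here
#                   await repo.complete_job(job.job_id)
#               except Exception as exc:
#                   await repo.fail_job(job.job_id, error_message=str(exc))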