Files
bakery-ia/services/training/app/repositories/training_log_repository.py
2026-01-18 09:02:27 +01:00

507 lines
20 KiB
Python

"""
Training Log Repository
Repository for model training log operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc
from datetime import datetime, timedelta
import structlog
from .base import TrainingBaseRepository
from app.models.training import ModelTrainingLog
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class TrainingLogRepository(TrainingBaseRepository):
    """Repository exposing read/write operations on ``ModelTrainingLog`` rows."""

    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 300):
        """Bind the repository to ``ModelTrainingLog`` and a session.

        Args:
            session: Async SQLAlchemy session forwarded to the base repository.
            cache_ttl: Cache lifetime forwarded to the base repository. The
                default of 300 (5 minutes) is deliberately short because
                training logs change frequently.
        """
        super().__init__(ModelTrainingLog, session, cache_ttl)
async def create_training_log(self, log_data: Dict[str, Any]) -> ModelTrainingLog:
    """Create a new training log entry.

    Args:
        log_data: Values for the new entry. ``job_id``, ``tenant_id`` and
            ``status`` are handed to ``_validate_training_data`` as the
            required fields. ``progress`` defaults to ``0`` and
            ``current_step`` to ``"initializing"`` when absent. The mapping
            is copied before defaults are applied, so the caller's dict is
            left untouched.

    Returns:
        The entry returned by ``self.create``.

    Raises:
        ValidationError: If validation reports the data as invalid.
        DatabaseError: If anything else fails. The original exception is
            chained as the cause and its text is kept in the message.
    """
    try:
        validation_result = self._validate_training_data(
            log_data,
            ["job_id", "tenant_id", "status"],
        )
        if not validation_result["is_valid"]:
            raise ValidationError(f"Invalid training log data: {validation_result['errors']}")

        # Fill defaults on a shallow copy so they do not leak back into the
        # dict the caller passed in.
        entry_data = dict(log_data)
        entry_data.setdefault("progress", 0)
        entry_data.setdefault("current_step", "initializing")

        log_entry = await self.create(entry_data)
        logger.info("Training log created",
                    job_id=log_entry.job_id,
                    tenant_id=log_entry.tenant_id,
                    status=log_entry.status)
        return log_entry
    except ValidationError:
        # Validation problems are the caller's to handle; do not wrap them.
        raise
    except Exception as e:
        logger.error("Failed to create training log",
                     job_id=log_data.get("job_id"),
                     error=str(e))
        raise DatabaseError(f"Failed to create training log: {str(e)}") from e
async def get_log_by_job_id(self, job_id: str) -> Optional[ModelTrainingLog]:
    """Return the training log for ``job_id``.

    Thin alias: the lookup is delegated to ``get_by_job_id`` and its result
    is returned unchanged.
    """
    log_entry = await self.get_by_job_id(job_id)
    return log_entry
async def update_log_progress(
    self,
    job_id: str,
    progress: int,
    current_step: str = None,
    status: str = None
) -> Optional[ModelTrainingLog]:
    """Record new progress for the training log of ``job_id``.

    Args:
        job_id: Job whose log should be updated.
        progress: New progress value; always written.
        current_step: Written only when truthy.
        status: Written only when truthy.

    Returns:
        The result of ``self.update``, or ``None`` when no log exists for
        ``job_id``.

    Raises:
        DatabaseError: If the lookup or the update fails.
    """
    try:
        changes = {"progress": progress, "updated_at": datetime.now()}
        # Optional fields are only touched when the caller supplied a value.
        optional_fields = (("current_step", current_step), ("status", status))
        changes.update({field: value for field, value in optional_fields if value})

        existing = await self.get_by_job_id(job_id)
        if not existing:
            logger.error(f"Training log not found for job {job_id}")
            return None

        refreshed = await self.update(existing.id, changes)
        logger.debug("Training log progress updated",
                     job_id=job_id,
                     progress=progress,
                     step=current_step)
        return refreshed
    except Exception as e:
        logger.error("Failed to update training log progress",
                     job_id=job_id,
                     error=str(e))
        raise DatabaseError(f"Failed to update progress: {str(e)}")
async def complete_training_log(
    self,
    job_id: str,
    results: Optional[Dict[str, Any]] = None,
    error_message: Optional[str] = None
) -> Optional[ModelTrainingLog]:
    """Mark the training log of ``job_id`` as completed or failed.

    The log is marked ``"failed"`` when ``error_message`` is truthy and
    ``"completed"`` otherwise.

    Args:
        job_id: Job whose log should be finalised.
        results: Stored under ``results`` when truthy.
        error_message: Stored under ``error_message`` when truthy; its
            presence is what selects the ``"failed"`` status.

    Returns:
        The result of ``self.update``, or ``None`` when no log exists for
        ``job_id``.

    Raises:
        DatabaseError: If the lookup or the update fails.
    """
    try:
        status = "failed" if error_message else "completed"
        # One timestamp so end_time and updated_at agree exactly.
        finished_at = datetime.now()
        update_data = {
            "status": status,
            "end_time": finished_at,
            "updated_at": finished_at,
        }
        # Only a successful run jumps to 100%. A failed run keeps whatever
        # progress was last recorded instead of having it overwritten
        # with None.
        if status == "completed":
            update_data["progress"] = 100
        if results:
            update_data["results"] = results
        if error_message:
            update_data["error_message"] = error_message

        log_entry = await self.get_by_job_id(job_id)
        if not log_entry:
            logger.error(f"Training log not found for job {job_id}")
            return None

        updated_log = await self.update(log_entry.id, update_data)
        logger.info("Training log completed",
                    job_id=job_id,
                    status=status,
                    has_results=bool(results))
        return updated_log
    except Exception as e:
        logger.error("Failed to complete training log",
                     job_id=job_id,
                     error=str(e))
        raise DatabaseError(f"Failed to complete training log: {str(e)}") from e
async def get_logs_by_tenant(
    self,
    tenant_id: str,
    status: str = None,
    skip: int = 0,
    limit: int = 100
) -> List[ModelTrainingLog]:
    """List a tenant's training logs, ordered by ``created_at`` descending.

    Args:
        tenant_id: Tenant whose logs are requested.
        status: When truthy, only logs with this status are returned.
        skip: Forwarded to ``get_multi`` as ``skip``.
        limit: Forwarded to ``get_multi`` as ``limit``.

    Returns:
        Whatever ``get_multi`` returns for the assembled filters.

    Raises:
        DatabaseError: If the query fails.
    """
    try:
        criteria = {"tenant_id": tenant_id}
        if status:
            criteria["status"] = status
        query_options = {
            "filters": criteria,
            "skip": skip,
            "limit": limit,
            "order_by": "created_at",
            "order_desc": True,
        }
        return await self.get_multi(**query_options)
    except Exception as e:
        logger.error("Failed to get logs by tenant",
                     tenant_id=tenant_id,
                     error=str(e))
        raise DatabaseError(f"Failed to get training logs: {str(e)}")
async def get_active_jobs(self, tenant_id: str = None) -> List[ModelTrainingLog]:
    """List jobs whose status is ``"running"``, ordered by ``start_time`` descending.

    Args:
        tenant_id: When truthy, restricts the result to this tenant.

    Returns:
        Whatever ``get_multi`` returns for the assembled filters.

    Raises:
        DatabaseError: If the query fails.
    """
    try:
        criteria = {"status": "running"}
        if tenant_id:
            criteria["tenant_id"] = tenant_id
        running_jobs = await self.get_multi(
            filters=criteria,
            order_by="start_time",
            order_desc=True,
        )
        return running_jobs
    except Exception as e:
        logger.error("Failed to get active jobs",
                     tenant_id=tenant_id,
                     error=str(e))
        raise DatabaseError(f"Failed to get active jobs: {str(e)}")
async def cancel_job(self, job_id: str, cancelled_by: str = None) -> Optional[ModelTrainingLog]:
    """Cancel the training job ``job_id`` if it has not finished yet.

    Args:
        job_id: Job to cancel.
        cancelled_by: When truthy, recorded in ``error_message`` as
            ``"Cancelled by <cancelled_by>"``.

    Returns:
        ``None`` when no log exists for ``job_id``; the unchanged log when
        its status is neither ``"pending"`` nor ``"running"``; otherwise
        the result of ``self.update``.

    Raises:
        DatabaseError: If the lookup or the update fails.
    """
    try:
        cancellation = {
            "status": "cancelled",
            "end_time": datetime.now(),
            "updated_at": datetime.now()
        }
        if cancelled_by:
            cancellation["error_message"] = f"Cancelled by {cancelled_by}"

        target = await self.get_by_job_id(job_id)
        if not target:
            logger.error(f"Training log not found for job {job_id}")
            return None

        # Jobs that already reached a terminal state are returned as-is.
        cancellable = target.status in ["pending", "running"]
        if not cancellable:
            logger.warning(f"Cannot cancel job {job_id} with status {target.status}")
            return target

        cancelled_log = await self.update(target.id, cancellation)
        logger.info("Training job cancelled",
                    job_id=job_id,
                    cancelled_by=cancelled_by)
        return cancelled_log
    except Exception as e:
        logger.error("Failed to cancel training job",
                     job_id=job_id,
                     error=str(e))
        raise DatabaseError(f"Failed to cancel job: {str(e)}")
async def get_job_statistics(self, tenant_id: Optional[str] = None) -> Dict[str, Any]:
    """Summarise training jobs by status, optionally for a single tenant.

    Args:
        tenant_id: When truthy, every figure (including the 7-day count) is
            restricted to this tenant.

    Returns:
        A dict with ``total_jobs``, ``completed_jobs``, ``failed_jobs``,
        ``running_jobs``, ``pending_jobs``, ``cancelled_jobs`` (total minus
        the four counted statuses), ``success_rate`` (completed as a
        percentage of completed + failed, rounded to 2 decimals) and
        ``recent_jobs_7d``. On any error the failure is logged and the same
        keys are returned with zero values; nothing is raised.
    """
    try:
        base_filters = {"tenant_id": tenant_id} if tenant_id else {}

        total_jobs = await self.count(filters=base_filters)
        completed_jobs = await self.count(filters={**base_filters, "status": "completed"})
        failed_jobs = await self.count(filters={**base_filters, "status": "failed"})
        running_jobs = await self.count(filters={**base_filters, "status": "running"})
        pending_jobs = await self.count(filters={**base_filters, "status": "pending"})

        # Count recent jobs in the database. This honours the tenant filter
        # and is not capped by a row limit, and no rows are loaded just to
        # be counted.
        # NOTE(review): assumes "recent" is measured on created_at - confirm
        # this matches the column get_records_by_date_range filtered on.
        window_end = datetime.now()
        window_start = window_end - timedelta(days=7)
        recent_params = {"window_start": window_start, "window_end": window_end}
        tenant_clause = ""
        if tenant_id:
            tenant_clause = "AND tenant_id = :tenant_id"
            recent_params["tenant_id"] = tenant_id
        recent_query = text(f"""
            SELECT COUNT(*)
            FROM model_training_logs
            WHERE created_at >= :window_start
            AND created_at <= :window_end
            {tenant_clause}
        """)
        recent_result = await self.session.execute(recent_query, recent_params)
        recent_jobs = int(recent_result.scalar() or 0)

        finished_jobs = completed_jobs + failed_jobs
        success_rate = (completed_jobs / finished_jobs * 100) if finished_jobs > 0 else 0

        return {
            "total_jobs": total_jobs,
            "completed_jobs": completed_jobs,
            "failed_jobs": failed_jobs,
            "running_jobs": running_jobs,
            "pending_jobs": pending_jobs,
            "cancelled_jobs": total_jobs - completed_jobs - failed_jobs - running_jobs - pending_jobs,
            "success_rate": round(success_rate, 2),
            "recent_jobs_7d": recent_jobs
        }
    except Exception as e:
        logger.error("Failed to get job statistics",
                     tenant_id=tenant_id,
                     error=str(e))
        # Best-effort: statistics are informational, so degrade to zeros.
        return {
            "total_jobs": 0,
            "completed_jobs": 0,
            "failed_jobs": 0,
            "running_jobs": 0,
            "pending_jobs": 0,
            "cancelled_jobs": 0,
            "success_rate": 0.0,
            "recent_jobs_7d": 0
        }
async def cleanup_old_logs(self, days_old: int = 90) -> int:
    """Clean up training logs older than ``days_old`` days, whatever their status.

    Delegates to ``cleanup_old_records`` with no status filter and returns
    its result unchanged.
    """
    cleanup_args = {
        "days_old": days_old,
        # No status filter: every old record is eligible, not only
        # completed or failed ones.
        "status_filter": None,
    }
    return await self.cleanup_old_records(**cleanup_args)
async def get_job_duration_stats(self, tenant_id: str = None) -> Dict[str, Any]:
    """Compute average/min/max duration, in minutes, of completed jobs.

    Only rows with status ``'completed'`` and both ``start_time`` and
    ``end_time`` set are considered.

    Args:
        tenant_id: When truthy, restricts the calculation to this tenant.

    Returns:
        A dict with ``avg_duration_minutes``, ``min_duration_minutes``,
        ``max_duration_minutes`` (each rounded to 2 decimals) and
        ``completed_jobs_with_duration``. All values are zero when no row
        qualifies or when the query fails; nothing is raised.
    """
    zero_stats = {
        "avg_duration_minutes": 0.0,
        "min_duration_minutes": 0.0,
        "max_duration_minutes": 0.0,
        "completed_jobs_with_duration": 0
    }
    try:
        # Only a fixed SQL fragment is interpolated; the tenant value itself
        # is always passed as a bound parameter.
        params = {}
        tenant_filter = ""
        if tenant_id:
            tenant_filter = "AND tenant_id = :tenant_id"
            params["tenant_id"] = tenant_id

        query = text(f"""
            SELECT
            AVG(EXTRACT(EPOCH FROM (end_time - start_time))/60) as avg_duration_minutes,
            MIN(EXTRACT(EPOCH FROM (end_time - start_time))/60) as min_duration_minutes,
            MAX(EXTRACT(EPOCH FROM (end_time - start_time))/60) as max_duration_minutes,
            COUNT(*) as completed_jobs_with_duration
            FROM model_training_logs
            WHERE status = 'completed'
            AND start_time IS NOT NULL
            AND end_time IS NOT NULL
            {tenant_filter}
        """)
        outcome = await self.session.execute(query, params)
        row = outcome.fetchone()

        if not row or row.completed_jobs_with_duration <= 0:
            return dict(zero_stats)

        def _minutes(value):
            # Aggregates may come back as None/Decimal; normalise to float.
            return round(float(value or 0), 2)

        return {
            "avg_duration_minutes": _minutes(row.avg_duration_minutes),
            "min_duration_minutes": _minutes(row.min_duration_minutes),
            "max_duration_minutes": _minutes(row.max_duration_minutes),
            "completed_jobs_with_duration": int(row.completed_jobs_with_duration)
        }
    except Exception as e:
        logger.error("Failed to get job duration statistics",
                     tenant_id=tenant_id,
                     error=str(e))
        return dict(zero_stats)
async def get_start_time(self, job_id: str) -> Optional[datetime]:
    """Return the ``start_time`` recorded for ``job_id``.

    Returns:
        The start time, or ``None`` when the job has no log, the log has no
        (truthy) start time, or the lookup fails. Failures are logged, not
        raised.
    """
    try:
        entry = await self.get_by_job_id(job_id)
        started_at = entry.start_time if entry else None
        return started_at or None
    except Exception as e:
        logger.error("Failed to get start time",
                     job_id=job_id,
                     error=str(e))
        return None
async def create_job_atomic(
    self,
    job_id: str,
    tenant_id: str,
    config: Optional[Dict[str, Any]] = None
) -> tuple[Optional[ModelTrainingLog], bool]:
    """
    Create a training job unless the tenant already has an active one.

    The method first looks for a running or pending job for the tenant and
    returns it if found. Otherwise it creates a ``pending`` log and commits.
    The check and the insert are separate steps, so two pods can both pass
    the check; the loser is expected to fail on the database's unique
    constraint (idx_unique_active_training_per_tenant), after which the
    job that won the race is fetched and returned.

    A constraint violation is recognised by the words "unique",
    "duplicate" or "constraint" in the error text, because
    ``create_training_log`` wraps the driver error in ``DatabaseError``.

    Args:
        job_id: Unique job identifier
        tenant_id: Tenant identifier
        config: Optional job configuration; stored as ``{}`` when falsy

    Returns:
        Tuple of (job, created):
        - If created: (new_job, True)
        - If conflict (existing active job): (existing_job, False)

    Raises:
        DatabaseError: On any failure, including a constraint violation
            for which no active job can be found afterwards.
    """

    async def _find_active_job():
        # A running job takes precedence over a pending one.
        running = await self.get_active_jobs(tenant_id=tenant_id)
        if running:
            return running[0]
        pending = await self.get_logs_by_tenant(tenant_id=tenant_id, status="pending", limit=1)
        return pending[0] if pending else None

    try:
        active_job = await _find_active_job()
        if active_job is not None:
            logger.info("Found existing active job, skipping creation",
                        existing_job_id=active_job.job_id,
                        tenant_id=tenant_id,
                        requested_job_id=job_id)
            return (active_job, False)

        # If another pod created a job in the meantime, the unique
        # constraint makes this insert fail.
        log_data = {
            "job_id": job_id,
            "tenant_id": tenant_id,
            "status": "pending",
            "progress": 0,
            "current_step": "initializing",
            "config": config or {}
        }
        try:
            new_job = await self.create_training_log(log_data)
            await self.session.commit()
            logger.info("Created new training job atomically",
                        job_id=job_id,
                        tenant_id=tenant_id)
            return (new_job, True)
        except Exception as create_error:
            # A failed flush or commit leaves the session unusable until it
            # is rolled back, whatever the cause, so always roll back
            # before deciding how to report the error.
            await self.session.rollback()

            error_str = str(create_error).lower()
            constraint_markers = ("unique", "duplicate", "constraint")
            if not any(marker in error_str for marker in constraint_markers):
                raise

            # Another pod created a job first; hand back that one.
            logger.info("Unique constraint hit, fetching existing job",
                        tenant_id=tenant_id,
                        requested_job_id=job_id)
            active_job = await _find_active_job()
            if active_job is not None:
                return (active_job, False)
            # If still no job found, something went wrong
            raise DatabaseError(
                f"Constraint violation but no active job found: {create_error}"
            ) from create_error
    except DatabaseError:
        raise
    except Exception as e:
        logger.error("Failed to create job atomically",
                     job_id=job_id,
                     tenant_id=tenant_id,
                     error=str(e))
        raise DatabaseError(f"Failed to create training job atomically: {str(e)}") from e
async def recover_stale_jobs(self, stale_threshold_minutes: int = 60) -> List[ModelTrainingLog]:
    """
    Find stale running or pending jobs and mark them as failed.

    This is used during service startup to clean up jobs that were
    in flight when a pod crashed. With multiple replicas, only stale
    jobs (not updated recently) should be marked as failed. A job counts
    as recovered only if this call's UPDATE actually changed its row, so a
    job that another replica has already handled is not reported twice.

    Args:
        stale_threshold_minutes: Jobs not updated for this long are considered stale

    Returns:
        List of jobs that were marked as failed by this call. On an
        unexpected error the session is rolled back and an empty list is
        returned.
    """
    try:
        stale_cutoff = datetime.now() - timedelta(minutes=stale_threshold_minutes)

        # Rows with a NULL updated_at never satisfy the comparison and are
        # therefore never selected.
        select_query = text("""
            SELECT id, job_id, tenant_id, status, updated_at
            FROM model_training_logs
            WHERE status IN ('running', 'pending')
            AND updated_at < :stale_cutoff
        """)
        result = await self.session.execute(select_query, {"stale_cutoff": stale_cutoff})
        stale_jobs = result.fetchall()

        # Built once; the statement is identical for every row. The status
        # guard keeps us from overwriting a job that changed state since
        # the SELECT above.
        update_query = text("""
            UPDATE model_training_logs
            SET status = 'failed',
            error_message = :error_msg,
            end_time = :end_time,
            updated_at = :updated_at
            WHERE id = :id AND status IN ('running', 'pending')
        """)

        recovered_jobs = []
        updated_count = 0
        for row in stale_jobs:
            try:
                last_updated = row.updated_at.isoformat() if row.updated_at else "unknown"
                recovered_at = datetime.now()
                update_result = await self.session.execute(update_query, {
                    "id": row.id,
                    "error_msg": f"Job recovered as failed - not updated since {last_updated}. Pod may have crashed.",
                    "end_time": recovered_at,
                    "updated_at": recovered_at
                })
                # Zero affected rows means the job left the running/pending
                # state in the meantime; it is not ours to report.
                if update_result.rowcount == 0:
                    logger.debug("Stale job already handled, skipping",
                                 job_id=row.job_id)
                    continue
                updated_count += 1

                logger.warning("Recovered stale training job",
                               job_id=row.job_id,
                               tenant_id=str(row.tenant_id),
                               last_updated=last_updated)

                # Fetch the updated job to return
                job = await self.get_by_job_id(row.job_id)
                if job:
                    recovered_jobs.append(job)
            except Exception as job_error:
                logger.error("Failed to recover individual stale job",
                             job_id=row.job_id,
                             error=str(job_error))

        # Commit whenever a row was changed, even if re-fetching the job
        # object failed; otherwise the status change would be lost.
        if updated_count:
            await self.session.commit()
            logger.info("Stale job recovery completed",
                        recovered_count=len(recovered_jobs),
                        stale_threshold_minutes=stale_threshold_minutes)
        return recovered_jobs
    except Exception as e:
        logger.error("Failed to recover stale jobs",
                     error=str(e))
        await self.session.rollback()
        return []