""" Training Log Repository Repository for model training log operations """ from typing import Optional, List, Dict, Any from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, and_, text, desc from datetime import datetime, timedelta import structlog from .base import TrainingBaseRepository from app.models.training import ModelTrainingLog from shared.database.exceptions import DatabaseError, ValidationError logger = structlog.get_logger() class TrainingLogRepository(TrainingBaseRepository): """Repository for training log operations""" def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 300): # Training logs change frequently, shorter cache time (5 minutes) super().__init__(ModelTrainingLog, session, cache_ttl) async def create_training_log(self, log_data: Dict[str, Any]) -> ModelTrainingLog: """Create a new training log entry""" try: # Validate log data validation_result = self._validate_training_data( log_data, ["job_id", "tenant_id", "status"] ) if not validation_result["is_valid"]: raise ValidationError(f"Invalid training log data: {validation_result['errors']}") # Set default values if "progress" not in log_data: log_data["progress"] = 0 if "current_step" not in log_data: log_data["current_step"] = "initializing" # Create log entry log_entry = await self.create(log_data) logger.info("Training log created", job_id=log_entry.job_id, tenant_id=log_entry.tenant_id, status=log_entry.status) return log_entry except ValidationError: raise except Exception as e: logger.error("Failed to create training log", job_id=log_data.get("job_id"), error=str(e)) raise DatabaseError(f"Failed to create training log: {str(e)}") async def get_log_by_job_id(self, job_id: str) -> Optional[ModelTrainingLog]: """Get training log by job ID""" return await self.get_by_job_id(job_id) async def update_log_progress( self, job_id: str, progress: int, current_step: str = None, status: str = None ) -> Optional[ModelTrainingLog]: """Update training log progress""" try: update_data = {"progress": progress, "updated_at": datetime.now()} if current_step: update_data["current_step"] = current_step if status: update_data["status"] = status log_entry = await self.get_by_job_id(job_id) if not log_entry: logger.error(f"Training log not found for job {job_id}") return None updated_log = await self.update(log_entry.id, update_data) logger.debug("Training log progress updated", job_id=job_id, progress=progress, step=current_step) return updated_log except Exception as e: logger.error("Failed to update training log progress", job_id=job_id, error=str(e)) raise DatabaseError(f"Failed to update progress: {str(e)}") async def complete_training_log( self, job_id: str, results: Dict[str, Any] = None, error_message: str = None ) -> Optional[ModelTrainingLog]: """Mark training log as completed or failed""" try: status = "failed" if error_message else "completed" update_data = { "status": status, "progress": 100 if status == "completed" else None, "end_time": datetime.now(), "updated_at": datetime.now() } if results: update_data["results"] = results if error_message: update_data["error_message"] = error_message log_entry = await self.get_by_job_id(job_id) if not log_entry: logger.error(f"Training log not found for job {job_id}") return None updated_log = await self.update(log_entry.id, update_data) logger.info("Training log completed", job_id=job_id, status=status, has_results=bool(results)) return updated_log except Exception as e: logger.error("Failed to complete training log", job_id=job_id, 
    async def get_logs_by_tenant(
        self,
        tenant_id: str,
        status: Optional[str] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[ModelTrainingLog]:
        """Get training logs for a tenant."""
        try:
            filters = {"tenant_id": tenant_id}
            if status:
                filters["status"] = status

            return await self.get_multi(
                filters=filters,
                skip=skip,
                limit=limit,
                order_by="created_at",
                order_desc=True
            )

        except Exception as e:
            logger.error("Failed to get logs by tenant",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get training logs: {str(e)}")

    async def get_active_jobs(self, tenant_id: Optional[str] = None) -> List[ModelTrainingLog]:
        """Get currently running training jobs."""
        try:
            filters = {"status": "running"}
            if tenant_id:
                filters["tenant_id"] = tenant_id

            return await self.get_multi(
                filters=filters,
                order_by="start_time",
                order_desc=True
            )

        except Exception as e:
            logger.error("Failed to get active jobs",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get active jobs: {str(e)}")

    async def cancel_job(self, job_id: str, cancelled_by: Optional[str] = None) -> Optional[ModelTrainingLog]:
        """Cancel a training job."""
        try:
            update_data = {
                "status": "cancelled",
                "end_time": datetime.now(),
                "updated_at": datetime.now()
            }
            if cancelled_by:
                update_data["error_message"] = f"Cancelled by {cancelled_by}"

            log_entry = await self.get_by_job_id(job_id)
            if not log_entry:
                logger.error(f"Training log not found for job {job_id}")
                return None

            # Only cancel jobs that have not finished yet
            if log_entry.status not in ["pending", "running"]:
                logger.warning(f"Cannot cancel job {job_id} with status {log_entry.status}")
                return log_entry

            updated_log = await self.update(log_entry.id, update_data)

            logger.info("Training job cancelled",
                        job_id=job_id,
                        cancelled_by=cancelled_by)
            return updated_log

        except Exception as e:
            logger.error("Failed to cancel training job",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to cancel job: {str(e)}")

    async def get_job_statistics(self, tenant_id: Optional[str] = None) -> Dict[str, Any]:
        """Get training job statistics."""
        try:
            base_filters = {}
            if tenant_id:
                base_filters["tenant_id"] = tenant_id

            # Get counts by status
            total_jobs = await self.count(filters=base_filters)
            completed_jobs = await self.count(filters={**base_filters, "status": "completed"})
            failed_jobs = await self.count(filters={**base_filters, "status": "failed"})
            running_jobs = await self.count(filters={**base_filters, "status": "running"})
            pending_jobs = await self.count(filters={**base_filters, "status": "pending"})

            # Get recent activity (jobs in the last 7 days). Note: as written,
            # this call does not pass the tenant filter, so the count spans all tenants.
            seven_days_ago = datetime.now() - timedelta(days=7)
            recent_jobs = len(await self.get_records_by_date_range(
                seven_days_ago,
                datetime.now(),
                limit=1000  # High limit to keep the count accurate
            ))

            # Success rate over finished jobs only
            finished_jobs = completed_jobs + failed_jobs
            success_rate = (completed_jobs / finished_jobs * 100) if finished_jobs > 0 else 0

            return {
                "total_jobs": total_jobs,
                "completed_jobs": completed_jobs,
                "failed_jobs": failed_jobs,
                "running_jobs": running_jobs,
                "pending_jobs": pending_jobs,
                "cancelled_jobs": total_jobs - completed_jobs - failed_jobs - running_jobs - pending_jobs,
                "success_rate": round(success_rate, 2),
                "recent_jobs_7d": recent_jobs
            }

        except Exception as e:
            logger.error("Failed to get job statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "total_jobs": 0,
                "completed_jobs": 0,
                "failed_jobs": 0,
                "running_jobs": 0,
                "pending_jobs": 0,
                "cancelled_jobs": 0,
                "success_rate": 0.0,
                "recent_jobs_7d": 0
            }
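    # The success rate above counts only finished jobs:
    # success_rate = completed / (completed + failed) * 100; running, pending,
    # and cancelled jobs do not affect it. A sketch of consuming the stats
    # (hypothetical scheduler check, names illustrative):
    #
    #     stats = await repo.get_job_statistics(tenant_id="tenant-a")
    #     if stats["running_jobs"] == 0 and stats["pending_jobs"] == 0:
    #         ...  # no active work for this tenant; safe to schedule a new job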
    async def cleanup_old_logs(self, days_old: int = 90) -> int:
        """Clean up old training logs, regardless of status."""
        return await self.cleanup_old_records(
            days_old=days_old,
            status_filter=None  # No status filter: remove all old records
        )

    async def get_job_duration_stats(self, tenant_id: Optional[str] = None) -> Dict[str, Any]:
        """Get job duration statistics."""
        try:
            # Use raw SQL for the duration aggregates; the tenant filter is a
            # fixed fragment and the tenant value itself is bound as a parameter.
            tenant_filter = "AND tenant_id = :tenant_id" if tenant_id else ""
            params = {"tenant_id": tenant_id} if tenant_id else {}

            query = text(f"""
                SELECT
                    AVG(EXTRACT(EPOCH FROM (end_time - start_time))/60) as avg_duration_minutes,
                    MIN(EXTRACT(EPOCH FROM (end_time - start_time))/60) as min_duration_minutes,
                    MAX(EXTRACT(EPOCH FROM (end_time - start_time))/60) as max_duration_minutes,
                    COUNT(*) as completed_jobs_with_duration
                FROM model_training_logs
                WHERE status = 'completed'
                    AND start_time IS NOT NULL
                    AND end_time IS NOT NULL
                    {tenant_filter}
            """)

            result = await self.session.execute(query, params)
            row = result.fetchone()

            if row and row.completed_jobs_with_duration > 0:
                return {
                    "avg_duration_minutes": round(float(row.avg_duration_minutes or 0), 2),
                    "min_duration_minutes": round(float(row.min_duration_minutes or 0), 2),
                    "max_duration_minutes": round(float(row.max_duration_minutes or 0), 2),
                    "completed_jobs_with_duration": int(row.completed_jobs_with_duration)
                }

            return {
                "avg_duration_minutes": 0.0,
                "min_duration_minutes": 0.0,
                "max_duration_minutes": 0.0,
                "completed_jobs_with_duration": 0
            }

        except Exception as e:
            logger.error("Failed to get job duration statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "avg_duration_minutes": 0.0,
                "min_duration_minutes": 0.0,
                "max_duration_minutes": 0.0,
                "completed_jobs_with_duration": 0
            }

    async def get_start_time(self, job_id: str) -> Optional[datetime]:
        """Get the start time for a training job."""
        try:
            log_entry = await self.get_by_job_id(job_id)
            if log_entry and log_entry.start_time:
                return log_entry.start_time
            return None
        except Exception as e:
            logger.error("Failed to get start time",
                         job_id=job_id,
                         error=str(e))
            return None
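    # Sketch of deriving elapsed runtime from get_start_time (hypothetical
    # caller; assumes naive datetimes, matching the datetime.now() calls used
    # throughout this module):
    #
    #     started = await repo.get_start_time("job-123")
    #     if started is not None:
    #         elapsed_minutes = (datetime.now() - started).total_seconds() / 60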
    async def create_job_atomic(
        self,
        job_id: str,
        tenant_id: str,
        config: Optional[Dict[str, Any]] = None
    ) -> Tuple[Optional[ModelTrainingLog], bool]:
        """
        Atomically create a training job, respecting the unique constraint.

        This method relies on the database's unique constraint
        (idx_unique_active_training_per_tenant), which ensures only one active
        job per tenant can exist, to handle race conditions when multiple pods
        try to create a job for the same tenant simultaneously: it checks for
        an existing active job first, and if a concurrent insert still violates
        the constraint, it catches the error and returns the job the other pod
        created.

        Args:
            job_id: Unique job identifier
            tenant_id: Tenant identifier
            config: Optional job configuration

        Returns:
            Tuple of (job, created):
                - If created: (new_job, True)
                - If conflict (existing active job): (existing_job, False)
                - If error: raises DatabaseError
        """
        try:
            # First, try to find an existing active job
            existing = await self.get_active_jobs(tenant_id=tenant_id)
            pending = await self.get_logs_by_tenant(tenant_id=tenant_id, status="pending", limit=1)

            if existing or pending:
                # Return the existing job
                active_job = existing[0] if existing else pending[0]
                logger.info("Found existing active job, skipping creation",
                            existing_job_id=active_job.job_id,
                            tenant_id=tenant_id,
                            requested_job_id=job_id)
                return (active_job, False)

            # Try to create the new job. If another pod created one in the
            # meantime, the unique constraint will reject the insert.
            log_data = {
                "job_id": job_id,
                "tenant_id": tenant_id,
                "status": "pending",
                "progress": 0,
                "current_step": "initializing",
                "config": config or {}
            }

            try:
                new_job = await self.create_training_log(log_data)
                await self.session.commit()
                logger.info("Created new training job atomically",
                            job_id=job_id,
                            tenant_id=tenant_id)
                return (new_job, True)

            except Exception as create_error:
                error_str = str(create_error).lower()
                # Check whether this is a unique constraint violation
                if "unique" in error_str or "duplicate" in error_str or "constraint" in error_str:
                    await self.session.rollback()

                    # Another pod created a job; fetch it
                    logger.info("Unique constraint hit, fetching existing job",
                                tenant_id=tenant_id,
                                requested_job_id=job_id)

                    existing = await self.get_active_jobs(tenant_id=tenant_id)
                    pending = await self.get_logs_by_tenant(tenant_id=tenant_id, status="pending", limit=1)

                    if existing or pending:
                        active_job = existing[0] if existing else pending[0]
                        return (active_job, False)

                    # If still no job is found, something went wrong
                    raise DatabaseError(f"Constraint violation but no active job found: {create_error}")
                else:
                    raise

        except DatabaseError:
            raise
        except Exception as e:
            logger.error("Failed to create job atomically",
                         job_id=job_id,
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to create training job atomically: {str(e)}")
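    # Sketch of the intended calling pattern for create_job_atomic
    # (hypothetical scheduler code; the uuid4 import is assumed):
    #
    #     job, created = await repo.create_job_atomic(str(uuid4()), tenant_id)
    #     if created:
    #         ...  # this pod won the race; enqueue the actual training work
    #     else:
    #         ...  # another pod owns the job; attach to job.job_id instead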
    async def recover_stale_jobs(self, stale_threshold_minutes: int = 60) -> List[ModelTrainingLog]:
        """
        Find and mark stale running jobs as failed.

        This is used during service startup to clean up jobs that were running
        when a pod crashed. With multiple replicas, only stale jobs (those not
        updated recently) should be marked as failed.

        Args:
            stale_threshold_minutes: Jobs not updated for this long are considered stale

        Returns:
            List of jobs that were marked as failed
        """
        try:
            stale_cutoff = datetime.now() - timedelta(minutes=stale_threshold_minutes)

            # Find running or pending jobs that have not been updated recently
            query = text("""
                SELECT id, job_id, tenant_id, status, updated_at
                FROM model_training_logs
                WHERE status IN ('running', 'pending')
                    AND updated_at < :stale_cutoff
            """)
            result = await self.session.execute(query, {"stale_cutoff": stale_cutoff})
            stale_jobs = result.fetchall()

            recovered_jobs = []
            for row in stale_jobs:
                try:
                    # Mark the job as failed; the status guard in the WHERE
                    # clause avoids clobbering a job another pod just updated
                    update_query = text("""
                        UPDATE model_training_logs
                        SET status = 'failed',
                            error_message = :error_msg,
                            end_time = :end_time,
                            updated_at = :updated_at
                        WHERE id = :id AND status IN ('running', 'pending')
                    """)
                    await self.session.execute(update_query, {
                        "id": row.id,
                        "error_msg": (
                            f"Job recovered as failed - not updated since "
                            f"{row.updated_at.isoformat()}. Pod may have crashed."
                        ),
                        "end_time": datetime.now(),
                        "updated_at": datetime.now()
                    })

                    logger.warning("Recovered stale training job",
                                   job_id=row.job_id,
                                   tenant_id=str(row.tenant_id),
                                   last_updated=row.updated_at.isoformat() if row.updated_at else "unknown")

                    # Fetch the updated job to return
                    job = await self.get_by_job_id(row.job_id)
                    if job:
                        recovered_jobs.append(job)

                except Exception as job_error:
                    logger.error("Failed to recover individual stale job",
                                 job_id=row.job_id,
                                 error=str(job_error))

            if recovered_jobs:
                await self.session.commit()
                logger.info("Stale job recovery completed",
                            recovered_count=len(recovered_jobs),
                            stale_threshold_minutes=stale_threshold_minutes)

            return recovered_jobs

        except Exception as e:
            logger.error("Failed to recover stale jobs", error=str(e))
            await self.session.rollback()
            return []
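# Startup recovery sketch (hypothetical service bootstrap; the session factory
# name is illustrative). Marking stale jobs as failed before accepting new work
# keeps the one-active-job-per-tenant constraint from blocking on dead pods:
#
#     async with async_session_factory() as session:
#         repo = TrainingLogRepository(session)
#         recovered = await repo.recover_stale_jobs(stale_threshold_minutes=60)
#         for job in recovered:
#             logger.warning("Marked stale job as failed", job_id=job.job_id)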