REFACTOR - Database logic

This commit is contained in:
Urtzi Alfaro
2025-08-08 09:08:41 +02:00
parent 0154365bfc
commit 488bb3ef93
113 changed files with 22842 additions and 6503 deletions

View File

@@ -0,0 +1,332 @@
"""
Training Log Repository
Repository for model training log operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc
from datetime import datetime, timedelta
import structlog
from .base import TrainingBaseRepository
from app.models.training import ModelTrainingLog
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class TrainingLogRepository(TrainingBaseRepository):
"""Repository for training log operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 300):
# Training logs change frequently, shorter cache time (5 minutes)
super().__init__(ModelTrainingLog, session, cache_ttl)
async def create_training_log(self, log_data: Dict[str, Any]) -> ModelTrainingLog:
"""Create a new training log entry"""
try:
# Validate log data
validation_result = self._validate_training_data(
log_data,
["job_id", "tenant_id", "status"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid training log data: {validation_result['errors']}")
# Set default values
if "progress" not in log_data:
log_data["progress"] = 0
if "current_step" not in log_data:
log_data["current_step"] = "initializing"
# Create log entry
log_entry = await self.create(log_data)
logger.info("Training log created",
job_id=log_entry.job_id,
tenant_id=log_entry.tenant_id,
status=log_entry.status)
return log_entry
except ValidationError:
raise
except Exception as e:
logger.error("Failed to create training log",
job_id=log_data.get("job_id"),
error=str(e))
raise DatabaseError(f"Failed to create training log: {str(e)}")
async def get_log_by_job_id(self, job_id: str) -> Optional[ModelTrainingLog]:
"""Get training log by job ID"""
return await self.get_by_job_id(job_id)
async def update_log_progress(
self,
job_id: str,
progress: int,
current_step: str = None,
status: str = None
) -> Optional[ModelTrainingLog]:
"""Update training log progress"""
try:
update_data = {"progress": progress, "updated_at": datetime.now()}
if current_step:
update_data["current_step"] = current_step
if status:
update_data["status"] = status
log_entry = await self.get_by_job_id(job_id)
if not log_entry:
logger.error(f"Training log not found for job {job_id}")
return None
updated_log = await self.update(log_entry.id, update_data)
logger.debug("Training log progress updated",
job_id=job_id,
progress=progress,
step=current_step)
return updated_log
except Exception as e:
logger.error("Failed to update training log progress",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to update progress: {str(e)}")
async def complete_training_log(
self,
job_id: str,
results: Dict[str, Any] = None,
error_message: str = None
) -> Optional[ModelTrainingLog]:
"""Mark training log as completed or failed"""
try:
status = "failed" if error_message else "completed"
update_data = {
"status": status,
"progress": 100 if status == "completed" else None,
"end_time": datetime.now(),
"updated_at": datetime.now()
}
if results:
update_data["results"] = results
if error_message:
update_data["error_message"] = error_message
log_entry = await self.get_by_job_id(job_id)
if not log_entry:
logger.error(f"Training log not found for job {job_id}")
return None
updated_log = await self.update(log_entry.id, update_data)
logger.info("Training log completed",
job_id=job_id,
status=status,
has_results=bool(results))
return updated_log
except Exception as e:
logger.error("Failed to complete training log",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to complete training log: {str(e)}")
async def get_logs_by_tenant(
self,
tenant_id: str,
status: str = None,
skip: int = 0,
limit: int = 100
) -> List[ModelTrainingLog]:
"""Get training logs for a tenant"""
try:
filters = {"tenant_id": tenant_id}
if status:
filters["status"] = status
return await self.get_multi(
filters=filters,
skip=skip,
limit=limit,
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get logs by tenant",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get training logs: {str(e)}")
async def get_active_jobs(self, tenant_id: str = None) -> List[ModelTrainingLog]:
"""Get currently running training jobs"""
try:
filters = {"status": "running"}
if tenant_id:
filters["tenant_id"] = tenant_id
return await self.get_multi(
filters=filters,
order_by="start_time",
order_desc=True
)
except Exception as e:
logger.error("Failed to get active jobs",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get active jobs: {str(e)}")
async def cancel_job(self, job_id: str, cancelled_by: str = None) -> Optional[ModelTrainingLog]:
"""Cancel a training job"""
try:
update_data = {
"status": "cancelled",
"end_time": datetime.now(),
"updated_at": datetime.now()
}
if cancelled_by:
update_data["error_message"] = f"Cancelled by {cancelled_by}"
log_entry = await self.get_by_job_id(job_id)
if not log_entry:
logger.error(f"Training log not found for job {job_id}")
return None
# Only cancel if job is still running
if log_entry.status not in ["pending", "running"]:
logger.warning(f"Cannot cancel job {job_id} with status {log_entry.status}")
return log_entry
updated_log = await self.update(log_entry.id, update_data)
logger.info("Training job cancelled",
job_id=job_id,
cancelled_by=cancelled_by)
return updated_log
except Exception as e:
logger.error("Failed to cancel training job",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to cancel job: {str(e)}")
async def get_job_statistics(self, tenant_id: str = None) -> Dict[str, Any]:
"""Get training job statistics"""
try:
base_filters = {}
if tenant_id:
base_filters["tenant_id"] = tenant_id
# Get counts by status
total_jobs = await self.count(filters=base_filters)
completed_jobs = await self.count(filters={**base_filters, "status": "completed"})
failed_jobs = await self.count(filters={**base_filters, "status": "failed"})
running_jobs = await self.count(filters={**base_filters, "status": "running"})
pending_jobs = await self.count(filters={**base_filters, "status": "pending"})
# Get recent activity (jobs in last 7 days)
seven_days_ago = datetime.now() - timedelta(days=7)
recent_jobs = len(await self.get_records_by_date_range(
seven_days_ago,
datetime.now(),
limit=1000 # High limit to get accurate count
))
# Calculate success rate
finished_jobs = completed_jobs + failed_jobs
success_rate = (completed_jobs / finished_jobs * 100) if finished_jobs > 0 else 0
return {
"total_jobs": total_jobs,
"completed_jobs": completed_jobs,
"failed_jobs": failed_jobs,
"running_jobs": running_jobs,
"pending_jobs": pending_jobs,
"cancelled_jobs": total_jobs - completed_jobs - failed_jobs - running_jobs - pending_jobs,
"success_rate": round(success_rate, 2),
"recent_jobs_7d": recent_jobs
}
except Exception as e:
logger.error("Failed to get job statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_jobs": 0,
"completed_jobs": 0,
"failed_jobs": 0,
"running_jobs": 0,
"pending_jobs": 0,
"cancelled_jobs": 0,
"success_rate": 0.0,
"recent_jobs_7d": 0
}
async def cleanup_old_logs(self, days_old: int = 90) -> int:
"""Clean up old completed/failed training logs"""
return await self.cleanup_old_records(
days_old=days_old,
status_filter=None # Clean up all old records regardless of status
)
async def get_job_duration_stats(self, tenant_id: str = None) -> Dict[str, Any]:
"""Get job duration statistics"""
try:
# Use raw SQL for complex duration calculations
tenant_filter = "AND tenant_id = :tenant_id" if tenant_id else ""
params = {"tenant_id": tenant_id} if tenant_id else {}
query = text(f"""
SELECT
AVG(EXTRACT(EPOCH FROM (end_time - start_time))/60) as avg_duration_minutes,
MIN(EXTRACT(EPOCH FROM (end_time - start_time))/60) as min_duration_minutes,
MAX(EXTRACT(EPOCH FROM (end_time - start_time))/60) as max_duration_minutes,
COUNT(*) as completed_jobs_with_duration
FROM model_training_logs
WHERE status = 'completed'
AND start_time IS NOT NULL
AND end_time IS NOT NULL
{tenant_filter}
""")
result = await self.session.execute(query, params)
row = result.fetchone()
if row and row.completed_jobs_with_duration > 0:
return {
"avg_duration_minutes": round(float(row.avg_duration_minutes or 0), 2),
"min_duration_minutes": round(float(row.min_duration_minutes or 0), 2),
"max_duration_minutes": round(float(row.max_duration_minutes or 0), 2),
"completed_jobs_with_duration": int(row.completed_jobs_with_duration)
}
return {
"avg_duration_minutes": 0.0,
"min_duration_minutes": 0.0,
"max_duration_minutes": 0.0,
"completed_jobs_with_duration": 0
}
except Exception as e:
logger.error("Failed to get job duration statistics",
tenant_id=tenant_id,
error=str(e))
return {
"avg_duration_minutes": 0.0,
"min_duration_minutes": 0.0,
"max_duration_minutes": 0.0,
"completed_jobs_with_duration": 0
}