Files
bakery-ia/services/training/app/repositories/training_log_repository.py
2025-08-08 09:08:41 +02:00

332 lines
13 KiB
Python

"""
Training Log Repository
Repository for model training log operations
"""
from datetime import datetime, timedelta
from typing import Optional, List, Dict, Any

import structlog
from sqlalchemy import select, and_, text, desc, func
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.training import ModelTrainingLog
from shared.database.exceptions import DatabaseError, ValidationError

from .base import TrainingBaseRepository
# Module-level structlog logger shared by the repository methods below;
# calls pass context as keyword fields (job_id=..., tenant_id=..., error=...).
logger = structlog.get_logger()
class TrainingLogRepository(TrainingBaseRepository):
    """Repository for ``ModelTrainingLog`` (training job log) operations.

    Adds training-job specific operations (create, progress updates,
    completion, cancellation, statistics, cleanup) on top of the generic
    helpers inherited from ``TrainingBaseRepository`` (``create``,
    ``update``, ``get_by_job_id``, ``get_multi``, ``count``, ...).
    """

    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 300):
        """Bind the repository to ``session`` for the ``ModelTrainingLog`` model.

        Args:
            session: Async SQLAlchemy session used for all queries.
            cache_ttl: Cache time-to-live forwarded to the base repository.
                Defaults to 300 (5 minutes), shorter than usual because
                training logs change frequently.
        """
        super().__init__(ModelTrainingLog, session, cache_ttl)

    @staticmethod
    def _empty_job_statistics() -> Dict[str, Any]:
        """Return the all-zero payload used when job statistics cannot be computed."""
        return {
            "total_jobs": 0,
            "completed_jobs": 0,
            "failed_jobs": 0,
            "running_jobs": 0,
            "pending_jobs": 0,
            "cancelled_jobs": 0,
            "success_rate": 0.0,
            "recent_jobs_7d": 0,
        }

    @staticmethod
    def _empty_duration_stats() -> Dict[str, Any]:
        """Return the all-zero payload used when no duration data is available."""
        return {
            "avg_duration_minutes": 0.0,
            "min_duration_minutes": 0.0,
            "max_duration_minutes": 0.0,
            "completed_jobs_with_duration": 0,
        }

    async def create_training_log(self, log_data: Dict[str, Any]) -> ModelTrainingLog:
        """Create a new training log entry.

        ``log_data`` must provide ``job_id``, ``tenant_id`` and ``status``.
        ``progress`` defaults to 0 and ``current_step`` to ``"initializing"``
        when absent. The caller's dict is not modified.

        Raises:
            ValidationError: if ``_validate_training_data`` rejects ``log_data``.
            DatabaseError: if the entry cannot be created.
        """
        try:
            validation_result = self._validate_training_data(
                log_data,
                ["job_id", "tenant_id", "status"],
            )
            if not validation_result["is_valid"]:
                raise ValidationError(
                    f"Invalid training log data: {validation_result['errors']}"
                )

            # Apply defaults on a copy so the caller's dict is left untouched;
            # values already present in log_data win over the defaults.
            entry_data = {"progress": 0, "current_step": "initializing", **log_data}

            log_entry = await self.create(entry_data)
            logger.info("Training log created",
                        job_id=log_entry.job_id,
                        tenant_id=log_entry.tenant_id,
                        status=log_entry.status)
            return log_entry
        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to create training log",
                         job_id=log_data.get("job_id"),
                         error=str(e))
            raise DatabaseError(f"Failed to create training log: {str(e)}") from e

    async def get_log_by_job_id(self, job_id: str) -> Optional[ModelTrainingLog]:
        """Return the training log for ``job_id`` (delegates to ``get_by_job_id``)."""
        return await self.get_by_job_id(job_id)

    async def update_log_progress(
        self,
        job_id: str,
        progress: int,
        current_step: Optional[str] = None,
        status: Optional[str] = None,
    ) -> Optional[ModelTrainingLog]:
        """Update progress (and optionally step/status) of a job's training log.

        Args:
            job_id: Job whose log is updated.
            progress: New progress value, stored as given.
            current_step: New step name; ignored when falsy.
            status: New status; ignored when falsy.

        Returns:
            The updated log, or ``None`` if no log exists for ``job_id``.

        Raises:
            DatabaseError: if the lookup or update fails.
        """
        try:
            log_entry = await self.get_by_job_id(job_id)
            if not log_entry:
                logger.error(f"Training log not found for job {job_id}")
                return None

            update_data: Dict[str, Any] = {
                "progress": progress,
                "updated_at": datetime.now(),
            }
            if current_step:
                update_data["current_step"] = current_step
            if status:
                update_data["status"] = status

            updated_log = await self.update(log_entry.id, update_data)
            logger.debug("Training log progress updated",
                         job_id=job_id,
                         progress=progress,
                         step=current_step)
            return updated_log
        except Exception as e:
            logger.error("Failed to update training log progress",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to update progress: {str(e)}") from e

    async def complete_training_log(
        self,
        job_id: str,
        results: Optional[Dict[str, Any]] = None,
        error_message: Optional[str] = None,
    ) -> Optional[ModelTrainingLog]:
        """Mark a job's training log as ``completed`` or ``failed``.

        The job is ``failed`` when ``error_message`` is truthy, otherwise
        ``completed``. ``end_time`` and ``updated_at`` are set to the same
        timestamp.

        Args:
            job_id: Job whose log is finalized.
            results: Result payload to store; ignored when falsy.
            error_message: Failure reason; its presence marks the job failed.

        Returns:
            The updated log, or ``None`` if no log exists for ``job_id``.

        Raises:
            DatabaseError: if the lookup or update fails.
        """
        try:
            status = "failed" if error_message else "completed"
            now = datetime.now()
            update_data: Dict[str, Any] = {
                "status": status,
                "end_time": now,
                "updated_at": now,
            }
            # Only a completed job is known to be at 100%. A failed job keeps
            # the progress it actually reached instead of having it overwritten
            # with NULL.
            if status == "completed":
                update_data["progress"] = 100
            if results:
                update_data["results"] = results
            if error_message:
                update_data["error_message"] = error_message

            log_entry = await self.get_by_job_id(job_id)
            if not log_entry:
                logger.error(f"Training log not found for job {job_id}")
                return None

            updated_log = await self.update(log_entry.id, update_data)
            logger.info("Training log completed",
                        job_id=job_id,
                        status=status,
                        has_results=bool(results))
            return updated_log
        except Exception as e:
            logger.error("Failed to complete training log",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to complete training log: {str(e)}") from e

    async def get_logs_by_tenant(
        self,
        tenant_id: str,
        status: Optional[str] = None,
        skip: int = 0,
        limit: int = 100,
    ) -> List[ModelTrainingLog]:
        """Return a page of a tenant's training logs, newest ``created_at`` first.

        Args:
            tenant_id: Tenant whose logs are returned.
            status: Optional status filter; ignored when falsy.
            skip: Number of rows to skip (pagination offset).
            limit: Maximum number of rows to return.

        Raises:
            DatabaseError: if the query fails.
        """
        try:
            filters: Dict[str, Any] = {"tenant_id": tenant_id}
            if status:
                filters["status"] = status

            return await self.get_multi(
                filters=filters,
                skip=skip,
                limit=limit,
                order_by="created_at",
                order_desc=True,
            )
        except Exception as e:
            logger.error("Failed to get logs by tenant",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get training logs: {str(e)}") from e

    async def get_active_jobs(self, tenant_id: Optional[str] = None) -> List[ModelTrainingLog]:
        """Return jobs with status ``running``, newest ``start_time`` first.

        Args:
            tenant_id: Restrict to this tenant; ``None`` covers all tenants.

        Raises:
            DatabaseError: if the query fails.
        """
        try:
            filters: Dict[str, Any] = {"status": "running"}
            if tenant_id:
                filters["tenant_id"] = tenant_id

            return await self.get_multi(
                filters=filters,
                order_by="start_time",
                order_desc=True,
            )
        except Exception as e:
            logger.error("Failed to get active jobs",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get active jobs: {str(e)}") from e

    async def cancel_job(
        self, job_id: str, cancelled_by: Optional[str] = None
    ) -> Optional[ModelTrainingLog]:
        """Cancel a job that is still ``pending`` or ``running``.

        Args:
            job_id: Job to cancel.
            cancelled_by: Who cancelled it; when given, it is recorded in
                ``error_message`` as ``"Cancelled by <who>"``.

        Returns:
            The updated log; the unchanged log if the job is in any other
            status; ``None`` if no log exists for ``job_id``.

        Raises:
            DatabaseError: if the lookup or update fails.
        """
        try:
            log_entry = await self.get_by_job_id(job_id)
            if not log_entry:
                logger.error(f"Training log not found for job {job_id}")
                return None

            # Jobs that already finished (or were cancelled) are left as they are.
            if log_entry.status not in ["pending", "running"]:
                logger.warning(f"Cannot cancel job {job_id} with status {log_entry.status}")
                return log_entry

            now = datetime.now()
            update_data: Dict[str, Any] = {
                "status": "cancelled",
                "end_time": now,
                "updated_at": now,
            }
            if cancelled_by:
                update_data["error_message"] = f"Cancelled by {cancelled_by}"

            updated_log = await self.update(log_entry.id, update_data)
            logger.info("Training job cancelled",
                        job_id=job_id,
                        cancelled_by=cancelled_by)
            return updated_log
        except Exception as e:
            logger.error("Failed to cancel training job",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to cancel job: {str(e)}") from e

    async def _count_recent_jobs(self, tenant_id: Optional[str], days: int = 7) -> int:
        """Count jobs created within the last ``days`` days, optionally per tenant."""
        since = datetime.now() - timedelta(days=days)
        # NOTE(review): the window is applied to created_at, the column
        # get_logs_by_tenant sorts on. Confirm this is the intended
        # "recent activity" column.
        stmt = (
            select(func.count())
            .select_from(ModelTrainingLog)
            .where(ModelTrainingLog.created_at >= since)
        )
        if tenant_id:
            stmt = stmt.where(ModelTrainingLog.tenant_id == tenant_id)
        result = await self.session.execute(stmt)
        return int(result.scalar_one())

    async def get_job_statistics(self, tenant_id: Optional[str] = None) -> Dict[str, Any]:
        """Return job counts by status, success rate and 7-day activity.

        Args:
            tenant_id: Restrict every figure to this tenant; ``None`` covers
                all tenants.

        Returns:
            Dict with ``total_jobs``, per-status counts (``completed_jobs``,
            ``failed_jobs``, ``running_jobs``, ``pending_jobs``,
            ``cancelled_jobs``), ``success_rate`` (percentage of finished
            jobs that completed, rounded to 2 decimals) and ``recent_jobs_7d``.
            On any error the failure is logged and an all-zero dict is returned.
        """
        try:
            base_filters: Dict[str, Any] = {}
            if tenant_id:
                base_filters["tenant_id"] = tenant_id

            total_jobs = await self.count(filters=base_filters)
            # 'cancelled' is counted explicitly rather than derived as total
            # minus the other statuses. A derived value would absorb any other
            # status and could go negative if jobs change state between these
            # separate queries.
            by_status: Dict[str, int] = {}
            for status in ("completed", "failed", "running", "pending", "cancelled"):
                by_status[status] = await self.count(
                    filters={**base_filters, "status": status}
                )

            # Counted in the database with the same tenant scope as the figures
            # above, instead of loading a capped page of rows and taking len().
            recent_jobs = await self._count_recent_jobs(tenant_id, days=7)

            finished_jobs = by_status["completed"] + by_status["failed"]
            success_rate = (
                by_status["completed"] / finished_jobs * 100
                if finished_jobs > 0
                else 0.0
            )

            return {
                "total_jobs": total_jobs,
                "completed_jobs": by_status["completed"],
                "failed_jobs": by_status["failed"],
                "running_jobs": by_status["running"],
                "pending_jobs": by_status["pending"],
                "cancelled_jobs": by_status["cancelled"],
                "success_rate": round(success_rate, 2),
                "recent_jobs_7d": recent_jobs,
            }
        except Exception as e:
            # Best-effort: statistics must not break callers, so log and
            # fall back to zeros.
            logger.error("Failed to get job statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            return self._empty_job_statistics()

    async def cleanup_old_logs(self, days_old: int = 90) -> int:
        """Delete training logs older than ``days_old`` days, whatever their status.

        Delegates to ``cleanup_old_records`` with ``status_filter=None``, so
        old ``pending``/``running`` logs are removed too, not only finished ones.

        Returns:
            The value returned by ``cleanup_old_records``.
        """
        return await self.cleanup_old_records(
            days_old=days_old,
            status_filter=None,  # Clean up all old records regardless of status
        )

    async def get_job_duration_stats(self, tenant_id: Optional[str] = None) -> Dict[str, Any]:
        """Return avg/min/max duration in minutes of completed jobs.

        Only jobs with status ``completed`` and both ``start_time`` and
        ``end_time`` set are considered. Values are rounded to 2 decimals.

        Args:
            tenant_id: Restrict to this tenant; ``None`` covers all tenants.

        Returns:
            Dict with ``avg_duration_minutes``, ``min_duration_minutes``,
            ``max_duration_minutes`` and ``completed_jobs_with_duration``.
            All zeros when there is no matching job or the query fails
            (the failure is logged).
        """
        try:
            # Only this constant fragment is interpolated into the SQL; the
            # tenant_id value itself is passed as a bound parameter.
            tenant_filter = "AND tenant_id = :tenant_id" if tenant_id else ""
            params = {"tenant_id": tenant_id} if tenant_id else {}

            # EXTRACT(EPOCH FROM <interval>) is PostgreSQL syntax.
            query = text(f"""
                SELECT
                    AVG(EXTRACT(EPOCH FROM (end_time - start_time))/60) as avg_duration_minutes,
                    MIN(EXTRACT(EPOCH FROM (end_time - start_time))/60) as min_duration_minutes,
                    MAX(EXTRACT(EPOCH FROM (end_time - start_time))/60) as max_duration_minutes,
                    COUNT(*) as completed_jobs_with_duration
                FROM model_training_logs
                WHERE status = 'completed'
                AND start_time IS NOT NULL
                AND end_time IS NOT NULL
                {tenant_filter}
            """)

            result = await self.session.execute(query, params)
            row = result.fetchone()

            if row and row.completed_jobs_with_duration > 0:
                return {
                    "avg_duration_minutes": round(float(row.avg_duration_minutes or 0), 2),
                    "min_duration_minutes": round(float(row.min_duration_minutes or 0), 2),
                    "max_duration_minutes": round(float(row.max_duration_minutes or 0), 2),
                    "completed_jobs_with_duration": int(row.completed_jobs_with_duration),
                }

            return self._empty_duration_stats()
        except Exception as e:
            logger.error("Failed to get job duration statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            return self._empty_duration_stats()