REFACTOR - Database logic

Urtzi Alfaro
2025-08-08 09:08:41 +02:00
parent 0154365bfc
commit 488bb3ef93
113 changed files with 22842 additions and 6503 deletions

View File

@@ -0,0 +1,20 @@
"""
Training Service Repositories
Repository implementations for training service
"""
from .base import TrainingBaseRepository
from .model_repository import ModelRepository
from .training_log_repository import TrainingLogRepository
from .performance_repository import PerformanceRepository
from .job_queue_repository import JobQueueRepository
from .artifact_repository import ArtifactRepository
__all__ = [
"TrainingBaseRepository",
"ModelRepository",
"TrainingLogRepository",
"PerformanceRepository",
"JobQueueRepository",
"ArtifactRepository"
]
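
For orientation, a minimal wiring sketch of how these repositories are meant to be instantiated; the DSN, session factory setup, and the `app.repositories` import path are illustrative assumptions, not part of this commit:

from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker

from app.repositories import ModelRepository, JobQueueRepository  # assumed package path

engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/training")  # illustrative DSN
SessionFactory = async_sessionmaker(engine, expire_on_commit=False)

async def handle_next_job() -> None:
    async with SessionFactory() as session:
        queue = JobQueueRepository(session, cache_ttl=30)  # override the 60s default
        job = await queue.get_next_job()
        if job is not None:
            print(f"next job: {job.job_id} ({job.job_type})")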

View File

@@ -0,0 +1,433 @@
"""
Artifact Repository
Repository for model artifact operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timedelta
import structlog
from .base import TrainingBaseRepository
from app.models.training import ModelArtifact
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class ArtifactRepository(TrainingBaseRepository):
"""Repository for model artifact operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 1800):
# Artifacts are stable, longer cache time (30 minutes)
super().__init__(ModelArtifact, session, cache_ttl)
async def create_artifact(self, artifact_data: Dict[str, Any]) -> ModelArtifact:
"""Create a new model artifact record"""
try:
# Validate artifact data
validation_result = self._validate_training_data(
artifact_data,
["model_id", "tenant_id", "artifact_type", "file_path"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid artifact data: {validation_result['errors']}")
# Set default values
if "storage_location" not in artifact_data:
artifact_data["storage_location"] = "local"
# Create artifact record
artifact = await self.create(artifact_data)
logger.info("Model artifact created",
model_id=artifact.model_id,
tenant_id=artifact.tenant_id,
artifact_type=artifact.artifact_type,
file_path=artifact.file_path)
return artifact
except ValidationError:
raise
except Exception as e:
logger.error("Failed to create model artifact",
model_id=artifact_data.get("model_id"),
error=str(e))
raise DatabaseError(f"Failed to create artifact: {str(e)}")
async def get_artifacts_by_model(
self,
model_id: str,
artifact_type: Optional[str] = None
) -> List[ModelArtifact]:
"""Get all artifacts for a model"""
try:
filters = {"model_id": model_id}
if artifact_type:
filters["artifact_type"] = artifact_type
return await self.get_multi(
filters=filters,
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get artifacts by model",
model_id=model_id,
artifact_type=artifact_type,
error=str(e))
raise DatabaseError(f"Failed to get artifacts: {str(e)}")
async def get_artifacts_by_tenant(
self,
tenant_id: str,
artifact_type: Optional[str] = None,
skip: int = 0,
limit: int = 100
) -> List[ModelArtifact]:
"""Get artifacts for a tenant"""
try:
filters = {"tenant_id": tenant_id}
if artifact_type:
filters["artifact_type"] = artifact_type
return await self.get_multi(
filters=filters,
skip=skip,
limit=limit,
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get artifacts by tenant",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get tenant artifacts: {str(e)}")
async def get_artifact_by_path(self, file_path: str) -> Optional[ModelArtifact]:
"""Get artifact by file path"""
try:
return await self.get_by_field("file_path", file_path)
except Exception as e:
logger.error("Failed to get artifact by path",
file_path=file_path,
error=str(e))
raise DatabaseError(f"Failed to get artifact: {str(e)}")
async def update_artifact_size(self, artifact_id: int, file_size_bytes: int) -> Optional[ModelArtifact]:
"""Update artifact file size"""
try:
return await self.update(artifact_id, {"file_size_bytes": file_size_bytes})
except Exception as e:
logger.error("Failed to update artifact size",
artifact_id=artifact_id,
error=str(e))
return None
async def update_artifact_checksum(self, artifact_id: int, checksum: str) -> Optional[ModelArtifact]:
"""Update artifact checksum for integrity verification"""
try:
return await self.update(artifact_id, {"checksum": checksum})
except Exception as e:
logger.error("Failed to update artifact checksum",
artifact_id=artifact_id,
error=str(e))
return None
async def mark_artifact_expired(self, artifact_id: int, expires_at: Optional[datetime] = None) -> Optional[ModelArtifact]:
"""Mark artifact for expiration/cleanup"""
try:
if not expires_at:
expires_at = datetime.now()
return await self.update(artifact_id, {"expires_at": expires_at})
except Exception as e:
logger.error("Failed to mark artifact as expired",
artifact_id=artifact_id,
error=str(e))
return None
async def get_expired_artifacts(self, days_expired: int = 0) -> List[ModelArtifact]:
"""Get artifacts that have expired"""
try:
cutoff_date = datetime.now() - timedelta(days=days_expired)
query_text = """
SELECT * FROM model_artifacts
WHERE expires_at IS NOT NULL
AND expires_at <= :cutoff_date
ORDER BY expires_at ASC
"""
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
expired_artifacts = []
for row in result.fetchall():
record_dict = dict(row._mapping)
artifact = self.model(**record_dict)
expired_artifacts.append(artifact)
return expired_artifacts
except Exception as e:
logger.error("Failed to get expired artifacts",
days_expired=days_expired,
error=str(e))
return []
async def cleanup_expired_artifacts(self, days_expired: int = 0) -> int:
"""Clean up expired artifacts"""
try:
cutoff_date = datetime.now() - timedelta(days=days_expired)
query_text = """
DELETE FROM model_artifacts
WHERE expires_at IS NOT NULL
AND expires_at <= :cutoff_date
"""
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
deleted_count = result.rowcount
logger.info("Cleaned up expired artifacts",
deleted_count=deleted_count,
days_expired=days_expired)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup expired artifacts",
days_expired=days_expired,
error=str(e))
raise DatabaseError(f"Artifact cleanup failed: {str(e)}")
async def get_large_artifacts(self, min_size_mb: int = 100) -> List[ModelArtifact]:
"""Get artifacts larger than specified size"""
try:
min_size_bytes = min_size_mb * 1024 * 1024 # Convert MB to bytes
query_text = """
SELECT * FROM model_artifacts
WHERE file_size_bytes >= :min_size_bytes
ORDER BY file_size_bytes DESC
"""
result = await self.session.execute(text(query_text), {"min_size_bytes": min_size_bytes})
large_artifacts = []
for row in result.fetchall():
record_dict = dict(row._mapping)
artifact = self.model(**record_dict)
large_artifacts.append(artifact)
return large_artifacts
except Exception as e:
logger.error("Failed to get large artifacts",
min_size_mb=min_size_mb,
error=str(e))
return []
async def get_artifacts_by_storage_location(
self,
storage_location: str,
tenant_id: Optional[str] = None
) -> List[ModelArtifact]:
"""Get artifacts by storage location"""
try:
filters = {"storage_location": storage_location}
if tenant_id:
filters["tenant_id"] = tenant_id
return await self.get_multi(
filters=filters,
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get artifacts by storage location",
storage_location=storage_location,
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get artifacts: {str(e)}")
async def get_artifact_statistics(self, tenant_id: Optional[str] = None) -> Dict[str, Any]:
"""Get artifact statistics"""
try:
base_filters = {}
if tenant_id:
base_filters["tenant_id"] = tenant_id
# Get basic counts
total_artifacts = await self.count(filters=base_filters)
# Get artifacts by type
type_query_params = {}
type_query_filter = ""
if tenant_id:
type_query_filter = "WHERE tenant_id = :tenant_id"
type_query_params["tenant_id"] = tenant_id
type_query = text(f"""
SELECT artifact_type, COUNT(*) as count
FROM model_artifacts
{type_query_filter}
GROUP BY artifact_type
ORDER BY count DESC
""")
result = await self.session.execute(type_query, type_query_params)
artifacts_by_type = {row.artifact_type: row.count for row in result.fetchall()}
# Get storage location stats
location_query = text(f"""
SELECT
storage_location,
COUNT(*) as count,
SUM(COALESCE(file_size_bytes, 0)) as total_size_bytes
FROM model_artifacts
{type_query_filter}
GROUP BY storage_location
ORDER BY count DESC
""")
location_result = await self.session.execute(location_query, type_query_params)
storage_stats = {}
total_size_bytes = 0
for row in location_result.fetchall():
storage_stats[row.storage_location] = {
"artifact_count": row.count,
"total_size_bytes": int(row.total_size_bytes or 0),
"total_size_mb": round((row.total_size_bytes or 0) / (1024 * 1024), 2)
}
total_size_bytes += row.total_size_bytes or 0
# Get expired artifacts count (scoped to the tenant when one is given)
expired_list = await self.get_expired_artifacts()
if tenant_id:
expired_list = [a for a in expired_list if a.tenant_id == tenant_id]
expired_artifacts = len(expired_list)
return {
"total_artifacts": total_artifacts,
"expired_artifacts": expired_artifacts,
"active_artifacts": total_artifacts - expired_artifacts,
"artifacts_by_type": artifacts_by_type,
"storage_statistics": storage_stats,
"total_storage": {
"total_size_bytes": total_size_bytes,
"total_size_mb": round(total_size_bytes / (1024 * 1024), 2),
"total_size_gb": round(total_size_bytes / (1024 * 1024 * 1024), 2)
}
}
except Exception as e:
logger.error("Failed to get artifact statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_artifacts": 0,
"expired_artifacts": 0,
"active_artifacts": 0,
"artifacts_by_type": {},
"storage_statistics": {},
"total_storage": {
"total_size_bytes": 0,
"total_size_mb": 0.0,
"total_size_gb": 0.0
}
}
async def verify_artifact_integrity(self, artifact_id: int) -> Dict[str, Any]:
"""Verify artifact file integrity (placeholder for file system checks)"""
try:
artifact = await self.get_by_id(artifact_id)
if not artifact:
return {"exists": False, "error": "Artifact not found"}
# This is a placeholder - in a real implementation, you would:
# 1. Check if the file exists at artifact.file_path
# 2. Calculate current checksum and compare with stored checksum
# 3. Verify file size matches stored file_size_bytes
return {
"artifact_id": artifact_id,
"file_path": artifact.file_path,
"exists": True, # Would check actual file existence
"checksum_valid": True, # Would verify actual checksum
"size_valid": True, # Would verify actual file size
"storage_location": artifact.storage_location,
"last_verified": datetime.now().isoformat()
}
except Exception as e:
logger.error("Failed to verify artifact integrity",
artifact_id=artifact_id,
error=str(e))
return {
"exists": False,
"error": f"Verification failed: {str(e)}"
}
async def migrate_artifacts_to_storage(
self,
from_location: str,
to_location: str,
tenant_id: Optional[str] = None
) -> Dict[str, Any]:
"""Migrate artifacts from one storage location to another (placeholder)"""
try:
# Get artifacts to migrate
artifacts = await self.get_artifacts_by_storage_location(from_location, tenant_id)
migrated_count = 0
failed_count = 0
# This is a placeholder - in a real implementation, you would:
# 1. Copy files from old location to new location
# 2. Update file paths in database
# 3. Verify successful migration
# 4. Clean up old files
for artifact in artifacts:
try:
# Placeholder migration logic
new_file_path = artifact.file_path.replace(from_location, to_location)
await self.update(artifact.id, {
"storage_location": to_location,
"file_path": new_file_path
})
migrated_count += 1
except Exception as migration_error:
logger.error("Failed to migrate artifact",
artifact_id=artifact.id,
error=str(migration_error))
failed_count += 1
logger.info("Artifact migration completed",
from_location=from_location,
to_location=to_location,
migrated_count=migrated_count,
failed_count=failed_count)
return {
"from_location": from_location,
"to_location": to_location,
"total_artifacts": len(artifacts),
"migrated_count": migrated_count,
"failed_count": failed_count,
"success_rate": round((migrated_count / len(artifacts)) * 100, 2) if artifacts else 100
}
except Exception as e:
logger.error("Failed to migrate artifacts",
from_location=from_location,
to_location=to_location,
error=str(e))
return {
"error": f"Migration failed: {str(e)}"
}
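
A short usage sketch of the artifact lifecycle this repository supports; the tenant ID, artifact type label, and file path below are illustrative assumptions:

import hashlib
from datetime import datetime, timedelta

async def register_artifact(repo: ArtifactRepository, model_id: str, payload: bytes) -> None:
    # Create the record, then fill in integrity metadata once the file is written
    artifact = await repo.create_artifact({
        "model_id": model_id,
        "tenant_id": "tenant-1",           # illustrative
        "artifact_type": "model_weights",  # illustrative
        "file_path": f"/data/artifacts/{model_id}.bin",
    })
    await repo.update_artifact_size(artifact.id, len(payload))
    await repo.update_artifact_checksum(artifact.id, hashlib.sha256(payload).hexdigest())
    # Schedule cleanup in 30 days, then purge anything already due
    await repo.mark_artifact_expired(artifact.id, datetime.now() + timedelta(days=30))
    await repo.cleanup_expired_artifacts()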

View File

@@ -0,0 +1,179 @@
"""
Base Repository for Training Service
Service-specific repository base class with training service utilities
"""
from typing import Optional, List, Dict, Any, Type
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timedelta
import structlog
from shared.database.repository import BaseRepository
from shared.database.exceptions import DatabaseError
logger = structlog.get_logger()
class TrainingBaseRepository(BaseRepository):
"""Base repository for training service with common training operations"""
def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 300):
# Training data changes frequently, shorter cache time (5 minutes)
super().__init__(model, session, cache_ttl)
async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List:
"""Get records by tenant ID"""
if hasattr(self.model, 'tenant_id'):
return await self.get_multi(
skip=skip,
limit=limit,
filters={"tenant_id": tenant_id},
order_by="created_at",
order_desc=True
)
return await self.get_multi(skip=skip, limit=limit)
async def get_active_records(self, skip: int = 0, limit: int = 100) -> List:
"""Get active records (if model has is_active field)"""
if hasattr(self.model, 'is_active'):
return await self.get_multi(
skip=skip,
limit=limit,
filters={"is_active": True},
order_by="created_at",
order_desc=True
)
return await self.get_multi(skip=skip, limit=limit)
async def get_by_job_id(self, job_id: str) -> Optional[Any]:
"""Get record by job ID (if model has job_id field)"""
if hasattr(self.model, 'job_id'):
return await self.get_by_field("job_id", job_id)
return None
async def get_by_model_id(self, model_id: str) -> Optional[Any]:
"""Get record by model ID (if model has model_id field)"""
if hasattr(self.model, 'model_id'):
return await self.get_by_field("model_id", model_id)
return None
async def deactivate_record(self, record_id: Any) -> Optional[Any]:
"""Deactivate a record instead of deleting it"""
if hasattr(self.model, 'is_active'):
return await self.update(record_id, {"is_active": False})
return await self.delete(record_id)
async def activate_record(self, record_id: Any) -> Optional[Any]:
"""Activate a record"""
if hasattr(self.model, 'is_active'):
return await self.update(record_id, {"is_active": True})
return await self.get_by_id(record_id)
async def cleanup_old_records(self, days_old: int = 90, status_filter: Optional[str] = None) -> int:
"""Clean up old training records"""
try:
cutoff_date = datetime.utcnow() - timedelta(days=days_old)
table_name = self.model.__tablename__
# Build query based on available fields
conditions = [f"created_at < :cutoff_date"]
params = {"cutoff_date": cutoff_date}
if status_filter and hasattr(self.model, 'status'):
conditions.append(f"status = :status")
params["status"] = status_filter
query_text = f"""
DELETE FROM {table_name}
WHERE {' AND '.join(conditions)}
"""
result = await self.session.execute(text(query_text), params)
deleted_count = result.rowcount
logger.info(f"Cleaned up old {self.model.__name__} records",
deleted_count=deleted_count,
days_old=days_old,
status_filter=status_filter)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup old records",
model=self.model.__name__,
error=str(e))
raise DatabaseError(f"Cleanup failed: {str(e)}")
async def get_records_by_date_range(
self,
start_date: datetime,
end_date: datetime,
skip: int = 0,
limit: int = 100
) -> List:
"""Get records within date range"""
if not hasattr(self.model, 'created_at'):
logger.warning(f"Model {self.model.__name__} has no created_at field")
return []
try:
table_name = self.model.__tablename__
query_text = f"""
SELECT * FROM {table_name}
WHERE created_at >= :start_date
AND created_at <= :end_date
ORDER BY created_at DESC
LIMIT :limit OFFSET :skip
"""
result = await self.session.execute(text(query_text), {
"start_date": start_date,
"end_date": end_date,
"limit": limit,
"skip": skip
})
# Convert rows to model objects
records = []
for row in result.fetchall():
# Create model instance from row data
record_dict = dict(row._mapping)
record = self.model(**record_dict)
records.append(record)
return records
except Exception as e:
logger.error("Failed to get records by date range",
model=self.model.__name__,
start_date=start_date,
end_date=end_date,
error=str(e))
raise DatabaseError(f"Date range query failed: {str(e)}")
def _validate_training_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
"""Validate training-related data"""
errors = []
for field in required_fields:
if field not in data or not data[field]:
errors.append(f"Missing required field: {field}")
# Validate tenant_id format if present
if "tenant_id" in data and data["tenant_id"]:
tenant_id = data["tenant_id"]
if not isinstance(tenant_id, str) or len(tenant_id) < 1:
errors.append("Invalid tenant_id format")
# Validate job_id format if present
if "job_id" in data and data["job_id"]:
job_id = data["job_id"]
if not isinstance(job_id, str) or len(job_id) < 1:
errors.append("Invalid job_id format")
return {
"is_valid": len(errors) == 0,
"errors": errors
}
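
To illustrate the intended extension pattern, a hypothetical subclass for a `training_datasets` table; the `TrainingDataset` model and the `ValidationError` import are assumptions for this sketch, mirroring how the concrete repositories in this commit are built:

from shared.database.exceptions import ValidationError  # assumed, as in the other repositories

class DatasetRepository(TrainingBaseRepository):
    """Hypothetical repository: any model with tenant_id / is_active /
    created_at columns inherits the base helpers unchanged"""
    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 300):
        super().__init__(TrainingDataset, session, cache_ttl)  # TrainingDataset is assumed

    async def create_dataset(self, data: Dict[str, Any]):
        check = self._validate_training_data(data, ["tenant_id", "name"])
        if not check["is_valid"]:
            raise ValidationError(f"Invalid dataset data: {check['errors']}")
        return await self.create(data)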

View File

@@ -0,0 +1,445 @@
"""
Job Queue Repository
Repository for training job queue operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text, bindparam
from datetime import datetime, timedelta
import structlog
from .base import TrainingBaseRepository
from app.models.training import TrainingJobQueue
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class JobQueueRepository(TrainingBaseRepository):
"""Repository for training job queue operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 60):
# Job queue changes frequently, very short cache time (1 minute)
super().__init__(TrainingJobQueue, session, cache_ttl)
async def enqueue_job(self, job_data: Dict[str, Any]) -> TrainingJobQueue:
"""Add a job to the training queue"""
try:
# Validate job data
validation_result = self._validate_training_data(
job_data,
["job_id", "tenant_id", "job_type"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid job data: {validation_result['errors']}")
# Set default values
if "priority" not in job_data:
job_data["priority"] = 1
if "status" not in job_data:
job_data["status"] = "queued"
if "max_retries" not in job_data:
job_data["max_retries"] = 3
# Create queue entry
queued_job = await self.create(job_data)
logger.info("Job enqueued",
job_id=queued_job.job_id,
tenant_id=queued_job.tenant_id,
job_type=queued_job.job_type,
priority=queued_job.priority)
return queued_job
except ValidationError:
raise
except Exception as e:
logger.error("Failed to enqueue job",
job_id=job_data.get("job_id"),
error=str(e))
raise DatabaseError(f"Failed to enqueue job: {str(e)}")
async def get_next_job(self, job_types: Optional[List[str]] = None) -> Optional[TrainingJobQueue]:
"""Get the next job to process from the queue"""
try:
# Queued jobs ready to run: honor scheduled_at, prefer high priority,
# then fall back to FIFO order; the optional job_type filter is bound
# as an expanding parameter rather than interpolated into the SQL
type_clause = "AND job_type IN :job_types" if job_types else ""
query = text(f"""
SELECT * FROM training_job_queue
WHERE status = 'queued'
{type_clause}
AND (scheduled_at IS NULL OR scheduled_at <= :now)
ORDER BY priority DESC, created_at ASC
LIMIT 1
""")
params: Dict[str, Any] = {"now": datetime.now()}
if job_types:
query = query.bindparams(bindparam("job_types", expanding=True))
params["job_types"] = job_types
result = await self.session.execute(query, params)
row = result.fetchone()
if row:
return self.model(**dict(row._mapping))
return None
except Exception as e:
logger.error("Failed to get next job from queue",
job_types=job_types,
error=str(e))
raise DatabaseError(f"Failed to get next job: {str(e)}")
async def start_job(self, job_id: str) -> Optional[TrainingJobQueue]:
"""Mark a job as started"""
try:
job = await self.get_by_job_id(job_id)
if not job:
logger.error(f"Job not found in queue: {job_id}")
return None
if job.status != "queued":
logger.warning(f"Job {job_id} is not queued (status: {job.status})")
return job
updated_job = await self.update(job.id, {
"status": "running",
"started_at": datetime.now(),
"updated_at": datetime.now()
})
logger.info("Job started",
job_id=job_id,
job_type=job.job_type)
return updated_job
except Exception as e:
logger.error("Failed to start job",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to start job: {str(e)}")
async def complete_job(self, job_id: str) -> Optional[TrainingJobQueue]:
"""Mark a job as completed"""
try:
job = await self.get_by_job_id(job_id)
if not job:
logger.error(f"Job not found in queue: {job_id}")
return None
updated_job = await self.update(job.id, {
"status": "completed",
"updated_at": datetime.now()
})
logger.info("Job completed",
job_id=job_id,
job_type=job.job_type)
return updated_job
except Exception as e:
logger.error("Failed to complete job",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to complete job: {str(e)}")
async def fail_job(self, job_id: str, error_message: Optional[str] = None) -> Optional[TrainingJobQueue]:
"""Mark a job as failed and handle retries"""
try:
job = await self.get_by_job_id(job_id)
if not job:
logger.error(f"Job not found in queue: {job_id}")
return None
# Increment retry count
new_retry_count = job.retry_count + 1
# Check if we should retry
if new_retry_count < job.max_retries:
# Reset to queued for retry
updated_job = await self.update(job.id, {
"status": "queued",
"retry_count": new_retry_count,
"updated_at": datetime.now(),
"started_at": None # Reset started_at for retry
})
logger.info("Job failed, queued for retry",
job_id=job_id,
retry_count=new_retry_count,
max_retries=job.max_retries)
else:
# Mark as permanently failed
updated_job = await self.update(job.id, {
"status": "failed",
"retry_count": new_retry_count,
"updated_at": datetime.now()
})
logger.error("Job permanently failed",
job_id=job_id,
retry_count=new_retry_count,
error_message=error_message)
return updated_job
except Exception as e:
logger.error("Failed to handle job failure",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to handle job failure: {str(e)}")
async def cancel_job(self, job_id: str, cancelled_by: Optional[str] = None) -> Optional[TrainingJobQueue]:
"""Cancel a job"""
try:
job = await self.get_by_job_id(job_id)
if not job:
logger.error(f"Job not found in queue: {job_id}")
return None
if job.status in ["completed", "failed"]:
logger.warning(f"Cannot cancel job {job_id} with status {job.status}")
return job
updated_job = await self.update(job.id, {
"status": "cancelled",
"cancelled_by": cancelled_by,
"updated_at": datetime.now()
})
logger.info("Job cancelled",
job_id=job_id,
cancelled_by=cancelled_by)
return updated_job
except Exception as e:
logger.error("Failed to cancel job",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to cancel job: {str(e)}")
async def get_queue_status(self, tenant_id: Optional[str] = None) -> Dict[str, Any]:
"""Get queue status and statistics"""
try:
base_filters = {}
if tenant_id:
base_filters["tenant_id"] = tenant_id
# Get counts by status
queued_jobs = await self.count(filters={**base_filters, "status": "queued"})
running_jobs = await self.count(filters={**base_filters, "status": "running"})
completed_jobs = await self.count(filters={**base_filters, "status": "completed"})
failed_jobs = await self.count(filters={**base_filters, "status": "failed"})
cancelled_jobs = await self.count(filters={**base_filters, "status": "cancelled"})
# Get jobs by type
type_query = text(f"""
SELECT job_type, COUNT(*) as count
FROM training_job_queue
WHERE 1=1
{' AND tenant_id = :tenant_id' if tenant_id else ''}
GROUP BY job_type
ORDER BY count DESC
""")
params = {"tenant_id": tenant_id} if tenant_id else {}
result = await self.session.execute(type_query, params)
jobs_by_type = {row.job_type: row.count for row in result.fetchall()}
# Get average wait time for completed jobs
wait_time_query = text(f"""
SELECT
AVG(EXTRACT(EPOCH FROM (started_at - created_at))/60) as avg_wait_minutes
FROM training_job_queue
WHERE status = 'completed'
AND started_at IS NOT NULL
AND created_at IS NOT NULL
{' AND tenant_id = :tenant_id' if tenant_id else ''}
""")
wait_result = await self.session.execute(wait_time_query, params)
wait_row = wait_result.fetchone()
avg_wait_time = float(wait_row.avg_wait_minutes) if wait_row and wait_row.avg_wait_minutes else 0.0
return {
"tenant_id": tenant_id,
"queue_counts": {
"queued": queued_jobs,
"running": running_jobs,
"completed": completed_jobs,
"failed": failed_jobs,
"cancelled": cancelled_jobs,
"total": queued_jobs + running_jobs + completed_jobs + failed_jobs + cancelled_jobs
},
"jobs_by_type": jobs_by_type,
"avg_wait_time_minutes": round(avg_wait_time, 2),
"queue_health": {
"has_queued_jobs": queued_jobs > 0,
"has_running_jobs": running_jobs > 0,
"failure_rate": round((failed_jobs / max(completed_jobs + failed_jobs, 1)) * 100, 2)
}
}
except Exception as e:
logger.error("Failed to get queue status",
tenant_id=tenant_id,
error=str(e))
return {
"tenant_id": tenant_id,
"queue_counts": {
"queued": 0, "running": 0, "completed": 0,
"failed": 0, "cancelled": 0, "total": 0
},
"jobs_by_type": {},
"avg_wait_time_minutes": 0.0,
"queue_health": {
"has_queued_jobs": False,
"has_running_jobs": False,
"failure_rate": 0.0
}
}
async def get_jobs_by_tenant(
self,
tenant_id: str,
status: Optional[str] = None,
job_type: Optional[str] = None,
skip: int = 0,
limit: int = 100
) -> List[TrainingJobQueue]:
"""Get jobs for a tenant with optional filtering"""
try:
filters = {"tenant_id": tenant_id}
if status:
filters["status"] = status
if job_type:
filters["job_type"] = job_type
return await self.get_multi(
filters=filters,
skip=skip,
limit=limit,
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get jobs by tenant",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get tenant jobs: {str(e)}")
async def cleanup_old_jobs(self, days_old: int = 30, status_filter: Optional[str] = None) -> int:
"""Clean up old completed/failed/cancelled jobs"""
try:
cutoff_date = datetime.now() - timedelta(days=days_old)
# Only clean up finished jobs by default; the status list goes through
# an expanding bind parameter so nothing is interpolated into the SQL
statuses = [status_filter] if status_filter else ["completed", "failed", "cancelled"]
query = text("""
DELETE FROM training_job_queue
WHERE created_at < :cutoff_date
AND status IN :statuses
""").bindparams(bindparam("statuses", expanding=True))
params = {"cutoff_date": cutoff_date, "statuses": statuses}
result = await self.session.execute(query, params)
deleted_count = result.rowcount
logger.info("Cleaned up old queue jobs",
deleted_count=deleted_count,
days_old=days_old,
status_filter=status_filter)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup old queue jobs",
error=str(e))
raise DatabaseError(f"Queue cleanup failed: {str(e)}")
async def get_stuck_jobs(self, hours_stuck: int = 2) -> List[TrainingJobQueue]:
"""Get jobs that have been running for too long"""
try:
cutoff_time = datetime.now() - timedelta(hours=hours_stuck)
query_text = """
SELECT * FROM training_job_queue
WHERE status = 'running'
AND started_at IS NOT NULL
AND started_at < :cutoff_time
ORDER BY started_at ASC
"""
result = await self.session.execute(text(query_text), {"cutoff_time": cutoff_time})
stuck_jobs = []
for row in result.fetchall():
record_dict = dict(row._mapping)
job = self.model(**record_dict)
stuck_jobs.append(job)
if stuck_jobs:
logger.warning("Found stuck jobs",
count=len(stuck_jobs),
hours_stuck=hours_stuck)
return stuck_jobs
except Exception as e:
logger.error("Failed to get stuck jobs",
hours_stuck=hours_stuck,
error=str(e))
return []
async def reset_stuck_jobs(self, hours_stuck: int = 2) -> int:
"""Reset stuck jobs back to queued status"""
try:
stuck_jobs = await self.get_stuck_jobs(hours_stuck)
reset_count = 0
for job in stuck_jobs:
# Reset job to queued status
await self.update(job.id, {
"status": "queued",
"started_at": None,
"updated_at": datetime.now()
})
reset_count += 1
if reset_count > 0:
logger.info("Reset stuck jobs",
reset_count=reset_count,
hours_stuck=hours_stuck)
return reset_count
except Exception as e:
logger.error("Failed to reset stuck jobs",
hours_stuck=hours_stuck,
error=str(e))
raise DatabaseError(f"Failed to reset stuck jobs: {str(e)}")

View File

@@ -0,0 +1,346 @@
"""
Model Repository
Repository for trained model operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timedelta
import structlog
from .base import TrainingBaseRepository
from app.models.training import TrainedModel
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError
logger = structlog.get_logger()
class ModelRepository(TrainingBaseRepository):
"""Repository for trained model operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 600):
# Models are relatively stable, longer cache time (10 minutes)
super().__init__(TrainedModel, session, cache_ttl)
async def create_model(self, model_data: Dict[str, Any]) -> TrainedModel:
"""Create a new trained model with validation"""
try:
# Validate model data
validation_result = self._validate_training_data(
model_data,
["tenant_id", "product_name", "model_path", "job_id"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid model data: {validation_result['errors']}")
# Check for duplicate active models for same tenant+product
existing_model = await self.get_active_model_for_product(
model_data["tenant_id"],
model_data["product_name"]
)
# If there's an existing active model, we may want to deactivate it
if existing_model and model_data.get("is_production", False):
logger.info("Deactivating previous production model",
previous_model_id=existing_model.id,
tenant_id=model_data["tenant_id"],
product_name=model_data["product_name"])
await self.update(existing_model.id, {"is_production": False})
# Create new model
model = await self.create(model_data)
logger.info("Trained model created successfully",
model_id=model.id,
tenant_id=model.tenant_id,
product_name=model.product_name,
model_type=model.model_type)
return model
except (ValidationError, DuplicateRecordError):
raise
except Exception as e:
logger.error("Failed to create trained model",
tenant_id=model_data.get("tenant_id"),
product_name=model_data.get("product_name"),
error=str(e))
raise DatabaseError(f"Failed to create model: {str(e)}")
async def get_model_by_tenant_and_product(
self,
tenant_id: str,
product_name: str
) -> List[TrainedModel]:
"""Get all models for a tenant and product"""
try:
return await self.get_multi(
filters={
"tenant_id": tenant_id,
"product_name": product_name
},
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get models by tenant and product",
tenant_id=tenant_id,
product_name=product_name,
error=str(e))
raise DatabaseError(f"Failed to get models: {str(e)}")
async def get_active_model_for_product(
self,
tenant_id: str,
product_name: str
) -> Optional[TrainedModel]:
"""Get the active production model for a product"""
try:
models = await self.get_multi(
filters={
"tenant_id": tenant_id,
"product_name": product_name,
"is_active": True,
"is_production": True
},
order_by="created_at",
order_desc=True,
limit=1
)
return models[0] if models else None
except Exception as e:
logger.error("Failed to get active model for product",
tenant_id=tenant_id,
product_name=product_name,
error=str(e))
raise DatabaseError(f"Failed to get active model: {str(e)}")
async def get_models_by_tenant(
self,
tenant_id: str,
skip: int = 0,
limit: int = 100
) -> List[TrainedModel]:
"""Get all models for a tenant"""
return await self.get_by_tenant_id(tenant_id, skip=skip, limit=limit)
async def promote_to_production(self, model_id: str) -> Optional[TrainedModel]:
"""Promote a model to production"""
try:
# Get the model first
model = await self.get_by_id(model_id)
if not model:
raise ValueError(f"Model {model_id} not found")
# Deactivate other production models for the same tenant+product
await self._deactivate_other_production_models(
model.tenant_id,
model.product_name,
model_id
)
# Promote this model
updated_model = await self.update(model_id, {
"is_production": True,
"last_used_at": datetime.utcnow()
})
logger.info("Model promoted to production",
model_id=model_id,
tenant_id=model.tenant_id,
product_name=model.product_name)
return updated_model
except Exception as e:
logger.error("Failed to promote model to production",
model_id=model_id,
error=str(e))
raise DatabaseError(f"Failed to promote model: {str(e)}")
async def update_model_usage(self, model_id: str) -> Optional[TrainedModel]:
"""Update model last used timestamp"""
try:
return await self.update(model_id, {
"last_used_at": datetime.utcnow()
})
except Exception as e:
logger.error("Failed to update model usage",
model_id=model_id,
error=str(e))
# Don't raise here - usage update is not critical
return None
async def archive_old_models(self, tenant_id: str, days_old: int = 90) -> int:
"""Archive old non-production models"""
try:
cutoff_date = datetime.utcnow() - timedelta(days=days_old)
query = text("""
UPDATE trained_models
SET is_active = false
WHERE tenant_id = :tenant_id
AND is_production = false
AND created_at < :cutoff_date
AND is_active = true
""")
result = await self.session.execute(query, {
"tenant_id": tenant_id,
"cutoff_date": cutoff_date
})
archived_count = result.rowcount
logger.info("Archived old models",
tenant_id=tenant_id,
archived_count=archived_count,
days_old=days_old)
return archived_count
except Exception as e:
logger.error("Failed to archive old models",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Model archival failed: {str(e)}")
async def get_model_statistics(self, tenant_id: str) -> Dict[str, Any]:
"""Get model statistics for a tenant"""
try:
# Get basic counts
total_models = await self.count(filters={"tenant_id": tenant_id})
active_models = await self.count(filters={
"tenant_id": tenant_id,
"is_active": True
})
production_models = await self.count(filters={
"tenant_id": tenant_id,
"is_production": True
})
# Get models by product using raw query
product_query = text("""
SELECT product_name, COUNT(*) as count
FROM trained_models
WHERE tenant_id = :tenant_id
AND is_active = true
GROUP BY product_name
ORDER BY count DESC
""")
result = await self.session.execute(product_query, {"tenant_id": tenant_id})
product_stats = {row.product_name: row.count for row in result.fetchall()}
# Recent activity (models created in last 30 days)
thirty_days_ago = datetime.utcnow() - timedelta(days=30)
recent_models_query = text("""
SELECT COUNT(*) as count
FROM trained_models
WHERE tenant_id = :tenant_id
AND created_at >= :thirty_days_ago
""")
recent_result = await self.session.execute(
recent_models_query,
{"tenant_id": tenant_id, "thirty_days_ago": thirty_days_ago}
)
recent_models = recent_result.scalar() or 0
return {
"total_models": total_models,
"active_models": active_models,
"inactive_models": total_models - active_models,
"production_models": production_models,
"models_by_product": product_stats,
"recent_models_30d": recent_models
}
except Exception as e:
logger.error("Failed to get model statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_models": 0,
"active_models": 0,
"inactive_models": 0,
"production_models": 0,
"models_by_product": {},
"recent_models_30d": 0
}
async def _deactivate_other_production_models(
self,
tenant_id: str,
product_name: str,
exclude_model_id: str
) -> int:
"""Deactivate other production models for the same tenant+product"""
try:
query = text("""
UPDATE trained_models
SET is_production = false
WHERE tenant_id = :tenant_id
AND product_name = :product_name
AND id != :exclude_model_id
AND is_production = true
""")
result = await self.session.execute(query, {
"tenant_id": tenant_id,
"product_name": product_name,
"exclude_model_id": exclude_model_id
})
return result.rowcount
except Exception as e:
logger.error("Failed to deactivate other production models",
tenant_id=tenant_id,
product_name=product_name,
error=str(e))
raise DatabaseError(f"Failed to deactivate models: {str(e)}")
async def get_model_performance_summary(self, model_id: str) -> Dict[str, Any]:
"""Get performance summary for a model"""
try:
model = await self.get_by_id(model_id)
if not model:
return {}
return {
"model_id": model.id,
"tenant_id": model.tenant_id,
"product_name": model.product_name,
"model_type": model.model_type,
"metrics": {
"mape": model.mape,
"mae": model.mae,
"rmse": model.rmse,
"r2_score": model.r2_score
},
"training_info": {
"training_samples": model.training_samples,
"training_start_date": model.training_start_date.isoformat() if model.training_start_date else None,
"training_end_date": model.training_end_date.isoformat() if model.training_end_date else None,
"data_quality_score": model.data_quality_score
},
"status": {
"is_active": model.is_active,
"is_production": model.is_production,
"created_at": model.created_at.isoformat() if model.created_at else None,
"last_used_at": model.last_used_at.isoformat() if model.last_used_at else None
},
"features": {
"hyperparameters": model.hyperparameters,
"features_used": model.features_used
}
}
except Exception as e:
logger.error("Failed to get model performance summary",
model_id=model_id,
error=str(e))
return {}
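
A sketch of a promotion flow built on these methods, assuming both records carry a populated `mape` column (lower is better); the helper name is illustrative:

async def promote_if_better(repo: ModelRepository, candidate_id: str) -> bool:
    candidate = await repo.get_by_id(candidate_id)
    if candidate is None:
        return False
    incumbent = await repo.get_active_model_for_product(
        candidate.tenant_id, candidate.product_name
    )
    # Promote when there is no incumbent or the candidate has lower error
    if incumbent is None or (candidate.mape or float("inf")) < (incumbent.mape or float("inf")):
        await repo.promote_to_production(candidate.id)
        return True
    return False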

View File

@@ -0,0 +1,433 @@
"""
Performance Repository
Repository for model performance metrics operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text, bindparam
from datetime import datetime, timedelta
import structlog
from .base import TrainingBaseRepository
from app.models.training import ModelPerformanceMetric
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class PerformanceRepository(TrainingBaseRepository):
"""Repository for model performance metrics operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 900):
# Performance metrics are relatively stable, longer cache time (15 minutes)
super().__init__(ModelPerformanceMetric, session, cache_ttl)
async def create_performance_metric(self, metric_data: Dict[str, Any]) -> ModelPerformanceMetric:
"""Create a new performance metric record"""
try:
# Validate metric data
validation_result = self._validate_training_data(
metric_data,
["model_id", "tenant_id", "product_name"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid metric data: {validation_result['errors']}")
# Set measurement timestamp if not provided
if "measured_at" not in metric_data:
metric_data["measured_at"] = datetime.now()
# Create metric record
metric = await self.create(metric_data)
logger.info("Performance metric created",
model_id=metric.model_id,
tenant_id=metric.tenant_id,
product_name=metric.product_name)
return metric
except ValidationError:
raise
except Exception as e:
logger.error("Failed to create performance metric",
model_id=metric_data.get("model_id"),
error=str(e))
raise DatabaseError(f"Failed to create metric: {str(e)}")
async def get_metrics_by_model(
self,
model_id: str,
skip: int = 0,
limit: int = 100
) -> List[ModelPerformanceMetric]:
"""Get all performance metrics for a model"""
try:
return await self.get_multi(
filters={"model_id": model_id},
skip=skip,
limit=limit,
order_by="measured_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get metrics by model",
model_id=model_id,
error=str(e))
raise DatabaseError(f"Failed to get metrics: {str(e)}")
async def get_latest_metric_for_model(self, model_id: str) -> Optional[ModelPerformanceMetric]:
"""Get the latest performance metric for a model"""
try:
metrics = await self.get_multi(
filters={"model_id": model_id},
limit=1,
order_by="measured_at",
order_desc=True
)
return metrics[0] if metrics else None
except Exception as e:
logger.error("Failed to get latest metric for model",
model_id=model_id,
error=str(e))
raise DatabaseError(f"Failed to get latest metric: {str(e)}")
async def get_metrics_by_tenant_and_product(
self,
tenant_id: str,
product_name: str,
skip: int = 0,
limit: int = 100
) -> List[ModelPerformanceMetric]:
"""Get performance metrics for a tenant's product"""
try:
return await self.get_multi(
filters={
"tenant_id": tenant_id,
"product_name": product_name
},
skip=skip,
limit=limit,
order_by="measured_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get metrics by tenant and product",
tenant_id=tenant_id,
product_name=product_name,
error=str(e))
raise DatabaseError(f"Failed to get metrics: {str(e)}")
async def get_metrics_in_date_range(
self,
start_date: datetime,
end_date: datetime,
tenant_id: Optional[str] = None,
model_id: Optional[str] = None,
skip: int = 0,
limit: int = 100
) -> List[ModelPerformanceMetric]:
"""Get performance metrics within a date range"""
try:
# Build filters
table_name = self.model.__tablename__
conditions = ["measured_at >= :start_date", "measured_at <= :end_date"]
params = {"start_date": start_date, "end_date": end_date, "limit": limit, "skip": skip}
if tenant_id:
conditions.append("tenant_id = :tenant_id")
params["tenant_id"] = tenant_id
if model_id:
conditions.append("model_id = :model_id")
params["model_id"] = model_id
query_text = f"""
SELECT * FROM {table_name}
WHERE {' AND '.join(conditions)}
ORDER BY measured_at DESC
LIMIT :limit OFFSET :skip
"""
result = await self.session.execute(text(query_text), params)
# Convert rows to model objects
metrics = []
for row in result.fetchall():
record_dict = dict(row._mapping)
metric = self.model(**record_dict)
metrics.append(metric)
return metrics
except Exception as e:
logger.error("Failed to get metrics in date range",
start_date=start_date,
end_date=end_date,
error=str(e))
raise DatabaseError(f"Date range query failed: {str(e)}")
async def get_performance_trends(
self,
tenant_id: str,
product_name: Optional[str] = None,
days: int = 30
) -> Dict[str, Any]:
"""Get performance trends for analysis"""
try:
start_date = datetime.now() - timedelta(days=days)
# Build query for performance trends
conditions = ["tenant_id = :tenant_id", "measured_at >= :start_date"]
params = {"tenant_id": tenant_id, "start_date": start_date}
if product_name:
conditions.append("product_name = :product_name")
params["product_name"] = product_name
query_text = f"""
SELECT
product_name,
AVG(mae) as avg_mae,
AVG(mse) as avg_mse,
AVG(rmse) as avg_rmse,
AVG(mape) as avg_mape,
AVG(r2_score) as avg_r2_score,
AVG(accuracy_percentage) as avg_accuracy,
COUNT(*) as measurement_count,
MIN(measured_at) as first_measurement,
MAX(measured_at) as last_measurement
FROM model_performance_metrics
WHERE {' AND '.join(conditions)}
GROUP BY product_name
ORDER BY avg_accuracy DESC
"""
result = await self.session.execute(text(query_text), params)
trends = []
for row in result.fetchall():
trends.append({
"product_name": row.product_name,
"metrics": {
"avg_mae": float(row.avg_mae) if row.avg_mae else None,
"avg_mse": float(row.avg_mse) if row.avg_mse else None,
"avg_rmse": float(row.avg_rmse) if row.avg_rmse else None,
"avg_mape": float(row.avg_mape) if row.avg_mape else None,
"avg_r2_score": float(row.avg_r2_score) if row.avg_r2_score else None,
"avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy else None
},
"measurement_count": int(row.measurement_count),
"period": {
"start": row.first_measurement.isoformat() if row.first_measurement else None,
"end": row.last_measurement.isoformat() if row.last_measurement else None,
"days": days
}
})
return {
"tenant_id": tenant_id,
"product_name": product_name,
"trends": trends,
"period_days": days,
"total_products": len(trends)
}
except Exception as e:
logger.error("Failed to get performance trends",
tenant_id=tenant_id,
product_name=product_name,
error=str(e))
return {
"tenant_id": tenant_id,
"product_name": product_name,
"trends": [],
"period_days": days,
"total_products": 0
}
async def get_best_performing_models(
self,
tenant_id: str,
metric_type: str = "accuracy_percentage",
limit: int = 10
) -> List[Dict[str, Any]]:
"""Get best performing models based on a specific metric"""
try:
# Validate metric type
valid_metrics = ["mae", "mse", "rmse", "mape", "r2_score", "accuracy_percentage"]
if metric_type not in valid_metrics:
metric_type = "accuracy_percentage"
# For error metrics (mae, mse, rmse, mape), lower is better
# For performance metrics (r2_score, accuracy_percentage), higher is better
order_desc = metric_type in ["r2_score", "accuracy_percentage"]
order_direction = "DESC" if order_desc else "ASC"
# metric_type is validated against the whitelist above, so interpolating
# it into the SQL is safe; the inner query keeps only the latest
# measurement per (product, model) pair and the outer query ranks those
query_text = f"""
SELECT model_id, product_name, {metric_type}, measured_at, evaluation_samples
FROM (
SELECT DISTINCT ON (product_name, model_id)
model_id,
product_name,
{metric_type},
measured_at,
evaluation_samples
FROM model_performance_metrics
WHERE tenant_id = :tenant_id
AND {metric_type} IS NOT NULL
ORDER BY product_name, model_id, measured_at DESC
) AS latest
ORDER BY {metric_type} {order_direction}
LIMIT :limit
"""
result = await self.session.execute(text(query_text), {
"tenant_id": tenant_id,
"limit": limit
})
best_models = []
for row in result.fetchall():
best_models.append({
"model_id": row.model_id,
"product_name": row.product_name,
"metric_value": float(getattr(row, metric_type)),
"metric_type": metric_type,
"measured_at": row.measured_at.isoformat() if row.measured_at else None,
"evaluation_samples": int(row.evaluation_samples) if row.evaluation_samples else None
})
return best_models
except Exception as e:
logger.error("Failed to get best performing models",
tenant_id=tenant_id,
metric_type=metric_type,
error=str(e))
return []
async def cleanup_old_metrics(self, days_old: int = 180) -> int:
"""Clean up old performance metrics"""
return await self.cleanup_old_records(days_old=days_old)
async def get_metric_statistics(self, tenant_id: str) -> Dict[str, Any]:
"""Get performance metric statistics for a tenant"""
try:
# Get basic counts
total_metrics = await self.count(filters={"tenant_id": tenant_id})
# Get metrics by product using raw query
product_query = text("""
SELECT
product_name,
COUNT(*) as metric_count,
AVG(accuracy_percentage) as avg_accuracy
FROM model_performance_metrics
WHERE tenant_id = :tenant_id
GROUP BY product_name
ORDER BY avg_accuracy DESC
""")
result = await self.session.execute(product_query, {"tenant_id": tenant_id})
product_stats = {}
for row in result.fetchall():
product_stats[row.product_name] = {
"metric_count": row.metric_count,
"avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy else None
}
# Recent activity (metrics in last 7 days, scoped to this tenant)
seven_days_ago = datetime.now() - timedelta(days=7)
recent_metrics = len(await self.get_metrics_in_date_range(
seven_days_ago,
datetime.now(),
tenant_id=tenant_id,
limit=1000 # Capped at 1000, so approximate for very active tenants
))
return {
"total_metrics": total_metrics,
"products_tracked": len(product_stats),
"metrics_by_product": product_stats,
"recent_metrics_7d": recent_metrics
}
except Exception as e:
logger.error("Failed to get metric statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_metrics": 0,
"products_tracked": 0,
"metrics_by_product": {},
"recent_metrics_7d": 0
}
async def compare_model_performance(
self,
model_ids: List[str],
metric_type: str = "accuracy_percentage"
) -> Dict[str, Any]:
"""Compare performance between multiple models"""
try:
if not model_ids or len(model_ids) < 2:
return {"error": "At least 2 model IDs required for comparison"}
# Validate metric type
valid_metrics = ["mae", "mse", "rmse", "mape", "r2_score", "accuracy_percentage"]
if metric_type not in valid_metrics:
metric_type = "accuracy_percentage"
# metric_type is whitelisted above; the model_id list is passed through
# an expanding bind parameter rather than interpolated into the SQL
query = text(f"""
SELECT
model_id,
product_name,
AVG({metric_type}) as avg_metric,
MIN({metric_type}) as min_metric,
MAX({metric_type}) as max_metric,
COUNT(*) as measurement_count,
MAX(measured_at) as latest_measurement
FROM model_performance_metrics
WHERE model_id IN :model_ids
AND {metric_type} IS NOT NULL
GROUP BY model_id, product_name
ORDER BY avg_metric DESC
""").bindparams(bindparam("model_ids", expanding=True))
result = await self.session.execute(query, {"model_ids": model_ids})
comparisons = []
for row in result.fetchall():
comparisons.append({
"model_id": row.model_id,
"product_name": row.product_name,
"avg_metric": float(row.avg_metric),
"min_metric": float(row.min_metric),
"max_metric": float(row.max_metric),
"measurement_count": int(row.measurement_count),
"latest_measurement": row.latest_measurement.isoformat() if row.latest_measurement else None
})
# Find best and worst performing models; for error metrics
# (mae, mse, rmse, mape) lower is better
higher_is_better = metric_type in ["r2_score", "accuracy_percentage"]
if comparisons:
ranked = sorted(comparisons, key=lambda x: x["avg_metric"], reverse=higher_is_better)
best_model = ranked[0]
worst_model = ranked[-1]
else:
best_model = worst_model = None
return {
"metric_type": metric_type,
"models_compared": len(set(comp["model_id"] for comp in comparisons)),
"comparisons": comparisons,
"best_performing": best_model,
"worst_performing": worst_model
}
except Exception as e:
logger.error("Failed to compare model performance",
model_ids=model_ids,
metric_type=metric_type,
error=str(e))
return {"error": f"Comparison failed: {str(e)}"}

View File

@@ -0,0 +1,332 @@
"""
Training Log Repository
Repository for model training log operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timedelta
import structlog
from .base import TrainingBaseRepository
from app.models.training import ModelTrainingLog
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class TrainingLogRepository(TrainingBaseRepository):
"""Repository for training log operations"""
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 300):
# Training logs change frequently, shorter cache time (5 minutes)
super().__init__(ModelTrainingLog, session, cache_ttl)
async def create_training_log(self, log_data: Dict[str, Any]) -> ModelTrainingLog:
"""Create a new training log entry"""
try:
# Validate log data
validation_result = self._validate_training_data(
log_data,
["job_id", "tenant_id", "status"]
)
if not validation_result["is_valid"]:
raise ValidationError(f"Invalid training log data: {validation_result['errors']}")
# Set default values
if "progress" not in log_data:
log_data["progress"] = 0
if "current_step" not in log_data:
log_data["current_step"] = "initializing"
# Create log entry
log_entry = await self.create(log_data)
logger.info("Training log created",
job_id=log_entry.job_id,
tenant_id=log_entry.tenant_id,
status=log_entry.status)
return log_entry
except ValidationError:
raise
except Exception as e:
logger.error("Failed to create training log",
job_id=log_data.get("job_id"),
error=str(e))
raise DatabaseError(f"Failed to create training log: {str(e)}")
async def get_log_by_job_id(self, job_id: str) -> Optional[ModelTrainingLog]:
"""Get training log by job ID"""
return await self.get_by_job_id(job_id)
async def update_log_progress(
self,
job_id: str,
progress: int,
current_step: Optional[str] = None,
status: Optional[str] = None,
) -> Optional[ModelTrainingLog]:
"""Update training log progress"""
try:
update_data = {"progress": progress, "updated_at": datetime.now()}
if current_step:
update_data["current_step"] = current_step
if status:
update_data["status"] = status
log_entry = await self.get_by_job_id(job_id)
if not log_entry:
logger.error(f"Training log not found for job {job_id}")
return None
updated_log = await self.update(log_entry.id, update_data)
logger.debug("Training log progress updated",
job_id=job_id,
progress=progress,
step=current_step)
return updated_log
except Exception as e:
logger.error("Failed to update training log progress",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to update progress: {str(e)}")
async def complete_training_log(
self,
job_id: str,
results: Optional[Dict[str, Any]] = None,
error_message: Optional[str] = None,
) -> Optional[ModelTrainingLog]:
"""Mark training log as completed or failed"""
try:
status = "failed" if error_message else "completed"
update_data = {
"status": status,
"progress": 100 if status == "completed" else None,
"end_time": datetime.now(),
"updated_at": datetime.now()
}
if results:
update_data["results"] = results
if error_message:
update_data["error_message"] = error_message
log_entry = await self.get_by_job_id(job_id)
if not log_entry:
logger.error(f"Training log not found for job {job_id}")
return None
updated_log = await self.update(log_entry.id, update_data)
logger.info("Training log completed",
job_id=job_id,
status=status,
has_results=bool(results))
return updated_log
except Exception as e:
logger.error("Failed to complete training log",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to complete training log: {str(e)}")
async def get_logs_by_tenant(
self,
tenant_id: str,
status: Optional[str] = None,
skip: int = 0,
limit: int = 100
) -> List[ModelTrainingLog]:
"""Get training logs for a tenant"""
try:
filters = {"tenant_id": tenant_id}
if status:
filters["status"] = status
return await self.get_multi(
filters=filters,
skip=skip,
limit=limit,
order_by="created_at",
order_desc=True
)
except Exception as e:
logger.error("Failed to get logs by tenant",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get training logs: {str(e)}")
async def get_active_jobs(self, tenant_id: Optional[str] = None) -> List[ModelTrainingLog]:
"""Get currently running training jobs"""
try:
filters = {"status": "running"}
if tenant_id:
filters["tenant_id"] = tenant_id
return await self.get_multi(
filters=filters,
order_by="start_time",
order_desc=True
)
except Exception as e:
logger.error("Failed to get active jobs",
tenant_id=tenant_id,
error=str(e))
raise DatabaseError(f"Failed to get active jobs: {str(e)}")
async def cancel_job(self, job_id: str, cancelled_by: Optional[str] = None) -> Optional[ModelTrainingLog]:
"""Cancel a training job"""
try:
update_data = {
"status": "cancelled",
"end_time": datetime.now(),
"updated_at": datetime.now()
}
if cancelled_by:
update_data["error_message"] = f"Cancelled by {cancelled_by}"
log_entry = await self.get_by_job_id(job_id)
if not log_entry:
logger.error(f"Training log not found for job {job_id}")
return None
# Only cancel if job is still running
if log_entry.status not in ["pending", "running"]:
logger.warning(f"Cannot cancel job {job_id} with status {log_entry.status}")
return log_entry
updated_log = await self.update(log_entry.id, update_data)
logger.info("Training job cancelled",
job_id=job_id,
cancelled_by=cancelled_by)
return updated_log
except Exception as e:
logger.error("Failed to cancel training job",
job_id=job_id,
error=str(e))
raise DatabaseError(f"Failed to cancel job: {str(e)}")
async def get_job_statistics(self, tenant_id: Optional[str] = None) -> Dict[str, Any]:
"""Get training job statistics"""
try:
base_filters = {}
if tenant_id:
base_filters["tenant_id"] = tenant_id
# Get counts by status
total_jobs = await self.count(filters=base_filters)
completed_jobs = await self.count(filters={**base_filters, "status": "completed"})
failed_jobs = await self.count(filters={**base_filters, "status": "failed"})
running_jobs = await self.count(filters={**base_filters, "status": "running"})
pending_jobs = await self.count(filters={**base_filters, "status": "pending"})
# Get recent activity (jobs in last 7 days), scoped to the tenant when given
seven_days_ago = datetime.now() - timedelta(days=7)
recent_query = text(f"""
SELECT COUNT(*) as count
FROM model_training_logs
WHERE created_at >= :seven_days_ago
{' AND tenant_id = :tenant_id' if tenant_id else ''}
""")
recent_params = {"seven_days_ago": seven_days_ago}
if tenant_id:
recent_params["tenant_id"] = tenant_id
recent_result = await self.session.execute(recent_query, recent_params)
recent_jobs = recent_result.scalar() or 0
# Calculate success rate
finished_jobs = completed_jobs + failed_jobs
success_rate = (completed_jobs / finished_jobs * 100) if finished_jobs > 0 else 0
return {
"total_jobs": total_jobs,
"completed_jobs": completed_jobs,
"failed_jobs": failed_jobs,
"running_jobs": running_jobs,
"pending_jobs": pending_jobs,
"cancelled_jobs": total_jobs - completed_jobs - failed_jobs - running_jobs - pending_jobs,
"success_rate": round(success_rate, 2),
"recent_jobs_7d": recent_jobs
}
except Exception as e:
logger.error("Failed to get job statistics",
tenant_id=tenant_id,
error=str(e))
return {
"total_jobs": 0,
"completed_jobs": 0,
"failed_jobs": 0,
"running_jobs": 0,
"pending_jobs": 0,
"cancelled_jobs": 0,
"success_rate": 0.0,
"recent_jobs_7d": 0
}
async def cleanup_old_logs(self, days_old: int = 90) -> int:
"""Clean up old completed/failed training logs"""
return await self.cleanup_old_records(
days_old=days_old,
status_filter=None # Clean up all old records regardless of status
)
async def get_job_duration_stats(self, tenant_id: Optional[str] = None) -> Dict[str, Any]:
"""Get job duration statistics"""
try:
# Use raw SQL for complex duration calculations
tenant_filter = "AND tenant_id = :tenant_id" if tenant_id else ""
params = {"tenant_id": tenant_id} if tenant_id else {}
query = text(f"""
SELECT
AVG(EXTRACT(EPOCH FROM (end_time - start_time))/60) as avg_duration_minutes,
MIN(EXTRACT(EPOCH FROM (end_time - start_time))/60) as min_duration_minutes,
MAX(EXTRACT(EPOCH FROM (end_time - start_time))/60) as max_duration_minutes,
COUNT(*) as completed_jobs_with_duration
FROM model_training_logs
WHERE status = 'completed'
AND start_time IS NOT NULL
AND end_time IS NOT NULL
{tenant_filter}
""")
result = await self.session.execute(query, params)
row = result.fetchone()
if row and row.completed_jobs_with_duration > 0:
return {
"avg_duration_minutes": round(float(row.avg_duration_minutes or 0), 2),
"min_duration_minutes": round(float(row.min_duration_minutes or 0), 2),
"max_duration_minutes": round(float(row.max_duration_minutes or 0), 2),
"completed_jobs_with_duration": int(row.completed_jobs_with_duration)
}
return {
"avg_duration_minutes": 0.0,
"min_duration_minutes": 0.0,
"max_duration_minutes": 0.0,
"completed_jobs_with_duration": 0
}
except Exception as e:
logger.error("Failed to get job duration statistics",
tenant_id=tenant_id,
error=str(e))
return {
"avg_duration_minutes": 0.0,
"min_duration_minutes": 0.0,
"max_duration_minutes": 0.0,
"completed_jobs_with_duration": 0
}
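
Finally, a sketch of how a training task would report through this repository; the step names and results payload are illustrative:

async def run_training(repo: TrainingLogRepository, job_id: str, tenant_id: str) -> None:
    await repo.create_training_log({
        "job_id": job_id,
        "tenant_id": tenant_id,
        "status": "running",
    })
    try:
        await repo.update_log_progress(job_id, 25, current_step="loading_data")
        # ... actual training work happens between progress updates ...
        await repo.update_log_progress(job_id, 75, current_step="fitting_model")
        await repo.complete_training_log(job_id, results={"mape": 11.2})  # illustrative
    except Exception as exc:
        await repo.complete_training_log(job_id, error_message=str(exc))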