REFACTOR - Database logic
This commit is contained in:
433
services/training/app/repositories/artifact_repository.py
Normal file
433
services/training/app/repositories/artifact_repository.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""
|
||||
Artifact Repository
|
||||
Repository for model artifact operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, text, desc
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
|
||||
from .base import TrainingBaseRepository
|
||||
from app.models.training import ModelArtifact
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class ArtifactRepository(TrainingBaseRepository):
    """Repository for model artifact operations.

    Provides CRUD helpers, expiration/cleanup utilities and aggregate
    statistics over ModelArtifact rows, built on TrainingBaseRepository.
    """

    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 1800):
        # Artifacts are stable, longer cache time (30 minutes)
        super().__init__(ModelArtifact, session, cache_ttl)
async def create_artifact(self, artifact_data: Dict[str, Any]) -> ModelArtifact:
|
||||
"""Create a new model artifact record"""
|
||||
try:
|
||||
# Validate artifact data
|
||||
validation_result = self._validate_training_data(
|
||||
artifact_data,
|
||||
["model_id", "tenant_id", "artifact_type", "file_path"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid artifact data: {validation_result['errors']}")
|
||||
|
||||
# Set default values
|
||||
if "storage_location" not in artifact_data:
|
||||
artifact_data["storage_location"] = "local"
|
||||
|
||||
# Create artifact record
|
||||
artifact = await self.create(artifact_data)
|
||||
|
||||
logger.info("Model artifact created",
|
||||
model_id=artifact.model_id,
|
||||
tenant_id=artifact.tenant_id,
|
||||
artifact_type=artifact.artifact_type,
|
||||
file_path=artifact.file_path)
|
||||
|
||||
return artifact
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create model artifact",
|
||||
model_id=artifact_data.get("model_id"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create artifact: {str(e)}")
|
||||
|
||||
async def get_artifacts_by_model(
|
||||
self,
|
||||
model_id: str,
|
||||
artifact_type: str = None
|
||||
) -> List[ModelArtifact]:
|
||||
"""Get all artifacts for a model"""
|
||||
try:
|
||||
filters = {"model_id": model_id}
|
||||
if artifact_type:
|
||||
filters["artifact_type"] = artifact_type
|
||||
|
||||
return await self.get_multi(
|
||||
filters=filters,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get artifacts by model",
|
||||
model_id=model_id,
|
||||
artifact_type=artifact_type,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get artifacts: {str(e)}")
|
||||
|
||||
async def get_artifacts_by_tenant(
|
||||
self,
|
||||
tenant_id: str,
|
||||
artifact_type: str = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[ModelArtifact]:
|
||||
"""Get artifacts for a tenant"""
|
||||
try:
|
||||
filters = {"tenant_id": tenant_id}
|
||||
if artifact_type:
|
||||
filters["artifact_type"] = artifact_type
|
||||
|
||||
return await self.get_multi(
|
||||
filters=filters,
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get artifacts by tenant",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get tenant artifacts: {str(e)}")
|
||||
|
||||
async def get_artifact_by_path(self, file_path: str) -> Optional[ModelArtifact]:
|
||||
"""Get artifact by file path"""
|
||||
try:
|
||||
return await self.get_by_field("file_path", file_path)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get artifact by path",
|
||||
file_path=file_path,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get artifact: {str(e)}")
|
||||
|
||||
async def update_artifact_size(self, artifact_id: int, file_size_bytes: int) -> Optional[ModelArtifact]:
|
||||
"""Update artifact file size"""
|
||||
try:
|
||||
return await self.update(artifact_id, {"file_size_bytes": file_size_bytes})
|
||||
except Exception as e:
|
||||
logger.error("Failed to update artifact size",
|
||||
artifact_id=artifact_id,
|
||||
error=str(e))
|
||||
return None
|
||||
|
||||
async def update_artifact_checksum(self, artifact_id: int, checksum: str) -> Optional[ModelArtifact]:
|
||||
"""Update artifact checksum for integrity verification"""
|
||||
try:
|
||||
return await self.update(artifact_id, {"checksum": checksum})
|
||||
except Exception as e:
|
||||
logger.error("Failed to update artifact checksum",
|
||||
artifact_id=artifact_id,
|
||||
error=str(e))
|
||||
return None
|
||||
|
||||
async def mark_artifact_expired(self, artifact_id: int, expires_at: datetime = None) -> Optional[ModelArtifact]:
|
||||
"""Mark artifact for expiration/cleanup"""
|
||||
try:
|
||||
if not expires_at:
|
||||
expires_at = datetime.now()
|
||||
|
||||
return await self.update(artifact_id, {"expires_at": expires_at})
|
||||
except Exception as e:
|
||||
logger.error("Failed to mark artifact as expired",
|
||||
artifact_id=artifact_id,
|
||||
error=str(e))
|
||||
return None
|
||||
|
||||
async def get_expired_artifacts(self, days_expired: int = 0) -> List[ModelArtifact]:
|
||||
"""Get artifacts that have expired"""
|
||||
try:
|
||||
cutoff_date = datetime.now() - timedelta(days=days_expired)
|
||||
|
||||
query_text = """
|
||||
SELECT * FROM model_artifacts
|
||||
WHERE expires_at IS NOT NULL
|
||||
AND expires_at <= :cutoff_date
|
||||
ORDER BY expires_at ASC
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
|
||||
|
||||
expired_artifacts = []
|
||||
for row in result.fetchall():
|
||||
record_dict = dict(row._mapping)
|
||||
artifact = self.model(**record_dict)
|
||||
expired_artifacts.append(artifact)
|
||||
|
||||
return expired_artifacts
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get expired artifacts",
|
||||
days_expired=days_expired,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def cleanup_expired_artifacts(self, days_expired: int = 0) -> int:
|
||||
"""Clean up expired artifacts"""
|
||||
try:
|
||||
cutoff_date = datetime.now() - timedelta(days=days_expired)
|
||||
|
||||
query_text = """
|
||||
DELETE FROM model_artifacts
|
||||
WHERE expires_at IS NOT NULL
|
||||
AND expires_at <= :cutoff_date
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info("Cleaned up expired artifacts",
|
||||
deleted_count=deleted_count,
|
||||
days_expired=days_expired)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup expired artifacts",
|
||||
days_expired=days_expired,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Artifact cleanup failed: {str(e)}")
|
||||
|
||||
async def get_large_artifacts(self, min_size_mb: int = 100) -> List[ModelArtifact]:
|
||||
"""Get artifacts larger than specified size"""
|
||||
try:
|
||||
min_size_bytes = min_size_mb * 1024 * 1024 # Convert MB to bytes
|
||||
|
||||
query_text = """
|
||||
SELECT * FROM model_artifacts
|
||||
WHERE file_size_bytes >= :min_size_bytes
|
||||
ORDER BY file_size_bytes DESC
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {"min_size_bytes": min_size_bytes})
|
||||
|
||||
large_artifacts = []
|
||||
for row in result.fetchall():
|
||||
record_dict = dict(row._mapping)
|
||||
artifact = self.model(**record_dict)
|
||||
large_artifacts.append(artifact)
|
||||
|
||||
return large_artifacts
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get large artifacts",
|
||||
min_size_mb=min_size_mb,
|
||||
error=str(e))
|
||||
return []
|
||||
|
||||
async def get_artifacts_by_storage_location(
|
||||
self,
|
||||
storage_location: str,
|
||||
tenant_id: str = None
|
||||
) -> List[ModelArtifact]:
|
||||
"""Get artifacts by storage location"""
|
||||
try:
|
||||
filters = {"storage_location": storage_location}
|
||||
if tenant_id:
|
||||
filters["tenant_id"] = tenant_id
|
||||
|
||||
return await self.get_multi(
|
||||
filters=filters,
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get artifacts by storage location",
|
||||
storage_location=storage_location,
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get artifacts: {str(e)}")
|
||||
|
||||
async def get_artifact_statistics(self, tenant_id: str = None) -> Dict[str, Any]:
|
||||
"""Get artifact statistics"""
|
||||
try:
|
||||
base_filters = {}
|
||||
if tenant_id:
|
||||
base_filters["tenant_id"] = tenant_id
|
||||
|
||||
# Get basic counts
|
||||
total_artifacts = await self.count(filters=base_filters)
|
||||
|
||||
# Get artifacts by type
|
||||
type_query_params = {}
|
||||
type_query_filter = ""
|
||||
if tenant_id:
|
||||
type_query_filter = "WHERE tenant_id = :tenant_id"
|
||||
type_query_params["tenant_id"] = tenant_id
|
||||
|
||||
type_query = text(f"""
|
||||
SELECT artifact_type, COUNT(*) as count
|
||||
FROM model_artifacts
|
||||
{type_query_filter}
|
||||
GROUP BY artifact_type
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
result = await self.session.execute(type_query, type_query_params)
|
||||
artifacts_by_type = {row.artifact_type: row.count for row in result.fetchall()}
|
||||
|
||||
# Get storage location stats
|
||||
location_query = text(f"""
|
||||
SELECT
|
||||
storage_location,
|
||||
COUNT(*) as count,
|
||||
SUM(COALESCE(file_size_bytes, 0)) as total_size_bytes
|
||||
FROM model_artifacts
|
||||
{type_query_filter}
|
||||
GROUP BY storage_location
|
||||
ORDER BY count DESC
|
||||
""")
|
||||
|
||||
location_result = await self.session.execute(location_query, type_query_params)
|
||||
storage_stats = {}
|
||||
total_size_bytes = 0
|
||||
|
||||
for row in location_result.fetchall():
|
||||
storage_stats[row.storage_location] = {
|
||||
"artifact_count": row.count,
|
||||
"total_size_bytes": int(row.total_size_bytes or 0),
|
||||
"total_size_mb": round((row.total_size_bytes or 0) / (1024 * 1024), 2)
|
||||
}
|
||||
total_size_bytes += row.total_size_bytes or 0
|
||||
|
||||
# Get expired artifacts count
|
||||
expired_artifacts = len(await self.get_expired_artifacts())
|
||||
|
||||
return {
|
||||
"total_artifacts": total_artifacts,
|
||||
"expired_artifacts": expired_artifacts,
|
||||
"active_artifacts": total_artifacts - expired_artifacts,
|
||||
"artifacts_by_type": artifacts_by_type,
|
||||
"storage_statistics": storage_stats,
|
||||
"total_storage": {
|
||||
"total_size_bytes": total_size_bytes,
|
||||
"total_size_mb": round(total_size_bytes / (1024 * 1024), 2),
|
||||
"total_size_gb": round(total_size_bytes / (1024 * 1024 * 1024), 2)
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get artifact statistics",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"total_artifacts": 0,
|
||||
"expired_artifacts": 0,
|
||||
"active_artifacts": 0,
|
||||
"artifacts_by_type": {},
|
||||
"storage_statistics": {},
|
||||
"total_storage": {
|
||||
"total_size_bytes": 0,
|
||||
"total_size_mb": 0.0,
|
||||
"total_size_gb": 0.0
|
||||
}
|
||||
}
|
||||
|
||||
async def verify_artifact_integrity(self, artifact_id: int) -> Dict[str, Any]:
|
||||
"""Verify artifact file integrity (placeholder for file system checks)"""
|
||||
try:
|
||||
artifact = await self.get_by_id(artifact_id)
|
||||
if not artifact:
|
||||
return {"exists": False, "error": "Artifact not found"}
|
||||
|
||||
# This is a placeholder - in a real implementation, you would:
|
||||
# 1. Check if the file exists at artifact.file_path
|
||||
# 2. Calculate current checksum and compare with stored checksum
|
||||
# 3. Verify file size matches stored file_size_bytes
|
||||
|
||||
return {
|
||||
"artifact_id": artifact_id,
|
||||
"file_path": artifact.file_path,
|
||||
"exists": True, # Would check actual file existence
|
||||
"checksum_valid": True, # Would verify actual checksum
|
||||
"size_valid": True, # Would verify actual file size
|
||||
"storage_location": artifact.storage_location,
|
||||
"last_verified": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to verify artifact integrity",
|
||||
artifact_id=artifact_id,
|
||||
error=str(e))
|
||||
return {
|
||||
"exists": False,
|
||||
"error": f"Verification failed: {str(e)}"
|
||||
}
|
||||
|
||||
    async def migrate_artifacts_to_storage(
        self,
        from_location: str,
        to_location: str,
        tenant_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Migrate artifacts from one storage location to another (placeholder).

        Only updates the database records (storage_location and file_path);
        no files are actually moved. Per-artifact failures are logged and
        counted rather than aborting the whole run.

        Returns:
            A summary dict with counts and a success_rate percentage, or
            {"error": ...} if the migration could not start at all.
        """
        try:
            # Get artifacts to migrate
            artifacts = await self.get_artifacts_by_storage_location(from_location, tenant_id)

            migrated_count = 0
            failed_count = 0

            # This is a placeholder - in a real implementation, you would:
            # 1. Copy files from old location to new location
            # 2. Update file paths in database
            # 3. Verify successful migration
            # 4. Clean up old files

            for artifact in artifacts:
                try:
                    # Placeholder migration logic
                    # NOTE(review): str.replace swaps EVERY occurrence of
                    # from_location in the path, not just a leading storage
                    # prefix — e.g. "local/models/local_backup" would be
                    # mangled. Confirm path layout before relying on this.
                    new_file_path = artifact.file_path.replace(from_location, to_location)

                    await self.update(artifact.id, {
                        "storage_location": to_location,
                        "file_path": new_file_path
                    })

                    migrated_count += 1

                except Exception as migration_error:
                    # One bad record must not abort the batch: log and move on.
                    logger.error("Failed to migrate artifact",
                                artifact_id=artifact.id,
                                error=str(migration_error))
                    failed_count += 1

            logger.info("Artifact migration completed",
                       from_location=from_location,
                       to_location=to_location,
                       migrated_count=migrated_count,
                       failed_count=failed_count)

            return {
                "from_location": from_location,
                "to_location": to_location,
                "total_artifacts": len(artifacts),
                "migrated_count": migrated_count,
                "failed_count": failed_count,
                # Success rate as a percentage; an empty batch counts as 100.
                "success_rate": round((migrated_count / len(artifacts)) * 100, 2) if artifacts else 100
            }

        except Exception as e:
            logger.error("Failed to migrate artifacts",
                        from_location=from_location,
                        to_location=to_location,
                        error=str(e))
            return {
                "error": f"Migration failed: {str(e)}"
            }
Reference in New Issue
Block a user