"""
|
|
Artifact Repository
|
|
Repository for model artifact operations
|
|
"""
|
|
|
|
from typing import Optional, List, Dict, Any
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, and_, text, desc
|
|
from datetime import datetime, timedelta
|
|
import structlog
|
|
|
|
from .base import TrainingBaseRepository
|
|
from app.models.training import ModelArtifact
|
|
from shared.database.exceptions import DatabaseError, ValidationError
|
|
|
|
logger = structlog.get_logger()
|
|
|
|
|
|

class ArtifactRepository(TrainingBaseRepository):
    """Repository for model artifact operations"""

    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 1800):
        # Artifacts are stable, so use a longer cache time (30 minutes)
        super().__init__(ModelArtifact, session, cache_ttl)

    async def create_artifact(self, artifact_data: Dict[str, Any]) -> ModelArtifact:
        """Create a new model artifact record"""
        try:
            # Validate artifact data
            validation_result = self._validate_training_data(
                artifact_data,
                ["model_id", "tenant_id", "artifact_type", "file_path"]
            )

            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid artifact data: {validation_result['errors']}")

            # Set default values
            if "storage_location" not in artifact_data:
                artifact_data["storage_location"] = "local"

            # Create artifact record
            artifact = await self.create(artifact_data)

            logger.info("Model artifact created",
                        model_id=artifact.model_id,
                        tenant_id=artifact.tenant_id,
                        artifact_type=artifact.artifact_type,
                        file_path=artifact.file_path)

            return artifact

        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to create model artifact",
                         model_id=artifact_data.get("model_id"),
                         error=str(e))
            raise DatabaseError(f"Failed to create artifact: {str(e)}")
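
    # Example payload for create_artifact (illustrative values; the four
    # keys below are the ones _validate_training_data requires above,
    # everything else is optional):
    #
    #     await repo.create_artifact({
    #         "model_id": "model-123",
    #         "tenant_id": "tenant-abc",
    #         "artifact_type": "weights",
    #         "file_path": "/artifacts/model-123/weights.bin",
    #     })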

    async def get_artifacts_by_model(
        self,
        model_id: str,
        artifact_type: Optional[str] = None
    ) -> List[ModelArtifact]:
        """Get all artifacts for a model"""
        try:
            filters = {"model_id": model_id}
            if artifact_type:
                filters["artifact_type"] = artifact_type

            return await self.get_multi(
                filters=filters,
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get artifacts by model",
                         model_id=model_id,
                         artifact_type=artifact_type,
                         error=str(e))
            raise DatabaseError(f"Failed to get artifacts: {str(e)}")

    async def get_artifacts_by_tenant(
        self,
        tenant_id: str,
        artifact_type: Optional[str] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[ModelArtifact]:
        """Get artifacts for a tenant"""
        try:
            filters = {"tenant_id": tenant_id}
            if artifact_type:
                filters["artifact_type"] = artifact_type

            return await self.get_multi(
                filters=filters,
                skip=skip,
                limit=limit,
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get artifacts by tenant",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get tenant artifacts: {str(e)}")

    async def get_artifact_by_path(self, file_path: str) -> Optional[ModelArtifact]:
        """Get artifact by file path"""
        try:
            return await self.get_by_field("file_path", file_path)
        except Exception as e:
            logger.error("Failed to get artifact by path",
                         file_path=file_path,
                         error=str(e))
            raise DatabaseError(f"Failed to get artifact: {str(e)}")

    async def update_artifact_size(self, artifact_id: int, file_size_bytes: int) -> Optional[ModelArtifact]:
        """Update artifact file size"""
        try:
            return await self.update(artifact_id, {"file_size_bytes": file_size_bytes})
        except Exception as e:
            logger.error("Failed to update artifact size",
                         artifact_id=artifact_id,
                         error=str(e))
            return None

    async def update_artifact_checksum(self, artifact_id: int, checksum: str) -> Optional[ModelArtifact]:
        """Update artifact checksum for integrity verification"""
        try:
            return await self.update(artifact_id, {"checksum": checksum})
        except Exception as e:
            logger.error("Failed to update artifact checksum",
                         artifact_id=artifact_id,
                         error=str(e))
            return None

    async def mark_artifact_expired(self, artifact_id: int, expires_at: Optional[datetime] = None) -> Optional[ModelArtifact]:
        """Mark artifact for expiration/cleanup"""
        try:
            if not expires_at:
                expires_at = datetime.now()

            return await self.update(artifact_id, {"expires_at": expires_at})
        except Exception as e:
            logger.error("Failed to mark artifact as expired",
                         artifact_id=artifact_id,
                         error=str(e))
            return None

    async def get_expired_artifacts(self, days_expired: int = 0) -> List[ModelArtifact]:
        """Get artifacts that have expired"""
        try:
            cutoff_date = datetime.now() - timedelta(days=days_expired)

            query_text = """
                SELECT * FROM model_artifacts
                WHERE expires_at IS NOT NULL
                  AND expires_at <= :cutoff_date
                ORDER BY expires_at ASC
            """

            result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})

            # Rehydrate model objects from the raw rows; these instances are
            # built from column values and are not attached to the session
            expired_artifacts = []
            for row in result.fetchall():
                record_dict = dict(row._mapping)
                artifact = self.model(**record_dict)
                expired_artifacts.append(artifact)

            return expired_artifacts

        except Exception as e:
            logger.error("Failed to get expired artifacts",
                         days_expired=days_expired,
                         error=str(e))
            return []

    async def cleanup_expired_artifacts(self, days_expired: int = 0) -> int:
        """Clean up expired artifacts"""
        try:
            cutoff_date = datetime.now() - timedelta(days=days_expired)

            query_text = """
                DELETE FROM model_artifacts
                WHERE expires_at IS NOT NULL
                  AND expires_at <= :cutoff_date
            """

            result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
            deleted_count = result.rowcount

            logger.info("Cleaned up expired artifacts",
                        deleted_count=deleted_count,
                        days_expired=days_expired)

            return deleted_count

        except Exception as e:
            logger.error("Failed to cleanup expired artifacts",
                         days_expired=days_expired,
                         error=str(e))
            raise DatabaseError(f"Artifact cleanup failed: {str(e)}")
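
    # Illustrative composition of the two expiry helpers above, assuming the
    # caller also removes the underlying files (this repository only deletes
    # the database rows):
    #
    #     expired = await repo.get_expired_artifacts(days_expired=7)
    #     for artifact in expired:
    #         ...  # remove artifact.file_path from storage
    #     deleted = await repo.cleanup_expired_artifacts(days_expired=7)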

    async def get_large_artifacts(self, min_size_mb: int = 100) -> List[ModelArtifact]:
        """Get artifacts larger than the specified size"""
        try:
            min_size_bytes = min_size_mb * 1024 * 1024  # Convert MB to bytes

            query_text = """
                SELECT * FROM model_artifacts
                WHERE file_size_bytes >= :min_size_bytes
                ORDER BY file_size_bytes DESC
            """

            result = await self.session.execute(text(query_text), {"min_size_bytes": min_size_bytes})

            large_artifacts = []
            for row in result.fetchall():
                record_dict = dict(row._mapping)
                artifact = self.model(**record_dict)
                large_artifacts.append(artifact)

            return large_artifacts

        except Exception as e:
            logger.error("Failed to get large artifacts",
                         min_size_mb=min_size_mb,
                         error=str(e))
            return []

    async def get_artifacts_by_storage_location(
        self,
        storage_location: str,
        tenant_id: Optional[str] = None
    ) -> List[ModelArtifact]:
        """Get artifacts by storage location"""
        try:
            filters = {"storage_location": storage_location}
            if tenant_id:
                filters["tenant_id"] = tenant_id

            return await self.get_multi(
                filters=filters,
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get artifacts by storage location",
                         storage_location=storage_location,
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get artifacts: {str(e)}")

    async def get_artifact_statistics(self, tenant_id: Optional[str] = None) -> Dict[str, Any]:
        """Get artifact statistics"""
        try:
            base_filters = {}
            if tenant_id:
                base_filters["tenant_id"] = tenant_id

            # Get basic counts
            total_artifacts = await self.count(filters=base_filters)

            # Get artifacts by type
            type_query_params = {}
            type_query_filter = ""
            if tenant_id:
                type_query_filter = "WHERE tenant_id = :tenant_id"
                type_query_params["tenant_id"] = tenant_id

            type_query = text(f"""
                SELECT artifact_type, COUNT(*) as count
                FROM model_artifacts
                {type_query_filter}
                GROUP BY artifact_type
                ORDER BY count DESC
            """)

            result = await self.session.execute(type_query, type_query_params)
            artifacts_by_type = {row.artifact_type: row.count for row in result.fetchall()}

            # Get storage location stats (reuses the same tenant filter)
            location_query = text(f"""
                SELECT
                    storage_location,
                    COUNT(*) as count,
                    SUM(COALESCE(file_size_bytes, 0)) as total_size_bytes
                FROM model_artifacts
                {type_query_filter}
                GROUP BY storage_location
                ORDER BY count DESC
            """)

            location_result = await self.session.execute(location_query, type_query_params)
            storage_stats = {}
            total_size_bytes = 0

            for row in location_result.fetchall():
                storage_stats[row.storage_location] = {
                    "artifact_count": row.count,
                    "total_size_bytes": int(row.total_size_bytes or 0),
                    "total_size_mb": round((row.total_size_bytes or 0) / (1024 * 1024), 2)
                }
                total_size_bytes += row.total_size_bytes or 0

            # Get expired artifacts count (note: this count is global,
            # not tenant-scoped, even when tenant_id is provided)
            expired_artifacts = len(await self.get_expired_artifacts())

            return {
                "total_artifacts": total_artifacts,
                "expired_artifacts": expired_artifacts,
                "active_artifacts": total_artifacts - expired_artifacts,
                "artifacts_by_type": artifacts_by_type,
                "storage_statistics": storage_stats,
                "total_storage": {
                    "total_size_bytes": total_size_bytes,
                    "total_size_mb": round(total_size_bytes / (1024 * 1024), 2),
                    "total_size_gb": round(total_size_bytes / (1024 * 1024 * 1024), 2)
                }
            }

        except Exception as e:
            logger.error("Failed to get artifact statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "total_artifacts": 0,
                "expired_artifacts": 0,
                "active_artifacts": 0,
                "artifacts_by_type": {},
                "storage_statistics": {},
                "total_storage": {
                    "total_size_bytes": 0,
                    "total_size_mb": 0.0,
                    "total_size_gb": 0.0
                }
            }

    async def verify_artifact_integrity(self, artifact_id: int) -> Dict[str, Any]:
        """Verify artifact file integrity (placeholder for file system checks)"""
        try:
            artifact = await self.get_by_id(artifact_id)
            if not artifact:
                return {"exists": False, "error": "Artifact not found"}

            # This is a placeholder - in a real implementation, you would:
            # 1. Check that the file exists at artifact.file_path
            # 2. Calculate the current checksum and compare it with the stored checksum
            # 3. Verify the file size matches the stored file_size_bytes

            return {
                "artifact_id": artifact_id,
                "file_path": artifact.file_path,
                "exists": True,  # Would check actual file existence
                "checksum_valid": True,  # Would verify actual checksum
                "size_valid": True,  # Would verify actual file size
                "storage_location": artifact.storage_location,
                "last_verified": datetime.now().isoformat()
            }

        except Exception as e:
            logger.error("Failed to verify artifact integrity",
                         artifact_id=artifact_id,
                         error=str(e))
            return {
                "exists": False,
                "error": f"Verification failed: {str(e)}"
            }
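
    # A minimal sketch of the checksum step the placeholder above describes,
    # assuming artifacts live on a local filesystem and checksums are SHA-256
    # hex digests (both are assumptions, not part of this repository's contract).
    @staticmethod
    def _compute_file_checksum(file_path: str, chunk_size: int = 65536) -> Optional[str]:
        """Return the SHA-256 hex digest of a local file, or None if it is missing."""
        import hashlib
        import os

        if not os.path.isfile(file_path):
            return None
        digest = hashlib.sha256()
        with open(file_path, "rb") as f:
            # Hash in chunks so large model files are never fully loaded into memory
            for chunk in iter(lambda: f.read(chunk_size), b""):
                digest.update(chunk)
        return digest.hexdigest()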

    async def migrate_artifacts_to_storage(
        self,
        from_location: str,
        to_location: str,
        tenant_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """Migrate artifacts from one storage location to another (placeholder)"""
        try:
            # Get artifacts to migrate
            artifacts = await self.get_artifacts_by_storage_location(from_location, tenant_id)

            migrated_count = 0
            failed_count = 0

            # This is a placeholder - in a real implementation, you would:
            # 1. Copy files from the old location to the new location
            # 2. Update file paths in the database
            # 3. Verify successful migration
            # 4. Clean up old files

            for artifact in artifacts:
                try:
                    # Placeholder migration logic
                    new_file_path = artifact.file_path.replace(from_location, to_location)

                    await self.update(artifact.id, {
                        "storage_location": to_location,
                        "file_path": new_file_path
                    })

                    migrated_count += 1

                except Exception as migration_error:
                    logger.error("Failed to migrate artifact",
                                 artifact_id=artifact.id,
                                 error=str(migration_error))
                    failed_count += 1

            logger.info("Artifact migration completed",
                        from_location=from_location,
                        to_location=to_location,
                        migrated_count=migrated_count,
                        failed_count=failed_count)

            return {
                "from_location": from_location,
                "to_location": to_location,
                "total_artifacts": len(artifacts),
                "migrated_count": migrated_count,
                "failed_count": failed_count,
                "success_rate": round((migrated_count / len(artifacts)) * 100, 2) if artifacts else 100
            }

        except Exception as e:
            logger.error("Failed to migrate artifacts",
                         from_location=from_location,
                         to_location=to_location,
                         error=str(e))
            return {
                "error": f"Migration failed: {str(e)}"
            }
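
    # A minimal sketch of step 1 of the migration placeholder above, assuming
    # both storage locations are local directory prefixes; real object-storage
    # backends (e.g. S3) would need their own client calls instead.
    @staticmethod
    def _copy_artifact_file(src_path: str, dest_path: str) -> None:
        """Copy an artifact file, creating destination directories as needed."""
        import os
        import shutil

        dest_dir = os.path.dirname(dest_path)
        if dest_dir:
            os.makedirs(dest_dir, exist_ok=True)
        shutil.copy2(src_path, dest_path)  # copy2 also preserves file metadata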