bakery-ia/services/training/app/repositories/artifact_repository.py

"""
Artifact Repository
Repository for model artifact operations
"""
from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc
from datetime import datetime, timedelta
import structlog
from .base import TrainingBaseRepository
from app.models.training import ModelArtifact
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
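
# Illustrative usage sketch (assumes an AsyncSession obtained from the
# service's session factory; the identifiers and path below are hypothetical
# placeholders matching create_artifact()'s required keys):
#
#     repo = ArtifactRepository(session)
#     artifact = await repo.create_artifact({
#         "model_id": "model-123",                       # hypothetical id
#         "tenant_id": "tenant-abc",                     # hypothetical id
#         "artifact_type": "model_weights",              # hypothetical type
#         "file_path": "/artifacts/model-123/model.pkl"  # hypothetical path
#     })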

class ArtifactRepository(TrainingBaseRepository):
    """Repository for model artifact operations"""

    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 1800):
        # Artifacts are stable, longer cache time (30 minutes)
        super().__init__(ModelArtifact, session, cache_ttl)

    async def create_artifact(self, artifact_data: Dict[str, Any]) -> ModelArtifact:
        """Create a new model artifact record"""
        try:
            # Validate artifact data
            validation_result = self._validate_training_data(
                artifact_data,
                ["model_id", "tenant_id", "artifact_type", "file_path"]
            )
            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid artifact data: {validation_result['errors']}")

            # Set default values
            if "storage_location" not in artifact_data:
                artifact_data["storage_location"] = "local"

            # Create artifact record
            artifact = await self.create(artifact_data)

            logger.info("Model artifact created",
                        model_id=artifact.model_id,
                        tenant_id=artifact.tenant_id,
                        artifact_type=artifact.artifact_type,
                        file_path=artifact.file_path)
            return artifact
        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to create model artifact",
                         model_id=artifact_data.get("model_id"),
                         error=str(e))
            raise DatabaseError(f"Failed to create artifact: {str(e)}")

    async def get_artifacts_by_model(
        self,
        model_id: str,
        artifact_type: Optional[str] = None
    ) -> List[ModelArtifact]:
        """Get all artifacts for a model"""
        try:
            filters = {"model_id": model_id}
            if artifact_type:
                filters["artifact_type"] = artifact_type
            return await self.get_multi(
                filters=filters,
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get artifacts by model",
                         model_id=model_id,
                         artifact_type=artifact_type,
                         error=str(e))
            raise DatabaseError(f"Failed to get artifacts: {str(e)}")

    async def get_artifacts_by_tenant(
        self,
        tenant_id: str,
        artifact_type: Optional[str] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[ModelArtifact]:
        """Get artifacts for a tenant"""
        try:
            filters = {"tenant_id": tenant_id}
            if artifact_type:
                filters["artifact_type"] = artifact_type
            return await self.get_multi(
                filters=filters,
                skip=skip,
                limit=limit,
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get artifacts by tenant",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get tenant artifacts: {str(e)}")

    async def get_artifact_by_path(self, file_path: str) -> Optional[ModelArtifact]:
        """Get artifact by file path"""
        try:
            return await self.get_by_field("file_path", file_path)
        except Exception as e:
            logger.error("Failed to get artifact by path",
                         file_path=file_path,
                         error=str(e))
            raise DatabaseError(f"Failed to get artifact: {str(e)}")

    async def update_artifact_size(self, artifact_id: int, file_size_bytes: int) -> Optional[ModelArtifact]:
        """Update artifact file size"""
        try:
            return await self.update(artifact_id, {"file_size_bytes": file_size_bytes})
        except Exception as e:
            logger.error("Failed to update artifact size",
                         artifact_id=artifact_id,
                         error=str(e))
            return None

    async def update_artifact_checksum(self, artifact_id: int, checksum: str) -> Optional[ModelArtifact]:
        """Update artifact checksum for integrity verification"""
        try:
            return await self.update(artifact_id, {"checksum": checksum})
        except Exception as e:
            logger.error("Failed to update artifact checksum",
                         artifact_id=artifact_id,
                         error=str(e))
            return None
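
    # Note: mark_artifact_expired() defaults expires_at to datetime.now(), so
    # calling it without a timestamp makes the artifact immediately eligible
    # for the expiration queries and cleanup pass below.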

    async def mark_artifact_expired(self, artifact_id: int, expires_at: Optional[datetime] = None) -> Optional[ModelArtifact]:
        """Mark artifact for expiration/cleanup"""
        try:
            if not expires_at:
                expires_at = datetime.now()
            return await self.update(artifact_id, {"expires_at": expires_at})
        except Exception as e:
            logger.error("Failed to mark artifact as expired",
                         artifact_id=artifact_id,
                         error=str(e))
            return None
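
    # The two expiration helpers below run raw SQL via text() against the
    # model_artifacts table: get_expired_artifacts() builds ModelArtifact
    # instances directly from row mappings (outside the ORM identity map),
    # and cleanup_expired_artifacts() issues a bulk DELETE; whether that
    # DELETE is committed here or by the caller depends on how
    # TrainingBaseRepository manages transactions (not shown in this file).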

    async def get_expired_artifacts(self, days_expired: int = 0) -> List[ModelArtifact]:
        """Get artifacts that have expired"""
        try:
            cutoff_date = datetime.now() - timedelta(days=days_expired)
            query_text = """
                SELECT * FROM model_artifacts
                WHERE expires_at IS NOT NULL
                AND expires_at <= :cutoff_date
                ORDER BY expires_at ASC
            """
            result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
            expired_artifacts = []
            for row in result.fetchall():
                record_dict = dict(row._mapping)
                artifact = self.model(**record_dict)
                expired_artifacts.append(artifact)
            return expired_artifacts
        except Exception as e:
            logger.error("Failed to get expired artifacts",
                         days_expired=days_expired,
                         error=str(e))
            return []

    async def cleanup_expired_artifacts(self, days_expired: int = 0) -> int:
        """Clean up expired artifacts"""
        try:
            cutoff_date = datetime.now() - timedelta(days=days_expired)
            query_text = """
                DELETE FROM model_artifacts
                WHERE expires_at IS NOT NULL
                AND expires_at <= :cutoff_date
            """
            result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
            deleted_count = result.rowcount
            logger.info("Cleaned up expired artifacts",
                        deleted_count=deleted_count,
                        days_expired=days_expired)
            return deleted_count
        except Exception as e:
            logger.error("Failed to cleanup expired artifacts",
                         days_expired=days_expired,
                         error=str(e))
            raise DatabaseError(f"Artifact cleanup failed: {str(e)}")

    async def get_large_artifacts(self, min_size_mb: int = 100) -> List[ModelArtifact]:
        """Get artifacts larger than specified size"""
        try:
            min_size_bytes = min_size_mb * 1024 * 1024  # Convert MB to bytes
            query_text = """
                SELECT * FROM model_artifacts
                WHERE file_size_bytes >= :min_size_bytes
                ORDER BY file_size_bytes DESC
            """
            result = await self.session.execute(text(query_text), {"min_size_bytes": min_size_bytes})
            large_artifacts = []
            for row in result.fetchall():
                record_dict = dict(row._mapping)
                artifact = self.model(**record_dict)
                large_artifacts.append(artifact)
            return large_artifacts
        except Exception as e:
            logger.error("Failed to get large artifacts",
                         min_size_mb=min_size_mb,
                         error=str(e))
            return []

    async def get_artifacts_by_storage_location(
        self,
        storage_location: str,
        tenant_id: Optional[str] = None
    ) -> List[ModelArtifact]:
        """Get artifacts by storage location"""
        try:
            filters = {"storage_location": storage_location}
            if tenant_id:
                filters["tenant_id"] = tenant_id
            return await self.get_multi(
                filters=filters,
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get artifacts by storage location",
                         storage_location=storage_location,
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get artifacts: {str(e)}")

    async def get_artifact_statistics(self, tenant_id: Optional[str] = None) -> Dict[str, Any]:
        """Get artifact statistics"""
        try:
            base_filters = {}
            if tenant_id:
                base_filters["tenant_id"] = tenant_id

            # Get basic counts
            total_artifacts = await self.count(filters=base_filters)

            # Get artifacts by type
            type_query_params = {}
            type_query_filter = ""
            if tenant_id:
                type_query_filter = "WHERE tenant_id = :tenant_id"
                type_query_params["tenant_id"] = tenant_id
            type_query = text(f"""
                SELECT artifact_type, COUNT(*) as count
                FROM model_artifacts
                {type_query_filter}
                GROUP BY artifact_type
                ORDER BY count DESC
            """)
            result = await self.session.execute(type_query, type_query_params)
            artifacts_by_type = {row.artifact_type: row.count for row in result.fetchall()}

            # Get storage location stats
            location_query = text(f"""
                SELECT
                    storage_location,
                    COUNT(*) as count,
                    SUM(COALESCE(file_size_bytes, 0)) as total_size_bytes
                FROM model_artifacts
                {type_query_filter}
                GROUP BY storage_location
                ORDER BY count DESC
            """)
            location_result = await self.session.execute(location_query, type_query_params)
            storage_stats = {}
            total_size_bytes = 0
            for row in location_result.fetchall():
                storage_stats[row.storage_location] = {
                    "artifact_count": row.count,
                    "total_size_bytes": int(row.total_size_bytes or 0),
                    "total_size_mb": round((row.total_size_bytes or 0) / (1024 * 1024), 2)
                }
                total_size_bytes += row.total_size_bytes or 0

            # Get expired artifacts count
            expired_artifacts = len(await self.get_expired_artifacts())

            return {
                "total_artifacts": total_artifacts,
                "expired_artifacts": expired_artifacts,
                "active_artifacts": total_artifacts - expired_artifacts,
                "artifacts_by_type": artifacts_by_type,
                "storage_statistics": storage_stats,
                "total_storage": {
                    "total_size_bytes": total_size_bytes,
                    "total_size_mb": round(total_size_bytes / (1024 * 1024), 2),
                    "total_size_gb": round(total_size_bytes / (1024 * 1024 * 1024), 2)
                }
            }
        except Exception as e:
            logger.error("Failed to get artifact statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "total_artifacts": 0,
                "expired_artifacts": 0,
                "active_artifacts": 0,
                "artifacts_by_type": {},
                "storage_statistics": {},
                "total_storage": {
                    "total_size_bytes": 0,
                    "total_size_mb": 0.0,
                    "total_size_gb": 0.0
                }
            }
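
    # Note: verify_artifact_integrity() checks the file through the local
    # filesystem (os.path.exists/getsize and open()), so it is only
    # meaningful for artifacts whose file_path resolves on a locally
    # mounted volume.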

    async def verify_artifact_integrity(self, artifact_id: int) -> Dict[str, Any]:
        """Verify artifact file integrity with actual file system checks"""
        try:
            import os
            import hashlib

            artifact = await self.get_by_id(artifact_id)
            if not artifact:
                return {"exists": False, "error": "Artifact not found"}

            # Check if file exists
            file_exists = os.path.exists(artifact.file_path)
            if not file_exists:
                return {
                    "artifact_id": artifact_id,
                    "file_path": artifact.file_path,
                    "exists": False,
                    "checksum_valid": False,
                    "size_valid": False,
                    "storage_location": artifact.storage_location,
                    "last_verified": datetime.now().isoformat(),
                    "error": "File does not exist on disk"
                }

            # Verify file size
            actual_size = os.path.getsize(artifact.file_path)
            size_valid = True
            if artifact.file_size_bytes:
                size_valid = (actual_size == artifact.file_size_bytes)

            # Verify checksum if stored
            checksum_valid = True
            actual_checksum = None
            if artifact.checksum:
                # Calculate checksum of actual file
                sha256_hash = hashlib.sha256()
                try:
                    with open(artifact.file_path, "rb") as f:
                        # Read file in chunks to handle large files
                        for byte_block in iter(lambda: f.read(4096), b""):
                            sha256_hash.update(byte_block)
                    actual_checksum = sha256_hash.hexdigest()
                    checksum_valid = (actual_checksum == artifact.checksum)
                except Exception as checksum_error:
                    logger.error(f"Failed to calculate checksum: {checksum_error}")
                    checksum_valid = False
                    actual_checksum = None

            # Overall integrity status
            integrity_valid = file_exists and size_valid and checksum_valid

            result = {
                "artifact_id": artifact_id,
                "file_path": artifact.file_path,
                "exists": file_exists,
                "checksum_valid": checksum_valid,
                "size_valid": size_valid,
                "integrity_valid": integrity_valid,
                "storage_location": artifact.storage_location,
                "last_verified": datetime.now().isoformat(),
                "details": {
                    "stored_size_bytes": artifact.file_size_bytes,
                    "actual_size_bytes": actual_size if file_exists else None,
                    "stored_checksum": artifact.checksum,
                    "actual_checksum": actual_checksum
                }
            }
            if not integrity_valid:
                issues = []
                if not file_exists:
                    issues.append("file_missing")
                if not size_valid:
                    issues.append("size_mismatch")
                if not checksum_valid:
                    issues.append("checksum_mismatch")
                result["issues"] = issues
            return result
        except Exception as e:
            logger.error("Failed to verify artifact integrity",
                         artifact_id=artifact_id,
                         error=str(e))
            return {
                "exists": False,
                "error": f"Verification failed: {str(e)}"
            }
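
    # The migration below derives each destination path by substituting the
    # first occurrence of from_location inside the stored file_path, and it
    # copies/moves files with shutil. This assumes the storage location name
    # appears verbatim in the path and that both locations are reachable as
    # local filesystem paths.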

    async def migrate_artifacts_to_storage(
        self,
        from_location: str,
        to_location: str,
        tenant_id: Optional[str] = None,
        copy_only: bool = False,
        verify: bool = True
    ) -> Dict[str, Any]:
        """Migrate artifacts from one storage location to another with actual file operations"""
        try:
            import os
            import shutil
            import hashlib

            # Get artifacts to migrate
            artifacts = await self.get_artifacts_by_storage_location(from_location, tenant_id)

            migrated_count = 0
            failed_count = 0
            failed_artifacts = []
            verified_count = 0

            for artifact in artifacts:
                try:
                    # Determine new file path
                    new_file_path = artifact.file_path.replace(from_location, to_location, 1)

                    # Create destination directory if it doesn't exist
                    dest_dir = os.path.dirname(new_file_path)
                    os.makedirs(dest_dir, exist_ok=True)

                    # Check if source file exists
                    if not os.path.exists(artifact.file_path):
                        logger.warning(f"Source file not found: {artifact.file_path}")
                        failed_count += 1
                        failed_artifacts.append({
                            "artifact_id": artifact.id,
                            "file_path": artifact.file_path,
                            "reason": "source_file_not_found"
                        })
                        continue

                    # Copy or move file
                    if copy_only:
                        shutil.copy2(artifact.file_path, new_file_path)
                        logger.debug(f"Copied file from {artifact.file_path} to {new_file_path}")
                    else:
                        shutil.move(artifact.file_path, new_file_path)
                        logger.debug(f"Moved file from {artifact.file_path} to {new_file_path}")

                    # Verify file was copied/moved successfully
                    if verify and os.path.exists(new_file_path):
                        # Verify file size
                        new_size = os.path.getsize(new_file_path)
                        if artifact.file_size_bytes and new_size != artifact.file_size_bytes:
                            logger.warning(f"File size mismatch after migration: {new_file_path}")
                            failed_count += 1
                            failed_artifacts.append({
                                "artifact_id": artifact.id,
                                "file_path": new_file_path,
                                "reason": "size_mismatch_after_migration"
                            })
                            continue

                        # Verify checksum if available
                        if artifact.checksum:
                            sha256_hash = hashlib.sha256()
                            with open(new_file_path, "rb") as f:
                                for byte_block in iter(lambda: f.read(4096), b""):
                                    sha256_hash.update(byte_block)
                            new_checksum = sha256_hash.hexdigest()
                            if new_checksum != artifact.checksum:
                                logger.warning(f"Checksum mismatch after migration: {new_file_path}")
                                failed_count += 1
                                failed_artifacts.append({
                                    "artifact_id": artifact.id,
                                    "file_path": new_file_path,
                                    "reason": "checksum_mismatch_after_migration"
                                })
                                continue

                        verified_count += 1

                    # Update database with new location
                    await self.update(artifact.id, {
                        "storage_location": to_location,
                        "file_path": new_file_path
                    })
                    migrated_count += 1

                except Exception as migration_error:
                    logger.error("Failed to migrate artifact",
                                 artifact_id=artifact.id,
                                 error=str(migration_error))
                    failed_count += 1
                    failed_artifacts.append({
                        "artifact_id": artifact.id,
                        "file_path": artifact.file_path,
                        "reason": str(migration_error)
                    })

            logger.info("Artifact migration completed",
                        from_location=from_location,
                        to_location=to_location,
                        migrated_count=migrated_count,
                        failed_count=failed_count,
                        verified_count=verified_count)

            return {
                "from_location": from_location,
                "to_location": to_location,
                "total_artifacts": len(artifacts),
                "migrated_count": migrated_count,
                "failed_count": failed_count,
                "verified_count": verified_count if verify else None,
                "success_rate": round((migrated_count / len(artifacts)) * 100, 2) if artifacts else 100,
                "copy_only": copy_only,
                "failed_artifacts": failed_artifacts if failed_artifacts else None
            }
        except Exception as e:
            logger.error("Failed to migrate artifacts",
                         from_location=from_location,
                         to_location=to_location,
                         error=str(e))
            return {
                "error": f"Migration failed: {str(e)}"
            }
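
# Illustrative migration call (a sketch with a hypothetical destination
# location name; copy_only=True keeps the source files in place, and
# verify=True re-checks size and checksum after each transfer):
#
#     report = await repo.migrate_artifacts_to_storage(
#         from_location="local",
#         to_location="archive",   # hypothetical storage location
#         copy_only=True,
#     )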