Initial commit - production deployment

services/training/app/repositories/artifact_repository.py (new file, 560 lines added)

@@ -0,0 +1,560 @@
"""
|
||||
Artifact Repository
|
||||
Repository for model artifact operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, text, desc
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
|
||||
from .base import TrainingBaseRepository
|
||||
from app.models.training import ModelArtifact
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||


class ArtifactRepository(TrainingBaseRepository):
    """Repository for model artifact operations"""

    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 1800):
        # Artifacts are stable, so use a longer cache TTL (30 minutes by default)
        super().__init__(ModelArtifact, session, cache_ttl)

    async def create_artifact(self, artifact_data: Dict[str, Any]) -> ModelArtifact:
        """Create a new model artifact record"""
        try:
            # Validate artifact data
            validation_result = self._validate_training_data(
                artifact_data,
                ["model_id", "tenant_id", "artifact_type", "file_path"]
            )

            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid artifact data: {validation_result['errors']}")

            # Set default values
            if "storage_location" not in artifact_data:
                artifact_data["storage_location"] = "local"

            # Create artifact record
            artifact = await self.create(artifact_data)

            logger.info("Model artifact created",
                        model_id=artifact.model_id,
                        tenant_id=artifact.tenant_id,
                        artifact_type=artifact.artifact_type,
                        file_path=artifact.file_path)

            return artifact

        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to create model artifact",
                         model_id=artifact_data.get("model_id"),
                         error=str(e))
            raise DatabaseError(f"Failed to create artifact: {str(e)}")

    async def get_artifacts_by_model(
        self,
        model_id: str,
        artifact_type: Optional[str] = None
    ) -> List[ModelArtifact]:
        """Get all artifacts for a model"""
        try:
            filters = {"model_id": model_id}
            if artifact_type:
                filters["artifact_type"] = artifact_type

            return await self.get_multi(
                filters=filters,
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get artifacts by model",
                         model_id=model_id,
                         artifact_type=artifact_type,
                         error=str(e))
            raise DatabaseError(f"Failed to get artifacts: {str(e)}")

    async def get_artifacts_by_tenant(
        self,
        tenant_id: str,
        artifact_type: Optional[str] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[ModelArtifact]:
        """Get artifacts for a tenant"""
        try:
            filters = {"tenant_id": tenant_id}
            if artifact_type:
                filters["artifact_type"] = artifact_type

            return await self.get_multi(
                filters=filters,
                skip=skip,
                limit=limit,
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get artifacts by tenant",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get tenant artifacts: {str(e)}")

    async def get_artifact_by_path(self, file_path: str) -> Optional[ModelArtifact]:
        """Get artifact by file path"""
        try:
            return await self.get_by_field("file_path", file_path)
        except Exception as e:
            logger.error("Failed to get artifact by path",
                         file_path=file_path,
                         error=str(e))
            raise DatabaseError(f"Failed to get artifact: {str(e)}")

    async def update_artifact_size(self, artifact_id: int, file_size_bytes: int) -> Optional[ModelArtifact]:
        """Update artifact file size"""
        try:
            return await self.update(artifact_id, {"file_size_bytes": file_size_bytes})
        except Exception as e:
            logger.error("Failed to update artifact size",
                         artifact_id=artifact_id,
                         error=str(e))
            return None

    async def update_artifact_checksum(self, artifact_id: int, checksum: str) -> Optional[ModelArtifact]:
        """Update artifact checksum for integrity verification"""
        try:
            return await self.update(artifact_id, {"checksum": checksum})
        except Exception as e:
            logger.error("Failed to update artifact checksum",
                         artifact_id=artifact_id,
                         error=str(e))
            return None

    async def mark_artifact_expired(self, artifact_id: int, expires_at: Optional[datetime] = None) -> Optional[ModelArtifact]:
        """Mark artifact for expiration/cleanup"""
        try:
            if not expires_at:
                expires_at = datetime.now()

            return await self.update(artifact_id, {"expires_at": expires_at})
        except Exception as e:
            logger.error("Failed to mark artifact as expired",
                         artifact_id=artifact_id,
                         error=str(e))
            return None
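
    # Note on the raw-SQL readers below (get_expired_artifacts, and get_large_artifacts
    # further down): they rebuild ModelArtifact objects via self.model(**row._mapping),
    # which assumes every selected column name matches a ModelArtifact constructor
    # keyword. The resulting instances are transient (not attached to the session),
    # so they should be treated as read-only snapshots rather than tracked rows.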

    async def get_expired_artifacts(self, days_expired: int = 0) -> List[ModelArtifact]:
        """Get artifacts that have expired"""
        try:
            cutoff_date = datetime.now() - timedelta(days=days_expired)

            query_text = """
                SELECT * FROM model_artifacts
                WHERE expires_at IS NOT NULL
                AND expires_at <= :cutoff_date
                ORDER BY expires_at ASC
            """

            result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})

            expired_artifacts = []
            for row in result.fetchall():
                record_dict = dict(row._mapping)
                artifact = self.model(**record_dict)
                expired_artifacts.append(artifact)

            return expired_artifacts

        except Exception as e:
            logger.error("Failed to get expired artifacts",
                         days_expired=days_expired,
                         error=str(e))
            return []
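
    # Sketch of the cleanup contract, based on how the repository session is used in
    # the rest of this file: the DELETE below runs on self.session and result.rowcount
    # reports how many rows the statement removed. Committing (or rolling back on
    # error) is assumed to be handled by the session owner, e.g. the service layer or
    # TrainingBaseRepository; this method does not commit on its own.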

    async def cleanup_expired_artifacts(self, days_expired: int = 0) -> int:
        """Clean up expired artifacts"""
        try:
            cutoff_date = datetime.now() - timedelta(days=days_expired)

            query_text = """
                DELETE FROM model_artifacts
                WHERE expires_at IS NOT NULL
                AND expires_at <= :cutoff_date
            """

            result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
            deleted_count = result.rowcount

            logger.info("Cleaned up expired artifacts",
                        deleted_count=deleted_count,
                        days_expired=days_expired)

            return deleted_count

        except Exception as e:
            logger.error("Failed to cleanup expired artifacts",
                         days_expired=days_expired,
                         error=str(e))
            raise DatabaseError(f"Artifact cleanup failed: {str(e)}")

    async def get_large_artifacts(self, min_size_mb: int = 100) -> List[ModelArtifact]:
        """Get artifacts larger than the specified size"""
        try:
            min_size_bytes = min_size_mb * 1024 * 1024  # Convert MB to bytes

            query_text = """
                SELECT * FROM model_artifacts
                WHERE file_size_bytes >= :min_size_bytes
                ORDER BY file_size_bytes DESC
            """

            result = await self.session.execute(text(query_text), {"min_size_bytes": min_size_bytes})

            large_artifacts = []
            for row in result.fetchall():
                record_dict = dict(row._mapping)
                artifact = self.model(**record_dict)
                large_artifacts.append(artifact)

            return large_artifacts

        except Exception as e:
            logger.error("Failed to get large artifacts",
                         min_size_mb=min_size_mb,
                         error=str(e))
            return []

    async def get_artifacts_by_storage_location(
        self,
        storage_location: str,
        tenant_id: Optional[str] = None
    ) -> List[ModelArtifact]:
        """Get artifacts by storage location"""
        try:
            filters = {"storage_location": storage_location}
            if tenant_id:
                filters["tenant_id"] = tenant_id

            return await self.get_multi(
                filters=filters,
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get artifacts by storage location",
                         storage_location=storage_location,
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get artifacts: {str(e)}")

    async def get_artifact_statistics(self, tenant_id: Optional[str] = None) -> Dict[str, Any]:
        """Get artifact statistics"""
        try:
            base_filters = {}
            if tenant_id:
                base_filters["tenant_id"] = tenant_id

            # Get basic counts
            total_artifacts = await self.count(filters=base_filters)

            # Get artifacts by type
            type_query_params = {}
            type_query_filter = ""
            if tenant_id:
                type_query_filter = "WHERE tenant_id = :tenant_id"
                type_query_params["tenant_id"] = tenant_id

            type_query = text(f"""
                SELECT artifact_type, COUNT(*) as count
                FROM model_artifacts
                {type_query_filter}
                GROUP BY artifact_type
                ORDER BY count DESC
            """)

            result = await self.session.execute(type_query, type_query_params)
            artifacts_by_type = {row.artifact_type: row.count for row in result.fetchall()}

            # Get storage location stats (reuses the same tenant filter and params)
            location_query = text(f"""
                SELECT
                    storage_location,
                    COUNT(*) as count,
                    SUM(COALESCE(file_size_bytes, 0)) as total_size_bytes
                FROM model_artifacts
                {type_query_filter}
                GROUP BY storage_location
                ORDER BY count DESC
            """)

            location_result = await self.session.execute(location_query, type_query_params)
            storage_stats = {}
            total_size_bytes = 0

            for row in location_result.fetchall():
                storage_stats[row.storage_location] = {
                    "artifact_count": row.count,
                    "total_size_bytes": int(row.total_size_bytes or 0),
                    "total_size_mb": round((row.total_size_bytes or 0) / (1024 * 1024), 2)
                }
                total_size_bytes += row.total_size_bytes or 0

            # Get expired artifacts count
            expired_artifacts = len(await self.get_expired_artifacts())

            return {
                "total_artifacts": total_artifacts,
                "expired_artifacts": expired_artifacts,
                "active_artifacts": total_artifacts - expired_artifacts,
                "artifacts_by_type": artifacts_by_type,
                "storage_statistics": storage_stats,
                "total_storage": {
                    "total_size_bytes": total_size_bytes,
                    "total_size_mb": round(total_size_bytes / (1024 * 1024), 2),
                    "total_size_gb": round(total_size_bytes / (1024 * 1024 * 1024), 2)
                }
            }

        except Exception as e:
            logger.error("Failed to get artifact statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "total_artifacts": 0,
                "expired_artifacts": 0,
                "active_artifacts": 0,
                "artifacts_by_type": {},
                "storage_statistics": {},
                "total_storage": {
                    "total_size_bytes": 0,
                    "total_size_mb": 0.0,
                    "total_size_gb": 0.0
                }
            }
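
    # The integrity check below works against the local filesystem: it assumes
    # artifact.file_path is reachable as a local path (e.g. storage_location is
    # "local" or a mounted volume). The stored checksum is assumed to be a SHA-256
    # hex digest; the file is re-hashed in 4096-byte chunks so that large model
    # artifacts never have to be loaded into memory in one piece.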

    async def verify_artifact_integrity(self, artifact_id: int) -> Dict[str, Any]:
        """Verify artifact file integrity with actual file system checks"""
        try:
            import os
            import hashlib

            artifact = await self.get_by_id(artifact_id)
            if not artifact:
                return {"exists": False, "error": "Artifact not found"}

            # Check if file exists
            file_exists = os.path.exists(artifact.file_path)
            if not file_exists:
                return {
                    "artifact_id": artifact_id,
                    "file_path": artifact.file_path,
                    "exists": False,
                    "checksum_valid": False,
                    "size_valid": False,
                    "storage_location": artifact.storage_location,
                    "last_verified": datetime.now().isoformat(),
                    "error": "File does not exist on disk"
                }

            # Verify file size
            actual_size = os.path.getsize(artifact.file_path)
            size_valid = True
            if artifact.file_size_bytes:
                size_valid = (actual_size == artifact.file_size_bytes)

            # Verify checksum if stored
            checksum_valid = True
            actual_checksum = None
            if artifact.checksum:
                # Calculate checksum of actual file
                sha256_hash = hashlib.sha256()
                try:
                    with open(artifact.file_path, "rb") as f:
                        # Read file in chunks to handle large files
                        for byte_block in iter(lambda: f.read(4096), b""):
                            sha256_hash.update(byte_block)
                    actual_checksum = sha256_hash.hexdigest()
                    checksum_valid = (actual_checksum == artifact.checksum)
                except Exception as checksum_error:
                    logger.error(f"Failed to calculate checksum: {checksum_error}")
                    checksum_valid = False
                    actual_checksum = None

            # Overall integrity status
            integrity_valid = file_exists and size_valid and checksum_valid

            result = {
                "artifact_id": artifact_id,
                "file_path": artifact.file_path,
                "exists": file_exists,
                "checksum_valid": checksum_valid,
                "size_valid": size_valid,
                "integrity_valid": integrity_valid,
                "storage_location": artifact.storage_location,
                "last_verified": datetime.now().isoformat(),
                "details": {
                    "stored_size_bytes": artifact.file_size_bytes,
                    "actual_size_bytes": actual_size if file_exists else None,
                    "stored_checksum": artifact.checksum,
                    "actual_checksum": actual_checksum
                }
            }

            if not integrity_valid:
                issues = []
                if not file_exists:
                    issues.append("file_missing")
                if not size_valid:
                    issues.append("size_mismatch")
                if not checksum_valid:
                    issues.append("checksum_mismatch")
                result["issues"] = issues

            return result

        except Exception as e:
            logger.error("Failed to verify artifact integrity",
                         artifact_id=artifact_id,
                         error=str(e))
            return {
                "exists": False,
                "error": f"Verification failed: {str(e)}"
            }
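
    # Migration assumptions: the destination path is derived by replacing the first
    # occurrence of from_location in artifact.file_path with to_location, so the
    # storage location names are assumed to appear literally as path prefixes
    # (e.g. "/mnt/local/..." becoming "/mnt/archive/..."). Both locations must be
    # reachable as local filesystem paths for shutil.copy2/shutil.move to work;
    # object stores would need a different transfer mechanism. With copy_only=True
    # the source files are left in place, otherwise they are moved.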

    async def migrate_artifacts_to_storage(
        self,
        from_location: str,
        to_location: str,
        tenant_id: Optional[str] = None,
        copy_only: bool = False,
        verify: bool = True
    ) -> Dict[str, Any]:
        """Migrate artifacts from one storage location to another with actual file operations"""
        try:
            import os
            import shutil
            import hashlib

            # Get artifacts to migrate
            artifacts = await self.get_artifacts_by_storage_location(from_location, tenant_id)

            migrated_count = 0
            failed_count = 0
            failed_artifacts = []
            verified_count = 0

            for artifact in artifacts:
                try:
                    # Determine new file path
                    new_file_path = artifact.file_path.replace(from_location, to_location, 1)

                    # Create destination directory if it doesn't exist
                    dest_dir = os.path.dirname(new_file_path)
                    os.makedirs(dest_dir, exist_ok=True)

                    # Check if source file exists
                    if not os.path.exists(artifact.file_path):
                        logger.warning(f"Source file not found: {artifact.file_path}")
                        failed_count += 1
                        failed_artifacts.append({
                            "artifact_id": artifact.id,
                            "file_path": artifact.file_path,
                            "reason": "source_file_not_found"
                        })
                        continue

                    # Copy or move file
                    if copy_only:
                        shutil.copy2(artifact.file_path, new_file_path)
                        logger.debug(f"Copied file from {artifact.file_path} to {new_file_path}")
                    else:
                        shutil.move(artifact.file_path, new_file_path)
                        logger.debug(f"Moved file from {artifact.file_path} to {new_file_path}")

                    # Verify file was copied/moved successfully
                    if verify and os.path.exists(new_file_path):
                        # Verify file size
                        new_size = os.path.getsize(new_file_path)
                        if artifact.file_size_bytes and new_size != artifact.file_size_bytes:
                            logger.warning(f"File size mismatch after migration: {new_file_path}")
                            failed_count += 1
                            failed_artifacts.append({
                                "artifact_id": artifact.id,
                                "file_path": new_file_path,
                                "reason": "size_mismatch_after_migration"
                            })
                            continue

                        # Verify checksum if available
                        if artifact.checksum:
                            sha256_hash = hashlib.sha256()
                            with open(new_file_path, "rb") as f:
                                for byte_block in iter(lambda: f.read(4096), b""):
                                    sha256_hash.update(byte_block)
                            new_checksum = sha256_hash.hexdigest()

                            if new_checksum != artifact.checksum:
                                logger.warning(f"Checksum mismatch after migration: {new_file_path}")
                                failed_count += 1
                                failed_artifacts.append({
                                    "artifact_id": artifact.id,
                                    "file_path": new_file_path,
                                    "reason": "checksum_mismatch_after_migration"
                                })
                                continue

                        verified_count += 1

                    # Update database with new location
                    await self.update(artifact.id, {
                        "storage_location": to_location,
                        "file_path": new_file_path
                    })

                    migrated_count += 1

                except Exception as migration_error:
                    logger.error("Failed to migrate artifact",
                                 artifact_id=artifact.id,
                                 error=str(migration_error))
                    failed_count += 1
                    failed_artifacts.append({
                        "artifact_id": artifact.id,
                        "file_path": artifact.file_path,
                        "reason": str(migration_error)
                    })

            logger.info("Artifact migration completed",
                        from_location=from_location,
                        to_location=to_location,
                        migrated_count=migrated_count,
                        failed_count=failed_count,
                        verified_count=verified_count)

            return {
                "from_location": from_location,
                "to_location": to_location,
                "total_artifacts": len(artifacts),
                "migrated_count": migrated_count,
                "failed_count": failed_count,
                "verified_count": verified_count if verify else None,
                "success_rate": round((migrated_count / len(artifacts)) * 100, 2) if artifacts else 100,
                "copy_only": copy_only,
                "failed_artifacts": failed_artifacts if failed_artifacts else None
            }

        except Exception as e:
            logger.error("Failed to migrate artifacts",
                         from_location=from_location,
                         to_location=to_location,
                         error=str(e))
            return {
                "error": f"Migration failed: {str(e)}"
            }