REFACTOR - Database logic
services/training/app/repositories/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
"""
Training Service Repositories
Repository implementations for training service
"""

from .base import TrainingBaseRepository
from .model_repository import ModelRepository
from .training_log_repository import TrainingLogRepository
from .performance_repository import PerformanceRepository
from .job_queue_repository import JobQueueRepository
from .artifact_repository import ArtifactRepository

__all__ = [
    "TrainingBaseRepository",
    "ModelRepository",
    "TrainingLogRepository",
    "PerformanceRepository",
    "JobQueueRepository",
    "ArtifactRepository"
]
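The package only re-exports the repository classes. Below is a minimal usage sketch (not part of this commit) of how a caller might wire one of them to an AsyncSession; the engine URL and the `async_session_factory` name are placeholders for whatever the training service already configures.

# Usage sketch only: instantiate a repository per unit of work.
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

from app.repositories import ModelRepository

engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/training")  # placeholder DSN
async_session_factory = async_sessionmaker(engine, expire_on_commit=False)

async def load_tenant_models(tenant_id: str):
    # The session scopes the transaction; the repository only issues queries.
    async with async_session_factory() as session:
        repo = ModelRepository(session)
        return await repo.get_models_by_tenant(tenant_id)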
services/training/app/repositories/artifact_repository.py (new file, 433 lines)
@@ -0,0 +1,433 @@
"""
Artifact Repository
Repository for model artifact operations
"""

from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc
from datetime import datetime, timedelta
import structlog

from .base import TrainingBaseRepository
from app.models.training import ModelArtifact
from shared.database.exceptions import DatabaseError, ValidationError

logger = structlog.get_logger()


class ArtifactRepository(TrainingBaseRepository):
    """Repository for model artifact operations"""

    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 1800):
        # Artifacts are stable, longer cache time (30 minutes)
        super().__init__(ModelArtifact, session, cache_ttl)

    async def create_artifact(self, artifact_data: Dict[str, Any]) -> ModelArtifact:
        """Create a new model artifact record"""
        try:
            # Validate artifact data
            validation_result = self._validate_training_data(
                artifact_data,
                ["model_id", "tenant_id", "artifact_type", "file_path"]
            )

            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid artifact data: {validation_result['errors']}")

            # Set default values
            if "storage_location" not in artifact_data:
                artifact_data["storage_location"] = "local"

            # Create artifact record
            artifact = await self.create(artifact_data)

            logger.info("Model artifact created",
                        model_id=artifact.model_id,
                        tenant_id=artifact.tenant_id,
                        artifact_type=artifact.artifact_type,
                        file_path=artifact.file_path)

            return artifact

        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to create model artifact",
                         model_id=artifact_data.get("model_id"),
                         error=str(e))
            raise DatabaseError(f"Failed to create artifact: {str(e)}")

    async def get_artifacts_by_model(
        self,
        model_id: str,
        artifact_type: str = None
    ) -> List[ModelArtifact]:
        """Get all artifacts for a model"""
        try:
            filters = {"model_id": model_id}
            if artifact_type:
                filters["artifact_type"] = artifact_type

            return await self.get_multi(
                filters=filters,
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get artifacts by model",
                         model_id=model_id,
                         artifact_type=artifact_type,
                         error=str(e))
            raise DatabaseError(f"Failed to get artifacts: {str(e)}")

    async def get_artifacts_by_tenant(
        self,
        tenant_id: str,
        artifact_type: str = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[ModelArtifact]:
        """Get artifacts for a tenant"""
        try:
            filters = {"tenant_id": tenant_id}
            if artifact_type:
                filters["artifact_type"] = artifact_type

            return await self.get_multi(
                filters=filters,
                skip=skip,
                limit=limit,
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get artifacts by tenant",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get tenant artifacts: {str(e)}")

    async def get_artifact_by_path(self, file_path: str) -> Optional[ModelArtifact]:
        """Get artifact by file path"""
        try:
            return await self.get_by_field("file_path", file_path)
        except Exception as e:
            logger.error("Failed to get artifact by path",
                         file_path=file_path,
                         error=str(e))
            raise DatabaseError(f"Failed to get artifact: {str(e)}")

    async def update_artifact_size(self, artifact_id: int, file_size_bytes: int) -> Optional[ModelArtifact]:
        """Update artifact file size"""
        try:
            return await self.update(artifact_id, {"file_size_bytes": file_size_bytes})
        except Exception as e:
            logger.error("Failed to update artifact size",
                         artifact_id=artifact_id,
                         error=str(e))
            return None

    async def update_artifact_checksum(self, artifact_id: int, checksum: str) -> Optional[ModelArtifact]:
        """Update artifact checksum for integrity verification"""
        try:
            return await self.update(artifact_id, {"checksum": checksum})
        except Exception as e:
            logger.error("Failed to update artifact checksum",
                         artifact_id=artifact_id,
                         error=str(e))
            return None

    async def mark_artifact_expired(self, artifact_id: int, expires_at: datetime = None) -> Optional[ModelArtifact]:
        """Mark artifact for expiration/cleanup"""
        try:
            if not expires_at:
                expires_at = datetime.now()

            return await self.update(artifact_id, {"expires_at": expires_at})
        except Exception as e:
            logger.error("Failed to mark artifact as expired",
                         artifact_id=artifact_id,
                         error=str(e))
            return None

    async def get_expired_artifacts(self, days_expired: int = 0) -> List[ModelArtifact]:
        """Get artifacts that have expired"""
        try:
            cutoff_date = datetime.now() - timedelta(days=days_expired)

            query_text = """
                SELECT * FROM model_artifacts
                WHERE expires_at IS NOT NULL
                AND expires_at <= :cutoff_date
                ORDER BY expires_at ASC
            """

            result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})

            expired_artifacts = []
            for row in result.fetchall():
                record_dict = dict(row._mapping)
                artifact = self.model(**record_dict)
                expired_artifacts.append(artifact)

            return expired_artifacts

        except Exception as e:
            logger.error("Failed to get expired artifacts",
                         days_expired=days_expired,
                         error=str(e))
            return []

    async def cleanup_expired_artifacts(self, days_expired: int = 0) -> int:
        """Clean up expired artifacts"""
        try:
            cutoff_date = datetime.now() - timedelta(days=days_expired)

            query_text = """
                DELETE FROM model_artifacts
                WHERE expires_at IS NOT NULL
                AND expires_at <= :cutoff_date
            """

            result = await self.session.execute(text(query_text), {"cutoff_date": cutoff_date})
            deleted_count = result.rowcount

            logger.info("Cleaned up expired artifacts",
                        deleted_count=deleted_count,
                        days_expired=days_expired)

            return deleted_count

        except Exception as e:
            logger.error("Failed to cleanup expired artifacts",
                         days_expired=days_expired,
                         error=str(e))
            raise DatabaseError(f"Artifact cleanup failed: {str(e)}")

    async def get_large_artifacts(self, min_size_mb: int = 100) -> List[ModelArtifact]:
        """Get artifacts larger than specified size"""
        try:
            min_size_bytes = min_size_mb * 1024 * 1024  # Convert MB to bytes

            query_text = """
                SELECT * FROM model_artifacts
                WHERE file_size_bytes >= :min_size_bytes
                ORDER BY file_size_bytes DESC
            """

            result = await self.session.execute(text(query_text), {"min_size_bytes": min_size_bytes})

            large_artifacts = []
            for row in result.fetchall():
                record_dict = dict(row._mapping)
                artifact = self.model(**record_dict)
                large_artifacts.append(artifact)

            return large_artifacts

        except Exception as e:
            logger.error("Failed to get large artifacts",
                         min_size_mb=min_size_mb,
                         error=str(e))
            return []

    async def get_artifacts_by_storage_location(
        self,
        storage_location: str,
        tenant_id: str = None
    ) -> List[ModelArtifact]:
        """Get artifacts by storage location"""
        try:
            filters = {"storage_location": storage_location}
            if tenant_id:
                filters["tenant_id"] = tenant_id

            return await self.get_multi(
                filters=filters,
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get artifacts by storage location",
                         storage_location=storage_location,
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get artifacts: {str(e)}")

    async def get_artifact_statistics(self, tenant_id: str = None) -> Dict[str, Any]:
        """Get artifact statistics"""
        try:
            base_filters = {}
            if tenant_id:
                base_filters["tenant_id"] = tenant_id

            # Get basic counts
            total_artifacts = await self.count(filters=base_filters)

            # Get artifacts by type
            type_query_params = {}
            type_query_filter = ""
            if tenant_id:
                type_query_filter = "WHERE tenant_id = :tenant_id"
                type_query_params["tenant_id"] = tenant_id

            type_query = text(f"""
                SELECT artifact_type, COUNT(*) as count
                FROM model_artifacts
                {type_query_filter}
                GROUP BY artifact_type
                ORDER BY count DESC
            """)

            result = await self.session.execute(type_query, type_query_params)
            artifacts_by_type = {row.artifact_type: row.count for row in result.fetchall()}

            # Get storage location stats
            location_query = text(f"""
                SELECT
                    storage_location,
                    COUNT(*) as count,
                    SUM(COALESCE(file_size_bytes, 0)) as total_size_bytes
                FROM model_artifacts
                {type_query_filter}
                GROUP BY storage_location
                ORDER BY count DESC
            """)

            location_result = await self.session.execute(location_query, type_query_params)
            storage_stats = {}
            total_size_bytes = 0

            for row in location_result.fetchall():
                storage_stats[row.storage_location] = {
                    "artifact_count": row.count,
                    "total_size_bytes": int(row.total_size_bytes or 0),
                    "total_size_mb": round((row.total_size_bytes or 0) / (1024 * 1024), 2)
                }
                total_size_bytes += row.total_size_bytes or 0

            # Get expired artifacts count
            expired_artifacts = len(await self.get_expired_artifacts())

            return {
                "total_artifacts": total_artifacts,
                "expired_artifacts": expired_artifacts,
                "active_artifacts": total_artifacts - expired_artifacts,
                "artifacts_by_type": artifacts_by_type,
                "storage_statistics": storage_stats,
                "total_storage": {
                    "total_size_bytes": total_size_bytes,
                    "total_size_mb": round(total_size_bytes / (1024 * 1024), 2),
                    "total_size_gb": round(total_size_bytes / (1024 * 1024 * 1024), 2)
                }
            }

        except Exception as e:
            logger.error("Failed to get artifact statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "total_artifacts": 0,
                "expired_artifacts": 0,
                "active_artifacts": 0,
                "artifacts_by_type": {},
                "storage_statistics": {},
                "total_storage": {
                    "total_size_bytes": 0,
                    "total_size_mb": 0.0,
                    "total_size_gb": 0.0
                }
            }

    async def verify_artifact_integrity(self, artifact_id: int) -> Dict[str, Any]:
        """Verify artifact file integrity (placeholder for file system checks)"""
        try:
            artifact = await self.get_by_id(artifact_id)
            if not artifact:
                return {"exists": False, "error": "Artifact not found"}

            # This is a placeholder - in a real implementation, you would:
            # 1. Check if the file exists at artifact.file_path
            # 2. Calculate current checksum and compare with stored checksum
            # 3. Verify file size matches stored file_size_bytes

            return {
                "artifact_id": artifact_id,
                "file_path": artifact.file_path,
                "exists": True,  # Would check actual file existence
                "checksum_valid": True,  # Would verify actual checksum
                "size_valid": True,  # Would verify actual file size
                "storage_location": artifact.storage_location,
                "last_verified": datetime.now().isoformat()
            }

        except Exception as e:
            logger.error("Failed to verify artifact integrity",
                         artifact_id=artifact_id,
                         error=str(e))
            return {
                "exists": False,
                "error": f"Verification failed: {str(e)}"
            }

    async def migrate_artifacts_to_storage(
        self,
        from_location: str,
        to_location: str,
        tenant_id: str = None
    ) -> Dict[str, Any]:
        """Migrate artifacts from one storage location to another (placeholder)"""
        try:
            # Get artifacts to migrate
            artifacts = await self.get_artifacts_by_storage_location(from_location, tenant_id)

            migrated_count = 0
            failed_count = 0

            # This is a placeholder - in a real implementation, you would:
            # 1. Copy files from old location to new location
            # 2. Update file paths in database
            # 3. Verify successful migration
            # 4. Clean up old files

            for artifact in artifacts:
                try:
                    # Placeholder migration logic
                    new_file_path = artifact.file_path.replace(from_location, to_location)

                    await self.update(artifact.id, {
                        "storage_location": to_location,
                        "file_path": new_file_path
                    })

                    migrated_count += 1

                except Exception as migration_error:
                    logger.error("Failed to migrate artifact",
                                 artifact_id=artifact.id,
                                 error=str(migration_error))
                    failed_count += 1

            logger.info("Artifact migration completed",
                        from_location=from_location,
                        to_location=to_location,
                        migrated_count=migrated_count,
                        failed_count=failed_count)

            return {
                "from_location": from_location,
                "to_location": to_location,
                "total_artifacts": len(artifacts),
                "migrated_count": migrated_count,
                "failed_count": failed_count,
                "success_rate": round((migrated_count / len(artifacts)) * 100, 2) if artifacts else 100
            }

        except Exception as e:
            logger.error("Failed to migrate artifacts",
                         from_location=from_location,
                         to_location=to_location,
                         error=str(e))
            return {
                "error": f"Migration failed: {str(e)}"
            }
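Note that verify_artifact_integrity above deliberately stubs out the file-system checks. A minimal sketch of what those checks could look like, assuming local-disk storage and a hex SHA-256 digest stored in the checksum column; this helper is illustrative and not part of the commit.

# Illustrative sketch only: the file checks verify_artifact_integrity stubs out.
import hashlib
import os
from typing import Optional

def check_artifact_file(file_path: str,
                        expected_checksum: Optional[str],
                        expected_size: Optional[int]) -> dict:
    """Return the exists/checksum/size flags that verify_artifact_integrity reports."""
    if not os.path.isfile(file_path):
        return {"exists": False, "checksum_valid": False, "size_valid": False}

    sha256 = hashlib.sha256()
    with open(file_path, "rb") as handle:
        # Stream in 1 MB chunks so large model artifacts are not loaded into memory at once
        for chunk in iter(lambda: handle.read(1024 * 1024), b""):
            sha256.update(chunk)

    actual_size = os.path.getsize(file_path)
    return {
        "exists": True,
        "checksum_valid": expected_checksum is None or sha256.hexdigest() == expected_checksum,
        "size_valid": expected_size is None or actual_size == expected_size,
    }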
services/training/app/repositories/base.py (new file, 179 lines)
@@ -0,0 +1,179 @@
"""
Base Repository for Training Service
Service-specific repository base class with training service utilities
"""

from typing import Optional, List, Dict, Any, Type
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
from datetime import datetime, timedelta
import structlog

from shared.database.repository import BaseRepository
from shared.database.exceptions import DatabaseError

logger = structlog.get_logger()


class TrainingBaseRepository(BaseRepository):
    """Base repository for training service with common training operations"""

    def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 300):
        # Training data changes frequently, shorter cache time (5 minutes)
        super().__init__(model, session, cache_ttl)

    async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List:
        """Get records by tenant ID"""
        if hasattr(self.model, 'tenant_id'):
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={"tenant_id": tenant_id},
                order_by="created_at",
                order_desc=True
            )
        return await self.get_multi(skip=skip, limit=limit)

    async def get_active_records(self, skip: int = 0, limit: int = 100) -> List:
        """Get active records (if model has is_active field)"""
        if hasattr(self.model, 'is_active'):
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={"is_active": True},
                order_by="created_at",
                order_desc=True
            )
        return await self.get_multi(skip=skip, limit=limit)

    async def get_by_job_id(self, job_id: str) -> Optional:
        """Get record by job ID (if model has job_id field)"""
        if hasattr(self.model, 'job_id'):
            return await self.get_by_field("job_id", job_id)
        return None

    async def get_by_model_id(self, model_id: str) -> Optional:
        """Get record by model ID (if model has model_id field)"""
        if hasattr(self.model, 'model_id'):
            return await self.get_by_field("model_id", model_id)
        return None

    async def deactivate_record(self, record_id: Any) -> Optional:
        """Deactivate a record instead of deleting it"""
        if hasattr(self.model, 'is_active'):
            return await self.update(record_id, {"is_active": False})
        return await self.delete(record_id)

    async def activate_record(self, record_id: Any) -> Optional:
        """Activate a record"""
        if hasattr(self.model, 'is_active'):
            return await self.update(record_id, {"is_active": True})
        return await self.get_by_id(record_id)

    async def cleanup_old_records(self, days_old: int = 90, status_filter: str = None) -> int:
        """Clean up old training records"""
        try:
            cutoff_date = datetime.utcnow() - timedelta(days=days_old)
            table_name = self.model.__tablename__

            # Build query based on available fields
            conditions = ["created_at < :cutoff_date"]
            params = {"cutoff_date": cutoff_date}

            if status_filter and hasattr(self.model, 'status'):
                conditions.append("status = :status")
                params["status"] = status_filter

            query_text = f"""
                DELETE FROM {table_name}
                WHERE {' AND '.join(conditions)}
            """

            result = await self.session.execute(text(query_text), params)
            deleted_count = result.rowcount

            logger.info(f"Cleaned up old {self.model.__name__} records",
                        deleted_count=deleted_count,
                        days_old=days_old,
                        status_filter=status_filter)

            return deleted_count

        except Exception as e:
            logger.error("Failed to cleanup old records",
                         model=self.model.__name__,
                         error=str(e))
            raise DatabaseError(f"Cleanup failed: {str(e)}")

    async def get_records_by_date_range(
        self,
        start_date: datetime,
        end_date: datetime,
        skip: int = 0,
        limit: int = 100
    ) -> List:
        """Get records within date range"""
        if not hasattr(self.model, 'created_at'):
            logger.warning(f"Model {self.model.__name__} has no created_at field")
            return []

        try:
            table_name = self.model.__tablename__

            query_text = f"""
                SELECT * FROM {table_name}
                WHERE created_at >= :start_date
                AND created_at <= :end_date
                ORDER BY created_at DESC
                LIMIT :limit OFFSET :skip
            """

            result = await self.session.execute(text(query_text), {
                "start_date": start_date,
                "end_date": end_date,
                "limit": limit,
                "skip": skip
            })

            # Convert rows to model objects
            records = []
            for row in result.fetchall():
                # Create model instance from row data
                record_dict = dict(row._mapping)
                record = self.model(**record_dict)
                records.append(record)

            return records

        except Exception as e:
            logger.error("Failed to get records by date range",
                         model=self.model.__name__,
                         start_date=start_date,
                         end_date=end_date,
                         error=str(e))
            raise DatabaseError(f"Date range query failed: {str(e)}")

    def _validate_training_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
        """Validate training-related data"""
        errors = []

        for field in required_fields:
            if field not in data or not data[field]:
                errors.append(f"Missing required field: {field}")

        # Validate tenant_id format if present
        if "tenant_id" in data and data["tenant_id"]:
            tenant_id = data["tenant_id"]
            if not isinstance(tenant_id, str) or len(tenant_id) < 1:
                errors.append("Invalid tenant_id format")

        # Validate job_id format if present
        if "job_id" in data and data["job_id"]:
            job_id = data["job_id"]
            if not isinstance(job_id, str) or len(job_id) < 1:
                errors.append("Invalid job_id format")

        return {
            "is_valid": len(errors) == 0,
            "errors": errors
        }
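For reference, a sketch of how a concrete repository is expected to build on TrainingBaseRepository, following the same pattern as the repositories in this commit. The TrainingLog model import is hypothetical and used only to illustrate the hasattr-guarded helpers; it is not asserted to exist in app.models.training.

# Sketch only: a hypothetical subclass showing the intended extension pattern.
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession

from app.repositories.base import TrainingBaseRepository
from app.models.training import TrainingLog  # hypothetical model name, for illustration

class ExampleLogRepository(TrainingBaseRepository):
    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 120):
        super().__init__(TrainingLog, session, cache_ttl)

    async def latest_for_job(self, job_id: str):
        # get_by_job_id only returns a record because TrainingLog would have a
        # job_id column; the base class guards this with hasattr() and
        # otherwise returns None.
        return await self.get_by_job_id(job_id)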
services/training/app/repositories/job_queue_repository.py (new file, 445 lines)
@@ -0,0 +1,445 @@
"""
Job Queue Repository
Repository for training job queue operations
"""

from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc
from datetime import datetime, timedelta
import structlog

from .base import TrainingBaseRepository
from app.models.training import TrainingJobQueue
from shared.database.exceptions import DatabaseError, ValidationError

logger = structlog.get_logger()


class JobQueueRepository(TrainingBaseRepository):
    """Repository for training job queue operations"""

    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 60):
        # Job queue changes frequently, very short cache time (1 minute)
        super().__init__(TrainingJobQueue, session, cache_ttl)

    async def enqueue_job(self, job_data: Dict[str, Any]) -> TrainingJobQueue:
        """Add a job to the training queue"""
        try:
            # Validate job data
            validation_result = self._validate_training_data(
                job_data,
                ["job_id", "tenant_id", "job_type"]
            )

            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid job data: {validation_result['errors']}")

            # Set default values
            if "priority" not in job_data:
                job_data["priority"] = 1
            if "status" not in job_data:
                job_data["status"] = "queued"
            if "max_retries" not in job_data:
                job_data["max_retries"] = 3

            # Create queue entry
            queued_job = await self.create(job_data)

            logger.info("Job enqueued",
                        job_id=queued_job.job_id,
                        tenant_id=queued_job.tenant_id,
                        job_type=queued_job.job_type,
                        priority=queued_job.priority)

            return queued_job

        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to enqueue job",
                         job_id=job_data.get("job_id"),
                         error=str(e))
            raise DatabaseError(f"Failed to enqueue job: {str(e)}")

    async def get_next_job(self, job_types: List[str] = None) -> Optional[TrainingJobQueue]:
        """Get the next job to process from the queue"""
        try:
            # Build filters for job types if specified
            filters = {"status": "queued"}

            if job_types:
                # For multiple job types, we need to use raw SQL
                job_types_str = "', '".join(job_types)
                query_text = f"""
                    SELECT * FROM training_job_queue
                    WHERE status = 'queued'
                    AND job_type IN ('{job_types_str}')
                    AND (scheduled_at IS NULL OR scheduled_at <= :now)
                    ORDER BY priority DESC, created_at ASC
                    LIMIT 1
                """

                result = await self.session.execute(text(query_text), {"now": datetime.now()})
                row = result.fetchone()

                if row:
                    record_dict = dict(row._mapping)
                    return self.model(**record_dict)
                return None
            else:
                # Simple case - get any queued job
                jobs = await self.get_multi(
                    filters=filters,
                    limit=1,
                    order_by="priority",
                    order_desc=True
                )
                return jobs[0] if jobs else None

        except Exception as e:
            logger.error("Failed to get next job from queue",
                         job_types=job_types,
                         error=str(e))
            raise DatabaseError(f"Failed to get next job: {str(e)}")

    async def start_job(self, job_id: str) -> Optional[TrainingJobQueue]:
        """Mark a job as started"""
        try:
            job = await self.get_by_job_id(job_id)
            if not job:
                logger.error(f"Job not found in queue: {job_id}")
                return None

            if job.status != "queued":
                logger.warning(f"Job {job_id} is not queued (status: {job.status})")
                return job

            updated_job = await self.update(job.id, {
                "status": "running",
                "started_at": datetime.now(),
                "updated_at": datetime.now()
            })

            logger.info("Job started",
                        job_id=job_id,
                        job_type=job.job_type)

            return updated_job

        except Exception as e:
            logger.error("Failed to start job",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to start job: {str(e)}")

    async def complete_job(self, job_id: str) -> Optional[TrainingJobQueue]:
        """Mark a job as completed"""
        try:
            job = await self.get_by_job_id(job_id)
            if not job:
                logger.error(f"Job not found in queue: {job_id}")
                return None

            updated_job = await self.update(job.id, {
                "status": "completed",
                "updated_at": datetime.now()
            })

            logger.info("Job completed",
                        job_id=job_id,
                        job_type=job.job_type if job else "unknown")

            return updated_job

        except Exception as e:
            logger.error("Failed to complete job",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to complete job: {str(e)}")

    async def fail_job(self, job_id: str, error_message: str = None) -> Optional[TrainingJobQueue]:
        """Mark a job as failed and handle retries"""
        try:
            job = await self.get_by_job_id(job_id)
            if not job:
                logger.error(f"Job not found in queue: {job_id}")
                return None

            # Increment retry count
            new_retry_count = job.retry_count + 1

            # Check if we should retry
            if new_retry_count < job.max_retries:
                # Reset to queued for retry
                updated_job = await self.update(job.id, {
                    "status": "queued",
                    "retry_count": new_retry_count,
                    "updated_at": datetime.now(),
                    "started_at": None  # Reset started_at for retry
                })

                logger.info("Job failed, queued for retry",
                            job_id=job_id,
                            retry_count=new_retry_count,
                            max_retries=job.max_retries)
            else:
                # Mark as permanently failed
                updated_job = await self.update(job.id, {
                    "status": "failed",
                    "retry_count": new_retry_count,
                    "updated_at": datetime.now()
                })

                logger.error("Job permanently failed",
                             job_id=job_id,
                             retry_count=new_retry_count,
                             error_message=error_message)

            return updated_job

        except Exception as e:
            logger.error("Failed to handle job failure",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to handle job failure: {str(e)}")

    async def cancel_job(self, job_id: str, cancelled_by: str = None) -> Optional[TrainingJobQueue]:
        """Cancel a job"""
        try:
            job = await self.get_by_job_id(job_id)
            if not job:
                logger.error(f"Job not found in queue: {job_id}")
                return None

            if job.status in ["completed", "failed"]:
                logger.warning(f"Cannot cancel job {job_id} with status {job.status}")
                return job

            updated_job = await self.update(job.id, {
                "status": "cancelled",
                "cancelled_by": cancelled_by,
                "updated_at": datetime.now()
            })

            logger.info("Job cancelled",
                        job_id=job_id,
                        cancelled_by=cancelled_by)

            return updated_job

        except Exception as e:
            logger.error("Failed to cancel job",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to cancel job: {str(e)}")

    async def get_queue_status(self, tenant_id: str = None) -> Dict[str, Any]:
        """Get queue status and statistics"""
        try:
            base_filters = {}
            if tenant_id:
                base_filters["tenant_id"] = tenant_id

            # Get counts by status
            queued_jobs = await self.count(filters={**base_filters, "status": "queued"})
            running_jobs = await self.count(filters={**base_filters, "status": "running"})
            completed_jobs = await self.count(filters={**base_filters, "status": "completed"})
            failed_jobs = await self.count(filters={**base_filters, "status": "failed"})
            cancelled_jobs = await self.count(filters={**base_filters, "status": "cancelled"})

            # Get jobs by type
            type_query = text(f"""
                SELECT job_type, COUNT(*) as count
                FROM training_job_queue
                WHERE 1=1
                {' AND tenant_id = :tenant_id' if tenant_id else ''}
                GROUP BY job_type
                ORDER BY count DESC
            """)

            params = {"tenant_id": tenant_id} if tenant_id else {}
            result = await self.session.execute(type_query, params)
            jobs_by_type = {row.job_type: row.count for row in result.fetchall()}

            # Get average wait time for completed jobs
            wait_time_query = text(f"""
                SELECT
                    AVG(EXTRACT(EPOCH FROM (started_at - created_at))/60) as avg_wait_minutes
                FROM training_job_queue
                WHERE status = 'completed'
                AND started_at IS NOT NULL
                AND created_at IS NOT NULL
                {' AND tenant_id = :tenant_id' if tenant_id else ''}
            """)

            wait_result = await self.session.execute(wait_time_query, params)
            wait_row = wait_result.fetchone()
            avg_wait_time = float(wait_row.avg_wait_minutes) if wait_row and wait_row.avg_wait_minutes else 0.0

            return {
                "tenant_id": tenant_id,
                "queue_counts": {
                    "queued": queued_jobs,
                    "running": running_jobs,
                    "completed": completed_jobs,
                    "failed": failed_jobs,
                    "cancelled": cancelled_jobs,
                    "total": queued_jobs + running_jobs + completed_jobs + failed_jobs + cancelled_jobs
                },
                "jobs_by_type": jobs_by_type,
                "avg_wait_time_minutes": round(avg_wait_time, 2),
                "queue_health": {
                    "has_queued_jobs": queued_jobs > 0,
                    "has_running_jobs": running_jobs > 0,
                    "failure_rate": round((failed_jobs / max(completed_jobs + failed_jobs, 1)) * 100, 2)
                }
            }

        except Exception as e:
            logger.error("Failed to get queue status",
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "tenant_id": tenant_id,
                "queue_counts": {
                    "queued": 0, "running": 0, "completed": 0,
                    "failed": 0, "cancelled": 0, "total": 0
                },
                "jobs_by_type": {},
                "avg_wait_time_minutes": 0.0,
                "queue_health": {
                    "has_queued_jobs": False,
                    "has_running_jobs": False,
                    "failure_rate": 0.0
                }
            }

    async def get_jobs_by_tenant(
        self,
        tenant_id: str,
        status: str = None,
        job_type: str = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[TrainingJobQueue]:
        """Get jobs for a tenant with optional filtering"""
        try:
            filters = {"tenant_id": tenant_id}
            if status:
                filters["status"] = status
            if job_type:
                filters["job_type"] = job_type

            return await self.get_multi(
                filters=filters,
                skip=skip,
                limit=limit,
                order_by="created_at",
                order_desc=True
            )

        except Exception as e:
            logger.error("Failed to get jobs by tenant",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get tenant jobs: {str(e)}")

    async def cleanup_old_jobs(self, days_old: int = 30, status_filter: str = None) -> int:
        """Clean up old completed/failed/cancelled jobs"""
        try:
            cutoff_date = datetime.now() - timedelta(days=days_old)

            # Only clean up finished jobs by default
            default_statuses = ["completed", "failed", "cancelled"]

            if status_filter:
                status_condition = "status = :status"
                params = {"cutoff_date": cutoff_date, "status": status_filter}
            else:
                status_list = "', '".join(default_statuses)
                status_condition = f"status IN ('{status_list}')"
                params = {"cutoff_date": cutoff_date}

            query_text = f"""
                DELETE FROM training_job_queue
                WHERE created_at < :cutoff_date
                AND {status_condition}
            """

            result = await self.session.execute(text(query_text), params)
            deleted_count = result.rowcount

            logger.info("Cleaned up old queue jobs",
                        deleted_count=deleted_count,
                        days_old=days_old,
                        status_filter=status_filter)

            return deleted_count

        except Exception as e:
            logger.error("Failed to cleanup old queue jobs",
                         error=str(e))
            raise DatabaseError(f"Queue cleanup failed: {str(e)}")

    async def get_stuck_jobs(self, hours_stuck: int = 2) -> List[TrainingJobQueue]:
        """Get jobs that have been running for too long"""
        try:
            cutoff_time = datetime.now() - timedelta(hours=hours_stuck)

            query_text = """
                SELECT * FROM training_job_queue
                WHERE status = 'running'
                AND started_at IS NOT NULL
                AND started_at < :cutoff_time
                ORDER BY started_at ASC
            """

            result = await self.session.execute(text(query_text), {"cutoff_time": cutoff_time})

            stuck_jobs = []
            for row in result.fetchall():
                record_dict = dict(row._mapping)
                job = self.model(**record_dict)
                stuck_jobs.append(job)

            if stuck_jobs:
                logger.warning("Found stuck jobs",
                               count=len(stuck_jobs),
                               hours_stuck=hours_stuck)

            return stuck_jobs

        except Exception as e:
            logger.error("Failed to get stuck jobs",
                         hours_stuck=hours_stuck,
                         error=str(e))
            return []

    async def reset_stuck_jobs(self, hours_stuck: int = 2) -> int:
        """Reset stuck jobs back to queued status"""
        try:
            stuck_jobs = await self.get_stuck_jobs(hours_stuck)
            reset_count = 0

            for job in stuck_jobs:
                # Reset job to queued status
                await self.update(job.id, {
                    "status": "queued",
                    "started_at": None,
                    "updated_at": datetime.now()
                })
                reset_count += 1

            if reset_count > 0:
                logger.info("Reset stuck jobs",
                            reset_count=reset_count,
                            hours_stuck=hours_stuck)

            return reset_count

        except Exception as e:
            logger.error("Failed to reset stuck jobs",
                         hours_stuck=hours_stuck,
                         error=str(e))
            raise DatabaseError(f"Failed to reset stuck jobs: {str(e)}")
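The queue methods above imply a poll-based worker lifecycle: enqueue_job, get_next_job, start_job, then complete_job or fail_job. A sketch of that loop follows, assuming a session factory like the one shown earlier; the polling interval and worker shape are illustrative, and commit/rollback handling is left to whatever BaseRepository already does with the session.

# Sketch only: one possible worker loop over the training job queue.
import asyncio

from app.repositories import JobQueueRepository

async def worker_loop(async_session_factory, poll_seconds: int = 5):
    while True:
        async with async_session_factory() as session:
            queue = JobQueueRepository(session)
            job = await queue.get_next_job(job_types=["train_model"])  # job_type value is illustrative
            if job is None:
                await asyncio.sleep(poll_seconds)
                continue

            await queue.start_job(job.job_id)
            try:
                # ... run the actual training work here ...
                await queue.complete_job(job.job_id)
            except Exception as exc:
                # fail_job re-queues until max_retries is reached, then marks the job failed
                await queue.fail_job(job.job_id, error_message=str(exc))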
services/training/app/repositories/model_repository.py (new file, 346 lines)
@@ -0,0 +1,346 @@
"""
Model Repository
Repository for trained model operations
"""

from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc
from datetime import datetime, timedelta
import structlog

from .base import TrainingBaseRepository
from app.models.training import TrainedModel
from shared.database.exceptions import DatabaseError, ValidationError, DuplicateRecordError

logger = structlog.get_logger()


class ModelRepository(TrainingBaseRepository):
    """Repository for trained model operations"""

    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 600):
        # Models are relatively stable, longer cache time (10 minutes)
        super().__init__(TrainedModel, session, cache_ttl)

    async def create_model(self, model_data: Dict[str, Any]) -> TrainedModel:
        """Create a new trained model with validation"""
        try:
            # Validate model data
            validation_result = self._validate_training_data(
                model_data,
                ["tenant_id", "product_name", "model_path", "job_id"]
            )

            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid model data: {validation_result['errors']}")

            # Check for duplicate active models for same tenant+product
            existing_model = await self.get_active_model_for_product(
                model_data["tenant_id"],
                model_data["product_name"]
            )

            # If there's an existing active model, we may want to deactivate it
            if existing_model and model_data.get("is_production", False):
                logger.info("Deactivating previous production model",
                            previous_model_id=existing_model.id,
                            tenant_id=model_data["tenant_id"],
                            product_name=model_data["product_name"])
                await self.update(existing_model.id, {"is_production": False})

            # Create new model
            model = await self.create(model_data)

            logger.info("Trained model created successfully",
                        model_id=model.id,
                        tenant_id=model.tenant_id,
                        product_name=model.product_name,
                        model_type=model.model_type)

            return model

        except (ValidationError, DuplicateRecordError):
            raise
        except Exception as e:
            logger.error("Failed to create trained model",
                         tenant_id=model_data.get("tenant_id"),
                         product_name=model_data.get("product_name"),
                         error=str(e))
            raise DatabaseError(f"Failed to create model: {str(e)}")

    async def get_model_by_tenant_and_product(
        self,
        tenant_id: str,
        product_name: str
    ) -> List[TrainedModel]:
        """Get all models for a tenant and product"""
        try:
            return await self.get_multi(
                filters={
                    "tenant_id": tenant_id,
                    "product_name": product_name
                },
                order_by="created_at",
                order_desc=True
            )
        except Exception as e:
            logger.error("Failed to get models by tenant and product",
                         tenant_id=tenant_id,
                         product_name=product_name,
                         error=str(e))
            raise DatabaseError(f"Failed to get models: {str(e)}")

    async def get_active_model_for_product(
        self,
        tenant_id: str,
        product_name: str
    ) -> Optional[TrainedModel]:
        """Get the active production model for a product"""
        try:
            models = await self.get_multi(
                filters={
                    "tenant_id": tenant_id,
                    "product_name": product_name,
                    "is_active": True,
                    "is_production": True
                },
                order_by="created_at",
                order_desc=True,
                limit=1
            )
            return models[0] if models else None
        except Exception as e:
            logger.error("Failed to get active model for product",
                         tenant_id=tenant_id,
                         product_name=product_name,
                         error=str(e))
            raise DatabaseError(f"Failed to get active model: {str(e)}")

    async def get_models_by_tenant(
        self,
        tenant_id: str,
        skip: int = 0,
        limit: int = 100
    ) -> List[TrainedModel]:
        """Get all models for a tenant"""
        return await self.get_by_tenant_id(tenant_id, skip=skip, limit=limit)

    async def promote_to_production(self, model_id: str) -> Optional[TrainedModel]:
        """Promote a model to production"""
        try:
            # Get the model first
            model = await self.get_by_id(model_id)
            if not model:
                raise ValueError(f"Model {model_id} not found")

            # Deactivate other production models for the same tenant+product
            await self._deactivate_other_production_models(
                model.tenant_id,
                model.product_name,
                model_id
            )

            # Promote this model
            updated_model = await self.update(model_id, {
                "is_production": True,
                "last_used_at": datetime.utcnow()
            })

            logger.info("Model promoted to production",
                        model_id=model_id,
                        tenant_id=model.tenant_id,
                        product_name=model.product_name)

            return updated_model

        except Exception as e:
            logger.error("Failed to promote model to production",
                         model_id=model_id,
                         error=str(e))
            raise DatabaseError(f"Failed to promote model: {str(e)}")

    async def update_model_usage(self, model_id: str) -> Optional[TrainedModel]:
        """Update model last used timestamp"""
        try:
            return await self.update(model_id, {
                "last_used_at": datetime.utcnow()
            })
        except Exception as e:
            logger.error("Failed to update model usage",
                         model_id=model_id,
                         error=str(e))
            # Don't raise here - usage update is not critical
            return None

    async def archive_old_models(self, tenant_id: str, days_old: int = 90) -> int:
        """Archive old non-production models"""
        try:
            cutoff_date = datetime.utcnow() - timedelta(days=days_old)

            query = text("""
                UPDATE trained_models
                SET is_active = false
                WHERE tenant_id = :tenant_id
                AND is_production = false
                AND created_at < :cutoff_date
                AND is_active = true
            """)

            result = await self.session.execute(query, {
                "tenant_id": tenant_id,
                "cutoff_date": cutoff_date
            })

            archived_count = result.rowcount

            logger.info("Archived old models",
                        tenant_id=tenant_id,
                        archived_count=archived_count,
                        days_old=days_old)

            return archived_count

        except Exception as e:
            logger.error("Failed to archive old models",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Model archival failed: {str(e)}")

    async def get_model_statistics(self, tenant_id: str) -> Dict[str, Any]:
        """Get model statistics for a tenant"""
        try:
            # Get basic counts
            total_models = await self.count(filters={"tenant_id": tenant_id})
            active_models = await self.count(filters={
                "tenant_id": tenant_id,
                "is_active": True
            })
            production_models = await self.count(filters={
                "tenant_id": tenant_id,
                "is_production": True
            })

            # Get models by product using raw query
            product_query = text("""
                SELECT product_name, COUNT(*) as count
                FROM trained_models
                WHERE tenant_id = :tenant_id
                AND is_active = true
                GROUP BY product_name
                ORDER BY count DESC
            """)

            result = await self.session.execute(product_query, {"tenant_id": tenant_id})
            product_stats = {row.product_name: row.count for row in result.fetchall()}

            # Recent activity (models created in last 30 days)
            thirty_days_ago = datetime.utcnow() - timedelta(days=30)
            recent_models_query = text("""
                SELECT COUNT(*) as count
                FROM trained_models
                WHERE tenant_id = :tenant_id
                AND created_at >= :thirty_days_ago
            """)

            recent_result = await self.session.execute(
                recent_models_query,
                {"tenant_id": tenant_id, "thirty_days_ago": thirty_days_ago}
            )
            recent_models = recent_result.scalar() or 0

            return {
                "total_models": total_models,
                "active_models": active_models,
                "inactive_models": total_models - active_models,
                "production_models": production_models,
                "models_by_product": product_stats,
                "recent_models_30d": recent_models
            }

        except Exception as e:
            logger.error("Failed to get model statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "total_models": 0,
                "active_models": 0,
                "inactive_models": 0,
                "production_models": 0,
                "models_by_product": {},
                "recent_models_30d": 0
            }

    async def _deactivate_other_production_models(
        self,
        tenant_id: str,
        product_name: str,
        exclude_model_id: str
    ) -> int:
        """Deactivate other production models for the same tenant+product"""
        try:
            query = text("""
                UPDATE trained_models
                SET is_production = false
                WHERE tenant_id = :tenant_id
                AND product_name = :product_name
                AND id != :exclude_model_id
                AND is_production = true
            """)

            result = await self.session.execute(query, {
                "tenant_id": tenant_id,
                "product_name": product_name,
                "exclude_model_id": exclude_model_id
            })

            return result.rowcount

        except Exception as e:
            logger.error("Failed to deactivate other production models",
                         tenant_id=tenant_id,
                         product_name=product_name,
                         error=str(e))
            raise DatabaseError(f"Failed to deactivate models: {str(e)}")

    async def get_model_performance_summary(self, model_id: str) -> Dict[str, Any]:
        """Get performance summary for a model"""
        try:
            model = await self.get_by_id(model_id)
            if not model:
                return {}

            return {
                "model_id": model.id,
                "tenant_id": model.tenant_id,
                "product_name": model.product_name,
                "model_type": model.model_type,
                "metrics": {
                    "mape": model.mape,
                    "mae": model.mae,
                    "rmse": model.rmse,
                    "r2_score": model.r2_score
                },
                "training_info": {
                    "training_samples": model.training_samples,
                    "training_start_date": model.training_start_date.isoformat() if model.training_start_date else None,
                    "training_end_date": model.training_end_date.isoformat() if model.training_end_date else None,
                    "data_quality_score": model.data_quality_score
                },
                "status": {
                    "is_active": model.is_active,
                    "is_production": model.is_production,
                    "created_at": model.created_at.isoformat() if model.created_at else None,
                    "last_used_at": model.last_used_at.isoformat() if model.last_used_at else None
                },
                "features": {
                    "hyperparameters": model.hyperparameters,
                    "features_used": model.features_used
                }
            }

        except Exception as e:
            logger.error("Failed to get model performance summary",
                         model_id=model_id,
                         error=str(e))
            return {}
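A short sketch of the promotion flow these methods support: create_model registers the new model (demoting an existing production model when is_production is requested), and promote_to_production demotes any other production model for the same tenant and product before flagging the new one. The field names follow the columns referenced above; the helper itself is illustrative, not part of the commit.

# Sketch only: register a freshly trained model and promote it to production.
from app.repositories import ModelRepository

async def register_and_promote(session, job_id: str, tenant_id: str, product: str, path: str):
    repo = ModelRepository(session)
    model = await repo.create_model({
        "tenant_id": tenant_id,
        "product_name": product,
        "model_path": path,
        "job_id": job_id,
        "is_active": True,
    })
    # Demotes competing production models, then marks this one as production.
    return await repo.promote_to_production(model.id)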
services/training/app/repositories/performance_repository.py (new file, 433 lines)
@@ -0,0 +1,433 @@
|
||||
"""
|
||||
Performance Repository
|
||||
Repository for model performance metrics operations
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_, text, desc
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
|
||||
from .base import TrainingBaseRepository
|
||||
from app.models.training import ModelPerformanceMetric
|
||||
from shared.database.exceptions import DatabaseError, ValidationError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class PerformanceRepository(TrainingBaseRepository):
|
||||
"""Repository for model performance metrics operations"""
|
||||
|
||||
def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 900):
|
||||
# Performance metrics are relatively stable, longer cache time (15 minutes)
|
||||
super().__init__(ModelPerformanceMetric, session, cache_ttl)
|
||||
|
||||
async def create_performance_metric(self, metric_data: Dict[str, Any]) -> ModelPerformanceMetric:
|
||||
"""Create a new performance metric record"""
|
||||
try:
|
||||
# Validate metric data
|
||||
validation_result = self._validate_training_data(
|
||||
metric_data,
|
||||
["model_id", "tenant_id", "product_name"]
|
||||
)
|
||||
|
||||
if not validation_result["is_valid"]:
|
||||
raise ValidationError(f"Invalid metric data: {validation_result['errors']}")
|
||||
|
||||
# Set measurement timestamp if not provided
|
||||
if "measured_at" not in metric_data:
|
||||
metric_data["measured_at"] = datetime.now()
|
||||
|
||||
# Create metric record
|
||||
metric = await self.create(metric_data)
|
||||
|
||||
logger.info("Performance metric created",
|
||||
model_id=metric.model_id,
|
||||
tenant_id=metric.tenant_id,
|
||||
product_name=metric.product_name)
|
||||
|
||||
return metric
|
||||
|
||||
except ValidationError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error("Failed to create performance metric",
|
||||
model_id=metric_data.get("model_id"),
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to create metric: {str(e)}")
|
||||
|
||||
async def get_metrics_by_model(
|
||||
self,
|
||||
model_id: str,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[ModelPerformanceMetric]:
|
||||
"""Get all performance metrics for a model"""
|
||||
try:
|
||||
return await self.get_multi(
|
||||
filters={"model_id": model_id},
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
order_by="measured_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get metrics by model",
|
||||
model_id=model_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get metrics: {str(e)}")
|
||||
|
||||
async def get_latest_metric_for_model(self, model_id: str) -> Optional[ModelPerformanceMetric]:
|
||||
"""Get the latest performance metric for a model"""
|
||||
try:
|
||||
metrics = await self.get_multi(
|
||||
filters={"model_id": model_id},
|
||||
limit=1,
|
||||
order_by="measured_at",
|
||||
order_desc=True
|
||||
)
|
||||
return metrics[0] if metrics else None
|
||||
except Exception as e:
|
||||
logger.error("Failed to get latest metric for model",
|
||||
model_id=model_id,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get latest metric: {str(e)}")
|
||||
|
||||
async def get_metrics_by_tenant_and_product(
|
||||
self,
|
||||
tenant_id: str,
|
||||
product_name: str,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[ModelPerformanceMetric]:
|
||||
"""Get performance metrics for a tenant's product"""
|
||||
try:
|
||||
return await self.get_multi(
|
||||
filters={
|
||||
"tenant_id": tenant_id,
|
||||
"product_name": product_name
|
||||
},
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
order_by="measured_at",
|
||||
order_desc=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Failed to get metrics by tenant and product",
|
||||
tenant_id=tenant_id,
|
||||
product_name=product_name,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get metrics: {str(e)}")
|
||||
|
||||
async def get_metrics_in_date_range(
|
||||
self,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
tenant_id: str = None,
|
||||
model_id: str = None,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[ModelPerformanceMetric]:
|
||||
"""Get performance metrics within a date range"""
|
||||
try:
|
||||
# Build filters
|
||||
table_name = self.model.__tablename__
|
||||
conditions = ["measured_at >= :start_date", "measured_at <= :end_date"]
|
||||
params = {"start_date": start_date, "end_date": end_date, "limit": limit, "skip": skip}
|
||||
|
||||
if tenant_id:
|
||||
conditions.append("tenant_id = :tenant_id")
|
||||
params["tenant_id"] = tenant_id
|
||||
|
||||
if model_id:
|
||||
conditions.append("model_id = :model_id")
|
||||
params["model_id"] = model_id
|
||||
|
||||
query_text = f"""
|
||||
SELECT * FROM {table_name}
|
||||
WHERE {' AND '.join(conditions)}
|
||||
ORDER BY measured_at DESC
|
||||
LIMIT :limit OFFSET :skip
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), params)
|
||||
|
||||
# Convert rows to model objects
|
||||
metrics = []
|
||||
for row in result.fetchall():
|
||||
record_dict = dict(row._mapping)
|
||||
metric = self.model(**record_dict)
|
||||
metrics.append(metric)
|
||||
|
||||
return metrics
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get metrics in date range",
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Date range query failed: {str(e)}")
|
||||
|

    async def get_performance_trends(
        self,
        tenant_id: str,
        product_name: str = None,
        days: int = 30
    ) -> Dict[str, Any]:
        """Get performance trends for analysis"""
        try:
            start_date = datetime.now() - timedelta(days=days)

            # Build query for performance trends
            conditions = ["tenant_id = :tenant_id", "measured_at >= :start_date"]
            params = {"tenant_id": tenant_id, "start_date": start_date}

            if product_name:
                conditions.append("product_name = :product_name")
                params["product_name"] = product_name

            query_text = f"""
                SELECT
                    product_name,
                    AVG(mae) as avg_mae,
                    AVG(mse) as avg_mse,
                    AVG(rmse) as avg_rmse,
                    AVG(mape) as avg_mape,
                    AVG(r2_score) as avg_r2_score,
                    AVG(accuracy_percentage) as avg_accuracy,
                    COUNT(*) as measurement_count,
                    MIN(measured_at) as first_measurement,
                    MAX(measured_at) as last_measurement
                FROM model_performance_metrics
                WHERE {' AND '.join(conditions)}
                GROUP BY product_name
                ORDER BY avg_accuracy DESC
            """

            result = await self.session.execute(text(query_text), params)

            trends = []
            for row in result.fetchall():
                trends.append({
                    "product_name": row.product_name,
                    "metrics": {
                        "avg_mae": float(row.avg_mae) if row.avg_mae else None,
                        "avg_mse": float(row.avg_mse) if row.avg_mse else None,
                        "avg_rmse": float(row.avg_rmse) if row.avg_rmse else None,
                        "avg_mape": float(row.avg_mape) if row.avg_mape else None,
                        "avg_r2_score": float(row.avg_r2_score) if row.avg_r2_score else None,
                        "avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy else None
                    },
                    "measurement_count": int(row.measurement_count),
                    "period": {
                        "start": row.first_measurement.isoformat() if row.first_measurement else None,
                        "end": row.last_measurement.isoformat() if row.last_measurement else None,
                        "days": days
                    }
                })

            return {
                "tenant_id": tenant_id,
                "product_name": product_name,
                "trends": trends,
                "period_days": days,
                "total_products": len(trends)
            }

        except Exception as e:
            logger.error("Failed to get performance trends",
                         tenant_id=tenant_id,
                         product_name=product_name,
                         error=str(e))
            return {
                "tenant_id": tenant_id,
                "product_name": product_name,
                "trends": [],
                "period_days": days,
                "total_products": 0
            }
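A possible caller-side sketch for the trends helper; the async_session_factory name is an assumption for illustration, and the constructor call mirrors the other repositories in this commit:

# Illustrative only: the session factory and tenant id are placeholders.
from app.repositories import PerformanceRepository


async def print_accuracy_trends(async_session_factory, tenant_id: str):
    async with async_session_factory() as session:
        repo = PerformanceRepository(session)
        report = await repo.get_performance_trends(tenant_id=tenant_id, days=30)
        for trend in report["trends"]:
            print(trend["product_name"], trend["metrics"]["avg_accuracy"])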

    async def get_best_performing_models(
        self,
        tenant_id: str,
        metric_type: str = "accuracy_percentage",
        limit: int = 10
    ) -> List[Dict[str, Any]]:
        """Get best performing models based on a specific metric"""
        try:
            # Validate metric type against a whitelist before interpolating it into SQL
            valid_metrics = ["mae", "mse", "rmse", "mape", "r2_score", "accuracy_percentage"]
            if metric_type not in valid_metrics:
                metric_type = "accuracy_percentage"

            # For error metrics (mae, mse, rmse, mape), lower is better
            # For performance metrics (r2_score, accuracy_percentage), higher is better
            order_desc = metric_type in ["r2_score", "accuracy_percentage"]
            order_direction = "DESC" if order_desc else "ASC"

            # DISTINCT ON keeps the latest measurement per (product, model);
            # the outer query then ranks those latest measurements by the metric
            query_text = f"""
                SELECT model_id, product_name, {metric_type}, measured_at, evaluation_samples
                FROM (
                    SELECT DISTINCT ON (product_name, model_id)
                        model_id,
                        product_name,
                        {metric_type},
                        measured_at,
                        evaluation_samples
                    FROM model_performance_metrics
                    WHERE tenant_id = :tenant_id
                    AND {metric_type} IS NOT NULL
                    ORDER BY product_name, model_id, measured_at DESC
                ) AS latest_metrics
                ORDER BY {metric_type} {order_direction}
                LIMIT :limit
            """

            result = await self.session.execute(text(query_text), {
                "tenant_id": tenant_id,
                "limit": limit
            })

            best_models = []
            for row in result.fetchall():
                best_models.append({
                    "model_id": row.model_id,
                    "product_name": row.product_name,
                    "metric_value": float(getattr(row, metric_type)),
                    "metric_type": metric_type,
                    "measured_at": row.measured_at.isoformat() if row.measured_at else None,
                    "evaluation_samples": int(row.evaluation_samples) if row.evaluation_samples else None
                })

            return best_models

        except Exception as e:
            logger.error("Failed to get best performing models",
                         tenant_id=tenant_id,
                         metric_type=metric_type,
                         error=str(e))
            return []
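As a usage sketch (assuming a repository instance constructed as above), ranking by MAPE exercises the lower-is-better branch:

# Sketch: MAPE is an error metric, so the helper sorts ascending internally.
async def top_models_by_mape(repo, tenant_id: str):
    return await repo.get_best_performing_models(
        tenant_id=tenant_id,
        metric_type="mape",
        limit=5,
    )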

    async def cleanup_old_metrics(self, days_old: int = 180) -> int:
        """Clean up old performance metrics"""
        return await self.cleanup_old_records(days_old=days_old)

    async def get_metric_statistics(self, tenant_id: str) -> Dict[str, Any]:
        """Get performance metric statistics for a tenant"""
        try:
            # Get basic counts
            total_metrics = await self.count(filters={"tenant_id": tenant_id})

            # Get metrics by product using raw query
            product_query = text("""
                SELECT
                    product_name,
                    COUNT(*) as metric_count,
                    AVG(accuracy_percentage) as avg_accuracy
                FROM model_performance_metrics
                WHERE tenant_id = :tenant_id
                GROUP BY product_name
                ORDER BY avg_accuracy DESC
            """)

            result = await self.session.execute(product_query, {"tenant_id": tenant_id})
            product_stats = {}

            for row in result.fetchall():
                product_stats[row.product_name] = {
                    "metric_count": row.metric_count,
                    "avg_accuracy": float(row.avg_accuracy) if row.avg_accuracy else None
                }

            # Recent activity (metrics in last 7 days)
            seven_days_ago = datetime.now() - timedelta(days=7)
            recent_metrics = len(await self.get_records_by_date_range(
                seven_days_ago,
                datetime.now(),
                limit=1000  # High limit to get accurate count
            ))

            return {
                "total_metrics": total_metrics,
                "products_tracked": len(product_stats),
                "metrics_by_product": product_stats,
                "recent_metrics_7d": recent_metrics
            }

        except Exception as e:
            logger.error("Failed to get metric statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "total_metrics": 0,
                "products_tracked": 0,
                "metrics_by_product": {},
                "recent_metrics_7d": 0
            }

    async def compare_model_performance(
        self,
        model_ids: List[str],
        metric_type: str = "accuracy_percentage"
    ) -> Dict[str, Any]:
        """Compare performance between multiple models"""
        try:
            if not model_ids or len(model_ids) < 2:
                return {"error": "At least 2 model IDs required for comparison"}

            # Validate metric type against a whitelist before interpolating it into SQL
            valid_metrics = ["mae", "mse", "rmse", "mape", "r2_score", "accuracy_percentage"]
            if metric_type not in valid_metrics:
                metric_type = "accuracy_percentage"

            # Bind each model ID as a named parameter instead of interpolating
            # raw strings into the IN clause
            id_params = {f"model_id_{i}": model_id for i, model_id in enumerate(model_ids)}
            id_placeholders = ", ".join(f":{name}" for name in id_params)

            query_text = f"""
                SELECT
                    model_id,
                    product_name,
                    AVG({metric_type}) as avg_metric,
                    MIN({metric_type}) as min_metric,
                    MAX({metric_type}) as max_metric,
                    COUNT(*) as measurement_count,
                    MAX(measured_at) as latest_measurement
                FROM model_performance_metrics
                WHERE model_id IN ({id_placeholders})
                AND {metric_type} IS NOT NULL
                GROUP BY model_id, product_name
                ORDER BY avg_metric DESC
            """

            result = await self.session.execute(text(query_text), id_params)

            comparisons = []
            for row in result.fetchall():
                comparisons.append({
                    "model_id": row.model_id,
                    "product_name": row.product_name,
                    "avg_metric": float(row.avg_metric),
                    "min_metric": float(row.min_metric),
                    "max_metric": float(row.max_metric),
                    "measurement_count": int(row.measurement_count),
                    "latest_measurement": row.latest_measurement.isoformat() if row.latest_measurement else None
                })

            # Find best and worst performing models; for error metrics
            # (mae, mse, rmse, mape) a lower average is better
            if comparisons:
                higher_is_better = metric_type in ["r2_score", "accuracy_percentage"]
                best_model = (max if higher_is_better else min)(comparisons, key=lambda x: x["avg_metric"])
                worst_model = (min if higher_is_better else max)(comparisons, key=lambda x: x["avg_metric"])
            else:
                best_model = worst_model = None

            return {
                "metric_type": metric_type,
                "models_compared": len(set(comp["model_id"] for comp in comparisons)),
                "comparisons": comparisons,
                "best_performing": best_model,
                "worst_performing": worst_model
            }

        except Exception as e:
            logger.error("Failed to compare model performance",
                         model_ids=model_ids,
                         metric_type=metric_type,
                         error=str(e))
            return {"error": f"Comparison failed: {str(e)}"}
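A hedged usage sketch for the comparison helper; the model IDs below are placeholders:

# Sketch: compare two candidate models on accuracy.
async def compare_candidates(repo):
    result = await repo.compare_model_performance(
        model_ids=["model-a", "model-b"],
        metric_type="accuracy_percentage",
    )
    if "error" not in result:
        print(result["best_performing"]["model_id"], result["best_performing"]["avg_metric"])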
332
services/training/app/repositories/training_log_repository.py
Normal file
@@ -0,0 +1,332 @@
"""
Training Log Repository
Repository for model training log operations
"""

from typing import Optional, List, Dict, Any
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, text, desc
from datetime import datetime, timedelta
import structlog

from .base import TrainingBaseRepository
from app.models.training import ModelTrainingLog
from shared.database.exceptions import DatabaseError, ValidationError

logger = structlog.get_logger()


class TrainingLogRepository(TrainingBaseRepository):
    """Repository for training log operations"""

    def __init__(self, session: AsyncSession, cache_ttl: Optional[int] = 300):
        # Training logs change frequently, shorter cache time (5 minutes)
        super().__init__(ModelTrainingLog, session, cache_ttl)

    async def create_training_log(self, log_data: Dict[str, Any]) -> ModelTrainingLog:
        """Create a new training log entry"""
        try:
            # Validate log data
            validation_result = self._validate_training_data(
                log_data,
                ["job_id", "tenant_id", "status"]
            )

            if not validation_result["is_valid"]:
                raise ValidationError(f"Invalid training log data: {validation_result['errors']}")

            # Set default values
            if "progress" not in log_data:
                log_data["progress"] = 0
            if "current_step" not in log_data:
                log_data["current_step"] = "initializing"

            # Create log entry
            log_entry = await self.create(log_data)

            logger.info("Training log created",
                        job_id=log_entry.job_id,
                        tenant_id=log_entry.tenant_id,
                        status=log_entry.status)

            return log_entry

        except ValidationError:
            raise
        except Exception as e:
            logger.error("Failed to create training log",
                         job_id=log_data.get("job_id"),
                         error=str(e))
            raise DatabaseError(f"Failed to create training log: {str(e)}")

    async def get_log_by_job_id(self, job_id: str) -> Optional[ModelTrainingLog]:
        """Get training log by job ID"""
        return await self.get_by_job_id(job_id)

    async def update_log_progress(
        self,
        job_id: str,
        progress: int,
        current_step: str = None,
        status: str = None
    ) -> Optional[ModelTrainingLog]:
        """Update training log progress"""
        try:
            update_data = {"progress": progress, "updated_at": datetime.now()}

            if current_step:
                update_data["current_step"] = current_step
            if status:
                update_data["status"] = status

            log_entry = await self.get_by_job_id(job_id)
            if not log_entry:
                logger.error(f"Training log not found for job {job_id}")
                return None

            updated_log = await self.update(log_entry.id, update_data)

            logger.debug("Training log progress updated",
                         job_id=job_id,
                         progress=progress,
                         step=current_step)

            return updated_log

        except Exception as e:
            logger.error("Failed to update training log progress",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to update progress: {str(e)}")

    async def complete_training_log(
        self,
        job_id: str,
        results: Dict[str, Any] = None,
        error_message: str = None
    ) -> Optional[ModelTrainingLog]:
        """Mark training log as completed or failed"""
        try:
            status = "failed" if error_message else "completed"

            update_data = {
                "status": status,
                "progress": 100 if status == "completed" else None,
                "end_time": datetime.now(),
                "updated_at": datetime.now()
            }

            if results:
                update_data["results"] = results
            if error_message:
                update_data["error_message"] = error_message

            log_entry = await self.get_by_job_id(job_id)
            if not log_entry:
                logger.error(f"Training log not found for job {job_id}")
                return None

            updated_log = await self.update(log_entry.id, update_data)

            logger.info("Training log completed",
                        job_id=job_id,
                        status=status,
                        has_results=bool(results))

            return updated_log

        except Exception as e:
            logger.error("Failed to complete training log",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to complete training log: {str(e)}")
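Taken together, these methods describe a simple log lifecycle: create, report progress, then mark completed or failed. A sketch of a caller wiring them up (the job/tenant IDs and the results payload are placeholders):

# Sketch of the log lifecycle; values are illustrative only.
async def run_logged_training(repo, job_id: str, tenant_id: str):
    await repo.create_training_log({"job_id": job_id, "tenant_id": tenant_id, "status": "running"})
    try:
        await repo.update_log_progress(job_id, progress=50, current_step="training_model")
        results = {"accuracy_percentage": 92.4}  # placeholder for real training output
        await repo.complete_training_log(job_id, results=results)
    except Exception as exc:
        await repo.complete_training_log(job_id, error_message=str(exc))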

    async def get_logs_by_tenant(
        self,
        tenant_id: str,
        status: str = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[ModelTrainingLog]:
        """Get training logs for a tenant"""
        try:
            filters = {"tenant_id": tenant_id}
            if status:
                filters["status"] = status

            return await self.get_multi(
                filters=filters,
                skip=skip,
                limit=limit,
                order_by="created_at",
                order_desc=True
            )

        except Exception as e:
            logger.error("Failed to get logs by tenant",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get training logs: {str(e)}")

    async def get_active_jobs(self, tenant_id: str = None) -> List[ModelTrainingLog]:
        """Get currently running training jobs"""
        try:
            filters = {"status": "running"}
            if tenant_id:
                filters["tenant_id"] = tenant_id

            return await self.get_multi(
                filters=filters,
                order_by="start_time",
                order_desc=True
            )

        except Exception as e:
            logger.error("Failed to get active jobs",
                         tenant_id=tenant_id,
                         error=str(e))
            raise DatabaseError(f"Failed to get active jobs: {str(e)}")

    async def cancel_job(self, job_id: str, cancelled_by: str = None) -> Optional[ModelTrainingLog]:
        """Cancel a training job"""
        try:
            update_data = {
                "status": "cancelled",
                "end_time": datetime.now(),
                "updated_at": datetime.now()
            }

            if cancelled_by:
                update_data["error_message"] = f"Cancelled by {cancelled_by}"

            log_entry = await self.get_by_job_id(job_id)
            if not log_entry:
                logger.error(f"Training log not found for job {job_id}")
                return None

            # Only cancel if the job is still pending or running
            if log_entry.status not in ["pending", "running"]:
                logger.warning(f"Cannot cancel job {job_id} with status {log_entry.status}")
                return log_entry

            updated_log = await self.update(log_entry.id, update_data)

            logger.info("Training job cancelled",
                        job_id=job_id,
                        cancelled_by=cancelled_by)

            return updated_log

        except Exception as e:
            logger.error("Failed to cancel training job",
                         job_id=job_id,
                         error=str(e))
            raise DatabaseError(f"Failed to cancel job: {str(e)}")
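One way these two helpers might combine, for example in an admin task that cancels everything still running for a tenant (a sketch; the function name is illustrative):

# Sketch: cancel every running job for a tenant.
async def cancel_tenant_jobs(repo, tenant_id: str, cancelled_by: str = "admin"):
    cancelled = []
    for log_entry in await repo.get_active_jobs(tenant_id=tenant_id):
        updated = await repo.cancel_job(log_entry.job_id, cancelled_by=cancelled_by)
        if updated and updated.status == "cancelled":
            cancelled.append(updated.job_id)
    return cancelled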

    async def get_job_statistics(self, tenant_id: str = None) -> Dict[str, Any]:
        """Get training job statistics"""
        try:
            base_filters = {}
            if tenant_id:
                base_filters["tenant_id"] = tenant_id

            # Get counts by status
            total_jobs = await self.count(filters=base_filters)
            completed_jobs = await self.count(filters={**base_filters, "status": "completed"})
            failed_jobs = await self.count(filters={**base_filters, "status": "failed"})
            running_jobs = await self.count(filters={**base_filters, "status": "running"})
            pending_jobs = await self.count(filters={**base_filters, "status": "pending"})

            # Get recent activity (jobs in last 7 days)
            seven_days_ago = datetime.now() - timedelta(days=7)
            recent_jobs = len(await self.get_records_by_date_range(
                seven_days_ago,
                datetime.now(),
                limit=1000  # High limit to get accurate count
            ))

            # Calculate success rate
            finished_jobs = completed_jobs + failed_jobs
            success_rate = (completed_jobs / finished_jobs * 100) if finished_jobs > 0 else 0

            return {
                "total_jobs": total_jobs,
                "completed_jobs": completed_jobs,
                "failed_jobs": failed_jobs,
                "running_jobs": running_jobs,
                "pending_jobs": pending_jobs,
                "cancelled_jobs": total_jobs - completed_jobs - failed_jobs - running_jobs - pending_jobs,
                "success_rate": round(success_rate, 2),
                "recent_jobs_7d": recent_jobs
            }

        except Exception as e:
            logger.error("Failed to get job statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "total_jobs": 0,
                "completed_jobs": 0,
                "failed_jobs": 0,
                "running_jobs": 0,
                "pending_jobs": 0,
                "cancelled_jobs": 0,
                "success_rate": 0.0,
                "recent_jobs_7d": 0
            }

    async def cleanup_old_logs(self, days_old: int = 90) -> int:
        """Clean up old training logs regardless of status"""
        return await self.cleanup_old_records(
            days_old=days_old,
            status_filter=None  # Clean up all old records regardless of status
        )

    async def get_job_duration_stats(self, tenant_id: str = None) -> Dict[str, Any]:
        """Get job duration statistics"""
        try:
            # Use raw SQL for complex duration calculations
            tenant_filter = "AND tenant_id = :tenant_id" if tenant_id else ""
            params = {"tenant_id": tenant_id} if tenant_id else {}

            query = text(f"""
                SELECT
                    AVG(EXTRACT(EPOCH FROM (end_time - start_time))/60) as avg_duration_minutes,
                    MIN(EXTRACT(EPOCH FROM (end_time - start_time))/60) as min_duration_minutes,
                    MAX(EXTRACT(EPOCH FROM (end_time - start_time))/60) as max_duration_minutes,
                    COUNT(*) as completed_jobs_with_duration
                FROM model_training_logs
                WHERE status = 'completed'
                AND start_time IS NOT NULL
                AND end_time IS NOT NULL
                {tenant_filter}
            """)

            result = await self.session.execute(query, params)
            row = result.fetchone()

            if row and row.completed_jobs_with_duration > 0:
                return {
                    "avg_duration_minutes": round(float(row.avg_duration_minutes or 0), 2),
                    "min_duration_minutes": round(float(row.min_duration_minutes or 0), 2),
                    "max_duration_minutes": round(float(row.max_duration_minutes or 0), 2),
                    "completed_jobs_with_duration": int(row.completed_jobs_with_duration)
                }

            return {
                "avg_duration_minutes": 0.0,
                "min_duration_minutes": 0.0,
                "max_duration_minutes": 0.0,
                "completed_jobs_with_duration": 0
            }

        except Exception as e:
            logger.error("Failed to get job duration statistics",
                         tenant_id=tenant_id,
                         error=str(e))
            return {
                "avg_duration_minutes": 0.0,
                "min_duration_minutes": 0.0,
                "max_duration_minutes": 0.0,
                "completed_jobs_with_duration": 0
            }
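For monitoring, the count and duration helpers can be combined into a single overview payload; a small sketch (the function name is an assumption):

# Sketch: aggregate job counts and duration stats for a monitoring view.
async def training_overview(repo, tenant_id: str = None):
    stats = await repo.get_job_statistics(tenant_id=tenant_id)
    durations = await repo.get_job_duration_stats(tenant_id=tenant_id)
    return {**stats, **durations}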