REFACTOR - Database logic
This commit is contained in:
179
services/training/app/repositories/base.py
Normal file
179
services/training/app/repositories/base.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""
|
||||
Base Repository for Training Service
|
||||
Service-specific repository base class with training service utilities
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any, Type
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text
|
||||
from datetime import datetime, timedelta
|
||||
import structlog
|
||||
|
||||
from shared.database.repository import BaseRepository
|
||||
from shared.database.exceptions import DatabaseError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TrainingBaseRepository(BaseRepository):
|
||||
"""Base repository for training service with common training operations"""
|
||||
|
||||
def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 300):
|
||||
# Training data changes frequently, shorter cache time (5 minutes)
|
||||
super().__init__(model, session, cache_ttl)
|
||||
|
||||
async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List:
|
||||
"""Get records by tenant ID"""
|
||||
if hasattr(self.model, 'tenant_id'):
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={"tenant_id": tenant_id},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return await self.get_multi(skip=skip, limit=limit)
|
||||
|
||||
async def get_active_records(self, skip: int = 0, limit: int = 100) -> List:
|
||||
"""Get active records (if model has is_active field)"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.get_multi(
|
||||
skip=skip,
|
||||
limit=limit,
|
||||
filters={"is_active": True},
|
||||
order_by="created_at",
|
||||
order_desc=True
|
||||
)
|
||||
return await self.get_multi(skip=skip, limit=limit)
|
||||
|
||||
async def get_by_job_id(self, job_id: str) -> Optional:
|
||||
"""Get record by job ID (if model has job_id field)"""
|
||||
if hasattr(self.model, 'job_id'):
|
||||
return await self.get_by_field("job_id", job_id)
|
||||
return None
|
||||
|
||||
async def get_by_model_id(self, model_id: str) -> Optional:
|
||||
"""Get record by model ID (if model has model_id field)"""
|
||||
if hasattr(self.model, 'model_id'):
|
||||
return await self.get_by_field("model_id", model_id)
|
||||
return None
|
||||
|
||||
async def deactivate_record(self, record_id: Any) -> Optional:
|
||||
"""Deactivate a record instead of deleting it"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.update(record_id, {"is_active": False})
|
||||
return await self.delete(record_id)
|
||||
|
||||
async def activate_record(self, record_id: Any) -> Optional:
|
||||
"""Activate a record"""
|
||||
if hasattr(self.model, 'is_active'):
|
||||
return await self.update(record_id, {"is_active": True})
|
||||
return await self.get_by_id(record_id)
|
||||
|
||||
async def cleanup_old_records(self, days_old: int = 90, status_filter: str = None) -> int:
|
||||
"""Clean up old training records"""
|
||||
try:
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=days_old)
|
||||
table_name = self.model.__tablename__
|
||||
|
||||
# Build query based on available fields
|
||||
conditions = [f"created_at < :cutoff_date"]
|
||||
params = {"cutoff_date": cutoff_date}
|
||||
|
||||
if status_filter and hasattr(self.model, 'status'):
|
||||
conditions.append(f"status = :status")
|
||||
params["status"] = status_filter
|
||||
|
||||
query_text = f"""
|
||||
DELETE FROM {table_name}
|
||||
WHERE {' AND '.join(conditions)}
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), params)
|
||||
deleted_count = result.rowcount
|
||||
|
||||
logger.info(f"Cleaned up old {self.model.__name__} records",
|
||||
deleted_count=deleted_count,
|
||||
days_old=days_old,
|
||||
status_filter=status_filter)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup old records",
|
||||
model=self.model.__name__,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Cleanup failed: {str(e)}")
|
||||
|
||||
async def get_records_by_date_range(
|
||||
self,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List:
|
||||
"""Get records within date range"""
|
||||
if not hasattr(self.model, 'created_at'):
|
||||
logger.warning(f"Model {self.model.__name__} has no created_at field")
|
||||
return []
|
||||
|
||||
try:
|
||||
table_name = self.model.__tablename__
|
||||
|
||||
query_text = f"""
|
||||
SELECT * FROM {table_name}
|
||||
WHERE created_at >= :start_date
|
||||
AND created_at <= :end_date
|
||||
ORDER BY created_at DESC
|
||||
LIMIT :limit OFFSET :skip
|
||||
"""
|
||||
|
||||
result = await self.session.execute(text(query_text), {
|
||||
"start_date": start_date,
|
||||
"end_date": end_date,
|
||||
"limit": limit,
|
||||
"skip": skip
|
||||
})
|
||||
|
||||
# Convert rows to model objects
|
||||
records = []
|
||||
for row in result.fetchall():
|
||||
# Create model instance from row data
|
||||
record_dict = dict(row._mapping)
|
||||
record = self.model(**record_dict)
|
||||
records.append(record)
|
||||
|
||||
return records
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get records by date range",
|
||||
model=self.model.__name__,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Date range query failed: {str(e)}")
|
||||
|
||||
def _validate_training_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
|
||||
"""Validate training-related data"""
|
||||
errors = []
|
||||
|
||||
for field in required_fields:
|
||||
if field not in data or not data[field]:
|
||||
errors.append(f"Missing required field: {field}")
|
||||
|
||||
# Validate tenant_id format if present
|
||||
if "tenant_id" in data and data["tenant_id"]:
|
||||
tenant_id = data["tenant_id"]
|
||||
if not isinstance(tenant_id, str) or len(tenant_id) < 1:
|
||||
errors.append("Invalid tenant_id format")
|
||||
|
||||
# Validate job_id format if present
|
||||
if "job_id" in data and data["job_id"]:
|
||||
job_id = data["job_id"]
|
||||
if not isinstance(job_id, str) or len(job_id) < 1:
|
||||
errors.append("Invalid job_id format")
|
||||
|
||||
return {
|
||||
"is_valid": len(errors) == 0,
|
||||
"errors": errors
|
||||
}
|
||||
Reference in New Issue
Block a user