Files
bakery-ia/services/training/app/repositories/base.py
2025-11-14 07:23:56 +01:00

179 lines
6.8 KiB
Python

"""
Base Repository for Training Service
Service-specific repository base class with training service utilities
"""
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Type

import structlog
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import AsyncSession

from shared.database.exceptions import DatabaseError
from shared.database.repository import BaseRepository
# Module-level structlog logger used by the repository methods below for
# info/warning/error events (key-value context passed as keyword arguments).
logger = structlog.get_logger()
class TrainingBaseRepository(BaseRepository):
    """Base repository for training service with common training operations.

    Extends the shared ``BaseRepository`` with helpers that adapt to whichever
    columns the concrete model exposes (``tenant_id``, ``is_active``,
    ``job_id``, ``model_id``, ``status``, ``created_at``). Each helper probes
    the model with ``hasattr`` so a single base class can serve several
    training models.
    """

    def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 300):
        """Create the repository.

        Args:
            model: Mapped model class managed by this repository.
            session: Async session used for every query.
            cache_ttl: Cache lifetime forwarded to ``BaseRepository``. Training
                data changes frequently, so the default is a short 300
                (five minutes).
        """
        super().__init__(model, session, cache_ttl)

    async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List:
        """Return a page of records for ``tenant_id``, newest ``created_at`` first.

        If the model has no ``tenant_id`` attribute, the records are not
        tenant-scoped. In that case an unfiltered, unordered page is returned.
        """
        if hasattr(self.model, 'tenant_id'):
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={"tenant_id": tenant_id},
                order_by="created_at",
                order_desc=True
            )
        return await self.get_multi(skip=skip, limit=limit)

    async def get_active_records(self, skip: int = 0, limit: int = 100) -> List:
        """Return a page of ``is_active=True`` records, newest ``created_at`` first.

        Models without an ``is_active`` attribute get an unfiltered,
        unordered page.
        """
        if hasattr(self.model, 'is_active'):
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={"is_active": True},
                order_by="created_at",
                order_desc=True
            )
        return await self.get_multi(skip=skip, limit=limit)

    async def get_by_job_id(self, job_id: str) -> Optional[Any]:
        """Look up a record by ``job_id``.

        Returns None without querying if the model has no ``job_id`` attribute.
        """
        if hasattr(self.model, 'job_id'):
            return await self.get_by_field("job_id", job_id)
        return None

    async def get_by_model_id(self, model_id: str) -> Optional[Any]:
        """Look up a record by ``model_id``.

        Returns None without querying if the model has no ``model_id`` attribute.
        """
        if hasattr(self.model, 'model_id'):
            return await self.get_by_field("model_id", model_id)
        return None

    async def deactivate_record(self, record_id: Any) -> Optional[Any]:
        """Soft-delete a record by setting ``is_active=False``.

        If the model has no ``is_active`` attribute, this falls back to a HARD
        delete via ``self.delete``. The return value is then whatever
        ``BaseRepository.delete`` returns.
        """
        if hasattr(self.model, 'is_active'):
            return await self.update(record_id, {"is_active": False})
        return await self.delete(record_id)

    async def activate_record(self, record_id: Any) -> Optional[Any]:
        """Set ``is_active=True`` on a record.

        If the model has no ``is_active`` attribute, the record is simply
        fetched by id and returned unchanged.
        """
        if hasattr(self.model, 'is_active'):
            return await self.update(record_id, {"is_active": True})
        return await self.get_by_id(record_id)

    async def cleanup_old_records(self, days_old: int = 90, status_filter: Optional[str] = None) -> int:
        """Delete records whose ``created_at`` is older than ``days_old`` days.

        Args:
            days_old: Age threshold in days, measured from the current UTC time.
            status_filter: If truthy, only rows with this ``status`` are
                deleted. If the model has no ``status`` attribute, nothing is
                deleted and 0 is returned. The filter is never silently
                dropped, because dropping it would delete every old row.

        Returns:
            The driver-reported ``rowcount`` of the DELETE (0 when refused).

        Raises:
            DatabaseError: Wraps any failure while executing the DELETE. This
                includes a table without a ``created_at`` column.

        NOTE(review): this method does not commit. Presumably the caller or
        ``BaseRepository`` owns the transaction; confirm.
        """
        if status_filter and not hasattr(self.model, 'status'):
            logger.warning("Refusing cleanup: status_filter given but model has no status field",
                           model=self.model.__name__,
                           status_filter=status_filter)
            return 0
        try:
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
            # __tablename__ comes from the model class, not from user input,
            # so interpolating it is safe. All values are bound parameters.
            table_name = self.model.__tablename__
            conditions = ["created_at < :cutoff_date"]
            params: Dict[str, Any] = {"cutoff_date": cutoff_date}
            if status_filter:
                conditions.append("status = :status")
                params["status"] = status_filter
            where_clause = " AND ".join(conditions)
            query_text = f"""
            DELETE FROM {table_name}
            WHERE {where_clause}
            """
            result = await self.session.execute(text(query_text), params)
            deleted_count = result.rowcount
            logger.info(f"Cleaned up old {self.model.__name__} records",
                        deleted_count=deleted_count,
                        days_old=days_old,
                        status_filter=status_filter)
            return deleted_count
        except Exception as e:
            logger.error("Failed to cleanup old records",
                         model=self.model.__name__,
                         error=str(e))
            raise DatabaseError(f"Cleanup failed: {str(e)}") from e

    async def get_records_by_date_range(
        self,
        start_date: datetime,
        end_date: datetime,
        skip: int = 0,
        limit: int = 100
    ) -> List:
        """Return records with ``start_date <= created_at <= end_date``.

        Results are ordered newest first and paged with ``skip``/``limit``.
        They are loaded through the ORM, so they are instances tracked by
        ``self.session``. They are not detached copies rebuilt from raw
        column values.

        Returns an empty list (with a warning) if the model has no
        ``created_at`` attribute.

        Raises:
            DatabaseError: Wraps any failure while running the query.
        """
        if not hasattr(self.model, 'created_at'):
            logger.warning(f"Model {self.model.__name__} has no created_at field")
            return []
        try:
            created_at = self.model.created_at
            stmt = (
                select(self.model)
                .where(created_at >= start_date, created_at <= end_date)
                .order_by(created_at.desc())
                .offset(skip)
                .limit(limit)
            )
            result = await self.session.execute(stmt)
            return list(result.scalars().all())
        except Exception as e:
            logger.error("Failed to get records by date range",
                         model=self.model.__name__,
                         start_date=start_date,
                         end_date=end_date,
                         error=str(e))
            raise DatabaseError(f"Date range query failed: {str(e)}") from e

    @staticmethod
    def _is_missing(value: Any) -> bool:
        """Return True if a required field value should count as absent."""
        if value is None:
            return True
        # 0, 0.0 and False are legitimate values for required fields
        # (e.g. a count or a flag). Only non-numeric empties such as "" or []
        # count as missing.
        if isinstance(value, (int, float)):
            return False
        return not value

    def _validate_training_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
        """Validate training-related input data.

        Checks:
        - every name in ``required_fields`` is present and non-empty. Numeric
          and boolean values are accepted even when falsy.
        - ``tenant_id`` and ``job_id``, when supplied with a truthy value, are
          non-blank strings.

        Returns:
            ``{"is_valid": bool, "errors": list[str]}``. Errors for required
            fields are listed first, then tenant_id, then job_id.
        """
        errors: List[str] = []
        for field in required_fields:
            if self._is_missing(data.get(field)):
                errors.append(f"Missing required field: {field}")
        for key in ("tenant_id", "job_id"):
            value = data.get(key)
            if value and (not isinstance(value, str) or not value.strip()):
                errors.append(f"Invalid {key} format")
        return {
            "is_valid": len(errors) == 0,
            "errors": errors
        }