Files
bakery-ia/services/training/app/repositories/base.py

179 lines
6.8 KiB
Python
Raw Normal View History

2025-08-08 09:08:41 +02:00
"""
Base Repository for Training Service
Service-specific repository base class with training service utilities
"""
from typing import Optional, List, Dict, Any, Type
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text
2025-11-14 07:23:56 +01:00
from datetime import datetime, timezone, timedelta
2025-08-08 09:08:41 +02:00
import structlog
from shared.database.repository import BaseRepository
from shared.database.exceptions import DatabaseError
# Module-level structlog logger used by the repository methods below.
logger = structlog.get_logger()
class TrainingBaseRepository(BaseRepository):
    """Base repository for training service with common training operations.

    Layers tenant/job/model lookups, activation toggles, age-based cleanup,
    date-range queries and a small payload validator on top of
    ``BaseRepository``. Most helpers probe the model class with ``hasattr``
    and fall back to a documented default when the column is absent.
    """

    def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 300):
        # Training data changes frequently, shorter cache time (5 minutes)
        super().__init__(model, session, cache_ttl)

    async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List[Any]:
        """Get records by tenant ID.

        When the model has a ``tenant_id`` attribute the page is filtered by
        it and ``get_multi`` is asked to order by ``created_at`` descending.
        Otherwise an unfiltered page is returned.
        """
        if hasattr(self.model, "tenant_id"):
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={"tenant_id": tenant_id},
                order_by="created_at",
                order_desc=True,
            )
        return await self.get_multi(skip=skip, limit=limit)

    async def get_active_records(self, skip: int = 0, limit: int = 100) -> List[Any]:
        """Get active records (if model has is_active field).

        Models without ``is_active`` yield an unfiltered page instead.
        """
        if hasattr(self.model, "is_active"):
            return await self.get_multi(
                skip=skip,
                limit=limit,
                filters={"is_active": True},
                order_by="created_at",
                order_desc=True,
            )
        return await self.get_multi(skip=skip, limit=limit)

    async def get_by_job_id(self, job_id: str) -> Optional[Any]:
        """Get record by job ID; ``None`` when the model has no job_id field."""
        if hasattr(self.model, "job_id"):
            return await self.get_by_field("job_id", job_id)
        return None

    async def get_by_model_id(self, model_id: str) -> Optional[Any]:
        """Get record by model ID; ``None`` when the model has no model_id field."""
        if hasattr(self.model, "model_id"):
            return await self.get_by_field("model_id", model_id)
        return None

    async def deactivate_record(self, record_id: Any) -> Optional[Any]:
        """Deactivate a record instead of deleting it.

        Sets ``is_active`` to ``False`` when the model has that attribute.
        Models without it are passed to ``self.delete`` and that call's
        result is returned.
        """
        if hasattr(self.model, "is_active"):
            return await self.update(record_id, {"is_active": False})
        return await self.delete(record_id)

    async def activate_record(self, record_id: Any) -> Optional[Any]:
        """Activate a record.

        Models without ``is_active`` are simply looked up by ID and returned.
        """
        if hasattr(self.model, "is_active"):
            return await self.update(record_id, {"is_active": True})
        return await self.get_by_id(record_id)

    async def cleanup_old_records(self, days_old: int = 90, status_filter: Optional[str] = None) -> int:
        """Clean up old training records.

        Deletes rows whose ``created_at`` is older than ``days_old`` days
        (cutoff computed in UTC). A negative ``days_old`` moves the cutoff
        into the future.

        Args:
            days_old: Minimum age, in days, of the rows to delete.
            status_filter: When given and the model has a ``status``
                attribute, only rows with that status are deleted. If the
                model has no ``status`` attribute the filter cannot be
                applied; a warning is logged and the age-only delete runs.

        Returns:
            The ``rowcount`` reported for the DELETE, or ``0`` when the model
            has no ``created_at`` attribute.

        Raises:
            DatabaseError: If executing the DELETE fails.
        """
        model_name = self.model.__name__

        # Same guard as get_records_by_date_range: without the column the
        # statement could only fail on the server.
        if not hasattr(self.model, "created_at"):
            logger.warning(f"Model {model_name} has no created_at field")
            return 0

        has_status = hasattr(self.model, "status")
        if status_filter and not has_status:
            # Make the broader-than-requested delete visible instead of silent.
            logger.warning("status_filter ignored: model has no status field",
                           model=model_name,
                           status_filter=status_filter)

        try:
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old)
            table_name = self.model.__tablename__

            # Build query based on available fields
            conditions = ["created_at < :cutoff_date"]
            params: Dict[str, Any] = {"cutoff_date": cutoff_date}
            if status_filter and has_status:
                conditions.append("status = :status")
                params["status"] = status_filter

            # Only the table name (taken from the model class) is interpolated;
            # every value is passed as a bound parameter.
            query_text = f"""
                DELETE FROM {table_name}
                WHERE {' AND '.join(conditions)}
            """
            result = await self.session.execute(text(query_text), params)
            deleted_count = result.rowcount

            logger.info(f"Cleaned up old {self.model.__name__} records",
                        deleted_count=deleted_count,
                        days_old=days_old,
                        status_filter=status_filter)
            return deleted_count
        except Exception as e:
            logger.error("Failed to cleanup old records",
                         model=self.model.__name__,
                         error=str(e))
            # Chain the cause so the driver-level traceback is preserved.
            raise DatabaseError(f"Cleanup failed: {str(e)}") from e

    async def get_records_by_date_range(
        self,
        start_date: datetime,
        end_date: datetime,
        skip: int = 0,
        limit: int = 100
    ) -> List[Any]:
        """Get records within date range.

        Selects model instances with ``start_date <= created_at <= end_date``
        ordered by ``created_at`` descending, applying ``limit`` and ``skip``
        as LIMIT/OFFSET.

        Returns:
            A list of model instances loaded through the session, or ``[]``
            when the model has no ``created_at`` attribute.

        Raises:
            DatabaseError: If the query fails.
        """
        if not hasattr(self.model, "created_at"):
            logger.warning(f"Model {self.model.__name__} has no created_at field")
            return []

        try:
            # Imported locally because the module header only imports `text`.
            from sqlalchemy import select

            created_at = self.model.created_at
            statement = (
                select(self.model)
                .where(created_at >= start_date, created_at <= end_date)
                .order_by(created_at.desc())
                .limit(limit)
                .offset(skip)
            )
            result = await self.session.execute(statement)
            # Let the ORM materialise entities rather than calling
            # self.model(**row): that produced untracked objects and broke
            # when a column name differed from its attribute name.
            # unique() is required if the model uses joined eager loading.
            return list(result.scalars().unique().all())
        except Exception as e:
            logger.error("Failed to get records by date range",
                         model=self.model.__name__,
                         start_date=start_date,
                         end_date=end_date,
                         error=str(e))
            raise DatabaseError(f"Date range query failed: {str(e)}") from e

    def _validate_training_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]:
        """Validate training-related data.

        Args:
            data: Payload to check.
            required_fields: Keys that must be present with a non-empty value.
                ``None`` and empty strings/containers count as missing;
                ``0`` and ``False`` are accepted as real values.

        Returns:
            ``{"is_valid": bool, "errors": [str, ...]}``.
        """
        errors: List[str] = []

        for field in required_fields:
            value = data.get(field)
            is_empty = hasattr(value, "__len__") and len(value) == 0
            if value is None or is_empty:
                errors.append(f"Missing required field: {field}")

        # Validate tenant_id / job_id format if present. Falsy values are
        # skipped here; the required-field check above reports them if needed.
        for id_field in ("tenant_id", "job_id"):
            identifier = data.get(id_field)
            if not identifier:
                continue
            # Must be a string with at least one non-whitespace character.
            if not isinstance(identifier, str) or not identifier.strip():
                errors.append(f"Invalid {id_field} format")

        return {
            "is_valid": not errors,
            "errors": errors
        }