""" Base Repository for Training Service Service-specific repository base class with training service utilities """ from typing import Optional, List, Dict, Any, Type from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import text from datetime import datetime, timezone, timedelta import structlog from shared.database.repository import BaseRepository from shared.database.exceptions import DatabaseError logger = structlog.get_logger() class TrainingBaseRepository(BaseRepository): """Base repository for training service with common training operations""" def __init__(self, model: Type, session: AsyncSession, cache_ttl: Optional[int] = 300): # Training data changes frequently, shorter cache time (5 minutes) super().__init__(model, session, cache_ttl) async def get_by_tenant_id(self, tenant_id: str, skip: int = 0, limit: int = 100) -> List: """Get records by tenant ID""" if hasattr(self.model, 'tenant_id'): return await self.get_multi( skip=skip, limit=limit, filters={"tenant_id": tenant_id}, order_by="created_at", order_desc=True ) return await self.get_multi(skip=skip, limit=limit) async def get_active_records(self, skip: int = 0, limit: int = 100) -> List: """Get active records (if model has is_active field)""" if hasattr(self.model, 'is_active'): return await self.get_multi( skip=skip, limit=limit, filters={"is_active": True}, order_by="created_at", order_desc=True ) return await self.get_multi(skip=skip, limit=limit) async def get_by_job_id(self, job_id: str) -> Optional: """Get record by job ID (if model has job_id field)""" if hasattr(self.model, 'job_id'): return await self.get_by_field("job_id", job_id) return None async def get_by_model_id(self, model_id: str) -> Optional: """Get record by model ID (if model has model_id field)""" if hasattr(self.model, 'model_id'): return await self.get_by_field("model_id", model_id) return None async def deactivate_record(self, record_id: Any) -> Optional: """Deactivate a record instead of deleting it""" if hasattr(self.model, 'is_active'): return await self.update(record_id, {"is_active": False}) return await self.delete(record_id) async def activate_record(self, record_id: Any) -> Optional: """Activate a record""" if hasattr(self.model, 'is_active'): return await self.update(record_id, {"is_active": True}) return await self.get_by_id(record_id) async def cleanup_old_records(self, days_old: int = 90, status_filter: str = None) -> int: """Clean up old training records""" try: cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_old) table_name = self.model.__tablename__ # Build query based on available fields conditions = [f"created_at < :cutoff_date"] params = {"cutoff_date": cutoff_date} if status_filter and hasattr(self.model, 'status'): conditions.append(f"status = :status") params["status"] = status_filter query_text = f""" DELETE FROM {table_name} WHERE {' AND '.join(conditions)} """ result = await self.session.execute(text(query_text), params) deleted_count = result.rowcount logger.info(f"Cleaned up old {self.model.__name__} records", deleted_count=deleted_count, days_old=days_old, status_filter=status_filter) return deleted_count except Exception as e: logger.error("Failed to cleanup old records", model=self.model.__name__, error=str(e)) raise DatabaseError(f"Cleanup failed: {str(e)}") async def get_records_by_date_range( self, start_date: datetime, end_date: datetime, skip: int = 0, limit: int = 100 ) -> List: """Get records within date range""" if not hasattr(self.model, 'created_at'): logger.warning(f"Model {self.model.__name__} has no created_at field") return [] try: table_name = self.model.__tablename__ query_text = f""" SELECT * FROM {table_name} WHERE created_at >= :start_date AND created_at <= :end_date ORDER BY created_at DESC LIMIT :limit OFFSET :skip """ result = await self.session.execute(text(query_text), { "start_date": start_date, "end_date": end_date, "limit": limit, "skip": skip }) # Convert rows to model objects records = [] for row in result.fetchall(): # Create model instance from row data record_dict = dict(row._mapping) record = self.model(**record_dict) records.append(record) return records except Exception as e: logger.error("Failed to get records by date range", model=self.model.__name__, start_date=start_date, end_date=end_date, error=str(e)) raise DatabaseError(f"Date range query failed: {str(e)}") def _validate_training_data(self, data: Dict[str, Any], required_fields: List[str]) -> Dict[str, Any]: """Validate training-related data""" errors = [] for field in required_fields: if field not in data or not data[field]: errors.append(f"Missing required field: {field}") # Validate tenant_id format if present if "tenant_id" in data and data["tenant_id"]: tenant_id = data["tenant_id"] if not isinstance(tenant_id, str) or len(tenant_id) < 1: errors.append("Invalid tenant_id format") # Validate job_id format if present if "job_id" in data and data["job_id"]: job_id = data["job_id"] if not isinstance(job_id, str) or len(job_id) < 1: errors.append("Invalid job_id format") return { "is_valid": len(errors) == 0, "errors": errors }