""" Database Utilities Helper functions for database operations and maintenance """ from typing import Dict, Any, List, Optional from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import text, inspect from sqlalchemy.exc import SQLAlchemyError import structlog from .exceptions import DatabaseError, HealthCheckError logger = structlog.get_logger() class DatabaseUtils: """Utility functions for database operations""" @staticmethod async def execute_health_check(session: AsyncSession, timeout: int = 5) -> Dict[str, Any]: """ Comprehensive database health check Returns: Dict with health status, metrics, and diagnostics """ try: # Basic connectivity test start_time = __import__('time').time() await session.execute(text("SELECT 1")) response_time = __import__('time').time() - start_time # Get database info db_info = await DatabaseUtils._get_database_info(session) # Connection pool status (if available) pool_info = await DatabaseUtils._get_pool_info(session) return { "status": "healthy", "response_time_seconds": round(response_time, 4), "database": db_info, "connection_pool": pool_info, "timestamp": __import__('datetime').datetime.utcnow().isoformat() } except Exception as e: logger.error("Database health check failed", error=str(e)) raise HealthCheckError(f"Health check failed: {str(e)}") @staticmethod async def _get_database_info(session: AsyncSession) -> Dict[str, Any]: """Get database server information""" try: # Try to get database version and basic stats if session.bind.dialect.name == 'postgresql': version_result = await session.execute(text("SELECT version()")) version = version_result.scalar() stats_result = await session.execute(text(""" SELECT count(*) as active_connections, (SELECT setting FROM pg_settings WHERE name = 'max_connections') as max_connections FROM pg_stat_activity WHERE state = 'active' """)) stats = stats_result.fetchone() return { "type": "postgresql", "version": version, "active_connections": stats.active_connections if stats else 0, "max_connections": stats.max_connections if stats else "unknown" } elif session.bind.dialect.name == 'sqlite': version_result = await session.execute(text("SELECT sqlite_version()")) version = version_result.scalar() return { "type": "sqlite", "version": version, "active_connections": 1, "max_connections": "unlimited" } else: return { "type": session.bind.dialect.name, "version": "unknown", "active_connections": "unknown", "max_connections": "unknown" } except Exception as e: logger.warning("Could not retrieve database info", error=str(e)) return { "type": session.bind.dialect.name, "version": "unknown", "error": str(e) } @staticmethod async def _get_pool_info(session: AsyncSession) -> Dict[str, Any]: """Get connection pool information""" try: pool = session.bind.pool if pool: return { "size": pool.size(), "checked_in": pool.checkedin(), "checked_out": pool.checkedout(), "overflow": pool.overflow(), "invalid": pool.invalid() } else: return {"status": "no_pool"} except Exception as e: logger.warning("Could not retrieve pool info", error=str(e)) return {"error": str(e)} @staticmethod async def validate_schema(session: AsyncSession, expected_tables: List[str]) -> Dict[str, Any]: """ Validate database schema against expected tables Args: session: Database session expected_tables: List of table names that should exist Returns: Validation results with missing/extra tables """ try: # Get existing tables inspector = inspect(session.bind) existing_tables = set(inspector.get_table_names()) expected_tables_set = set(expected_tables) missing_tables = expected_tables_set - existing_tables extra_tables = existing_tables - expected_tables_set return { "valid": len(missing_tables) == 0, "existing_tables": list(existing_tables), "expected_tables": expected_tables, "missing_tables": list(missing_tables), "extra_tables": list(extra_tables), "total_tables": len(existing_tables) } except Exception as e: logger.error("Schema validation failed", error=str(e)) raise DatabaseError(f"Schema validation failed: {str(e)}") @staticmethod async def get_table_stats(session: AsyncSession, table_names: List[str]) -> Dict[str, Any]: """ Get statistics for specified tables Args: session: Database session table_names: List of table names to analyze Returns: Dictionary with table statistics """ try: stats = {} for table_name in table_names: if session.bind.dialect.name == 'postgresql': # PostgreSQL specific queries count_result = await session.execute( text(f"SELECT COUNT(*) FROM {table_name}") ) row_count = count_result.scalar() size_result = await session.execute( text(f"SELECT pg_total_relation_size('{table_name}')") ) table_size = size_result.scalar() stats[table_name] = { "row_count": row_count, "size_bytes": table_size, "size_mb": round(table_size / (1024 * 1024), 2) if table_size else 0 } elif session.bind.dialect.name == 'sqlite': # SQLite specific queries count_result = await session.execute( text(f"SELECT COUNT(*) FROM {table_name}") ) row_count = count_result.scalar() stats[table_name] = { "row_count": row_count, "size_bytes": "unknown", "size_mb": "unknown" } else: # Generic fallback count_result = await session.execute( text(f"SELECT COUNT(*) FROM {table_name}") ) row_count = count_result.scalar() stats[table_name] = { "row_count": row_count, "size_bytes": "unknown", "size_mb": "unknown" } return stats except Exception as e: logger.error("Failed to get table statistics", tables=table_names, error=str(e)) raise DatabaseError(f"Failed to get table stats: {str(e)}") @staticmethod async def cleanup_old_records( session: AsyncSession, table_name: str, date_column: str, days_old: int, batch_size: int = 1000 ) -> int: """ Clean up old records from a table Args: session: Database session table_name: Name of table to clean date_column: Date column to filter by days_old: Records older than this many days will be deleted batch_size: Number of records to delete per batch Returns: Total number of records deleted """ try: total_deleted = 0 while True: if session.bind.dialect.name == 'postgresql': delete_query = text(f""" DELETE FROM {table_name} WHERE {date_column} < NOW() - INTERVAL :days_param AND ctid IN ( SELECT ctid FROM {table_name} WHERE {date_column} < NOW() - INTERVAL :days_param LIMIT :batch_size ) """) params = { "days_param": f"{days_old} days", "batch_size": batch_size } elif session.bind.dialect.name == 'sqlite': delete_query = text(f""" DELETE FROM {table_name} WHERE {date_column} < datetime('now', :days_param) AND rowid IN ( SELECT rowid FROM {table_name} WHERE {date_column} < datetime('now', :days_param) LIMIT :batch_size ) """) params = { "days_param": f"-{days_old} days", "batch_size": batch_size } else: # Generic fallback (may not work for all databases) delete_query = text(f""" DELETE FROM {table_name} WHERE {date_column} < DATE_SUB(NOW(), INTERVAL :days_old DAY) LIMIT :batch_size """) params = { "days_old": days_old, "batch_size": batch_size } result = await session.execute(delete_query, params) deleted_count = result.rowcount if deleted_count == 0: break total_deleted += deleted_count await session.commit() logger.debug(f"Deleted batch from {table_name}", batch_size=deleted_count, total_deleted=total_deleted) logger.info(f"Cleanup completed for {table_name}", total_deleted=total_deleted, days_old=days_old) return total_deleted except Exception as e: await session.rollback() logger.error(f"Cleanup failed for {table_name}", error=str(e)) raise DatabaseError(f"Cleanup failed: {str(e)}") @staticmethod async def execute_maintenance(session: AsyncSession) -> Dict[str, Any]: """ Execute database maintenance tasks Returns: Dictionary with maintenance results """ try: results = {} if session.bind.dialect.name == 'postgresql': # PostgreSQL maintenance await session.execute(text("VACUUM ANALYZE")) results["vacuum"] = "completed" # Update statistics await session.execute(text("ANALYZE")) results["analyze"] = "completed" elif session.bind.dialect.name == 'sqlite': # SQLite maintenance await session.execute(text("VACUUM")) results["vacuum"] = "completed" await session.execute(text("ANALYZE")) results["analyze"] = "completed" else: results["maintenance"] = "not_supported" await session.commit() logger.info("Database maintenance completed", results=results) return results except Exception as e: await session.rollback() logger.error("Database maintenance failed", error=str(e)) raise DatabaseError(f"Maintenance failed: {str(e)}") class QueryLogger: """Utility for logging and analyzing database queries""" def __init__(self, session: AsyncSession): self.session = session self._query_log = [] async def log_query(self, query: str, params: Optional[Dict] = None, execution_time: Optional[float] = None): """Log a database query with metadata""" log_entry = { "query": query, "params": params, "execution_time": execution_time, "timestamp": __import__('datetime').datetime.utcnow().isoformat() } self._query_log.append(log_entry) # Log slow queries if execution_time and execution_time > 1.0: # 1 second threshold logger.warning("Slow query detected", query=query, execution_time=execution_time) def get_query_stats(self) -> Dict[str, Any]: """Get statistics about logged queries""" if not self._query_log: return {"total_queries": 0} execution_times = [ entry["execution_time"] for entry in self._query_log if entry["execution_time"] is not None ] return { "total_queries": len(self._query_log), "avg_execution_time": sum(execution_times) / len(execution_times) if execution_times else 0, "max_execution_time": max(execution_times) if execution_times else 0, "slow_queries_count": len([t for t in execution_times if t > 1.0]) } def clear_log(self): """Clear the query log""" self._query_log.clear()