"""
|
||
|
|
Database Utilities
|
||
|
|
Helper functions for database operations and maintenance
|
||
|
|
"""
|
||
|
|
|
||
|
|
from typing import Dict, Any, List, Optional
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
|
from sqlalchemy import text, inspect
|
||
|
|
from sqlalchemy.exc import SQLAlchemyError
|
||
|
|
import structlog
|
||
|
|
|
||
|
|
from .exceptions import DatabaseError, HealthCheckError
|
||
|
|
|
||
|
|
logger = structlog.get_logger()
|
||
|
|
|
||
|
|
|
||
|
|
class DatabaseUtils:
    """Utility functions for database operations"""

    @staticmethod
    async def execute_health_check(session: AsyncSession, timeout: int = 5) -> Dict[str, Any]:
        """
        Comprehensive database health check

        Args:
            session: Database session
            timeout: Seconds to wait for the connectivity probe

        Returns:
            Dict with health status, metrics, and diagnostics
        """
        try:
            # Basic connectivity test, bounded by the timeout argument
            start_time = time.perf_counter()
            await asyncio.wait_for(session.execute(text("SELECT 1")), timeout=timeout)
            response_time = time.perf_counter() - start_time

            # Get database info
            db_info = await DatabaseUtils._get_database_info(session)

            # Connection pool status (if available)
            pool_info = await DatabaseUtils._get_pool_info(session)

            return {
                "status": "healthy",
                "response_time_seconds": round(response_time, 4),
                "database": db_info,
                "connection_pool": pool_info,
                "timestamp": datetime.now(timezone.utc).isoformat()
            }

        except Exception as e:
            logger.error("Database health check failed", error=str(e))
            raise HealthCheckError(f"Health check failed: {e}") from e

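    # A minimal usage sketch (hypothetical factory name): run the check
    # inside a session context created by an async_sessionmaker, e.g.
    #
    #     async with async_session_factory() as session:
    #         report = await DatabaseUtils.execute_health_check(session, timeout=5)
    #         assert report["status"] == "healthy"
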
    @staticmethod
    async def _get_database_info(session: AsyncSession) -> Dict[str, Any]:
        """Get database server information"""
        try:
            # Try to get database version and basic stats
            if session.bind.dialect.name == 'postgresql':
                version_result = await session.execute(text("SELECT version()"))
                version = version_result.scalar()

                stats_result = await session.execute(text("""
                    SELECT
                        count(*) as active_connections,
                        (SELECT setting FROM pg_settings WHERE name = 'max_connections') as max_connections
                    FROM pg_stat_activity
                    WHERE state = 'active'
                """))
                stats = stats_result.fetchone()

                return {
                    "type": "postgresql",
                    "version": version,
                    "active_connections": stats.active_connections if stats else 0,
                    "max_connections": stats.max_connections if stats else "unknown"
                }

            elif session.bind.dialect.name == 'sqlite':
                version_result = await session.execute(text("SELECT sqlite_version()"))
                version = version_result.scalar()

                return {
                    "type": "sqlite",
                    "version": version,
                    "active_connections": 1,
                    "max_connections": "unlimited"
                }

            else:
                return {
                    "type": session.bind.dialect.name,
                    "version": "unknown",
                    "active_connections": "unknown",
                    "max_connections": "unknown"
                }

        except Exception as e:
            logger.warning("Could not retrieve database info", error=str(e))
            return {
                "type": session.bind.dialect.name,
                "version": "unknown",
                "error": str(e)
            }

    @staticmethod
    async def _get_pool_info(session: AsyncSession) -> Dict[str, Any]:
        """Get connection pool information"""
        try:
            pool = session.bind.pool
            if pool:
                # These counters are provided by QueuePool; other pool
                # implementations may not expose all of them, in which case
                # the except branch below reports the error instead.
                return {
                    "size": pool.size(),
                    "checked_in": pool.checkedin(),
                    "checked_out": pool.checkedout(),
                    "overflow": pool.overflow(),
                    "status": pool.status()
                }
            else:
                return {"status": "no_pool"}

        except Exception as e:
            logger.warning("Could not retrieve pool info", error=str(e))
            return {"error": str(e)}

    @staticmethod
    async def validate_schema(session: AsyncSession, expected_tables: List[str]) -> Dict[str, Any]:
        """
        Validate database schema against expected tables

        Args:
            session: Database session
            expected_tables: List of table names that should exist

        Returns:
            Validation results with missing/extra tables
        """
        try:
            # Inspection is a synchronous API, so run it on the underlying
            # sync session via run_sync instead of inspecting the async
            # engine directly (which raises).
            def _table_names(sync_session) -> List[str]:
                return inspect(sync_session.get_bind()).get_table_names()

            existing_tables = set(await session.run_sync(_table_names))
            expected_tables_set = set(expected_tables)

            missing_tables = expected_tables_set - existing_tables
            extra_tables = existing_tables - expected_tables_set

            return {
                "valid": len(missing_tables) == 0,
                "existing_tables": list(existing_tables),
                "expected_tables": expected_tables,
                "missing_tables": list(missing_tables),
                "extra_tables": list(extra_tables),
                "total_tables": len(existing_tables)
            }

        except Exception as e:
            logger.error("Schema validation failed", error=str(e))
            raise DatabaseError(f"Schema validation failed: {e}") from e

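    # Hypothetical call, e.g. at application startup (table names are
    # illustrative):
    #
    #     result = await DatabaseUtils.validate_schema(session, ["users", "orders"])
    #     if not result["valid"]:
    #         raise RuntimeError(f"Missing tables: {result['missing_tables']}")
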
    @staticmethod
    async def get_table_stats(session: AsyncSession, table_names: List[str]) -> Dict[str, Any]:
        """
        Get statistics for specified tables

        Args:
            session: Database session
            table_names: List of table names to analyze

        Returns:
            Dictionary with table statistics
        """
        try:
            stats = {}

            # Note: table names are interpolated into the SQL text because
            # identifiers cannot be bound as parameters; callers must pass
            # trusted table names only.
            for table_name in table_names:
                count_result = await session.execute(
                    text(f"SELECT COUNT(*) FROM {table_name}")
                )
                row_count = count_result.scalar()

                if session.bind.dialect.name == 'postgresql':
                    # PostgreSQL can also report on-disk size
                    size_result = await session.execute(
                        text(f"SELECT pg_total_relation_size('{table_name}')")
                    )
                    table_size = size_result.scalar()

                    stats[table_name] = {
                        "row_count": row_count,
                        "size_bytes": table_size,
                        "size_mb": round(table_size / (1024 * 1024), 2) if table_size else 0
                    }

                else:
                    # SQLite and other dialects: row count only
                    stats[table_name] = {
                        "row_count": row_count,
                        "size_bytes": "unknown",
                        "size_mb": "unknown"
                    }

            return stats

        except Exception as e:
            logger.error("Failed to get table statistics",
                         tables=table_names, error=str(e))
            raise DatabaseError(f"Failed to get table stats: {e}") from e

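    # Hypothetical call (table names are illustrative):
    #
    #     stats = await DatabaseUtils.get_table_stats(session, ["users", "orders"])
    #     print(stats["users"]["row_count"])
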
    @staticmethod
    async def cleanup_old_records(
        session: AsyncSession,
        table_name: str,
        date_column: str,
        days_old: int,
        batch_size: int = 1000
    ) -> int:
        """
        Clean up old records from a table

        Args:
            session: Database session
            table_name: Name of table to clean
            date_column: Date column to filter by
            days_old: Records older than this many days will be deleted
            batch_size: Number of records to delete per batch

        Returns:
            Total number of records deleted
        """
        try:
            total_deleted = 0

            # Identifiers (table and column names) are interpolated, so they
            # must come from trusted code, never from user input.
            while True:
                if session.bind.dialect.name == 'postgresql':
                    # CAST the bound parameter to INTERVAL; a bare
                    # "INTERVAL :param" is not valid PostgreSQL syntax.
                    delete_query = text(f"""
                        DELETE FROM {table_name}
                        WHERE ctid IN (
                            SELECT ctid FROM {table_name}
                            WHERE {date_column} < NOW() - CAST(:days_param AS INTERVAL)
                            LIMIT :batch_size
                        )
                    """)
                    params = {
                        "days_param": f"{days_old} days",
                        "batch_size": batch_size
                    }

                elif session.bind.dialect.name == 'sqlite':
                    delete_query = text(f"""
                        DELETE FROM {table_name}
                        WHERE rowid IN (
                            SELECT rowid FROM {table_name}
                            WHERE {date_column} < datetime('now', :days_param)
                            LIMIT :batch_size
                        )
                    """)
                    params = {
                        "days_param": f"-{days_old} days",
                        "batch_size": batch_size
                    }

                else:
                    # Generic fallback (MySQL-style syntax; may not work for
                    # all databases)
                    delete_query = text(f"""
                        DELETE FROM {table_name}
                        WHERE {date_column} < DATE_SUB(NOW(), INTERVAL :days_old DAY)
                        LIMIT :batch_size
                    """)
                    params = {
                        "days_old": days_old,
                        "batch_size": batch_size
                    }

                result = await session.execute(delete_query, params)
                deleted_count = result.rowcount

                if deleted_count == 0:
                    break

                total_deleted += deleted_count
                await session.commit()

                logger.debug("Deleted batch", table=table_name,
                             batch_size=deleted_count,
                             total_deleted=total_deleted)

            logger.info("Cleanup completed", table=table_name,
                        total_deleted=total_deleted,
                        days_old=days_old)

            return total_deleted

        except Exception as e:
            await session.rollback()
            logger.error("Cleanup failed", table=table_name, error=str(e))
            raise DatabaseError(f"Cleanup failed: {e}") from e

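    # Hypothetical call: delete audit rows older than 90 days in batches of
    # 500 (table and column names are illustrative):
    #
    #     deleted = await DatabaseUtils.cleanup_old_records(
    #         session, "audit_log", "created_at", days_old=90, batch_size=500
    #     )
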
    @staticmethod
    async def execute_maintenance(session: AsyncSession) -> Dict[str, Any]:
        """
        Execute database maintenance tasks

        Returns:
            Dictionary with maintenance results
        """
        try:
            results = {}
            dialect = session.bind.dialect.name

            if dialect in ('postgresql', 'sqlite'):
                # VACUUM cannot run inside a transaction block on either
                # dialect, so run it on a dedicated autocommit connection
                # (this assumes the session is bound to an AsyncEngine).
                async with session.bind.connect() as conn:
                    await conn.execution_options(isolation_level="AUTOCOMMIT")
                    if dialect == 'postgresql':
                        # VACUUM ANALYZE also refreshes planner statistics
                        await conn.execute(text("VACUUM ANALYZE"))
                    else:
                        await conn.execute(text("VACUUM"))
                        await conn.execute(text("ANALYZE"))
                results["vacuum"] = "completed"
                results["analyze"] = "completed"

            else:
                results["maintenance"] = "not_supported"

            logger.info("Database maintenance completed", results=results)
            return results

        except Exception as e:
            logger.error("Database maintenance failed", error=str(e))
            raise DatabaseError(f"Maintenance failed: {e}") from e


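# A minimal, hypothetical usage sketch for DatabaseUtils: the engine URL,
# session factory, and expected table list below are illustrative
# placeholders, not part of this module's API.
async def example_database_checkup() -> None:
    from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

    engine = create_async_engine("sqlite+aiosqlite:///./example.db")
    session_factory = async_sessionmaker(engine, expire_on_commit=False)

    async with session_factory() as session:
        health = await DatabaseUtils.execute_health_check(session)
        schema = await DatabaseUtils.validate_schema(session, ["users", "orders"])
        maintenance = await DatabaseUtils.execute_maintenance(session)
        logger.info("Checkup finished",
                    status=health["status"],
                    schema_valid=schema["valid"],
                    maintenance=maintenance)

    await engine.dispose()

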
class QueryLogger:
    """Utility for logging and analyzing database queries"""

    def __init__(self, session: AsyncSession):
        self.session = session
        self._query_log: List[Dict[str, Any]] = []

    async def log_query(self, query: str, params: Optional[Dict] = None, execution_time: Optional[float] = None):
        """Log a database query with metadata"""
        log_entry = {
            "query": query,
            "params": params,
            "execution_time": execution_time,
            "timestamp": datetime.now(timezone.utc).isoformat()
        }

        self._query_log.append(log_entry)

        # Log slow queries
        if execution_time and execution_time > 1.0:  # 1 second threshold
            logger.warning("Slow query detected",
                           query=query,
                           execution_time=execution_time)

    def get_query_stats(self) -> Dict[str, Any]:
        """Get statistics about logged queries"""
        if not self._query_log:
            return {"total_queries": 0}

        execution_times = [
            entry["execution_time"]
            for entry in self._query_log
            if entry["execution_time"] is not None
        ]

        return {
            "total_queries": len(self._query_log),
            "avg_execution_time": sum(execution_times) / len(execution_times) if execution_times else 0,
            "max_execution_time": max(execution_times) if execution_times else 0,
            "slow_queries_count": len([t for t in execution_times if t > 1.0])
        }

    def clear_log(self):
        """Clear the query log"""
        self._query_log.clear()
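

# A small, hypothetical end-to-end demonstration wiring QueryLogger around a
# raw query; the in-memory SQLite URL is a placeholder and timings use
# time.perf_counter().
if __name__ == "__main__":
    from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

    async def _demo() -> None:
        engine = create_async_engine("sqlite+aiosqlite:///:memory:")
        session_factory = async_sessionmaker(engine, expire_on_commit=False)

        async with session_factory() as session:
            query_logger = QueryLogger(session)

            start = time.perf_counter()
            await session.execute(text("SELECT 1"))
            await query_logger.log_query("SELECT 1",
                                         execution_time=time.perf_counter() - start)

            print(query_logger.get_query_stats())

        await engine.dispose()

    asyncio.run(_demo())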