REFACTOR - Database logic
402
shared/database/utils.py
Normal file
@@ -0,0 +1,402 @@
"""
Database Utilities

Helper functions for database operations and maintenance.
"""

import asyncio
import time
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import structlog
from sqlalchemy import inspect, text
from sqlalchemy.ext.asyncio import AsyncSession

from .exceptions import DatabaseError, HealthCheckError

logger = structlog.get_logger()


class DatabaseUtils:
    """Utility functions for database operations"""

    @staticmethod
    async def execute_health_check(session: AsyncSession, timeout: int = 5) -> Dict[str, Any]:
        """
        Comprehensive database health check.

        Args:
            session: Database session
            timeout: Seconds to wait for the connectivity probe

        Returns:
            Dict with health status, metrics, and diagnostics
        """
        try:
            # Basic connectivity test, timed with a monotonic clock
            start_time = time.monotonic()
            await asyncio.wait_for(session.execute(text("SELECT 1")), timeout=timeout)
            response_time = time.monotonic() - start_time

            # Database server information
            db_info = await DatabaseUtils._get_database_info(session)

            # Connection pool status (if available)
            pool_info = await DatabaseUtils._get_pool_info(session)

            return {
                "status": "healthy",
                "response_time_seconds": round(response_time, 4),
                "database": db_info,
                "connection_pool": pool_info,
                "timestamp": datetime.now(timezone.utc).isoformat(),
            }

        except Exception as e:
            logger.error("Database health check failed", error=str(e))
            raise HealthCheckError(f"Health check failed: {e}") from e
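
    # Usage sketch (``async_session_factory`` is a placeholder for your
    # project's sessionmaker):
    #
    #     async with async_session_factory() as session:
    #         health = await DatabaseUtils.execute_health_check(session, timeout=5)
    #         print(health["status"], health["response_time_seconds"])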

    @staticmethod
    async def _get_database_info(session: AsyncSession) -> Dict[str, Any]:
        """Get database server information"""
        dialect = session.bind.dialect.name
        try:
            if dialect == "postgresql":
                version_result = await session.execute(text("SELECT version()"))
                version = version_result.scalar()

                stats_result = await session.execute(text("""
                    SELECT
                        count(*) AS active_connections,
                        (SELECT setting FROM pg_settings WHERE name = 'max_connections') AS max_connections
                    FROM pg_stat_activity
                    WHERE state = 'active'
                """))
                stats = stats_result.fetchone()

                return {
                    "type": "postgresql",
                    "version": version,
                    "active_connections": stats.active_connections if stats else 0,
                    "max_connections": stats.max_connections if stats else "unknown",
                }

            elif dialect == "sqlite":
                version_result = await session.execute(text("SELECT sqlite_version()"))
                version = version_result.scalar()

                return {
                    "type": "sqlite",
                    "version": version,
                    # SQLite is embedded: one connection per session, no server cap
                    "active_connections": 1,
                    "max_connections": "unlimited",
                }

            else:
                return {
                    "type": dialect,
                    "version": "unknown",
                    "active_connections": "unknown",
                    "max_connections": "unknown",
                }

        except Exception as e:
            logger.warning("Could not retrieve database info", error=str(e))
            return {
                "type": dialect,
                "version": "unknown",
                "error": str(e),
            }

    @staticmethod
    async def _get_pool_info(session: AsyncSession) -> Dict[str, Any]:
        """Get connection pool information"""
        try:
            pool = session.bind.pool
            if pool is not None:
                # Standard QueuePool accessors; other pool implementations
                # may not provide all of these.
                return {
                    "size": pool.size(),
                    "checked_in": pool.checkedin(),
                    "checked_out": pool.checkedout(),
                    "overflow": pool.overflow(),
                }
            else:
                return {"status": "no_pool"}

        except Exception as e:
            logger.warning("Could not retrieve pool info", error=str(e))
            return {"error": str(e)}

    @staticmethod
    async def validate_schema(session: AsyncSession, expected_tables: List[str]) -> Dict[str, Any]:
        """
        Validate database schema against expected tables.

        Args:
            session: Database session
            expected_tables: List of table names that should exist

        Returns:
            Validation results with missing/extra tables
        """
        try:
            # inspect() cannot be used on an async engine directly, so run
            # the reflection against the session's underlying sync bind.
            def _table_names(sync_session):
                return inspect(sync_session.get_bind()).get_table_names()

            existing_tables = set(await session.run_sync(_table_names))
            expected_tables_set = set(expected_tables)

            missing_tables = expected_tables_set - existing_tables
            extra_tables = existing_tables - expected_tables_set

            return {
                "valid": not missing_tables,
                "existing_tables": sorted(existing_tables),
                "expected_tables": expected_tables,
                "missing_tables": sorted(missing_tables),
                "extra_tables": sorted(extra_tables),
                "total_tables": len(existing_tables),
            }

        except Exception as e:
            logger.error("Schema validation failed", error=str(e))
            raise DatabaseError(f"Schema validation failed: {e}") from e
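
    # Usage sketch ("users"/"orders" are placeholder table names):
    #
    #     report = await DatabaseUtils.validate_schema(
    #         session, expected_tables=["users", "orders"]
    #     )
    #     if not report["valid"]:
    #         logger.warning("Schema drift", missing=report["missing_tables"])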

    @staticmethod
    async def get_table_stats(session: AsyncSession, table_names: List[str]) -> Dict[str, Any]:
        """
        Get statistics for specified tables.

        Args:
            session: Database session
            table_names: List of table names to analyze

        Returns:
            Dictionary with table statistics

        Note:
            Table names are interpolated into the SQL directly (identifiers
            cannot be bound as parameters), so they must come from trusted
            code, never from user input.
        """
        try:
            stats: Dict[str, Any] = {}
            dialect = session.bind.dialect.name

            for table_name in table_names:
                # Row count works on every dialect
                count_result = await session.execute(
                    text(f"SELECT COUNT(*) FROM {table_name}")
                )
                row_count = count_result.scalar()

                if dialect == "postgresql":
                    # Total on-disk size including indexes and TOAST data
                    size_result = await session.execute(
                        text(f"SELECT pg_total_relation_size('{table_name}')")
                    )
                    table_size = size_result.scalar()

                    stats[table_name] = {
                        "row_count": row_count,
                        "size_bytes": table_size,
                        "size_mb": round(table_size / (1024 * 1024), 2) if table_size else 0,
                    }

                else:
                    # SQLite and other dialects: no portable size query
                    stats[table_name] = {
                        "row_count": row_count,
                        "size_bytes": "unknown",
                        "size_mb": "unknown",
                    }

            return stats

        except Exception as e:
            logger.error("Failed to get table statistics",
                         tables=table_names, error=str(e))
            raise DatabaseError(f"Failed to get table stats: {e}") from e
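
    # Usage sketch (placeholder table names):
    #
    #     stats = await DatabaseUtils.get_table_stats(session, ["users", "orders"])
    #     for table, info in stats.items():
    #         print(table, info["row_count"], info["size_mb"])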

    @staticmethod
    async def cleanup_old_records(
        session: AsyncSession,
        table_name: str,
        date_column: str,
        days_old: int,
        batch_size: int = 1000
    ) -> int:
        """
        Clean up old records from a table in batches.

        Args:
            session: Database session
            table_name: Name of table to clean
            date_column: Date column to filter by
            days_old: Records older than this many days will be deleted
            batch_size: Number of records to delete per batch

        Returns:
            Total number of records deleted

        Note:
            ``table_name`` and ``date_column`` are interpolated into the SQL
            and must come from trusted code, never from user input.
        """
        try:
            total_deleted = 0
            dialect = session.bind.dialect.name

            while True:
                if dialect == "postgresql":
                    # INTERVAL cannot take a bound parameter directly, so
                    # cast the bound string to an interval instead.
                    delete_query = text(f"""
                        DELETE FROM {table_name}
                        WHERE ctid IN (
                            SELECT ctid FROM {table_name}
                            WHERE {date_column} < NOW() - CAST(:days_param AS INTERVAL)
                            LIMIT :batch_size
                        )
                    """)
                    params = {
                        "days_param": f"{days_old} days",
                        "batch_size": batch_size,
                    }

                elif dialect == "sqlite":
                    delete_query = text(f"""
                        DELETE FROM {table_name}
                        WHERE rowid IN (
                            SELECT rowid FROM {table_name}
                            WHERE {date_column} < datetime('now', :days_param)
                            LIMIT :batch_size
                        )
                    """)
                    params = {
                        "days_param": f"-{days_old} days",
                        "batch_size": batch_size,
                    }

                else:
                    # MySQL-style fallback; may not work on other databases
                    delete_query = text(f"""
                        DELETE FROM {table_name}
                        WHERE {date_column} < DATE_SUB(NOW(), INTERVAL :days_old DAY)
                        LIMIT :batch_size
                    """)
                    params = {
                        "days_old": days_old,
                        "batch_size": batch_size,
                    }

                result = await session.execute(delete_query, params)
                deleted_count = result.rowcount

                if deleted_count == 0:
                    break

                total_deleted += deleted_count
                await session.commit()

                logger.debug(f"Deleted batch from {table_name}",
                             batch_size=deleted_count,
                             total_deleted=total_deleted)

            logger.info(f"Cleanup completed for {table_name}",
                        total_deleted=total_deleted,
                        days_old=days_old)

            return total_deleted

        except Exception as e:
            await session.rollback()
            logger.error(f"Cleanup failed for {table_name}", error=str(e))
            raise DatabaseError(f"Cleanup failed: {e}") from e
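
    # Usage sketch (hypothetical audit-log table):
    #
    #     deleted = await DatabaseUtils.cleanup_old_records(
    #         session,
    #         table_name="audit_log",
    #         date_column="created_at",
    #         days_old=90,
    #         batch_size=500,
    #     )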

    @staticmethod
    async def execute_maintenance(session: AsyncSession) -> Dict[str, Any]:
        """
        Execute database maintenance tasks.

        Returns:
            Dictionary with maintenance results

        Note:
            On PostgreSQL and SQLite, VACUUM cannot run inside a transaction
            block, so call this on a session whose connection is in
            autocommit mode; inside an ordinary transaction it will fail.
        """
        try:
            results: Dict[str, Any] = {}
            dialect = session.bind.dialect.name

            if dialect == "postgresql":
                # VACUUM ANALYZE reclaims space and refreshes planner
                # statistics in one pass, so no separate ANALYZE is needed.
                await session.execute(text("VACUUM ANALYZE"))
                results["vacuum"] = "completed"
                results["analyze"] = "completed"

            elif dialect == "sqlite":
                await session.execute(text("VACUUM"))
                results["vacuum"] = "completed"

                await session.execute(text("ANALYZE"))
                results["analyze"] = "completed"

            else:
                results["maintenance"] = "not_supported"

            await session.commit()

            logger.info("Database maintenance completed", results=results)
            return results

        except Exception as e:
            await session.rollback()
            logger.error("Database maintenance failed", error=str(e))
            raise DatabaseError(f"Maintenance failed: {e}") from e
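
    # Typically invoked from a scheduled job (cron, Celery beat, etc.):
    #
    #     results = await DatabaseUtils.execute_maintenance(session)
    #
    # See the docstring note above about VACUUM and transaction blocks.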


class QueryLogger:
    """Utility for logging and analyzing database queries"""

    # Queries slower than this many seconds are flagged as slow
    SLOW_QUERY_THRESHOLD = 1.0

    def __init__(self, session: AsyncSession):
        self.session = session
        self._query_log: List[Dict[str, Any]] = []

    async def log_query(self, query: str, params: Optional[Dict] = None,
                        execution_time: Optional[float] = None):
        """Log a database query with metadata"""
        log_entry = {
            "query": query,
            "params": params,
            "execution_time": execution_time,
            "timestamp": datetime.now(timezone.utc).isoformat(),
        }

        self._query_log.append(log_entry)

        # Flag slow queries
        if execution_time and execution_time > self.SLOW_QUERY_THRESHOLD:
            logger.warning("Slow query detected",
                           query=query,
                           execution_time=execution_time)

    def get_query_stats(self) -> Dict[str, Any]:
        """Get statistics about logged queries"""
        if not self._query_log:
            return {"total_queries": 0}

        execution_times = [
            entry["execution_time"]
            for entry in self._query_log
            if entry["execution_time"] is not None
        ]

        return {
            "total_queries": len(self._query_log),
            "avg_execution_time": sum(execution_times) / len(execution_times) if execution_times else 0,
            "max_execution_time": max(execution_times) if execution_times else 0,
            "slow_queries_count": len([t for t in execution_times if t > self.SLOW_QUERY_THRESHOLD]),
        }

    def clear_log(self):
        """Clear the query log"""
        self._query_log.clear()
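

# A minimal runnable demo of the utilities above, assuming the aiosqlite
# driver is installed. The in-memory SQLite URL is illustrative only;
# substitute your real database URL. Because this module uses a relative
# import, run it as part of the package: ``python -m shared.database.utils``.
if __name__ == "__main__":
    from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

    async def _demo() -> None:
        engine = create_async_engine("sqlite+aiosqlite:///:memory:")
        session_factory = async_sessionmaker(engine, expire_on_commit=False)

        async with session_factory() as session:
            health = await DatabaseUtils.execute_health_check(session)
            print("health:", health["status"], health["database"]["type"])

            report = await DatabaseUtils.validate_schema(session, expected_tables=[])
            print("schema valid:", report["valid"], "tables:", report["total_tables"])

        await engine.dispose()

    asyncio.run(_demo())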