Files
bakery-ia/shared/database/utils.py
2025-08-08 09:08:41 +02:00

402 lines
15 KiB
Python

"""
Database Utilities
Helper functions for database operations and maintenance
"""
import asyncio
import re
import time
from datetime import datetime, timezone
from typing import Dict, Any, List, Optional

import structlog
from sqlalchemy import text, inspect
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession

from .exceptions import DatabaseError, HealthCheckError
# Module-level structlog logger shared by the database helpers in this module
# (used for debug, info, warning and error events).
logger = structlog.get_logger()
class DatabaseUtils:
    """Utility functions for database operations.

    Every helper is a static method taking an ``AsyncSession``. Dialect-specific
    SQL is chosen from ``session.bind.dialect.name`` (``postgresql``, ``sqlite``,
    or a generic fallback).
    """

    # One identifier part: an unquoted name (letter or underscore first, then
    # word characters or '$') or a double-quoted name without embedded quotes.
    _IDENTIFIER_PART = r'(?:[^\W\d][\w$]*|"[^"]+")'
    # A possibly dotted identifier such as ``table`` or ``schema.table``.
    _IDENTIFIER_RE = re.compile(rf'{_IDENTIFIER_PART}(?:\.{_IDENTIFIER_PART})*')

    @staticmethod
    def _is_safe_identifier(name: Any) -> bool:
        """Return True if *name* may be interpolated into SQL as an identifier.

        Table/column names cannot be sent as bound parameters, so the
        f-string-built queries in this class check them against this pattern
        first. Stray quotes, semicolons or whitespace then cannot inject extra SQL.
        """
        return isinstance(name, str) and DatabaseUtils._IDENTIFIER_RE.fullmatch(name) is not None

    @staticmethod
    def _require_safe_identifier(name: Any) -> str:
        """Return *name* unchanged, or raise ValueError if it is not a safe identifier."""
        if not DatabaseUtils._is_safe_identifier(name):
            raise ValueError(f"Unsafe SQL identifier: {name!r}")
        return name

    @staticmethod
    async def execute_health_check(session: AsyncSession, timeout: int = 5) -> Dict[str, Any]:
        """
        Comprehensive database health check

        Args:
            session: Database session to probe
            timeout: Seconds to wait for the ``SELECT 1`` connectivity probe
                (``None`` waits indefinitely)

        Returns:
            Dict with health status, response time, database info,
            connection-pool info and a naive-UTC ISO timestamp

        Raises:
            HealthCheckError: if the probe fails or does not finish within ``timeout``
        """
        try:
            # Basic connectivity test, bounded by `timeout`. On timeout the
            # in-flight execute is cancelled; the caller's session may need a
            # rollback before it is reused.
            start_time = time.perf_counter()
            await asyncio.wait_for(session.execute(text("SELECT 1")), timeout=timeout)
            response_time = time.perf_counter() - start_time
            # These helpers catch their own errors and return diagnostic dicts.
            db_info = await DatabaseUtils._get_database_info(session)
            pool_info = await DatabaseUtils._get_pool_info(session)
            return {
                "status": "healthy",
                "response_time_seconds": round(response_time, 4),
                "database": db_info,
                "connection_pool": pool_info,
                # Naive UTC ISO string: same format datetime.utcnow() produced.
                "timestamp": datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
            }
        except asyncio.TimeoutError as e:
            logger.error("Database health check timed out", timeout=timeout)
            raise HealthCheckError(f"Health check timed out after {timeout}s") from e
        except Exception as e:
            logger.error("Database health check failed", error=str(e))
            raise HealthCheckError(f"Health check failed: {str(e)}") from e

    @staticmethod
    async def _get_database_info(session: AsyncSession) -> Dict[str, Any]:
        """Get database server type, version and connection counts.

        Best-effort: any error is logged as a warning and reported in the
        returned dict under ``"error"`` instead of being raised.
        """
        try:
            if session.bind.dialect.name == 'postgresql':
                version_result = await session.execute(text("SELECT version()"))
                version = version_result.scalar()
                # Active backends plus the server's configured connection limit.
                stats_result = await session.execute(text("""
                    SELECT
                        count(*) as active_connections,
                        (SELECT setting FROM pg_settings WHERE name = 'max_connections') as max_connections
                    FROM pg_stat_activity
                    WHERE state = 'active'
                """))
                stats = stats_result.fetchone()
                return {
                    "type": "postgresql",
                    "version": version,
                    "active_connections": stats.active_connections if stats else 0,
                    "max_connections": stats.max_connections if stats else "unknown"
                }
            elif session.bind.dialect.name == 'sqlite':
                version_result = await session.execute(text("SELECT sqlite_version()"))
                version = version_result.scalar()
                return {
                    "type": "sqlite",
                    "version": version,
                    "active_connections": 1,
                    "max_connections": "unlimited"
                }
            else:
                return {
                    "type": session.bind.dialect.name,
                    "version": "unknown",
                    "active_connections": "unknown",
                    "max_connections": "unknown"
                }
        except Exception as e:
            logger.warning("Could not retrieve database info", error=str(e))
            return {
                "type": session.bind.dialect.name,
                "version": "unknown",
                "error": str(e)
            }

    @staticmethod
    async def _get_pool_info(session: AsyncSession) -> Dict[str, Any]:
        """Get connection pool counters.

        Best-effort: pools lacking these counters (or any other error) yield
        ``{"error": ...}`` after a logged warning.
        """
        try:
            pool = session.bind.pool
            if pool:
                return {
                    "size": pool.size(),
                    "checked_in": pool.checkedin(),
                    "checked_out": pool.checkedout(),
                    "overflow": pool.overflow(),
                    "invalid": pool.invalid()
                }
            else:
                return {"status": "no_pool"}
        except Exception as e:
            logger.warning("Could not retrieve pool info", error=str(e))
            return {"error": str(e)}

    @staticmethod
    async def validate_schema(session: AsyncSession, expected_tables: List[str]) -> Dict[str, Any]:
        """
        Validate database schema against expected tables

        Args:
            session: Database session
            expected_tables: List of table names that should exist

        Returns:
            Validation results with missing/extra tables (name lists are sorted)

        Raises:
            DatabaseError: if the table list cannot be read
        """
        try:
            # inspect() does not support an AsyncEngine directly; run the
            # synchronous Inspector on the session's connection via run_sync.
            connection = await session.connection()
            table_names = await connection.run_sync(
                lambda sync_conn: inspect(sync_conn).get_table_names()
            )
            existing_tables = set(table_names)
            expected_tables_set = set(expected_tables)
            missing_tables = expected_tables_set - existing_tables
            extra_tables = existing_tables - expected_tables_set
            return {
                "valid": not missing_tables,
                "existing_tables": sorted(existing_tables),
                "expected_tables": expected_tables,
                "missing_tables": sorted(missing_tables),
                "extra_tables": sorted(extra_tables),
                "total_tables": len(existing_tables)
            }
        except Exception as e:
            logger.error("Schema validation failed", error=str(e))
            raise DatabaseError(f"Schema validation failed: {str(e)}") from e

    @staticmethod
    async def get_table_stats(session: AsyncSession, table_names: List[str]) -> Dict[str, Any]:
        """
        Get statistics for specified tables

        Args:
            session: Database session
            table_names: List of table names (plain or dotted identifiers) to analyze

        Returns:
            Dictionary keyed by table name with ``row_count``, ``size_bytes`` and
            ``size_mb``. Sizes are only known on PostgreSQL, otherwise "unknown".

        Raises:
            DatabaseError: if a name is not a safe identifier or a query fails
        """
        try:
            dialect_name = session.bind.dialect.name
            stats = {}
            for table_name in table_names:
                DatabaseUtils._require_safe_identifier(table_name)
                count_result = await session.execute(
                    text(f"SELECT COUNT(*) FROM {table_name}")
                )
                row_count = count_result.scalar()
                if dialect_name == 'postgresql':
                    # The relation name travels as a bound parameter (cast to
                    # regclass server-side) instead of being spliced into a literal.
                    size_result = await session.execute(
                        text("SELECT pg_total_relation_size(CAST(CAST(:relation AS TEXT) AS regclass))"),
                        {"relation": table_name}
                    )
                    table_size = size_result.scalar()
                    stats[table_name] = {
                        "row_count": row_count,
                        "size_bytes": table_size,
                        "size_mb": round(table_size / (1024 * 1024), 2) if table_size else 0
                    }
                else:
                    # SQLite and other dialects: no portable table-size query.
                    stats[table_name] = {
                        "row_count": row_count,
                        "size_bytes": "unknown",
                        "size_mb": "unknown"
                    }
            return stats
        except Exception as e:
            logger.error("Failed to get table statistics",
                         tables=table_names, error=str(e))
            raise DatabaseError(f"Failed to get table stats: {str(e)}") from e

    @staticmethod
    async def cleanup_old_records(
        session: AsyncSession,
        table_name: str,
        date_column: str,
        days_old: int,
        batch_size: int = 1000
    ) -> int:
        """
        Clean up old records from a table

        Rows are deleted in batches of at most ``batch_size``. Each non-empty
        batch is committed immediately, so batches committed before a failure
        stay deleted; only the failing batch is rolled back.

        Args:
            session: Database session
            table_name: Name of table to clean (plain or dotted identifier)
            date_column: Date column to filter by
            days_old: Records older than this many days will be deleted
            batch_size: Number of records to delete per batch

        Returns:
            Total number of records deleted

        Raises:
            DatabaseError: if a name is not a safe identifier or a delete fails
        """
        try:
            DatabaseUtils._require_safe_identifier(table_name)
            DatabaseUtils._require_safe_identifier(date_column)
            dialect_name = session.bind.dialect.name
            if dialect_name == 'postgresql':
                # `INTERVAL :param` is a syntax error once the driver sends a
                # real bind parameter (e.g. asyncpg's $1); make_interval() takes
                # the day count as an ordinary integer argument instead.
                # The outer date filter is a re-check because a row's ctid is
                # not stable across concurrent updates.
                delete_query = text(f"""
                    DELETE FROM {table_name}
                    WHERE {date_column} < NOW() - make_interval(days => :days_old)
                    AND ctid IN (
                        SELECT ctid FROM {table_name}
                        WHERE {date_column} < NOW() - make_interval(days => :days_old)
                        LIMIT :batch_size
                    )
                """)
                params = {
                    "days_old": days_old,
                    "batch_size": batch_size
                }
            elif dialect_name == 'sqlite':
                delete_query = text(f"""
                    DELETE FROM {table_name}
                    WHERE {date_column} < datetime('now', :days_param)
                    AND rowid IN (
                        SELECT rowid FROM {table_name}
                        WHERE {date_column} < datetime('now', :days_param)
                        LIMIT :batch_size
                    )
                """)
                params = {
                    "days_param": f"-{days_old} days",
                    "batch_size": batch_size
                }
            else:
                # Generic fallback (MySQL syntax; may not work for all databases)
                delete_query = text(f"""
                    DELETE FROM {table_name}
                    WHERE {date_column} < DATE_SUB(NOW(), INTERVAL :days_old DAY)
                    LIMIT :batch_size
                """)
                params = {
                    "days_old": days_old,
                    "batch_size": batch_size
                }

            total_deleted = 0
            while True:
                result = await session.execute(delete_query, params)
                deleted_count = result.rowcount
                # Stop on an empty batch. Treating an unknown rowcount
                # (None / -1) the same way prevents an endless loop.
                if deleted_count is None or deleted_count <= 0:
                    break
                total_deleted += deleted_count
                await session.commit()
                logger.debug(f"Deleted batch from {table_name}",
                             batch_size=deleted_count,
                             total_deleted=total_deleted)
            logger.info(f"Cleanup completed for {table_name}",
                        total_deleted=total_deleted,
                        days_old=days_old)
            return total_deleted
        except Exception as e:
            await session.rollback()
            logger.error(f"Cleanup failed for {table_name}", error=str(e))
            raise DatabaseError(f"Cleanup failed: {str(e)}") from e

    @staticmethod
    async def execute_maintenance(session: AsyncSession) -> Dict[str, Any]:
        """
        Execute database maintenance tasks (VACUUM and ANALYZE)

        Returns:
            Dictionary with maintenance results

        Raises:
            DatabaseError: if maintenance fails (the session is rolled back)
        """
        try:
            results = {}
            dialect_name = session.bind.dialect.name
            if dialect_name in ('postgresql', 'sqlite'):
                # VACUUM cannot run inside a transaction block, and the session's
                # connection normally sits inside one (SQLAlchemy autobegins).
                # So run maintenance on a separate AUTOCOMMIT connection from the
                # same engine.
                # NOTE(review): assumes session.bind is an AsyncEngine, as the
                # dialect/pool lookups in this class already do — confirm.
                autocommit_engine = session.bind.execution_options(isolation_level="AUTOCOMMIT")
                async with autocommit_engine.connect() as conn:
                    if dialect_name == 'postgresql':
                        # VACUUM ANALYZE also refreshes planner statistics, so a
                        # separate ANALYZE pass would be redundant.
                        await conn.execute(text("VACUUM ANALYZE"))
                        results["vacuum"] = "completed"
                        results["analyze"] = "completed"
                    else:
                        await conn.execute(text("VACUUM"))
                        results["vacuum"] = "completed"
                        await conn.execute(text("ANALYZE"))
                        results["analyze"] = "completed"
            else:
                results["maintenance"] = "not_supported"
            # Keep the original contract of committing the caller's session.
            await session.commit()
            logger.info("Database maintenance completed", results=results)
            return results
        except Exception as e:
            await session.rollback()
            logger.error("Database maintenance failed", error=str(e))
            raise DatabaseError(f"Maintenance failed: {str(e)}") from e
class QueryLogger:
    """Utility for logging and analyzing database queries.

    Entries accumulate in memory until :meth:`clear_log` is called. Queries
    slower than ``slow_query_threshold`` seconds are also emitted as warnings.
    """

    def __init__(self, session: "AsyncSession", slow_query_threshold: float = 1.0):
        """
        Args:
            session: Database session the logged queries belong to (stored as
                ``self.session``; not used by this class itself)
            slow_query_threshold: Execution time in seconds above which a query
                counts as slow (default 1.0)
        """
        self.session = session
        self.slow_query_threshold = slow_query_threshold
        self._query_log: List[Dict[str, Any]] = []

    async def log_query(self, query: str, params: Optional[Dict] = None, execution_time: Optional[float] = None):
        """Log a database query with metadata.

        Args:
            query: SQL text that was executed
            params: Bound parameters, if any
            execution_time: Duration in seconds, if measured
        """
        self._query_log.append({
            "query": query,
            "params": params,
            "execution_time": execution_time,
            # Naive UTC ISO string: same format datetime.utcnow() produced.
            "timestamp": datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
        })
        # Log slow queries
        if execution_time and execution_time > self.slow_query_threshold:
            logger.warning("Slow query detected",
                           query=query,
                           execution_time=execution_time)

    def get_query_stats(self) -> Dict[str, Any]:
        """Get statistics about logged queries.

        Timing figures only consider entries with a recorded execution time.
        They are 0 when no entry has one.
        """
        if not self._query_log:
            return {"total_queries": 0}
        execution_times = [
            entry["execution_time"]
            for entry in self._query_log
            if entry["execution_time"] is not None
        ]
        return {
            "total_queries": len(self._query_log),
            "avg_execution_time": sum(execution_times) / len(execution_times) if execution_times else 0,
            "max_execution_time": max(execution_times, default=0),
            "slow_queries_count": sum(1 for t in execution_times if t > self.slow_query_threshold)
        }

    def clear_log(self):
        """Clear the query log"""
        self._query_log.clear()