"""
|
|
Enhanced Base Database Configuration for All Microservices
|
|
Provides DatabaseManager with connection pooling, health checks, and multi-database support
|
|
|
|
Fixed: SSL configuration now uses connect_args instead of URL parameters to avoid asyncpg parameter parsing issues
|
|
"""
|
|
|
|
import os
import ssl
import time
from contextlib import asynccontextmanager
from typing import Optional, Dict, Any, List

import structlog
from sqlalchemy import create_engine, text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.pool import StaticPool

from .exceptions import DatabaseError, ConnectionError, HealthCheckError
from .utils import DatabaseUtils

logger = structlog.get_logger()

Base = declarative_base()


class DatabaseManager:
    """Enhanced Database Manager for microservices.

    Provides:
    - Connection pooling with configurable settings
    - Health checks and monitoring
    - Multi-database support
    - Session lifecycle management
    - Background task session support
    """

    def __init__(
        self,
        database_url: str,
        service_name: str = "unknown",
        pool_size: int = 20,
        max_overflow: int = 30,
        pool_recycle: int = 3600,
        pool_pre_ping: bool = True,
        echo: bool = False,
        connect_timeout: int = 30,
        **engine_kwargs
    ):
        self.database_url = database_url

        # Configure SSL for PostgreSQL via connect_args instead of URL parameters;
        # this avoids asyncpg parameter-parsing issues.
        self.use_ssl = False
        if "postgresql" in database_url.lower():
            # Enable SSL unless it is already configured in the URL
            if "ssl" not in database_url.lower() and "sslmode" not in database_url.lower():
                self.use_ssl = True
                logger.info(f"SSL will be enabled for PostgreSQL connection: {service_name}")
        self.service_name = service_name
        self.pool_size = pool_size
        self.max_overflow = max_overflow

        # Pooling for async engines:
        # SQLAlchemy 2.0 async engines automatically use AsyncAdaptedQueuePool,
        # so poolclass is not set here except for SQLite (see below).

        # Prepare connect_args for asyncpg
        connect_args = {"timeout": connect_timeout}

        # Add SSL configuration if needed (for the asyncpg driver)
        if self.use_ssl and "asyncpg" in database_url.lower():
            # SSL context without certificate verification, intended for local development.
            # In production, use a proper SSL context with certificate verification.
            ssl_context = ssl.create_default_context()
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE
            connect_args["ssl"] = ssl_context
            logger.info(f"SSL enabled with relaxed verification for {service_name}")

        engine_config = {
            "echo": echo,
            "pool_pre_ping": pool_pre_ping,
            "pool_recycle": pool_recycle,
            "pool_size": pool_size,
            "max_overflow": max_overflow,
            "connect_args": connect_args,
            **engine_kwargs
        }

        # SQLite requires StaticPool for async use. StaticPool does not accept
        # QueuePool sizing arguments, so drop them rather than overriding them.
        if "sqlite" in database_url.lower():
            engine_config["poolclass"] = StaticPool
            engine_config.pop("pool_size", None)
            engine_config.pop("max_overflow", None)

        self.async_engine = create_async_engine(database_url, **engine_config)

        # Create the async session factory.
        # Note: SQLAlchemy 2.0 removed the Session "autocommit" parameter, so it
        # is not passed here; all commits are explicit.
        self.async_session_local = async_sessionmaker(
            self.async_engine,
            class_=AsyncSession,
            expire_on_commit=False,
            autoflush=False
        )

        logger.info(f"DatabaseManager initialized for {service_name}",
                    pool_size=pool_size,
                    max_overflow=max_overflow,
                    database_type=self._get_database_type())
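
    # Example (a sketch, not executed here): supplying a verifying SSL context for
    # production by overriding connect_args. The CA bundle path and URL are
    # illustrative assumptions.
    #
    #   verified_ctx = ssl.create_default_context(cafile="/etc/ssl/certs/db-ca.pem")
    #   manager = DatabaseManager(
    #       "postgresql+asyncpg://user:pass@db-host/app",
    #       service_name="orders",
    #       connect_args={"timeout": 30, "ssl": verified_ctx},
    #   )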

    async def get_db(self):
        """Get a database session for request handlers (FastAPI dependency)."""
        async with self.async_session_local() as session:
            try:
                logger.debug("Database session created for request")
                yield session
            except Exception as e:
                await session.rollback()

                # Don't wrap HTTPExceptions - let them pass through.
                # Check by type name to avoid import dependencies.
                exception_type = type(e).__name__
                if exception_type in ('HTTPException', 'StarletteHTTPException'):
                    logger.debug(f"Re-raising HTTPException: {e}", service=self.service_name)
                    raise

                error_msg = str(e) if str(e) else f"{type(e).__name__}: {repr(e)}"
                logger.error(f"Database session error: {error_msg}", service=self.service_name)

                # Handle ASGI stream disconnects more gracefully
                if "EndOfStream" in str(type(e)) or "WouldBlock" in str(type(e)):
                    raise DatabaseError(f"Session error: Request stream disconnected ({type(e).__name__})")
                else:
                    raise DatabaseError(f"Session error: {error_msg}")
            finally:
                await session.close()
                logger.debug("Database session closed")

    @asynccontextmanager
    async def get_background_session(self):
        """
        Get a database session for background tasks with auto-commit.

        Usage:
            async with database_manager.get_background_session() as session:
                # Your background task code here
                # Auto-commits on success, rolls back on exception
        """
        async with self.async_session_local() as session:
            try:
                logger.debug("Background session created", service=self.service_name)
                yield session
                await session.commit()
                logger.debug("Background session committed")
            except Exception as e:
                await session.rollback()
                logger.error(f"Background task database error: {e}",
                             service=self.service_name)
                raise DatabaseError(f"Background task failed: {str(e)}")
            finally:
                await session.close()
                logger.debug("Background session closed")

    @asynccontextmanager
    async def get_session(self):
        """Get a plain database session (no auto-commit)."""
        async with self.async_session_local() as session:
            try:
                yield session
            except Exception as e:
                await session.rollback()

                # Don't wrap HTTPExceptions - let them pass through
                exception_type = type(e).__name__
                if exception_type in ('HTTPException', 'StarletteHTTPException'):
                    logger.debug(f"Re-raising HTTPException: {e}", service=self.service_name)
                    raise

                logger.error(f"Session error: {e}", service=self.service_name)
                raise DatabaseError(f"Session error: {str(e)}")
            finally:
                await session.close()
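
    # Example with explicit transaction control (a sketch; `some_object` is a
    # placeholder model instance): unlike get_background_session(), nothing is
    # committed unless the caller commits.
    #
    #   async with manager.get_session() as session:
    #       session.add(some_object)
    #       await session.commit()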

    # ===== TABLE MANAGEMENT =====

    async def create_tables(self, metadata=None):
        """Create database tables with enhanced error handling and transaction verification."""
        try:
            target_metadata = metadata or Base.metadata
            table_names = list(target_metadata.tables.keys())
            logger.info(f"Creating tables: {table_names}", service=self.service_name)

            # Use an explicit transaction with proper error handling
            async with self.async_engine.begin() as conn:
                try:
                    # Create tables within the transaction
                    await conn.run_sync(target_metadata.create_all, checkfirst=True)

                    # Verify the transaction is not in an error state:
                    # a simple query confirms the connection is still valid
                    await conn.execute(text("SELECT 1"))

                    logger.info("Database tables creation transaction completed successfully",
                                service=self.service_name, tables=table_names)
                except Exception as create_error:
                    logger.error(f"Error during table creation within transaction: {create_error}",
                                 service=self.service_name)
                    # Re-raise to trigger transaction rollback
                    raise

            logger.info("Database tables created successfully", service=self.service_name)

        except Exception as e:
            # "Relation already exists" errors can be safely ignored
            error_str = str(e).lower()
            if "already exists" in error_str or "duplicate" in error_str:
                logger.warning(f"Some database objects already exist - continuing: {e}",
                               service=self.service_name)
                logger.info("Database tables creation completed (some already existed)",
                            service=self.service_name)
            else:
                logger.error(f"Failed to create tables: {e}", service=self.service_name)

                # Check for specific transaction error indicators
                if any(indicator in error_str for indicator in [
                    "transaction", "rollback", "aborted", "failed sql transaction"
                ]):
                    logger.error("Transaction-related error detected during table creation",
                                 service=self.service_name)

                raise DatabaseError(f"Table creation failed: {str(e)}")

    async def drop_tables(self, metadata=None):
        """Drop database tables."""
        try:
            target_metadata = metadata or Base.metadata
            async with self.async_engine.begin() as conn:
                await conn.run_sync(target_metadata.drop_all)
            logger.info("Database tables dropped successfully", service=self.service_name)
        except Exception as e:
            logger.error(f"Failed to drop tables: {e}", service=self.service_name)
            raise DatabaseError(f"Table drop failed: {str(e)}")

    # ===== HEALTH CHECKS AND MONITORING =====

    async def health_check(self) -> Dict[str, Any]:
        """Comprehensive health check for the database."""
        try:
            async with self.get_session() as session:
                return await DatabaseUtils.execute_health_check(session)
        except Exception as e:
            logger.error(f"Health check failed: {e}", service=self.service_name)
            raise HealthCheckError(f"Health check failed: {str(e)}")

    async def get_connection_info(self) -> Dict[str, Any]:
        """Get database connection information."""
        try:
            pool = self.async_engine.pool
            return {
                "service_name": self.service_name,
                "database_type": self._get_database_type(),
                "pool_size": self.pool_size,
                "max_overflow": self.max_overflow,
                "current_checked_in": pool.checkedin() if pool else 0,
                "current_checked_out": pool.checkedout() if pool else 0,
                "current_overflow": pool.overflow() if pool else 0,
                "invalid_connections": getattr(pool, 'invalid', lambda: 0)() if pool else 0
            }
        except Exception as e:
            logger.error(f"Failed to get connection info: {e}", service=self.service_name)
            return {"error": str(e)}

    def _get_database_type(self) -> str:
        """Get the database type from the URL."""
        return self.database_url.split("://")[0].lower() if "://" in self.database_url else "unknown"

    # ===== CLEANUP AND MAINTENANCE =====

    async def close_connections(self):
        """Close all database connections."""
        try:
            await self.async_engine.dispose()
            logger.info("Database connections closed", service=self.service_name)
        except Exception as e:
            logger.error(f"Failed to close connections: {e}", service=self.service_name)
            raise DatabaseError(f"Connection cleanup failed: {str(e)}")

    async def execute_maintenance(self) -> Dict[str, Any]:
        """Execute database maintenance tasks."""
        try:
            async with self.get_session() as session:
                return await DatabaseUtils.execute_maintenance(session)
        except Exception as e:
            logger.error(f"Maintenance failed: {e}", service=self.service_name)
            raise DatabaseError(f"Maintenance failed: {str(e)}")

    # ===== UTILITY METHODS =====

    async def test_connection(self) -> bool:
        """Test database connectivity."""
        try:
            async with self.async_engine.begin() as conn:
                await conn.execute(text("SELECT 1"))
            logger.debug("Connection test successful", service=self.service_name)
            return True
        except Exception as e:
            logger.error(f"Connection test failed: {e}", service=self.service_name)
            return False
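
    # Example startup retry loop (a sketch; retry count and delay are arbitrary,
    # and it assumes `import asyncio` in the calling code):
    #
    #   async def wait_for_database(manager: DatabaseManager, attempts: int = 10) -> None:
    #       for _ in range(attempts):
    #           if await manager.test_connection():
    #               return
    #           await asyncio.sleep(2)
    #       raise ConnectionError("database did not become available")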

    def __repr__(self) -> str:
        return f"DatabaseManager(service='{self.service_name}', type='{self._get_database_type()}')"

    async def execute(self, query: str, *args, **kwargs):
        """
        Execute a raw SQL query with proper session management.

        Note: use this method carefully to avoid transaction conflicts.
        """
        # Use a new session context to avoid conflicts with existing sessions
        async with self.get_session() as session:
            # Capture the raw SQL up front so the except block can still tell
            # whether this was a modifying statement after conversion to text().
            sql_text = query if isinstance(query, str) else str(query)
            is_write = sql_text.strip().upper().startswith(('INSERT', 'UPDATE', 'DELETE'))
            try:
                # Convert the query to a SQLAlchemy text() construct if it's a plain string
                if isinstance(query, str):
                    query = text(query)

                result = await session.execute(query, *args, **kwargs)

                # INSERT/UPDATE/DELETE statements need an explicit commit
                if is_write:
                    await session.commit()

                return result
            except Exception as e:
                # Only roll back if this was a modifying operation
                if is_write:
                    await session.rollback()
                logger.error("Database execute failed", error=str(e))
                raise
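
    # Example raw query (a sketch): parameters are bound through text() rather
    # than interpolated into the SQL string.
    #
    #   result = await db_manager.execute(
    #       "UPDATE items SET archived = TRUE WHERE id = :id", {"id": 42}
    #   )
    #   print(result.rowcount)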


# ===== CONVENIENCE FUNCTIONS =====

def create_database_manager(
    database_url: str,
    service_name: str,
    **kwargs
) -> DatabaseManager:
    """Factory function to create DatabaseManager instances."""
    return DatabaseManager(database_url, service_name, **kwargs)
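
# Example (a sketch; the environment variable name is an assumption):
#
#   db_manager = create_database_manager(
#       os.environ["DATABASE_URL"],
#       service_name="orders",
#       pool_size=10,
#   )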


# ===== LEGACY COMPATIBILITY =====

# Keep backward compatibility for existing code
engine = None
AsyncSessionLocal = None


def init_legacy_compatibility(database_url: str):
    """Initialize legacy global variables for backward compatibility."""
    global engine, AsyncSessionLocal

    # Configure SSL for PostgreSQL if needed
    connect_args = {}
    if "postgresql" in database_url.lower() and "asyncpg" in database_url.lower():
        if "ssl" not in database_url.lower() and "sslmode" not in database_url.lower():
            # SSL context without certificate verification, intended for local development
            ssl_context = ssl.create_default_context()
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE
            connect_args["ssl"] = ssl_context
            logger.info("SSL enabled with relaxed verification for legacy database connection")

    engine = create_async_engine(
        database_url,
        echo=False,
        pool_pre_ping=True,
        pool_recycle=300,
        pool_size=20,
        max_overflow=30,
        connect_args=connect_args
    )

    AsyncSessionLocal = async_sessionmaker(
        engine,
        class_=AsyncSession,
        expire_on_commit=False
    )

    logger.warning("Using legacy database configuration - consider migrating to DatabaseManager")


async def get_legacy_db():
    """Legacy database session getter for backward compatibility."""
    if not AsyncSessionLocal:
        raise RuntimeError("Legacy database not initialized - call init_legacy_compatibility first")

    async with AsyncSessionLocal() as session:
        try:
            yield session
        except Exception as e:
            logger.error(f"Legacy database session error: {e}")
            await session.rollback()
            raise
        finally:
            await session.close()
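
# Example legacy wiring (a sketch, for code not yet migrated to DatabaseManager;
# the FastAPI app is an assumption):
#
#   init_legacy_compatibility(os.environ["DATABASE_URL"])
#
#   @app.get("/legacy-ping")
#   async def legacy_ping(db: AsyncSession = Depends(get_legacy_db)):
#       return {"ok": (await db.execute(text("SELECT 1"))).scalar() == 1}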