Initial commit - production deployment
68
shared/database/__init__.py
Executable file
@@ -0,0 +1,68 @@
"""
Shared Database Infrastructure
Provides consistent database patterns across all microservices
"""

from .base import DatabaseManager, Base, create_database_manager
from .repository import BaseRepository
from .unit_of_work import UnitOfWork, ServiceUnitOfWork, RepositoryRegistry
from .transactions import (
    transactional,
    unit_of_work_transactional,
    managed_transaction,
    managed_unit_of_work,
    TransactionManager,
    run_in_transaction,
    run_with_unit_of_work
)
from .exceptions import (
    DatabaseError,
    ConnectionError,
    RecordNotFoundError,
    DuplicateRecordError,
    ConstraintViolationError,
    TransactionError,
    ValidationError,
    MigrationError,
    HealthCheckError
)
from .utils import DatabaseUtils, QueryLogger

__all__ = [
    # Core components
    "DatabaseManager",
    "Base",
    "create_database_manager",

    # Repository pattern
    "BaseRepository",

    # Unit of Work pattern
    "UnitOfWork",
    "ServiceUnitOfWork",
    "RepositoryRegistry",

    # Transaction management
    "transactional",
    "unit_of_work_transactional",
    "managed_transaction",
    "managed_unit_of_work",
    "TransactionManager",
    "run_in_transaction",
    "run_with_unit_of_work",

    # Exceptions
    "DatabaseError",
    "ConnectionError",
    "RecordNotFoundError",
    "DuplicateRecordError",
    "ConstraintViolationError",
    "TransactionError",
    "ValidationError",
    "MigrationError",
    "HealthCheckError",

    # Utilities
    "DatabaseUtils",
    "QueryLogger"
]
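# Illustrative usage sketch (not part of the original commit): how a service might
# wire this package at startup. The service name and DSN are hypothetical placeholders.
#
#     from shared.database import create_database_manager
#
#     db_manager = create_database_manager(
#         "postgresql+asyncpg://user:pass@db:5432/orders",  # placeholder DSN
#         service_name="orders-service",
#     )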
408
shared/database/base.py
Executable file
@@ -0,0 +1,408 @@
"""
Enhanced Base Database Configuration for All Microservices
Provides DatabaseManager with connection pooling, health checks, and multi-database support

Fixed: SSL configuration now uses connect_args instead of URL parameters to avoid asyncpg parameter parsing issues
"""

import os
import ssl
from typing import Optional, Dict, Any, List
from sqlalchemy import create_engine, text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.pool import StaticPool
from contextlib import asynccontextmanager
import structlog
import time

from .exceptions import DatabaseError, ConnectionError, HealthCheckError
from .utils import DatabaseUtils

logger = structlog.get_logger()

Base = declarative_base()


class DatabaseManager:
    """Enhanced Database Manager for Microservices

    Provides:
    - Connection pooling with configurable settings
    - Health checks and monitoring
    - Multi-database support
    - Session lifecycle management
    - Background task session support
    """

    def __init__(
        self,
        database_url: str,
        service_name: str = "unknown",
        pool_size: int = 20,
        max_overflow: int = 30,
        pool_recycle: int = 3600,
        pool_pre_ping: bool = True,
        echo: bool = False,
        connect_timeout: int = 30,
        **engine_kwargs
    ):
        self.database_url = database_url

        # Configure SSL for PostgreSQL via connect_args instead of URL parameters
        # This avoids asyncpg parameter parsing issues
        self.use_ssl = False
        if "postgresql" in database_url.lower():
            # Check if SSL is already configured in URL or should be enabled
            if "ssl" not in database_url.lower() and "sslmode" not in database_url.lower():
                # Enable SSL for production, but allow override via URL
                self.use_ssl = True
                logger.info(f"SSL will be enabled for PostgreSQL connection: {service_name}")
        self.service_name = service_name
        self.pool_size = pool_size
        self.max_overflow = max_overflow

        # Configure pool for async engines
        # Note: SQLAlchemy 2.0 async engines automatically use AsyncAdaptedQueuePool
        # We should NOT specify poolclass for async engines unless using StaticPool for SQLite

        # Prepare connect_args for asyncpg
        connect_args = {"timeout": connect_timeout}

        # Add SSL configuration if needed (for asyncpg driver)
        if self.use_ssl and "asyncpg" in database_url.lower():
            # Create SSL context that doesn't verify certificates (for local development)
            # In production, you should use a proper SSL context with certificate verification
            ssl_context = ssl.create_default_context()
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE
            connect_args["ssl"] = ssl_context
            logger.info(f"SSL enabled with relaxed verification for {service_name}")
        engine_config = {
            "echo": echo,
            "pool_pre_ping": pool_pre_ping,
            "pool_recycle": pool_recycle,
            "pool_size": pool_size,
            "max_overflow": max_overflow,
            "connect_args": connect_args,
            **engine_kwargs
        }

        # Only set poolclass for SQLite (requires StaticPool for async).
        # StaticPool does not accept queue-pool sizing arguments, so remove them
        # here instead of overriding them (passing pool_size with StaticPool raises).
        if "sqlite" in database_url.lower():
            engine_config["poolclass"] = StaticPool
            engine_config.pop("pool_size", None)
            engine_config.pop("max_overflow", None)
        self.async_engine = create_async_engine(database_url, **engine_config)

        # Create session factory
        # (SQLAlchemy 2.0 sessions never autocommit; the legacy autocommit flag no longer exists)
        self.async_session_local = async_sessionmaker(
            self.async_engine,
            class_=AsyncSession,
            expire_on_commit=False,
            autoflush=False
        )

        logger.info(f"DatabaseManager initialized for {service_name}",
                    pool_size=pool_size,
                    max_overflow=max_overflow,
                    database_type=self._get_database_type())

    async def get_db(self):
        """Get database session for request handlers (FastAPI dependency)"""
        async with self.async_session_local() as session:
            try:
                logger.debug("Database session created for request")
                yield session
            except Exception as e:
                await session.rollback()

                # Don't wrap HTTPExceptions - let them pass through
                # Check by type name to avoid import dependencies
                exception_type = type(e).__name__
                if exception_type in ('HTTPException', 'StarletteHTTPException', 'RequestValidationError', 'ValidationError'):
                    logger.debug(f"Re-raising {exception_type}: {e}", service=self.service_name)
                    raise

                error_msg = str(e) if str(e) else f"{type(e).__name__}: {repr(e)}"
                logger.error(f"Database session error: {error_msg}", service=self.service_name)

                # Handle specific ASGI stream issues more gracefully
                if "EndOfStream" in str(type(e)) or "WouldBlock" in str(type(e)):
                    raise DatabaseError(f"Session error: Request stream disconnected ({type(e).__name__})")
                else:
                    raise DatabaseError(f"Session error: {error_msg}")
            finally:
                await session.close()
                logger.debug("Database session closed")
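    # Illustrative sketch (not part of the original commit): wiring get_db as a
    # FastAPI dependency. The app object, route, OrderRepository, and Order model
    # are hypothetical placeholders.
    #
    #     from fastapi import Depends, FastAPI
    #     from sqlalchemy.ext.asyncio import AsyncSession
    #
    #     app = FastAPI()
    #     db_manager = create_database_manager(DATABASE_URL, service_name="orders-service")
    #
    #     @app.get("/orders/{order_id}")
    #     async def get_order(order_id: int, session: AsyncSession = Depends(db_manager.get_db)):
    #         repo = OrderRepository(Order, session)  # hypothetical repository/model
    #         return await repo.get_by_id(order_id)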
    @asynccontextmanager
    async def get_background_session(self):
        """
        Get database session for background tasks with auto-commit

        Usage:
            async with database_manager.get_background_session() as session:
                # Your background task code here
                # Auto-commits on success, rolls back on exception
        """
        async with self.async_session_local() as session:
            try:
                logger.debug("Background session created", service=self.service_name)
                yield session
                await session.commit()
                logger.debug("Background session committed")
            except Exception as e:
                await session.rollback()
                logger.error(f"Background task database error: {e}",
                             service=self.service_name)
                raise DatabaseError(f"Background task failed: {str(e)}")
            finally:
                await session.close()
                logger.debug("Background session closed")

    @asynccontextmanager
    async def get_session(self):
        """Get a plain database session (no auto-commit)"""
        async with self.async_session_local() as session:
            try:
                yield session
            except Exception as e:
                await session.rollback()

                # Don't wrap HTTPExceptions - let them pass through
                exception_type = type(e).__name__
                if exception_type in ('HTTPException', 'StarletteHTTPException'):
                    logger.debug(f"Re-raising HTTPException: {e}", service=self.service_name)
                    raise

                logger.error(f"Session error: {e}", service=self.service_name)
                raise DatabaseError(f"Session error: {str(e)}")
            finally:
                await session.close()
    # ===== TABLE MANAGEMENT =====

    async def create_tables(self, metadata=None):
        """Create database tables with enhanced error handling and transaction verification"""
        try:
            target_metadata = metadata or Base.metadata
            table_names = list(target_metadata.tables.keys())
            logger.info(f"Creating tables: {table_names}", service=self.service_name)

            # Use explicit transaction with proper error handling
            async with self.async_engine.begin() as conn:
                try:
                    # Create tables within the transaction
                    await conn.run_sync(target_metadata.create_all, checkfirst=True)

                    # Verify transaction is not in error state
                    # Try a simple query to ensure connection is still valid
                    await conn.execute(text("SELECT 1"))

                    logger.info("Database tables creation transaction completed successfully",
                                service=self.service_name, tables=table_names)

                except Exception as create_error:
                    logger.error(f"Error during table creation within transaction: {create_error}",
                                 service=self.service_name)
                    # Re-raise to trigger transaction rollback
                    raise

            logger.info("Database tables created successfully", service=self.service_name)

        except Exception as e:
            # Check if it's a "relation already exists" error which can be safely ignored
            error_str = str(e).lower()
            if "already exists" in error_str or "duplicate" in error_str:
                logger.warning(f"Some database objects already exist - continuing: {e}", service=self.service_name)
                logger.info("Database tables creation completed (some already existed)", service=self.service_name)
            else:
                logger.error(f"Failed to create tables: {e}", service=self.service_name)

                # Check for specific transaction error indicators
                if any(indicator in error_str for indicator in [
                    "transaction", "rollback", "aborted", "failed sql transaction"
                ]):
                    logger.error("Transaction-related error detected during table creation",
                                 service=self.service_name)

                raise DatabaseError(f"Table creation failed: {str(e)}")

    async def drop_tables(self, metadata=None):
        """Drop database tables"""
        try:
            target_metadata = metadata or Base.metadata
            async with self.async_engine.begin() as conn:
                await conn.run_sync(target_metadata.drop_all)
            logger.info("Database tables dropped successfully", service=self.service_name)
        except Exception as e:
            logger.error(f"Failed to drop tables: {e}", service=self.service_name)
            raise DatabaseError(f"Table drop failed: {str(e)}")

    # ===== HEALTH CHECKS AND MONITORING =====

    async def health_check(self) -> Dict[str, Any]:
        """Comprehensive health check for the database"""
        try:
            async with self.get_session() as session:
                return await DatabaseUtils.execute_health_check(session)
        except Exception as e:
            logger.error(f"Health check failed: {e}", service=self.service_name)
            raise HealthCheckError(f"Health check failed: {str(e)}")
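    # Illustrative sketch (not part of the original commit): exposing the health
    # check through a service endpoint. The app object and route are hypothetical.
    #
    #     @app.get("/health/db")
    #     async def database_health():
    #         try:
    #             return await db_manager.health_check()
    #         except HealthCheckError as exc:
    #             return {"status": "unhealthy", "detail": str(exc)}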
    async def get_connection_info(self) -> Dict[str, Any]:
        """Get database connection information"""
        try:
            pool = self.async_engine.pool
            return {
                "service_name": self.service_name,
                "database_type": self._get_database_type(),
                "pool_size": self.pool_size,
                "max_overflow": self.max_overflow,
                "current_checked_in": pool.checkedin() if pool else 0,
                "current_checked_out": pool.checkedout() if pool else 0,
                "current_overflow": pool.overflow() if pool else 0,
                "invalid_connections": getattr(pool, 'invalid', lambda: 0)() if pool else 0
            }
        except Exception as e:
            logger.error(f"Failed to get connection info: {e}", service=self.service_name)
            return {"error": str(e)}

    def _get_database_type(self) -> str:
        """Get database type from URL"""
        return self.database_url.split("://")[0].lower() if "://" in self.database_url else "unknown"

    # ===== CLEANUP AND MAINTENANCE =====

    async def close_connections(self):
        """Close all database connections"""
        try:
            await self.async_engine.dispose()
            logger.info("Database connections closed", service=self.service_name)
        except Exception as e:
            logger.error(f"Failed to close connections: {e}", service=self.service_name)
            raise DatabaseError(f"Connection cleanup failed: {str(e)}")

    async def execute_maintenance(self) -> Dict[str, Any]:
        """Execute database maintenance tasks"""
        try:
            async with self.get_session() as session:
                return await DatabaseUtils.execute_maintenance(session)
        except Exception as e:
            logger.error(f"Maintenance failed: {e}", service=self.service_name)
            raise DatabaseError(f"Maintenance failed: {str(e)}")

    # ===== UTILITY METHODS =====

    async def test_connection(self) -> bool:
        """Test database connectivity"""
        try:
            async with self.async_engine.begin() as conn:
                await conn.execute(text("SELECT 1"))
            logger.debug("Connection test successful", service=self.service_name)
            return True
        except Exception as e:
            logger.error(f"Connection test failed: {e}", service=self.service_name)
            return False

    def __repr__(self) -> str:
        return f"DatabaseManager(service='{self.service_name}', type='{self._get_database_type()}')"
    async def execute(self, query: str, *args, **kwargs):
        """
        Execute a raw SQL query with proper session management
        Note: Use this method carefully to avoid transaction conflicts
        """
        from sqlalchemy import text

        # Use a new session context to avoid conflicts with existing sessions
        async with self.get_session() as session:
            # Determine whether this is a modifying statement before converting the
            # string to a text() construct (the original check ran after conversion,
            # so it could never match in the error path).
            sql_text = query if isinstance(query, str) else str(query)
            is_modifying = sql_text.strip().upper().startswith(('UPDATE', 'DELETE', 'INSERT'))
            if isinstance(query, str):
                query = text(query)

            try:
                result = await session.execute(query, *args, **kwargs)

                # UPDATE/DELETE/INSERT operations need to be committed
                if is_modifying:
                    await session.commit()

                return result
            except Exception as e:
                # Only roll back if it was a modifying operation
                if is_modifying:
                    await session.rollback()
                logger.error("Database execute failed", error=str(e))
                raise
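    # Illustrative sketch (not part of the original commit): ad-hoc raw SQL through
    # an existing DatabaseManager instance. Table and column names are hypothetical.
    #
    #     result = await db_manager.execute("SELECT COUNT(*) FROM orders")
    #     total = result.scalar()
    #
    #     await db_manager.execute(
    #         "UPDATE orders SET status = :status WHERE id = :id",
    #         {"status": "shipped", "id": 42},
    #     )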
# ===== CONVENIENCE FUNCTIONS =====

def create_database_manager(
    database_url: str,
    service_name: str,
    **kwargs
) -> DatabaseManager:
    """Factory function to create DatabaseManager instances"""
    return DatabaseManager(database_url, service_name, **kwargs)


# ===== LEGACY COMPATIBILITY =====

# Keep backward compatibility for existing code
engine = None
AsyncSessionLocal = None


def init_legacy_compatibility(database_url: str):
    """Initialize legacy global variables for backward compatibility"""
    global engine, AsyncSessionLocal

    # Configure SSL for PostgreSQL if needed
    connect_args = {}
    if "postgresql" in database_url.lower() and "asyncpg" in database_url.lower():
        if "ssl" not in database_url.lower() and "sslmode" not in database_url.lower():
            # Create SSL context that doesn't verify certificates (for local development)
            ssl_context = ssl.create_default_context()
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE
            connect_args["ssl"] = ssl_context
            logger.info("SSL enabled with relaxed verification for legacy database connection")

    engine = create_async_engine(
        database_url,
        echo=False,
        pool_pre_ping=True,
        pool_recycle=300,
        pool_size=20,
        max_overflow=30,
        connect_args=connect_args
    )

    AsyncSessionLocal = async_sessionmaker(
        engine,
        class_=AsyncSession,
        expire_on_commit=False
    )

    logger.warning("Using legacy database configuration - consider migrating to DatabaseManager")


async def get_legacy_db():
    """Legacy database session getter for backward compatibility"""
    if not AsyncSessionLocal:
        raise RuntimeError("Legacy database not initialized - call init_legacy_compatibility first")

    async with AsyncSessionLocal() as session:
        try:
            yield session
        except Exception as e:
            logger.error(f"Legacy database session error: {e}")
            await session.rollback()
            raise
        finally:
            await session.close()
52
shared/database/exceptions.py
Executable file
@@ -0,0 +1,52 @@
"""
Custom Database Exceptions
Provides consistent error handling across all microservices
"""

class DatabaseError(Exception):
    """Base exception for database-related errors"""

    def __init__(self, message: str, details: dict = None):
        self.message = message
        self.details = details or {}
        super().__init__(self.message)


class ConnectionError(DatabaseError):
    """Raised when database connection fails"""
    pass


class RecordNotFoundError(DatabaseError):
    """Raised when a requested record is not found"""
    pass


class DuplicateRecordError(DatabaseError):
    """Raised when trying to create a duplicate record"""
    pass


class ConstraintViolationError(DatabaseError):
    """Raised when database constraints are violated"""
    pass


class TransactionError(DatabaseError):
    """Raised when transaction operations fail"""
    pass


class ValidationError(DatabaseError):
    """Raised when data validation fails before database operations"""
    pass


class MigrationError(DatabaseError):
    """Raised when database migration operations fail"""
    pass


class HealthCheckError(DatabaseError):
    """Raised when database health checks fail"""
    pass
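# Illustrative sketch (not part of the original commit): translating these
# exceptions at the API boundary. FastAPI, the repository, and the ids shown
# are hypothetical placeholders.
#
#     from fastapi import HTTPException
#
#     try:
#         user = await user_repo.get_by_id(user_id)
#         if user is None:
#             raise RecordNotFoundError(f"user {user_id} not found")
#     except RecordNotFoundError as exc:
#         raise HTTPException(status_code=404, detail=exc.message)
#     except DuplicateRecordError as exc:
#         raise HTTPException(status_code=409, detail=exc.message)
#     except DatabaseError:
#         raise HTTPException(status_code=503, detail="database unavailable")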
381
shared/database/init_manager.py
Executable file
@@ -0,0 +1,381 @@
"""
Database Initialization Manager

Handles Alembic-based migrations with autogenerate support:
1. First-time deployment: Generate initial migration from models
2. Subsequent deployments: Run pending migrations
3. Development reset: Drop tables and regenerate migrations
"""

import os
import asyncio
import structlog
from typing import Optional, List, Dict, Any
from pathlib import Path
from sqlalchemy import text, inspect
from sqlalchemy.ext.asyncio import AsyncSession
from alembic.config import Config
from alembic import command
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory

from .base import DatabaseManager, Base

logger = structlog.get_logger()


class DatabaseInitManager:
    """
    Manages database initialization using Alembic migrations exclusively.

    Two modes:
    1. Migration mode (for migration jobs): Runs alembic upgrade head
    2. Verification mode (for services): Only verifies database is ready
    """

    def __init__(
        self,
        database_manager: DatabaseManager,
        service_name: str,
        alembic_ini_path: Optional[str] = None,
        models_module: Optional[str] = None,
        verify_only: bool = True,  # Default: services only verify
        force_recreate: bool = False
    ):
        self.database_manager = database_manager
        self.service_name = service_name
        self.alembic_ini_path = alembic_ini_path
        self.models_module = models_module
        self.verify_only = verify_only
        self.force_recreate = force_recreate
        self.logger = logger.bind(service=service_name)

    async def initialize_database(self) -> Dict[str, Any]:
        """
        Main initialization method.

        Two modes:
        1. verify_only=True (default, for services):
           - Verifies database is ready
           - Checks tables exist
           - Checks alembic_version exists
           - DOES NOT run migrations

        2. verify_only=False (for migration jobs only):
           - Runs alembic upgrade head
           - Applies pending migrations
           - Can force recreate if needed
        """
        if self.verify_only:
            self.logger.info("Database verification mode - checking database is ready")
            return await self._verify_database_ready()
        else:
            self.logger.info("Migration mode - running database migrations")
            return await self._run_migrations_mode()
    async def _verify_database_ready(self) -> Dict[str, Any]:
        """
        Verify database is ready for service startup.
        Services should NOT run migrations - only verify they've been applied.
        """
        try:
            # Check alembic configuration exists
            if not self.alembic_ini_path or not os.path.exists(self.alembic_ini_path):
                raise Exception(f"Alembic configuration not found at {self.alembic_ini_path}")

            # Check database state
            db_state = await self._check_database_state()
            self.logger.info("Database state checked", state=db_state)

            # Verify migrations exist
            if not db_state["has_migrations"]:
                raise Exception(
                    f"No migration files found for {self.service_name}. "
                    f"Migrations must be generated and included in the Docker image."
                )

            # Verify database is not empty
            if db_state["is_empty"]:
                raise Exception(
                    "Database is empty. Migration job must run before service startup. "
                    "Ensure migration job completes successfully before starting services."
                )

            # Verify alembic_version table exists
            if not db_state["has_alembic_version"]:
                raise Exception(
                    "No alembic_version table found. Migration job must run before service startup."
                )

            # Verify current revision exists
            if not db_state["current_revision"]:
                raise Exception(
                    "No current migration revision found. Database may not be properly initialized."
                )

            self.logger.info(
                "Database verification successful",
                migration_count=db_state["migration_count"],
                current_revision=db_state["current_revision"],
                table_count=len(db_state["existing_tables"])
            )

            return {
                "action": "verified",
                "message": "Database verified successfully - ready for service",
                "current_revision": db_state["current_revision"],
                "migration_count": db_state["migration_count"],
                "table_count": len(db_state["existing_tables"])
            }

        except Exception as e:
            self.logger.error("Database verification failed", error=str(e))
            raise

    async def _run_migrations_mode(self) -> Dict[str, Any]:
        """
        Run migrations mode - for migration jobs only.
        """
        try:
            if not self.alembic_ini_path or not os.path.exists(self.alembic_ini_path):
                raise Exception(f"Alembic configuration not found at {self.alembic_ini_path}")

            # Check current database state
            db_state = await self._check_database_state()
            self.logger.info("Database state checked", state=db_state)

            # Handle force recreate
            if self.force_recreate:
                return await self._handle_force_recreate()

            # Check migrations exist
            if not db_state["has_migrations"]:
                raise Exception(
                    f"No migration files found for {self.service_name}. "
                    f"Generate migrations using regenerate_migrations_k8s.sh script."
                )

            # Run migrations
            result = await self._handle_run_migrations()

            self.logger.info("Migration mode completed", result=result)
            return result

        except Exception as e:
            self.logger.error("Migration mode failed", error=str(e))
            raise

    async def _check_database_state(self) -> Dict[str, Any]:
        """Check the current state of migrations"""
        state = {
            "has_migrations": False,
            "migration_count": 0,
            "is_empty": False,
            "existing_tables": [],
            "has_alembic_version": False,
            "current_revision": None
        }

        try:
            # Check if migration files exist
            migrations_dir = self._get_migrations_versions_dir()
            if migrations_dir.exists():
                migration_files = list(migrations_dir.glob("*.py"))
                migration_files = [f for f in migration_files if f.name != "__pycache__" and not f.name.startswith("_")]
                state["migration_count"] = len(migration_files)
                state["has_migrations"] = len(migration_files) > 0
                self.logger.info("Found migration files", count=len(migration_files))

            # Check database tables
            async with self.database_manager.get_session() as session:
                existing_tables = await self._get_existing_tables(session)
                state["existing_tables"] = existing_tables
                state["is_empty"] = len(existing_tables) == 0

                # Check alembic_version table
                if "alembic_version" in existing_tables:
                    state["has_alembic_version"] = True
                    result = await session.execute(text("SELECT version_num FROM alembic_version"))
                    version = result.scalar()
                    state["current_revision"] = version

        except Exception as e:
            self.logger.warning("Error checking database state", error=str(e))

        return state
    async def _handle_run_migrations(self) -> Dict[str, Any]:
        """Handle normal migration scenario - run pending migrations"""
        self.logger.info("Running pending migrations")

        try:
            await self._run_migrations()

            return {
                "action": "migrations_applied",
                "message": "Pending migrations applied successfully"
            }

        except Exception as e:
            self.logger.error("Failed to run migrations", error=str(e))
            raise

    async def _handle_force_recreate(self) -> Dict[str, Any]:
        """Handle development reset scenario - drop and recreate tables using existing migrations"""
        self.logger.info("Force recreate: dropping tables and rerunning migrations")

        try:
            # Drop all tables
            await self._drop_all_tables()

            # Apply migrations from scratch
            await self._run_migrations()

            return {
                "action": "force_recreate",
                "tables_dropped": True,
                "migrations_applied": True,
                "message": "Database recreated from existing migrations"
            }

        except Exception as e:
            self.logger.error("Failed to force recreate", error=str(e))
            raise

    async def _run_migrations(self):
        """Run pending Alembic migrations (upgrade head)"""
        try:
            def run_alembic_upgrade():
                # Ensure we're in the correct working directory
                alembic_dir = Path(self.alembic_ini_path).parent
                original_cwd = os.getcwd()

                try:
                    os.chdir(alembic_dir)

                    alembic_cfg = Config(self.alembic_ini_path)

                    # Set the SQLAlchemy URL from the database manager
                    alembic_cfg.set_main_option("sqlalchemy.url", str(self.database_manager.database_url))

                    # Run upgrade
                    command.upgrade(alembic_cfg, "head")

                finally:
                    os.chdir(original_cwd)

            # Run in executor to avoid blocking the event loop
            await asyncio.get_event_loop().run_in_executor(None, run_alembic_upgrade)
            self.logger.info("Migrations applied successfully")

        except Exception as e:
            self.logger.error("Failed to run migrations", error=str(e))
            raise

    async def _drop_all_tables(self):
        """Drop all tables (for development reset)"""
        try:
            async with self.database_manager.async_engine.begin() as conn:
                await conn.run_sync(Base.metadata.drop_all)
            self.logger.info("All tables dropped")
        except Exception as e:
            self.logger.error("Failed to drop tables", error=str(e))
            raise

    def _get_migrations_versions_dir(self) -> Path:
        """Get the migrations/versions directory path"""
        alembic_path = Path(self.alembic_ini_path).parent
        return alembic_path / "migrations" / "versions"

    async def _get_existing_tables(self, session: AsyncSession) -> List[str]:
        """Get list of existing tables in the database"""
        def get_tables_sync(connection):
            insp = inspect(connection)
            return insp.get_table_names()

        connection = await session.connection()
        return await connection.run_sync(get_tables_sync)


def create_init_manager(
    database_manager: DatabaseManager,
    service_name: str,
    service_path: Optional[str] = None,
    verify_only: bool = True,
    force_recreate: bool = False
) -> DatabaseInitManager:
    """
    Factory function to create a DatabaseInitManager with auto-detected paths

    Args:
        database_manager: DatabaseManager instance
        service_name: Name of the service
        service_path: Path to service directory (auto-detected if None)
        verify_only: True = verify DB ready (services), False = run migrations (jobs only)
        force_recreate: Force recreate tables (requires verify_only=False)
    """
    # Auto-detect paths if not provided
    if service_path is None:
        # Try Docker container path first (service files at root level)
        if os.path.exists("alembic.ini"):
            service_path = "."
        else:
            # Fallback to development path
            service_path = f"services/{service_name}"

    # Set up paths based on environment
    if service_path == ".":
        # Docker container environment
        alembic_ini_path = "alembic.ini"
        models_module = "app.models"
    else:
        # Development environment
        alembic_ini_path = f"{service_path}/alembic.ini"
        models_module = f"services.{service_name}.app.models"

    # Check if paths exist
    if not os.path.exists(alembic_ini_path):
        logger.warning("Alembic config not found", path=alembic_ini_path)
        alembic_ini_path = None

    return DatabaseInitManager(
        database_manager=database_manager,
        service_name=service_name,
        alembic_ini_path=alembic_ini_path,
        models_module=models_module,
        verify_only=verify_only,
        force_recreate=force_recreate
    )


async def initialize_service_database(
    database_manager: DatabaseManager,
    service_name: str,
    verify_only: bool = True,
    force_recreate: bool = False
) -> Dict[str, Any]:
    """
    Convenience function for database initialization

    Args:
        database_manager: DatabaseManager instance
        service_name: Name of the service
        verify_only: True = verify DB ready (default, services), False = run migrations (jobs only)
        force_recreate: Force recreate tables (requires verify_only=False)

    Returns:
        Dict with initialization results
    """
    init_manager = create_init_manager(
        database_manager=database_manager,
        service_name=service_name,
        verify_only=verify_only,
        force_recreate=force_recreate
    )

    return await init_manager.initialize_database()
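# Illustrative sketch (not part of the original commit): typical call sites for the
# two modes. The "sales-service" name and db_manager object are hypothetical.
#
#     # Migration job (runs to completion before any service starts):
#     await initialize_service_database(db_manager, "sales-service", verify_only=False)
#
#     # Service startup (verifies schema and revision only, never migrates):
#     await initialize_service_database(db_manager, "sales-service")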
428
shared/database/repository.py
Executable file
@@ -0,0 +1,428 @@
"""
Base Repository Pattern for Database Operations
Provides generic CRUD operations, query building, and caching
"""

from typing import Optional, List, Dict, Any, TypeVar, Generic, Type, Union
from abc import ABC, abstractmethod
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import declarative_base
from sqlalchemy import select, update, delete, and_, or_, desc, asc, func, text
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from contextlib import asynccontextmanager
import structlog

from .exceptions import (
    DatabaseError,
    RecordNotFoundError,
    DuplicateRecordError,
    ConstraintViolationError
)

logger = structlog.get_logger()

# Type variables for generic repository
Model = TypeVar('Model', bound=declarative_base())
CreateSchema = TypeVar('CreateSchema')
UpdateSchema = TypeVar('UpdateSchema')


class BaseRepository(Generic[Model, CreateSchema, UpdateSchema], ABC):
    """
    Base repository providing generic CRUD operations

    Args:
        model: SQLAlchemy model class
        session: Database session
        cache_ttl: Cache time-to-live in seconds (optional)
    """

    def __init__(self, model: Type[Model], session: AsyncSession, cache_ttl: Optional[int] = None):
        self.model = model
        self.session = session
        self.cache_ttl = cache_ttl
        self._cache = {} if cache_ttl else None

    # ===== CORE CRUD OPERATIONS =====

    async def create(self, obj_in: CreateSchema, **kwargs) -> Model:
        """Create a new record"""
        try:
            # Convert schema to dict if needed
            if hasattr(obj_in, 'model_dump'):
                obj_data = obj_in.model_dump()
            elif hasattr(obj_in, 'dict'):
                obj_data = obj_in.dict()
            else:
                obj_data = obj_in

            # Merge with additional kwargs
            obj_data.update(kwargs)

            db_obj = self.model(**obj_data)
            self.session.add(db_obj)
            await self.session.flush()  # Get ID without committing
            await self.session.refresh(db_obj)

            logger.debug(f"Created {self.model.__name__}", record_id=getattr(db_obj, 'id', None))
            return db_obj

        except IntegrityError as e:
            await self.session.rollback()
            logger.error(f"Integrity error creating {self.model.__name__}", error=str(e))
            raise DuplicateRecordError("Record with provided data already exists")
        except SQLAlchemyError as e:
            await self.session.rollback()
            logger.error(f"Database error creating {self.model.__name__}", error=str(e))
            raise DatabaseError(f"Failed to create record: {str(e)}")

    async def get_by_id(self, record_id: Any) -> Optional[Model]:
        """Get record by ID with optional caching"""
        cache_key = f"{self.model.__name__}:{record_id}"

        # Check cache first
        if self._cache and cache_key in self._cache:
            logger.debug(f"Cache hit for {cache_key}")
            return self._cache[cache_key]

        try:
            result = await self.session.execute(
                select(self.model).where(self.model.id == record_id)
            )
            record = result.scalar_one_or_none()

            # Cache the result
            # (check "is not None" - an empty dict is falsy, so the original
            # truthiness check meant the cache never received its first entry)
            if self._cache is not None and record:
                self._cache[cache_key] = record

            return record

        except SQLAlchemyError as e:
            logger.error(f"Database error getting {self.model.__name__} by ID",
                         record_id=record_id, error=str(e))
            raise DatabaseError(f"Failed to get record: {str(e)}")

    async def get_by_field(self, field_name: str, value: Any) -> Optional[Model]:
        """Get record by specific field"""
        try:
            result = await self.session.execute(
                select(self.model).where(getattr(self.model, field_name) == value)
            )
            return result.scalar_one_or_none()

        except AttributeError:
            raise ValueError(f"Field '{field_name}' not found in {self.model.__name__}")
        except SQLAlchemyError as e:
            logger.error(f"Database error getting {self.model.__name__} by {field_name}",
                         value=value, error=str(e))
            raise DatabaseError(f"Failed to get record: {str(e)}")

    async def get_multi(
        self,
        skip: int = 0,
        limit: int = 100,
        order_by: Optional[str] = None,
        order_desc: bool = False,
        filters: Optional[Dict[str, Any]] = None
    ) -> List[Model]:
        """Get multiple records with pagination, sorting, and filtering"""
        try:
            query = select(self.model)

            # Apply filters
            if filters:
                conditions = []
                for field, value in filters.items():
                    if hasattr(self.model, field):
                        if isinstance(value, list):
                            conditions.append(getattr(self.model, field).in_(value))
                        else:
                            conditions.append(getattr(self.model, field) == value)

                if conditions:
                    query = query.where(and_(*conditions))

            # Apply ordering
            if order_by and hasattr(self.model, order_by):
                order_field = getattr(self.model, order_by)
                if order_desc:
                    query = query.order_by(desc(order_field))
                else:
                    query = query.order_by(asc(order_field))

            # Apply pagination
            query = query.offset(skip).limit(limit)

            result = await self.session.execute(query)
            return result.scalars().all()

        except SQLAlchemyError as e:
            logger.error(f"Database error getting multiple {self.model.__name__} records",
                         error=str(e))
            raise DatabaseError(f"Failed to get records: {str(e)}")
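    # Illustrative sketch (not part of the original commit): how filters, ordering,
    # and pagination combine. Field names and values are hypothetical.
    #
    #     page = await repo.get_multi(
    #         skip=0,
    #         limit=50,
    #         order_by="created_at",
    #         order_desc=True,
    #         filters={"status": ["new", "paid"], "region": "EU"},  # list -> IN, scalar -> ==
    #     )
    #     total = await repo.count(filters={"status": ["new", "paid"]})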
    async def update(self, record_id: Any, obj_in: UpdateSchema, **kwargs) -> Optional[Model]:
        """Update record by ID"""
        try:
            # Convert schema to dict if needed
            if hasattr(obj_in, 'model_dump'):
                update_data = obj_in.model_dump(exclude_unset=True)
            elif hasattr(obj_in, 'dict'):
                update_data = obj_in.dict(exclude_unset=True)
            else:
                update_data = obj_in

            # Merge with additional kwargs
            update_data.update(kwargs)

            # Remove None values
            update_data = {k: v for k, v in update_data.items() if v is not None}

            if not update_data:
                logger.warning(f"No data to update for {self.model.__name__}", record_id=record_id)
                return await self.get_by_id(record_id)

            # Perform update
            result = await self.session.execute(
                update(self.model)
                .where(self.model.id == record_id)
                .values(**update_data)
                .returning(self.model)
            )

            updated_record = result.scalar_one_or_none()

            if not updated_record:
                raise RecordNotFoundError(f"{self.model.__name__} with id {record_id} not found")

            # Clear cache
            if self._cache:
                cache_key = f"{self.model.__name__}:{record_id}"
                self._cache.pop(cache_key, None)

            logger.debug(f"Updated {self.model.__name__}", record_id=record_id)
            return updated_record

        except IntegrityError as e:
            await self.session.rollback()
            logger.error(f"Integrity error updating {self.model.__name__}",
                         record_id=record_id, error=str(e))
            raise ConstraintViolationError("Update violates database constraints")
        except SQLAlchemyError as e:
            await self.session.rollback()
            logger.error(f"Database error updating {self.model.__name__}",
                         record_id=record_id, error=str(e))
            raise DatabaseError(f"Failed to update record: {str(e)}")

    async def delete(self, record_id: Any) -> bool:
        """Delete record by ID"""
        try:
            result = await self.session.execute(
                delete(self.model).where(self.model.id == record_id)
            )

            deleted_count = result.rowcount

            if deleted_count == 0:
                raise RecordNotFoundError(f"{self.model.__name__} with id {record_id} not found")

            # Clear cache
            if self._cache:
                cache_key = f"{self.model.__name__}:{record_id}"
                self._cache.pop(cache_key, None)

            logger.debug(f"Deleted {self.model.__name__}", record_id=record_id)
            return True

        except SQLAlchemyError as e:
            await self.session.rollback()
            logger.error(f"Database error deleting {self.model.__name__}",
                         record_id=record_id, error=str(e))
            raise DatabaseError(f"Failed to delete record: {str(e)}")

    # ===== ADVANCED QUERY OPERATIONS =====

    async def count(self, filters: Optional[Dict[str, Any]] = None) -> int:
        """Count records with optional filters"""
        try:
            query = select(func.count(self.model.id))

            if filters:
                conditions = []
                for field, value in filters.items():
                    if hasattr(self.model, field):
                        if isinstance(value, list):
                            conditions.append(getattr(self.model, field).in_(value))
                        else:
                            conditions.append(getattr(self.model, field) == value)

                if conditions:
                    query = query.where(and_(*conditions))

            result = await self.session.execute(query)
            return result.scalar() or 0

        except SQLAlchemyError as e:
            logger.error(f"Database error counting {self.model.__name__} records", error=str(e))
            raise DatabaseError(f"Failed to count records: {str(e)}")

    async def exists(self, record_id: Any) -> bool:
        """Check if record exists by ID"""
        try:
            result = await self.session.execute(
                select(func.count(self.model.id)).where(self.model.id == record_id)
            )
            count = result.scalar() or 0
            return count > 0

        except SQLAlchemyError as e:
            logger.error(f"Database error checking existence of {self.model.__name__}",
                         record_id=record_id, error=str(e))
            raise DatabaseError(f"Failed to check record existence: {str(e)}")

    async def bulk_create(self, objects: List[CreateSchema]) -> List[Model]:
        """Create multiple records in bulk"""
        try:
            if not objects:
                return []

            db_objects = []
            for obj_in in objects:
                if hasattr(obj_in, 'model_dump'):
                    obj_data = obj_in.model_dump()
                elif hasattr(obj_in, 'dict'):
                    obj_data = obj_in.dict()
                else:
                    obj_data = obj_in

                db_objects.append(self.model(**obj_data))

            self.session.add_all(db_objects)
            await self.session.flush()

            # Skip expensive individual refresh operations for large datasets
            # Only refresh if we have a small number of objects
            if len(db_objects) <= 100:
                for db_obj in db_objects:
                    await self.session.refresh(db_obj)
            else:
                # For large datasets, just log without refresh to prevent memory issues
                logger.debug(f"Skipped individual refresh for large bulk operation ({len(db_objects)} records)")

            logger.debug(f"Bulk created {len(db_objects)} {self.model.__name__} records")
            return db_objects

        except IntegrityError as e:
            await self.session.rollback()
            logger.error(f"Integrity error bulk creating {self.model.__name__}", error=str(e))
            raise DuplicateRecordError("One or more records already exist")
        except SQLAlchemyError as e:
            await self.session.rollback()
            logger.error(f"Database error bulk creating {self.model.__name__}", error=str(e))
            raise DatabaseError(f"Failed to create records: {str(e)}")
    async def bulk_update(self, updates: List[Dict[str, Any]]) -> int:
        """Update multiple records in bulk"""
        try:
            if not updates:
                return 0

            updated_ids = []
            for update_data in updates:
                if 'id' not in update_data:
                    raise ValueError("Each update must include 'id' field")

                record_id = update_data.pop('id')
                updated_ids.append(record_id)
                await self.session.execute(
                    update(self.model)
                    .where(self.model.id == record_id)
                    .values(**update_data)
                )

            # Clear relevant cache entries
            # (ids were popped from the update dicts above, so use the collected list;
            # the original re-read 'id' from the dicts and therefore never cleared anything)
            if self._cache:
                for record_id in updated_ids:
                    cache_key = f"{self.model.__name__}:{record_id}"
                    self._cache.pop(cache_key, None)

            logger.debug(f"Bulk updated {len(updates)} {self.model.__name__} records")
            return len(updates)

        except SQLAlchemyError as e:
            await self.session.rollback()
            logger.error(f"Database error bulk updating {self.model.__name__}", error=str(e))
            raise DatabaseError(f"Failed to update records: {str(e)}")
    # ===== SEARCH AND QUERY BUILDING =====

    async def search(
        self,
        search_term: str,
        search_fields: List[str],
        skip: int = 0,
        limit: int = 100
    ) -> List[Model]:
        """Search records across multiple fields"""
        try:
            conditions = []
            for field in search_fields:
                if hasattr(self.model, field):
                    field_obj = getattr(self.model, field)
                    # Case-insensitive partial match
                    conditions.append(field_obj.ilike(f"%{search_term}%"))

            if not conditions:
                logger.warning(f"No valid search fields provided for {self.model.__name__}")
                return []

            query = select(self.model).where(or_(*conditions)).offset(skip).limit(limit)
            result = await self.session.execute(query)
            return result.scalars().all()

        except SQLAlchemyError as e:
            logger.error(f"Database error searching {self.model.__name__}",
                         search_term=search_term, error=str(e))
            raise DatabaseError(f"Failed to search records: {str(e)}")

    async def execute_raw_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> Any:
        """Execute raw SQL query (use with caution)"""
        try:
            result = await self.session.execute(text(query), params or {})
            return result

        except SQLAlchemyError as e:
            logger.error("Database error executing raw query", query=query, error=str(e))
            raise DatabaseError(f"Failed to execute query: {str(e)}")

    # ===== CACHE MANAGEMENT =====

    def clear_cache(self, record_id: Optional[Any] = None):
        """Clear cache for specific record or all records"""
        if not self._cache:
            return

        if record_id:
            cache_key = f"{self.model.__name__}:{record_id}"
            self._cache.pop(cache_key, None)
        else:
            # Clear all cache entries for this model
            keys_to_remove = [k for k in self._cache.keys() if k.startswith(f"{self.model.__name__}:")]
            for key in keys_to_remove:
                self._cache.pop(key, None)

        logger.debug(f"Cleared cache for {self.model.__name__}", record_id=record_id)

    # ===== CONTEXT MANAGERS =====

    @asynccontextmanager
    async def transaction(self):
        """Context manager for explicit transaction handling"""
        try:
            yield self.session
            await self.session.commit()
        except Exception as e:
            await self.session.rollback()
            logger.error(f"Transaction failed for {self.model.__name__}", error=str(e))
            raise
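# Illustrative sketch (not part of the original commit): a concrete repository for a
# hypothetical User model with Pydantic UserCreate/UserUpdate schemas.
#
#     class UserRepository(BaseRepository[User, UserCreate, UserUpdate]):
#         def __init__(self, session: AsyncSession):
#             super().__init__(User, session, cache_ttl=60)
#
#         async def get_by_email(self, email: str) -> Optional[User]:
#             return await self.get_by_field("email", email)
#
#     # Typical request-scoped use:
#     #     repo = UserRepository(session)
#     #     user = await repo.create(UserCreate(email="ada@example.com", name="Ada"))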
306
shared/database/transactions.py
Executable file
@@ -0,0 +1,306 @@
"""
Transaction Decorators and Context Managers
Provides convenient transaction handling for service methods
"""

from functools import wraps
from typing import Callable, Any, Optional
from contextlib import asynccontextmanager
import structlog

from .base import DatabaseManager
from .unit_of_work import UnitOfWork
from .exceptions import TransactionError

logger = structlog.get_logger()


def transactional(database_manager: DatabaseManager, auto_commit: bool = True):
    """
    Decorator that wraps a method in a database transaction

    Args:
        database_manager: DatabaseManager instance
        auto_commit: Whether to auto-commit on success

    Usage:
        @transactional(database_manager)
        async def create_user_with_profile(self, user_data, profile_data):
            # Your business logic here
            # Transaction is automatically managed
            pass
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            async with database_manager.get_background_session() as session:
                try:
                    # Inject session into kwargs if not present
                    if 'session' not in kwargs:
                        kwargs['session'] = session

                    result = await func(*args, **kwargs)

                    # Session is auto-committed by get_background_session
                    logger.debug(f"Transaction completed successfully for {func.__name__}")
                    return result

                except Exception as e:
                    # Session is auto-rolled back by get_background_session
                    logger.error(f"Transaction failed for {func.__name__}", error=str(e))
                    raise TransactionError(f"Transaction failed: {str(e)}")

        return wrapper
    return decorator


def unit_of_work_transactional(database_manager: DatabaseManager):
    """
    Decorator that provides Unit of Work pattern for complex operations

    Usage:
        @unit_of_work_transactional(database_manager)
        async def complex_business_operation(self, data, uow: UnitOfWork):
            user_repo = uow.register_repository("users", UserRepository, User)
            sales_repo = uow.register_repository("sales", SalesRepository, SalesData)

            user = await user_repo.create(data.user)
            sale = await sales_repo.create(data.sale)

            # UnitOfWork automatically commits
            return {"user": user, "sale": sale}
    """
    def decorator(func: Callable) -> Callable:
        @wraps(func)
        async def wrapper(*args, **kwargs):
            async with database_manager.get_background_session() as session:
                async with UnitOfWork(session, auto_commit=True) as uow:
                    try:
                        # Inject UnitOfWork into kwargs
                        kwargs['uow'] = uow

                        result = await func(*args, **kwargs)

                        logger.debug(f"Unit of Work transaction completed for {func.__name__}")
                        return result

                    except Exception as e:
                        logger.error(f"Unit of Work transaction failed for {func.__name__}",
                                     error=str(e))
                        raise TransactionError(f"Transaction failed: {str(e)}")

        return wrapper
    return decorator
@asynccontextmanager
async def managed_transaction(database_manager: DatabaseManager):
    """
    Context manager for explicit transaction control

    Usage:
        async with managed_transaction(database_manager) as session:
            # Your database operations here
            user = User(name="John")
            session.add(user)
            # Auto-commits on exit, rolls back on exception
    """
    async with database_manager.get_background_session() as session:
        try:
            logger.debug("Starting managed transaction")
            yield session
            logger.debug("Managed transaction completed successfully")
        except Exception as e:
            logger.error("Managed transaction failed", error=str(e))
            raise


@asynccontextmanager
async def managed_unit_of_work(database_manager: DatabaseManager, event_publisher=None):
    """
    Context manager for explicit Unit of Work control

    Usage:
        async with managed_unit_of_work(database_manager) as uow:
            user_repo = uow.register_repository("users", UserRepository, User)
            user = await user_repo.create(user_data)
            await uow.commit()
    """
    async with database_manager.get_background_session() as session:
        uow = UnitOfWork(session)
        try:
            logger.debug("Starting managed Unit of Work")
            yield uow

            if not uow._committed:
                await uow.commit()

            logger.debug("Managed Unit of Work completed successfully")

        except Exception as e:
            if not uow._rolled_back:
                await uow.rollback()
            logger.error("Managed Unit of Work failed", error=str(e))
            raise
class TransactionManager:
    """
    Advanced transaction manager for complex scenarios

    Usage:
        tx_manager = TransactionManager(database_manager)

        async with tx_manager.create_transaction() as tx:
            await tx.execute_in_transaction(my_business_logic, data)
    """

    def __init__(self, database_manager: DatabaseManager):
        self.database_manager = database_manager

    @asynccontextmanager
    async def create_transaction(self, isolation_level: Optional[str] = None):
        """Create a transaction with optional isolation level"""
        async with self.database_manager.get_background_session() as session:
            transaction_context = TransactionContext(session, isolation_level)
            try:
                yield transaction_context
            except Exception as e:
                logger.error("Transaction manager failed", error=str(e))
                raise

    async def execute_with_retry(
        self,
        func: Callable,
        max_retries: int = 3,
        *args,
        **kwargs
    ):
        """Execute function with transaction retry on failure"""
        last_error = None

        for attempt in range(max_retries):
            try:
                async with managed_transaction(self.database_manager) as session:
                    kwargs['session'] = session
                    result = await func(*args, **kwargs)
                    logger.debug(f"Transaction succeeded on attempt {attempt + 1}")
                    return result

            except Exception as e:
                last_error = e
                logger.warning(f"Transaction attempt {attempt + 1} failed",
                               error=str(e), remaining_attempts=max_retries - attempt - 1)

                if attempt == max_retries - 1:
                    break

        logger.error(f"All transaction attempts failed after {max_retries} tries")
        raise TransactionError(f"Transaction failed after {max_retries} retries: {str(last_error)}")


class TransactionContext:
    """Context for managing individual transactions"""

    def __init__(self, session, isolation_level: Optional[str] = None):
        self.session = session
        self.isolation_level = isolation_level

    async def execute_in_transaction(self, func: Callable, *args, **kwargs):
        """Execute function within the transaction context"""
        try:
            kwargs['session'] = self.session
            result = await func(*args, **kwargs)
            return result
        except Exception as e:
            logger.error("Function execution failed in transaction context", error=str(e))
            raise


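# --- Illustrative caller-side sketch (not part of this file). ---
# Shows how execute_with_retry might be used; `update_inventory` and
# `database_manager` are hypothetical names. The only contract assumed is the
# one visible above: the helper injects a `session` keyword argument.
#
# from sqlalchemy import text
#
# async def update_inventory(item_id: int, delta: int, session=None):
#     # `session` is injected by execute_with_retry / managed_transaction
#     await session.execute(
#         text("UPDATE inventory SET quantity = quantity + :delta WHERE id = :item_id"),
#         {"delta": delta, "item_id": item_id},
#     )
#     return item_id
#
# async def adjust_stock(database_manager):
#     tx_manager = TransactionManager(database_manager)
#     # Retries the whole transaction up to 3 times, then raises TransactionError
#     return await tx_manager.execute_with_retry(update_inventory, 3, 42, -5)
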
# ===== UTILITY FUNCTIONS =====


async def run_in_transaction(database_manager: DatabaseManager, func: Callable, *args, **kwargs):
    """
    Utility function to run any async function in a transaction

    Usage:
        result = await run_in_transaction(
            database_manager,
            my_async_function,
            arg1, arg2,
            kwarg1="value"
        )
    """
    async with managed_transaction(database_manager) as session:
        kwargs['session'] = session
        return await func(*args, **kwargs)


async def run_with_unit_of_work(
    database_manager: DatabaseManager,
    func: Callable,
    *args,
    **kwargs
):
    """
    Utility function to run any async function with Unit of Work

    Usage:
        result = await run_with_unit_of_work(
            database_manager,
            my_complex_function,
            arg1, arg2
        )
    """
    async with managed_unit_of_work(database_manager) as uow:
        kwargs['uow'] = uow
        return await func(*args, **kwargs)


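# --- Illustrative callee sketch (not part of this file). ---
# The docstrings above show the caller side; this hypothetical function shows
# the callee side, where the injected `uow` keyword argument arrives.
# UserRepository and User are the illustrative names reused from the docstrings.
#
# async def create_user_with_profile(name: str, uow=None):
#     # `uow` is injected by run_with_unit_of_work; if the function returns
#     # without committing, managed_unit_of_work commits on exit.
#     user_repo = uow.register_repository("users", UserRepository, User)
#     return await user_repo.create({"name": name})
#
# result = await run_with_unit_of_work(database_manager, create_user_with_profile, "Ada")
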
# ===== BATCH OPERATIONS =====


@asynccontextmanager
async def batch_operation(database_manager: DatabaseManager, batch_size: int = 1000):
    """
    Context manager for batch operations with automatic commit batching

    Usage:
        async with batch_operation(database_manager, batch_size=500) as batch:
            for item in large_dataset:
                await batch.add_operation(create_record, item)
    """
    async with database_manager.get_background_session() as session:
        batch_context = BatchOperationContext(session, batch_size)
        try:
            yield batch_context
            await batch_context.flush_remaining()
        except Exception as e:
            logger.error("Batch operation failed", error=str(e))
            raise


class BatchOperationContext:
    """Context for managing batch database operations"""

    def __init__(self, session, batch_size: int):
        self.session = session
        self.batch_size = batch_size
        self.operation_count = 0

    async def add_operation(self, func: Callable, *args, **kwargs):
        """Add operation to batch"""
        kwargs['session'] = self.session
        await func(*args, **kwargs)

        self.operation_count += 1

        if self.operation_count >= self.batch_size:
            await self.session.commit()
            self.operation_count = 0
            logger.debug(f"Batch committed at {self.batch_size} operations")

    async def flush_remaining(self):
        """Commit any remaining operations"""
        if self.operation_count > 0:
            await self.session.commit()
            logger.debug(f"Final batch committed with {self.operation_count} operations")

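# --- Illustrative bulk-import sketch (not part of this file). ---
# `AuditLog` and `records` are hypothetical; the only assumption beyond the
# code above is that `create_record` accepts the injected `session` kwarg.
#
# async def create_record(payload: dict, session=None):
#     session.add(AuditLog(**payload))
#
# async def import_audit_logs(database_manager, records):
#     async with batch_operation(database_manager, batch_size=500) as batch:
#         for payload in records:
#             # Commits automatically every 500 operations; the remainder is
#             # flushed when the context manager exits.
#             await batch.add_operation(create_record, payload)
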
304
shared/database/unit_of_work.py
Executable file
@@ -0,0 +1,304 @@
"""
|
||||
Unit of Work Pattern Implementation
|
||||
Manages transactions across multiple repositories with event publishing
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Type, TypeVar, Generic
|
||||
from contextlib import asynccontextmanager
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from abc import ABC, abstractmethod
|
||||
import structlog
|
||||
|
||||
from .repository import BaseRepository
|
||||
from .exceptions import TransactionError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
Model = TypeVar('Model')
|
||||
Repository = TypeVar('Repository', bound=BaseRepository)
|
||||
|
||||
|
||||
class BaseEvent(ABC):
|
||||
"""Base class for domain events"""
|
||||
|
||||
def __init__(self, event_type: str, data: Dict[str, Any]):
|
||||
self.event_type = event_type
|
||||
self.data = data
|
||||
|
||||
@abstractmethod
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert event to dictionary for publishing"""
|
||||
pass
|
||||
|
||||
|
||||
class DomainEvent(BaseEvent):
|
||||
"""Standard domain event implementation"""
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"event_type": self.event_type,
|
||||
"data": self.data
|
||||
}
|
||||
|
||||
|
||||
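# --- Illustrative sketch (not part of this file). ---
# Constructing a domain event; the event type string and payload are made up.
#
# event = DomainEvent("user.created", {"user_id": 42, "email": "ada@example.com"})
# assert event.to_dict() == {
#     "event_type": "user.created",
#     "data": {"user_id": 42, "email": "ada@example.com"},
# }
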
class UnitOfWork:
    """
    Unit of Work pattern for managing transactions and coordinating repositories

    Usage:
        async with UnitOfWork(session) as uow:
            user_repo = uow.register_repository("users", UserRepository, User)
            sales_repo = uow.register_repository("sales", SalesRepository, SalesData)

            user = await user_repo.create(user_data)
            sale = await sales_repo.create(sales_data)

            await uow.commit()
    """

    def __init__(self, session: AsyncSession, auto_commit: bool = False):
        self.session = session
        self.auto_commit = auto_commit
        self._repositories: Dict[str, BaseRepository] = {}
        self._events: List[BaseEvent] = []
        self._committed = False
        self._rolled_back = False

    def register_repository(
        self,
        name: str,
        repository_class: Type[Repository],
        model_class: Type[Model],
        **kwargs
    ) -> Repository:
        """
        Register a repository with the unit of work

        Args:
            name: Unique name for the repository
            repository_class: Repository class to instantiate
            model_class: SQLAlchemy model class
            **kwargs: Additional arguments for repository

        Returns:
            Instantiated repository
        """
        if name in self._repositories:
            logger.warning(f"Repository '{name}' already registered, returning existing instance")
            return self._repositories[name]

        repository = repository_class(model_class, self.session, **kwargs)
        self._repositories[name] = repository

        logger.debug("Registered repository", name=name, model=model_class.__name__)
        return repository

    def get_repository(self, name: str) -> Optional[Repository]:
        """Get registered repository by name"""
        return self._repositories.get(name)

    def add_event(self, event: BaseEvent):
        """Add domain event to be published after commit"""
        self._events.append(event)
        logger.debug("Added event", event_type=event.event_type)

    async def commit(self):
        """Commit the transaction and publish events"""
        if self._committed:
            logger.warning("Unit of Work already committed")
            return

        if self._rolled_back:
            raise TransactionError("Cannot commit after rollback")

        try:
            await self.session.commit()
            self._committed = True

            # Capture the count before publishing, since _publish_events()
            # clears the event queue.
            events_count = len(self._events)

            # Publish events after successful commit
            await self._publish_events()

            logger.debug("Unit of Work committed successfully",
                         repositories=list(self._repositories.keys()),
                         events_published=events_count)

        except SQLAlchemyError as e:
            await self.rollback()
            logger.error("Failed to commit Unit of Work", error=str(e))
            raise TransactionError(f"Commit failed: {str(e)}")

    async def rollback(self):
        """Rollback the transaction"""
        if self._rolled_back:
            logger.warning("Unit of Work already rolled back")
            return

        try:
            await self.session.rollback()
            self._rolled_back = True
            self._events.clear()  # Clear events on rollback

            logger.debug("Unit of Work rolled back",
                         repositories=list(self._repositories.keys()))

        except SQLAlchemyError as e:
            logger.error("Failed to rollback Unit of Work", error=str(e))
            raise TransactionError(f"Rollback failed: {str(e)}")

    async def _publish_events(self):
        """Publish domain events (override in subclasses for actual publishing)"""
        if not self._events:
            return

        # Default implementation just logs events.
        # Override this method in service-specific implementations.
        for event in self._events:
            logger.info("Publishing event",
                        event_type=event.event_type,
                        event_data=event.to_dict())

        # Clear events after publishing
        self._events.clear()

    async def __aenter__(self):
        """Async context manager entry"""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit"""
        if exc_type is not None:
            # Exception occurred, rollback
            await self.rollback()
            return False

        # No exception, auto-commit if enabled
        if self.auto_commit and not self._committed:
            await self.commit()

        return False


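# --- Illustrative sketch (not part of this file). ---
# Combines repositories with a queued domain event; UserRepository, User and
# the payloads are illustrative names reused from the docstrings above.
#
# async def register_user(session, user_data: dict):
#     async with UnitOfWork(session, auto_commit=True) as uow:
#         user_repo = uow.register_repository("users", UserRepository, User)
#         user = await user_repo.create(user_data)
#         # Queued events are published only after a successful commit;
#         # with auto_commit=True the commit happens on context exit.
#         uow.add_event(DomainEvent("user.registered", {"email": user_data.get("email")}))
#         return user
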
class ServiceUnitOfWork(UnitOfWork):
    """
    Service-specific Unit of Work with event publishing integration

    Example usage with message publishing:

        class AuthUnitOfWork(ServiceUnitOfWork):
            def __init__(self, session: AsyncSession, message_publisher=None):
                super().__init__(session)
                self.message_publisher = message_publisher

            async def _publish_events(self):
                for event in self._events:
                    if self.message_publisher:
                        await self.message_publisher.publish(
                            topic="auth.events",
                            message=event.to_dict()
                        )
    """

    def __init__(self, session: AsyncSession, event_publisher=None, auto_commit: bool = False):
        super().__init__(session, auto_commit)
        self.event_publisher = event_publisher

    async def _publish_events(self):
        """Publish events using the provided event publisher"""
        if not self._events or not self.event_publisher:
            return

        try:
            for event in self._events:
                await self.event_publisher.publish(event)
                logger.debug("Published event via publisher",
                             event_type=event.event_type)

            self._events.clear()

        except Exception as e:
            logger.error("Failed to publish events", error=str(e))
            # Don't raise here to avoid breaking the transaction.
            # Events will be retried or handled by the event publisher.


# ===== TRANSACTION CONTEXT MANAGER =====


@asynccontextmanager
async def transaction_scope(session: AsyncSession, auto_commit: bool = True):
    """
    Simple transaction context manager for single-repository operations

    Usage:
        async with transaction_scope(session) as tx_session:
            user = User(name="John")
            tx_session.add(user)
            # Auto-commits on success, rolls back on exception
    """
    try:
        yield session
        if auto_commit:
            await session.commit()
    except Exception as e:
        await session.rollback()
        logger.error("Transaction scope failed", error=str(e))
        raise


# ===== UTILITIES =====


class RepositoryRegistry:
    """Registry for commonly used repository configurations"""

    _registry: Dict[str, Dict[str, Any]] = {}

    @classmethod
    def register(
        cls,
        name: str,
        repository_class: Type[Repository],
        model_class: Type[Model],
        **kwargs
    ):
        """Register a repository configuration"""
        cls._registry[name] = {
            "repository_class": repository_class,
            "model_class": model_class,
            "kwargs": kwargs
        }
        logger.debug("Registered repository configuration", name=name)

    @classmethod
    def create_repository(cls, name: str, session: AsyncSession) -> Optional[Repository]:
        """Create repository instance from registry"""
        config = cls._registry.get(name)
        if not config:
            logger.warning(f"Repository configuration '{name}' not found in registry")
            return None

        return config["repository_class"](
            config["model_class"],
            session,
            **config["kwargs"]
        )

    @classmethod
    def list_registered(cls) -> List[str]:
        """List all registered repository names"""
        return list(cls._registry.keys())


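# --- Illustrative sketch (not part of this file). ---
# Register a configuration once at startup, then create repositories per
# session later; UserRepository and User are illustrative names.
#
# RepositoryRegistry.register("users", UserRepository, User)
#
# async def handle_request(session):
#     user_repo = RepositoryRegistry.create_repository("users", session)
#     if user_repo is not None:
#         return await user_repo.create({"name": "Ada"})
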
# ===== FACTORY FUNCTIONS =====


def create_unit_of_work(session: AsyncSession, **kwargs) -> UnitOfWork:
    """Factory function to create Unit of Work instances"""
    return UnitOfWork(session, **kwargs)


def create_service_unit_of_work(
    session: AsyncSession,
    event_publisher=None,
    **kwargs
) -> ServiceUnitOfWork:
    """Factory function to create Service Unit of Work instances"""
    return ServiceUnitOfWork(session, event_publisher, **kwargs)
402
shared/database/utils.py
Executable file
@@ -0,0 +1,402 @@
"""
|
||||
Database Utilities
|
||||
Helper functions for database operations and maintenance
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text, inspect
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
import structlog
|
||||
|
||||
from .exceptions import DatabaseError, HealthCheckError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class DatabaseUtils:
    """Utility functions for database operations"""

    @staticmethod
    async def execute_health_check(session: AsyncSession, timeout: int = 5) -> Dict[str, Any]:
        """
        Comprehensive database health check

        Returns:
            Dict with health status, metrics, and diagnostics
        """
        try:
            # Basic connectivity test
            start_time = time.time()
            await session.execute(text("SELECT 1"))
            response_time = time.time() - start_time

            # Get database info
            db_info = await DatabaseUtils._get_database_info(session)

            # Connection pool status (if available)
            pool_info = await DatabaseUtils._get_pool_info(session)

            return {
                "status": "healthy",
                "response_time_seconds": round(response_time, 4),
                "database": db_info,
                "connection_pool": pool_info,
                "timestamp": datetime.utcnow().isoformat()
            }

        except Exception as e:
            logger.error("Database health check failed", error=str(e))
            raise HealthCheckError(f"Health check failed: {str(e)}")

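    # --- Illustrative sketch (not part of this file). ---
    # Wiring the health check into a readiness probe; `database_manager` and
    # the wrapper function are hypothetical, only get_background_session
    # (used throughout this module) is assumed.
    #
    # async def readiness_probe(database_manager) -> dict:
    #     try:
    #         async with database_manager.get_background_session() as session:
    #             return await DatabaseUtils.execute_health_check(session)
    #     except HealthCheckError as exc:
    #         return {"status": "unhealthy", "error": str(exc)}
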
    @staticmethod
    async def _get_database_info(session: AsyncSession) -> Dict[str, Any]:
        """Get database server information"""
        try:
            # Try to get database version and basic stats
            if session.bind.dialect.name == 'postgresql':
                version_result = await session.execute(text("SELECT version()"))
                version = version_result.scalar()

                stats_result = await session.execute(text("""
                    SELECT
                        count(*) as active_connections,
                        (SELECT setting FROM pg_settings WHERE name = 'max_connections') as max_connections
                    FROM pg_stat_activity
                    WHERE state = 'active'
                """))
                stats = stats_result.fetchone()

                return {
                    "type": "postgresql",
                    "version": version,
                    "active_connections": stats.active_connections if stats else 0,
                    "max_connections": stats.max_connections if stats else "unknown"
                }

            elif session.bind.dialect.name == 'sqlite':
                version_result = await session.execute(text("SELECT sqlite_version()"))
                version = version_result.scalar()

                return {
                    "type": "sqlite",
                    "version": version,
                    "active_connections": 1,
                    "max_connections": "unlimited"
                }

            else:
                return {
                    "type": session.bind.dialect.name,
                    "version": "unknown",
                    "active_connections": "unknown",
                    "max_connections": "unknown"
                }

        except Exception as e:
            logger.warning("Could not retrieve database info", error=str(e))
            return {
                "type": session.bind.dialect.name,
                "version": "unknown",
                "error": str(e)
            }

    @staticmethod
    async def _get_pool_info(session: AsyncSession) -> Dict[str, Any]:
        """Get connection pool information"""
        try:
            pool = session.bind.pool
            if pool:
                return {
                    "size": pool.size(),
                    "checked_in": pool.checkedin(),
                    "checked_out": pool.checkedout(),
                    "overflow": pool.overflow(),
                    "status": pool.status()
                }
            else:
                return {"status": "no_pool"}

        except Exception as e:
            logger.warning("Could not retrieve pool info", error=str(e))
            return {"error": str(e)}

    @staticmethod
    async def validate_schema(session: AsyncSession, expected_tables: List[str]) -> Dict[str, Any]:
        """
        Validate database schema against expected tables

        Args:
            session: Database session
            expected_tables: List of table names that should exist

        Returns:
            Validation results with missing/extra tables
        """
        try:
            # Get existing tables. inspect() cannot be used directly on an
            # async engine, so run the inspector on the underlying sync
            # session via run_sync.
            def _table_names(sync_session):
                return inspect(sync_session.get_bind()).get_table_names()

            existing_tables = set(await session.run_sync(_table_names))
            expected_tables_set = set(expected_tables)

            missing_tables = expected_tables_set - existing_tables
            extra_tables = existing_tables - expected_tables_set

            return {
                "valid": len(missing_tables) == 0,
                "existing_tables": list(existing_tables),
                "expected_tables": expected_tables,
                "missing_tables": list(missing_tables),
                "extra_tables": list(extra_tables),
                "total_tables": len(existing_tables)
            }

        except Exception as e:
            logger.error("Schema validation failed", error=str(e))
            raise DatabaseError(f"Schema validation failed: {str(e)}")

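    # --- Illustrative sketch (not part of this file). ---
    # Hypothetical startup check built on validate_schema; the table names
    # and the wrapper function are made up for illustration.
    #
    # async def verify_schema(database_manager):
    #     async with database_manager.get_background_session() as session:
    #         report = await DatabaseUtils.validate_schema(session, ["users", "sales_data"])
    #         if not report["valid"]:
    #             raise RuntimeError(f"Missing tables: {report['missing_tables']}")
    #         return report
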
    @staticmethod
    async def get_table_stats(session: AsyncSession, table_names: List[str]) -> Dict[str, Any]:
        """
        Get statistics for specified tables

        Args:
            session: Database session
            table_names: List of table names to analyze

        Returns:
            Dictionary with table statistics
        """
        try:
            stats = {}

            for table_name in table_names:
                if session.bind.dialect.name == 'postgresql':
                    # PostgreSQL specific queries
                    count_result = await session.execute(
                        text(f"SELECT COUNT(*) FROM {table_name}")
                    )
                    row_count = count_result.scalar()

                    size_result = await session.execute(
                        text(f"SELECT pg_total_relation_size('{table_name}')")
                    )
                    table_size = size_result.scalar()

                    stats[table_name] = {
                        "row_count": row_count,
                        "size_bytes": table_size,
                        "size_mb": round(table_size / (1024 * 1024), 2) if table_size else 0
                    }

                elif session.bind.dialect.name == 'sqlite':
                    # SQLite specific queries
                    count_result = await session.execute(
                        text(f"SELECT COUNT(*) FROM {table_name}")
                    )
                    row_count = count_result.scalar()

                    stats[table_name] = {
                        "row_count": row_count,
                        "size_bytes": "unknown",
                        "size_mb": "unknown"
                    }

                else:
                    # Generic fallback
                    count_result = await session.execute(
                        text(f"SELECT COUNT(*) FROM {table_name}")
                    )
                    row_count = count_result.scalar()

                    stats[table_name] = {
                        "row_count": row_count,
                        "size_bytes": "unknown",
                        "size_mb": "unknown"
                    }

            return stats

        except Exception as e:
            logger.error("Failed to get table statistics",
                         tables=table_names, error=str(e))
            raise DatabaseError(f"Failed to get table stats: {str(e)}")

    @staticmethod
    async def cleanup_old_records(
        session: AsyncSession,
        table_name: str,
        date_column: str,
        days_old: int,
        batch_size: int = 1000
    ) -> int:
        """
        Clean up old records from a table

        Args:
            session: Database session
            table_name: Name of table to clean
            date_column: Date column to filter by
            days_old: Records older than this many days will be deleted
            batch_size: Number of records to delete per batch

        Returns:
            Total number of records deleted
        """
        try:
            total_deleted = 0

            while True:
                if session.bind.dialect.name == 'postgresql':
                    # Cast the bound parameter to interval; PostgreSQL does not
                    # accept a bind parameter directly after the INTERVAL keyword.
                    delete_query = text(f"""
                        DELETE FROM {table_name}
                        WHERE {date_column} < NOW() - CAST(:days_param AS interval)
                        AND ctid IN (
                            SELECT ctid FROM {table_name}
                            WHERE {date_column} < NOW() - CAST(:days_param AS interval)
                            LIMIT :batch_size
                        )
                    """)
                    params = {
                        "days_param": f"{days_old} days",
                        "batch_size": batch_size
                    }

                elif session.bind.dialect.name == 'sqlite':
                    delete_query = text(f"""
                        DELETE FROM {table_name}
                        WHERE {date_column} < datetime('now', :days_param)
                        AND rowid IN (
                            SELECT rowid FROM {table_name}
                            WHERE {date_column} < datetime('now', :days_param)
                            LIMIT :batch_size
                        )
                    """)
                    params = {
                        "days_param": f"-{days_old} days",
                        "batch_size": batch_size
                    }

                else:
                    # Generic fallback (may not work for all databases)
                    delete_query = text(f"""
                        DELETE FROM {table_name}
                        WHERE {date_column} < DATE_SUB(NOW(), INTERVAL :days_old DAY)
                        LIMIT :batch_size
                    """)
                    params = {
                        "days_old": days_old,
                        "batch_size": batch_size
                    }

                result = await session.execute(delete_query, params)
                deleted_count = result.rowcount

                if deleted_count == 0:
                    break

                total_deleted += deleted_count
                await session.commit()

                logger.debug(f"Deleted batch from {table_name}",
                             batch_size=deleted_count,
                             total_deleted=total_deleted)

            logger.info(f"Cleanup completed for {table_name}",
                        total_deleted=total_deleted,
                        days_old=days_old)

            return total_deleted

        except Exception as e:
            await session.rollback()
            logger.error(f"Cleanup failed for {table_name}", error=str(e))
            raise DatabaseError(f"Cleanup failed: {str(e)}")

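    # --- Illustrative sketch (not part of this file). ---
    # Hypothetical retention job; the table and column names are made up.
    #
    # async def purge_expired_sessions(database_manager):
    #     async with database_manager.get_background_session() as session:
    #         return await DatabaseUtils.cleanup_old_records(
    #             session,
    #             table_name="user_sessions",
    #             date_column="created_at",
    #             days_old=30,
    #             batch_size=500,
    #         )
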
    @staticmethod
    async def execute_maintenance(session: AsyncSession) -> Dict[str, Any]:
        """
        Execute database maintenance tasks

        Returns:
            Dictionary with maintenance results
        """
        try:
            results = {}

            if session.bind.dialect.name == 'postgresql':
                # PostgreSQL maintenance. Note: VACUUM cannot run inside a
                # transaction block, so this branch expects a session bound to
                # an AUTOCOMMIT-level connection.
                await session.execute(text("VACUUM ANALYZE"))
                results["vacuum"] = "completed"

                # Update statistics
                await session.execute(text("ANALYZE"))
                results["analyze"] = "completed"

            elif session.bind.dialect.name == 'sqlite':
                # SQLite maintenance
                await session.execute(text("VACUUM"))
                results["vacuum"] = "completed"

                await session.execute(text("ANALYZE"))
                results["analyze"] = "completed"

            else:
                results["maintenance"] = "not_supported"

            await session.commit()

            logger.info("Database maintenance completed", results=results)
            return results

        except Exception as e:
            await session.rollback()
            logger.error("Database maintenance failed", error=str(e))
            raise DatabaseError(f"Maintenance failed: {str(e)}")


class QueryLogger:
    """Utility for logging and analyzing database queries"""

    def __init__(self, session: AsyncSession):
        self.session = session
        self._query_log = []

    async def log_query(self, query: str, params: Optional[Dict] = None, execution_time: Optional[float] = None):
        """Log a database query with metadata"""
        log_entry = {
            "query": query,
            "params": params,
            "execution_time": execution_time,
            "timestamp": datetime.utcnow().isoformat()
        }

        self._query_log.append(log_entry)

        # Log slow queries
        if execution_time and execution_time > 1.0:  # 1 second threshold
            logger.warning("Slow query detected",
                           query=query,
                           execution_time=execution_time)

    def get_query_stats(self) -> Dict[str, Any]:
        """Get statistics about logged queries"""
        if not self._query_log:
            return {"total_queries": 0}

        execution_times = [
            entry["execution_time"]
            for entry in self._query_log
            if entry["execution_time"] is not None
        ]

        return {
            "total_queries": len(self._query_log),
            "avg_execution_time": sum(execution_times) / len(execution_times) if execution_times else 0,
            "max_execution_time": max(execution_times) if execution_times else 0,
            "slow_queries_count": len([t for t in execution_times if t > 1.0])
        }

    def clear_log(self):
        """Clear the query log"""
        self._query_log.clear()
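
# --- Illustrative sketch (not part of this file). ---
# Timing a single query and feeding it to QueryLogger, assuming the
# module-level `time` and `text` imports above; the wrapper is hypothetical.
#
# async def timed_query(session, query_logger: QueryLogger):
#     started = time.time()
#     result = await session.execute(text("SELECT 1"))
#     await query_logger.log_query("SELECT 1", execution_time=time.time() - started)
#     return result.scalar(), query_logger.get_query_stats()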