Initial commit - production deployment

commit c23d00dd92 (2026-01-21 17:17:16 +01:00)
2289 changed files with 638440 additions and 0 deletions

shared/database/__init__.py Executable file

@@ -0,0 +1,68 @@
"""
Shared Database Infrastructure
Provides consistent database patterns across all microservices
"""
from .base import DatabaseManager, Base, create_database_manager
from .repository import BaseRepository
from .unit_of_work import UnitOfWork, ServiceUnitOfWork, RepositoryRegistry
from .transactions import (
transactional,
unit_of_work_transactional,
managed_transaction,
managed_unit_of_work,
TransactionManager,
run_in_transaction,
run_with_unit_of_work
)
from .exceptions import (
DatabaseError,
ConnectionError,
RecordNotFoundError,
DuplicateRecordError,
ConstraintViolationError,
TransactionError,
ValidationError,
MigrationError,
HealthCheckError
)
from .utils import DatabaseUtils, QueryLogger
__all__ = [
# Core components
"DatabaseManager",
"Base",
"create_database_manager",
# Repository pattern
"BaseRepository",
# Unit of Work pattern
"UnitOfWork",
"ServiceUnitOfWork",
"RepositoryRegistry",
# Transaction management
"transactional",
"unit_of_work_transactional",
"managed_transaction",
"managed_unit_of_work",
"TransactionManager",
"run_in_transaction",
"run_with_unit_of_work",
# Exceptions
"DatabaseError",
"ConnectionError",
"RecordNotFoundError",
"DuplicateRecordError",
"ConstraintViolationError",
"TransactionError",
"ValidationError",
"MigrationError",
"HealthCheckError",
# Utilities
"DatabaseUtils",
"QueryLogger"
]

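Taken together, the package gives every service a single import surface. A minimal wiring sketch (illustrative, not part of this commit: the "sales" service name and DATABASE_URL variable are assumptions):

import os
from shared.database import create_database_manager

# Hypothetical service wiring; service name and env var are placeholders
database_manager = create_database_manager(
    database_url=os.environ["DATABASE_URL"],
    service_name="sales",
    pool_size=10,
)

async def on_startup() -> None:
    # Fail fast if the database is unreachable
    if not await database_manager.test_connection():
        raise RuntimeError("database unreachable at startup")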
shared/database/base.py Executable file

@@ -0,0 +1,408 @@
"""
Enhanced Base Database Configuration for All Microservices
Provides DatabaseManager with connection pooling, health checks, and multi-database support
Fixed: SSL configuration now uses connect_args instead of URL parameters to avoid asyncpg parameter parsing issues
"""
import os
import ssl
from typing import Dict, Any
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import declarative_base
from sqlalchemy.pool import StaticPool
from contextlib import asynccontextmanager
import structlog
from .exceptions import DatabaseError, ConnectionError, HealthCheckError
from .utils import DatabaseUtils
logger = structlog.get_logger()
Base = declarative_base()
class DatabaseManager:
"""Enhanced Database Manager for Microservices
Provides:
- Connection pooling with configurable settings
- Health checks and monitoring
- Multi-database support
- Session lifecycle management
- Background task session support
"""
def __init__(
self,
database_url: str,
service_name: str = "unknown",
pool_size: int = 20,
max_overflow: int = 30,
pool_recycle: int = 3600,
pool_pre_ping: bool = True,
echo: bool = False,
connect_timeout: int = 30,
**engine_kwargs
):
self.database_url = database_url
# Configure SSL for PostgreSQL via connect_args instead of URL parameters
# This avoids asyncpg parameter parsing issues
self.use_ssl = False
if "postgresql" in database_url.lower():
# Check if SSL is already configured in URL or should be enabled
if "ssl" not in database_url.lower() and "sslmode" not in database_url.lower():
# Enable SSL for production, but allow override via URL
self.use_ssl = True
logger.info(f"SSL will be enabled for PostgreSQL connection: {service_name}")
self.service_name = service_name
self.pool_size = pool_size
self.max_overflow = max_overflow
# Configure pool for async engines
# Note: SQLAlchemy 2.0 async engines automatically use AsyncAdaptedQueuePool
# We should NOT specify poolclass for async engines unless using StaticPool for SQLite
# Prepare connect_args for asyncpg
connect_args = {"timeout": connect_timeout}
# Add SSL configuration if needed (for asyncpg driver)
if self.use_ssl and "asyncpg" in database_url.lower():
# Create SSL context that doesn't verify certificates (for local development)
# In production, you should use a proper SSL context with certificate verification
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
connect_args["ssl"] = ssl_context
logger.info(f"SSL enabled with relaxed verification for {service_name}")
engine_config = {
"echo": echo,
"pool_pre_ping": pool_pre_ping,
"pool_recycle": pool_recycle,
"pool_size": pool_size,
"max_overflow": max_overflow,
"connect_args": connect_args,
**engine_kwargs
}
        # SQLite requires StaticPool for async engines, and StaticPool does
        # not accept queue-pool sizing arguments, so strip them here
        if "sqlite" in database_url.lower():
            engine_config["poolclass"] = StaticPool
            engine_config.pop("pool_size", None)
            engine_config.pop("max_overflow", None)
self.async_engine = create_async_engine(database_url, **engine_config)
# Create session factory
self.async_session_local = async_sessionmaker(
self.async_engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False
)
logger.info(f"DatabaseManager initialized for {service_name}",
pool_size=pool_size,
max_overflow=max_overflow,
database_type=self._get_database_type())
async def get_db(self):
"""Get database session for request handlers (FastAPI dependency)"""
async with self.async_session_local() as session:
try:
logger.debug("Database session created for request")
yield session
except Exception as e:
await session.rollback()
# Don't wrap HTTPExceptions - let them pass through
# Check by type name to avoid import dependencies
exception_type = type(e).__name__
if exception_type in ('HTTPException', 'StarletteHTTPException', 'RequestValidationError', 'ValidationError'):
logger.debug(f"Re-raising {exception_type}: {e}", service=self.service_name)
raise
error_msg = str(e) if str(e) else f"{type(e).__name__}: {repr(e)}"
logger.error(f"Database session error: {error_msg}", service=self.service_name)
# Handle specific ASGI stream issues more gracefully
if "EndOfStream" in str(type(e)) or "WouldBlock" in str(type(e)):
raise DatabaseError(f"Session error: Request stream disconnected ({type(e).__name__})")
else:
raise DatabaseError(f"Session error: {error_msg}")
finally:
await session.close()
logger.debug("Database session closed")
@asynccontextmanager
async def get_background_session(self):
"""
Get database session for background tasks with auto-commit
Usage:
async with database_manager.get_background_session() as session:
# Your background task code here
# Auto-commits on success, rolls back on exception
"""
async with self.async_session_local() as session:
try:
logger.debug("Background session created", service=self.service_name)
yield session
await session.commit()
logger.debug("Background session committed")
except Exception as e:
await session.rollback()
logger.error(f"Background task database error: {e}",
service=self.service_name)
raise DatabaseError(f"Background task failed: {str(e)}")
finally:
await session.close()
logger.debug("Background session closed")
@asynccontextmanager
async def get_session(self):
"""Get a plain database session (no auto-commit)"""
async with self.async_session_local() as session:
try:
yield session
except Exception as e:
await session.rollback()
# Don't wrap HTTPExceptions - let them pass through
exception_type = type(e).__name__
if exception_type in ('HTTPException', 'StarletteHTTPException'):
logger.debug(f"Re-raising HTTPException: {e}", service=self.service_name)
raise
logger.error(f"Session error: {e}", service=self.service_name)
raise DatabaseError(f"Session error: {str(e)}")
finally:
await session.close()
# ===== TABLE MANAGEMENT =====
async def create_tables(self, metadata=None):
"""Create database tables with enhanced error handling and transaction verification"""
try:
target_metadata = metadata or Base.metadata
table_names = list(target_metadata.tables.keys())
logger.info(f"Creating tables: {table_names}", service=self.service_name)
# Use explicit transaction with proper error handling
async with self.async_engine.begin() as conn:
try:
# Create tables within the transaction
await conn.run_sync(target_metadata.create_all, checkfirst=True)
# Verify transaction is not in error state
# Try a simple query to ensure connection is still valid
await conn.execute(text("SELECT 1"))
logger.info("Database tables creation transaction completed successfully",
service=self.service_name, tables=table_names)
except Exception as create_error:
logger.error(f"Error during table creation within transaction: {create_error}",
service=self.service_name)
# Re-raise to trigger transaction rollback
raise
logger.info("Database tables created successfully", service=self.service_name)
except Exception as e:
# Check if it's a "relation already exists" error which can be safely ignored
error_str = str(e).lower()
if "already exists" in error_str or "duplicate" in error_str:
logger.warning(f"Some database objects already exist - continuing: {e}", service=self.service_name)
logger.info("Database tables creation completed (some already existed)", service=self.service_name)
else:
logger.error(f"Failed to create tables: {e}", service=self.service_name)
# Check for specific transaction error indicators
if any(indicator in error_str for indicator in [
"transaction", "rollback", "aborted", "failed sql transaction"
]):
logger.error("Transaction-related error detected during table creation",
service=self.service_name)
raise DatabaseError(f"Table creation failed: {str(e)}")
async def drop_tables(self, metadata=None):
"""Drop database tables"""
try:
target_metadata = metadata or Base.metadata
async with self.async_engine.begin() as conn:
await conn.run_sync(target_metadata.drop_all)
logger.info("Database tables dropped successfully", service=self.service_name)
except Exception as e:
logger.error(f"Failed to drop tables: {e}", service=self.service_name)
raise DatabaseError(f"Table drop failed: {str(e)}")
# ===== HEALTH CHECKS AND MONITORING =====
async def health_check(self) -> Dict[str, Any]:
"""Comprehensive health check for the database"""
try:
async with self.get_session() as session:
return await DatabaseUtils.execute_health_check(session)
except Exception as e:
logger.error(f"Health check failed: {e}", service=self.service_name)
raise HealthCheckError(f"Health check failed: {str(e)}")
async def get_connection_info(self) -> Dict[str, Any]:
"""Get database connection information"""
try:
pool = self.async_engine.pool
return {
"service_name": self.service_name,
"database_type": self._get_database_type(),
"pool_size": self.pool_size,
"max_overflow": self.max_overflow,
"current_checked_in": pool.checkedin() if pool else 0,
"current_checked_out": pool.checkedout() if pool else 0,
"current_overflow": pool.overflow() if pool else 0,
"invalid_connections": getattr(pool, 'invalid', lambda: 0)() if pool else 0
}
except Exception as e:
logger.error(f"Failed to get connection info: {e}", service=self.service_name)
return {"error": str(e)}
def _get_database_type(self) -> str:
"""Get database type from URL"""
return self.database_url.split("://")[0].lower() if "://" in self.database_url else "unknown"
# ===== CLEANUP AND MAINTENANCE =====
async def close_connections(self):
"""Close all database connections"""
try:
await self.async_engine.dispose()
logger.info("Database connections closed", service=self.service_name)
except Exception as e:
logger.error(f"Failed to close connections: {e}", service=self.service_name)
raise DatabaseError(f"Connection cleanup failed: {str(e)}")
async def execute_maintenance(self) -> Dict[str, Any]:
"""Execute database maintenance tasks"""
try:
async with self.get_session() as session:
return await DatabaseUtils.execute_maintenance(session)
except Exception as e:
logger.error(f"Maintenance failed: {e}", service=self.service_name)
raise DatabaseError(f"Maintenance failed: {str(e)}")
# ===== UTILITY METHODS =====
async def test_connection(self) -> bool:
"""Test database connectivity"""
try:
async with self.async_engine.begin() as conn:
await conn.execute(text("SELECT 1"))
logger.debug("Connection test successful", service=self.service_name)
return True
except Exception as e:
logger.error(f"Connection test failed: {e}", service=self.service_name)
return False
def __repr__(self) -> str:
return f"DatabaseManager(service='{self.service_name}', type='{self._get_database_type()}')"
async def execute(self, query: str, *args, **kwargs):
"""
Execute a raw SQL query with proper session management
Note: Use this method carefully to avoid transaction conflicts
"""
        # Use a new session context to avoid conflicts with existing sessions
        async with self.get_session() as session:
            # Convert plain strings to a TextClause (`text` is imported at
            # module level) and detect writes before executing
            if isinstance(query, str):
                query = text(query)
            sql_text = query.text if hasattr(query, "text") else str(query)
            is_write = sql_text.strip().upper().startswith(("INSERT", "UPDATE", "DELETE"))
            try:
                result = await session.execute(query, *args, **kwargs)
                # INSERT/UPDATE/DELETE need an explicit commit in this scope
                if is_write:
                    await session.commit()
                return result
            except Exception as e:
                # Roll back modifying statements on failure
                if is_write:
                    await session.rollback()
                logger.error("Database execute failed", error=str(e))
                raise
# ===== CONVENIENCE FUNCTIONS =====
def create_database_manager(
database_url: str,
service_name: str,
**kwargs
) -> DatabaseManager:
"""Factory function to create DatabaseManager instances"""
return DatabaseManager(database_url, service_name, **kwargs)
# ===== LEGACY COMPATIBILITY =====
# Keep backward compatibility for existing code
engine = None
AsyncSessionLocal = None
def init_legacy_compatibility(database_url: str):
"""Initialize legacy global variables for backward compatibility"""
global engine, AsyncSessionLocal
# Configure SSL for PostgreSQL if needed
connect_args = {}
if "postgresql" in database_url.lower() and "asyncpg" in database_url.lower():
if "ssl" not in database_url.lower() and "sslmode" not in database_url.lower():
# Create SSL context that doesn't verify certificates (for local development)
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
connect_args["ssl"] = ssl_context
logger.info("SSL enabled with relaxed verification for legacy database connection")
engine = create_async_engine(
database_url,
echo=False,
pool_pre_ping=True,
pool_recycle=300,
pool_size=20,
max_overflow=30,
connect_args=connect_args
)
AsyncSessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False
)
logger.warning("Using legacy database configuration - consider migrating to DatabaseManager")
async def get_legacy_db():
"""Legacy database session getter for backward compatibility"""
if not AsyncSessionLocal:
raise RuntimeError("Legacy database not initialized - call init_legacy_compatibility first")
async with AsyncSessionLocal() as session:
try:
yield session
except Exception as e:
logger.error(f"Legacy database session error: {e}")
await session.rollback()
raise
finally:
await session.close()

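Because get_db is an async generator, it drops straight into FastAPI's dependency injection. A sketch of that wiring, assuming FastAPI and the database_manager from the earlier sketch (route path and response shape are illustrative):

from fastapi import Depends, FastAPI
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

app = FastAPI()

@app.get("/health/db")
async def db_health(session: AsyncSession = Depends(database_manager.get_db)):
    # get_db yields the session and handles rollback/close on error
    result = await session.execute(text("SELECT 1"))
    return {"db": "ok", "probe": result.scalar()}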
shared/database/exceptions.py Executable file

@@ -0,0 +1,52 @@
"""
Custom Database Exceptions
Provides consistent error handling across all microservices
"""
from typing import Optional
class DatabaseError(Exception):
    """Base exception for database-related errors"""
    def __init__(self, message: str, details: Optional[dict] = None):
self.message = message
self.details = details or {}
super().__init__(self.message)
class ConnectionError(DatabaseError):
    """Raised when database connection fails.
    Note: this name shadows Python's builtin ConnectionError; alias it on
    import where the builtin is also needed.
    """
    pass
class RecordNotFoundError(DatabaseError):
"""Raised when a requested record is not found"""
pass
class DuplicateRecordError(DatabaseError):
"""Raised when trying to create a duplicate record"""
pass
class ConstraintViolationError(DatabaseError):
"""Raised when database constraints are violated"""
pass
class TransactionError(DatabaseError):
"""Raised when transaction operations fail"""
pass
class ValidationError(DatabaseError):
"""Raised when data validation fails before database operations"""
pass
class MigrationError(DatabaseError):
"""Raised when database migration operations fail"""
pass
class HealthCheckError(DatabaseError):
"""Raised when database health checks fail"""
pass

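Since every exception derives from DatabaseError, callers can catch narrowly first and fall back to the base class. A hedged sketch (the fetch_user helper and its HTTP mapping are hypothetical):

from shared.database import DatabaseError, RecordNotFoundError

async def fetch_user(repo, user_id):
    try:
        user = await repo.get_by_id(user_id)
        if user is None:
            raise RecordNotFoundError(f"user {user_id} not found")
        return user
    except RecordNotFoundError:
        raise  # typically mapped to a 404 at the API layer
    except DatabaseError as exc:
        # exc.details carries optional structured context
        raise RuntimeError(f"storage failure: {exc.message}") from exc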
shared/database/init_manager.py Executable file

@@ -0,0 +1,381 @@
"""
Database Initialization Manager
Handles Alembic-based migrations with autogenerate support:
1. First-time deployment: Generate initial migration from models
2. Subsequent deployments: Run pending migrations
3. Development reset: Drop tables and regenerate migrations
"""
import os
import asyncio
import structlog
from typing import Optional, List, Dict, Any
from pathlib import Path
from sqlalchemy import text, inspect
from sqlalchemy.ext.asyncio import AsyncSession
from alembic.config import Config
from alembic import command
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from .base import DatabaseManager, Base
logger = structlog.get_logger()
class DatabaseInitManager:
"""
Manages database initialization using Alembic migrations exclusively.
Two modes:
1. Migration mode (for migration jobs): Runs alembic upgrade head
2. Verification mode (for services): Only verifies database is ready
"""
def __init__(
self,
database_manager: DatabaseManager,
service_name: str,
alembic_ini_path: Optional[str] = None,
models_module: Optional[str] = None,
verify_only: bool = True, # Default: services only verify
force_recreate: bool = False
):
self.database_manager = database_manager
self.service_name = service_name
self.alembic_ini_path = alembic_ini_path
self.models_module = models_module
self.verify_only = verify_only
self.force_recreate = force_recreate
self.logger = logger.bind(service=service_name)
async def initialize_database(self) -> Dict[str, Any]:
"""
Main initialization method.
Two modes:
1. verify_only=True (default, for services):
- Verifies database is ready
- Checks tables exist
- Checks alembic_version exists
- DOES NOT run migrations
2. verify_only=False (for migration jobs only):
- Runs alembic upgrade head
- Applies pending migrations
- Can force recreate if needed
"""
if self.verify_only:
self.logger.info("Database verification mode - checking database is ready")
return await self._verify_database_ready()
else:
self.logger.info("Migration mode - running database migrations")
return await self._run_migrations_mode()
async def _verify_database_ready(self) -> Dict[str, Any]:
"""
Verify database is ready for service startup.
Services should NOT run migrations - only verify they've been applied.
"""
try:
# Check alembic configuration exists
if not self.alembic_ini_path or not os.path.exists(self.alembic_ini_path):
raise Exception(f"Alembic configuration not found at {self.alembic_ini_path}")
# Check database state
db_state = await self._check_database_state()
self.logger.info("Database state checked", state=db_state)
# Verify migrations exist
if not db_state["has_migrations"]:
raise Exception(
f"No migration files found for {self.service_name}. "
f"Migrations must be generated and included in the Docker image."
)
# Verify database is not empty
if db_state["is_empty"]:
raise Exception(
f"Database is empty. Migration job must run before service startup. "
f"Ensure migration job completes successfully before starting services."
)
# Verify alembic_version table exists
if not db_state["has_alembic_version"]:
raise Exception(
f"No alembic_version table found. Migration job must run before service startup."
)
# Verify current revision exists
if not db_state["current_revision"]:
raise Exception(
f"No current migration revision found. Database may not be properly initialized."
)
self.logger.info(
"Database verification successful",
migration_count=db_state["migration_count"],
current_revision=db_state["current_revision"],
table_count=len(db_state["existing_tables"])
)
return {
"action": "verified",
"message": "Database verified successfully - ready for service",
"current_revision": db_state["current_revision"],
"migration_count": db_state["migration_count"],
"table_count": len(db_state["existing_tables"])
}
except Exception as e:
self.logger.error("Database verification failed", error=str(e))
raise
async def _run_migrations_mode(self) -> Dict[str, Any]:
"""
Run migrations mode - for migration jobs only.
"""
try:
if not self.alembic_ini_path or not os.path.exists(self.alembic_ini_path):
raise Exception(f"Alembic configuration not found at {self.alembic_ini_path}")
# Check current database state
db_state = await self._check_database_state()
self.logger.info("Database state checked", state=db_state)
# Handle force recreate
if self.force_recreate:
return await self._handle_force_recreate()
# Check migrations exist
if not db_state["has_migrations"]:
raise Exception(
f"No migration files found for {self.service_name}. "
f"Generate migrations using regenerate_migrations_k8s.sh script."
)
# Run migrations
result = await self._handle_run_migrations()
self.logger.info("Migration mode completed", result=result)
return result
except Exception as e:
self.logger.error("Migration mode failed", error=str(e))
raise
async def _check_database_state(self) -> Dict[str, Any]:
"""Check the current state of migrations"""
state = {
"has_migrations": False,
"migration_count": 0,
"is_empty": False,
"existing_tables": [],
"has_alembic_version": False,
"current_revision": None
}
try:
# Check if migration files exist
migrations_dir = self._get_migrations_versions_dir()
if migrations_dir.exists():
migration_files = list(migrations_dir.glob("*.py"))
                migration_files = [f for f in migration_files if not f.name.startswith("_")]
state["migration_count"] = len(migration_files)
state["has_migrations"] = len(migration_files) > 0
self.logger.info("Found migration files", count=len(migration_files))
# Check database tables
async with self.database_manager.get_session() as session:
existing_tables = await self._get_existing_tables(session)
state["existing_tables"] = existing_tables
state["is_empty"] = len(existing_tables) == 0
# Check alembic_version table
if "alembic_version" in existing_tables:
state["has_alembic_version"] = True
result = await session.execute(text("SELECT version_num FROM alembic_version"))
version = result.scalar()
state["current_revision"] = version
except Exception as e:
self.logger.warning("Error checking database state", error=str(e))
return state
async def _handle_run_migrations(self) -> Dict[str, Any]:
"""Handle normal migration scenario - run pending migrations"""
self.logger.info("Running pending migrations")
try:
await self._run_migrations()
return {
"action": "migrations_applied",
"message": "Pending migrations applied successfully"
}
except Exception as e:
self.logger.error("Failed to run migrations", error=str(e))
raise
async def _handle_force_recreate(self) -> Dict[str, Any]:
"""Handle development reset scenario - drop and recreate tables using existing migrations"""
self.logger.info("Force recreate: dropping tables and rerunning migrations")
try:
# Drop all tables
await self._drop_all_tables()
# Apply migrations from scratch
await self._run_migrations()
return {
"action": "force_recreate",
"tables_dropped": True,
"migrations_applied": True,
"message": "Database recreated from existing migrations"
}
except Exception as e:
self.logger.error("Failed to force recreate", error=str(e))
raise
async def _run_migrations(self):
"""Run pending Alembic migrations (upgrade head)"""
try:
            def run_alembic_upgrade():
                # Alembic resolves script paths relative to the working
                # directory, so temporarily chdir next to alembic.ini
                alembic_dir = Path(self.alembic_ini_path).parent
                original_cwd = os.getcwd()
                try:
                    os.chdir(alembic_dir)
                    alembic_cfg = Config(self.alembic_ini_path)
                    # Point Alembic at the same database the manager uses
                    alembic_cfg.set_main_option("sqlalchemy.url", str(self.database_manager.database_url))
                    command.upgrade(alembic_cfg, "head")
                finally:
                    os.chdir(original_cwd)
            # command.upgrade is synchronous; run it in a worker thread so the
            # event loop is not blocked
            await asyncio.get_running_loop().run_in_executor(None, run_alembic_upgrade)
self.logger.info("Migrations applied successfully")
except Exception as e:
self.logger.error("Failed to run migrations", error=str(e))
raise
async def _drop_all_tables(self):
"""Drop all tables (for development reset)"""
try:
async with self.database_manager.async_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
self.logger.info("All tables dropped")
except Exception as e:
self.logger.error("Failed to drop tables", error=str(e))
raise
def _get_migrations_versions_dir(self) -> Path:
"""Get the migrations/versions directory path"""
alembic_path = Path(self.alembic_ini_path).parent
return alembic_path / "migrations" / "versions"
async def _get_existing_tables(self, session: AsyncSession) -> List[str]:
"""Get list of existing tables in the database"""
def get_tables_sync(connection):
insp = inspect(connection)
return insp.get_table_names()
connection = await session.connection()
return await connection.run_sync(get_tables_sync)
def create_init_manager(
database_manager: DatabaseManager,
service_name: str,
service_path: Optional[str] = None,
verify_only: bool = True,
force_recreate: bool = False
) -> DatabaseInitManager:
"""
Factory function to create a DatabaseInitManager with auto-detected paths
Args:
database_manager: DatabaseManager instance
service_name: Name of the service
service_path: Path to service directory (auto-detected if None)
verify_only: True = verify DB ready (services), False = run migrations (jobs only)
force_recreate: Force recreate tables (requires verify_only=False)
"""
# Auto-detect paths if not provided
if service_path is None:
# Try Docker container path first (service files at root level)
if os.path.exists("alembic.ini"):
service_path = "."
else:
# Fallback to development path
service_path = f"services/{service_name}"
# Set up paths based on environment
if service_path == ".":
# Docker container environment
alembic_ini_path = "alembic.ini"
models_module = "app.models"
else:
# Development environment
alembic_ini_path = f"{service_path}/alembic.ini"
models_module = f"services.{service_name}.app.models"
# Check if paths exist
if not os.path.exists(alembic_ini_path):
logger.warning("Alembic config not found", path=alembic_ini_path)
alembic_ini_path = None
return DatabaseInitManager(
database_manager=database_manager,
service_name=service_name,
alembic_ini_path=alembic_ini_path,
models_module=models_module,
verify_only=verify_only,
force_recreate=force_recreate
)
async def initialize_service_database(
database_manager: DatabaseManager,
service_name: str,
verify_only: bool = True,
force_recreate: bool = False
) -> Dict[str, Any]:
"""
Convenience function for database initialization
Args:
database_manager: DatabaseManager instance
service_name: Name of the service
verify_only: True = verify DB ready (default, services), False = run migrations (jobs only)
force_recreate: Force recreate tables (requires verify_only=False)
Returns:
Dict with initialization results
"""
init_manager = create_init_manager(
database_manager=database_manager,
service_name=service_name,
verify_only=verify_only,
force_recreate=force_recreate
)
return await init_manager.initialize_database()

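The verify_only flag splits responsibilities between a one-shot migration job and the services it gates. A sketch of both call sites, assuming the database_manager and service name from the earlier sketch:

from shared.database.init_manager import initialize_service_database

async def migration_job() -> None:
    # Run by the migration job before services start: applies alembic upgrade head
    result = await initialize_service_database(
        database_manager, "sales", verify_only=False
    )
    print(result["action"])  # "migrations_applied" on success

async def service_startup() -> None:
    # Run inside each service: checks tables and alembic_version exist,
    # never mutates the schema
    await initialize_service_database(database_manager, "sales", verify_only=True)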
shared/database/repository.py Executable file

@@ -0,0 +1,428 @@
"""
Base Repository Pattern for Database Operations
Provides generic CRUD operations, query building, and caching
"""
from typing import Optional, List, Dict, Any, TypeVar, Generic, Type
from abc import ABC
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, delete, and_, or_, desc, asc, func, text
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from contextlib import asynccontextmanager
import structlog
from .exceptions import (
DatabaseError,
RecordNotFoundError,
DuplicateRecordError,
ConstraintViolationError
)
logger = structlog.get_logger()
# Type variables for generic repository. Model is left unbound: binding it to
# declarative_base() would create a brand-new Base class unrelated to the
# model classes actually passed in.
Model = TypeVar('Model')
CreateSchema = TypeVar('CreateSchema')
UpdateSchema = TypeVar('UpdateSchema')
class BaseRepository(Generic[Model, CreateSchema, UpdateSchema], ABC):
"""
Base repository providing generic CRUD operations
Args:
model: SQLAlchemy model class
session: Database session
cache_ttl: Cache time-to-live in seconds (optional)
"""
def __init__(self, model: Type[Model], session: AsyncSession, cache_ttl: Optional[int] = None):
self.model = model
self.session = session
self.cache_ttl = cache_ttl
self._cache = {} if cache_ttl else None
# ===== CORE CRUD OPERATIONS =====
async def create(self, obj_in: CreateSchema, **kwargs) -> Model:
"""Create a new record"""
try:
# Convert schema to dict if needed
if hasattr(obj_in, 'model_dump'):
obj_data = obj_in.model_dump()
elif hasattr(obj_in, 'dict'):
obj_data = obj_in.dict()
else:
obj_data = obj_in
# Merge with additional kwargs
obj_data.update(kwargs)
db_obj = self.model(**obj_data)
self.session.add(db_obj)
await self.session.flush() # Get ID without committing
await self.session.refresh(db_obj)
logger.debug(f"Created {self.model.__name__}", record_id=getattr(db_obj, 'id', None))
return db_obj
except IntegrityError as e:
await self.session.rollback()
logger.error(f"Integrity error creating {self.model.__name__}", error=str(e))
raise DuplicateRecordError(f"Record with provided data already exists")
except SQLAlchemyError as e:
await self.session.rollback()
logger.error(f"Database error creating {self.model.__name__}", error=str(e))
raise DatabaseError(f"Failed to create record: {str(e)}")
async def get_by_id(self, record_id: Any) -> Optional[Model]:
"""Get record by ID with optional caching"""
cache_key = f"{self.model.__name__}:{record_id}"
# Check cache first
if self._cache and cache_key in self._cache:
logger.debug(f"Cache hit for {cache_key}")
return self._cache[cache_key]
try:
result = await self.session.execute(
select(self.model).where(self.model.id == record_id)
)
record = result.scalar_one_or_none()
# Cache the result
if self._cache and record:
self._cache[cache_key] = record
return record
except SQLAlchemyError as e:
logger.error(f"Database error getting {self.model.__name__} by ID",
record_id=record_id, error=str(e))
raise DatabaseError(f"Failed to get record: {str(e)}")
async def get_by_field(self, field_name: str, value: Any) -> Optional[Model]:
"""Get record by specific field"""
try:
result = await self.session.execute(
select(self.model).where(getattr(self.model, field_name) == value)
)
return result.scalar_one_or_none()
except AttributeError:
raise ValueError(f"Field '{field_name}' not found in {self.model.__name__}")
except SQLAlchemyError as e:
logger.error(f"Database error getting {self.model.__name__} by {field_name}",
value=value, error=str(e))
raise DatabaseError(f"Failed to get record: {str(e)}")
async def get_multi(
self,
skip: int = 0,
limit: int = 100,
order_by: Optional[str] = None,
order_desc: bool = False,
filters: Optional[Dict[str, Any]] = None
) -> List[Model]:
"""Get multiple records with pagination, sorting, and filtering"""
try:
query = select(self.model)
# Apply filters
if filters:
conditions = []
for field, value in filters.items():
if hasattr(self.model, field):
if isinstance(value, list):
conditions.append(getattr(self.model, field).in_(value))
else:
conditions.append(getattr(self.model, field) == value)
if conditions:
query = query.where(and_(*conditions))
# Apply ordering
if order_by and hasattr(self.model, order_by):
order_field = getattr(self.model, order_by)
if order_desc:
query = query.order_by(desc(order_field))
else:
query = query.order_by(asc(order_field))
# Apply pagination
query = query.offset(skip).limit(limit)
result = await self.session.execute(query)
return result.scalars().all()
except SQLAlchemyError as e:
logger.error(f"Database error getting multiple {self.model.__name__} records",
error=str(e))
raise DatabaseError(f"Failed to get records: {str(e)}")
async def update(self, record_id: Any, obj_in: UpdateSchema, **kwargs) -> Optional[Model]:
"""Update record by ID"""
try:
# Convert schema to dict if needed
if hasattr(obj_in, 'model_dump'):
update_data = obj_in.model_dump(exclude_unset=True)
elif hasattr(obj_in, 'dict'):
update_data = obj_in.dict(exclude_unset=True)
else:
update_data = obj_in
# Merge with additional kwargs
update_data.update(kwargs)
            # Remove None values (side effect: columns cannot be set to NULL
            # through this method)
            update_data = {k: v for k, v in update_data.items() if v is not None}
if not update_data:
logger.warning(f"No data to update for {self.model.__name__}", record_id=record_id)
return await self.get_by_id(record_id)
# Perform update
result = await self.session.execute(
update(self.model)
.where(self.model.id == record_id)
.values(**update_data)
.returning(self.model)
)
updated_record = result.scalar_one_or_none()
if not updated_record:
raise RecordNotFoundError(f"{self.model.__name__} with id {record_id} not found")
# Clear cache
if self._cache:
cache_key = f"{self.model.__name__}:{record_id}"
self._cache.pop(cache_key, None)
logger.debug(f"Updated {self.model.__name__}", record_id=record_id)
return updated_record
except IntegrityError as e:
await self.session.rollback()
logger.error(f"Integrity error updating {self.model.__name__}",
record_id=record_id, error=str(e))
raise ConstraintViolationError(f"Update violates database constraints")
except SQLAlchemyError as e:
await self.session.rollback()
logger.error(f"Database error updating {self.model.__name__}",
record_id=record_id, error=str(e))
raise DatabaseError(f"Failed to update record: {str(e)}")
async def delete(self, record_id: Any) -> bool:
"""Delete record by ID"""
try:
result = await self.session.execute(
delete(self.model).where(self.model.id == record_id)
)
deleted_count = result.rowcount
if deleted_count == 0:
raise RecordNotFoundError(f"{self.model.__name__} with id {record_id} not found")
# Clear cache
if self._cache:
cache_key = f"{self.model.__name__}:{record_id}"
self._cache.pop(cache_key, None)
logger.debug(f"Deleted {self.model.__name__}", record_id=record_id)
return True
except SQLAlchemyError as e:
await self.session.rollback()
logger.error(f"Database error deleting {self.model.__name__}",
record_id=record_id, error=str(e))
raise DatabaseError(f"Failed to delete record: {str(e)}")
# ===== ADVANCED QUERY OPERATIONS =====
async def count(self, filters: Optional[Dict[str, Any]] = None) -> int:
"""Count records with optional filters"""
try:
query = select(func.count(self.model.id))
if filters:
conditions = []
for field, value in filters.items():
if hasattr(self.model, field):
if isinstance(value, list):
conditions.append(getattr(self.model, field).in_(value))
else:
conditions.append(getattr(self.model, field) == value)
if conditions:
query = query.where(and_(*conditions))
result = await self.session.execute(query)
return result.scalar() or 0
except SQLAlchemyError as e:
logger.error(f"Database error counting {self.model.__name__} records", error=str(e))
raise DatabaseError(f"Failed to count records: {str(e)}")
async def exists(self, record_id: Any) -> bool:
"""Check if record exists by ID"""
try:
result = await self.session.execute(
select(func.count(self.model.id)).where(self.model.id == record_id)
)
count = result.scalar() or 0
return count > 0
except SQLAlchemyError as e:
logger.error(f"Database error checking existence of {self.model.__name__}",
record_id=record_id, error=str(e))
raise DatabaseError(f"Failed to check record existence: {str(e)}")
async def bulk_create(self, objects: List[CreateSchema]) -> List[Model]:
"""Create multiple records in bulk"""
try:
if not objects:
return []
db_objects = []
for obj_in in objects:
if hasattr(obj_in, 'model_dump'):
obj_data = obj_in.model_dump()
elif hasattr(obj_in, 'dict'):
obj_data = obj_in.dict()
else:
obj_data = obj_in
db_objects.append(self.model(**obj_data))
self.session.add_all(db_objects)
await self.session.flush()
# Skip expensive individual refresh operations for large datasets
# Only refresh if we have a small number of objects
if len(db_objects) <= 100:
for db_obj in db_objects:
await self.session.refresh(db_obj)
else:
# For large datasets, just log without refresh to prevent memory issues
logger.debug(f"Skipped individual refresh for large bulk operation ({len(db_objects)} records)")
logger.debug(f"Bulk created {len(db_objects)} {self.model.__name__} records")
return db_objects
except IntegrityError as e:
await self.session.rollback()
logger.error(f"Integrity error bulk creating {self.model.__name__}", error=str(e))
raise DuplicateRecordError(f"One or more records already exist")
except SQLAlchemyError as e:
await self.session.rollback()
logger.error(f"Database error bulk creating {self.model.__name__}", error=str(e))
raise DatabaseError(f"Failed to create records: {str(e)}")
async def bulk_update(self, updates: List[Dict[str, Any]]) -> int:
"""Update multiple records in bulk"""
try:
if not updates:
return 0
            updated_ids = []
            for update_data in updates:
                if 'id' not in update_data:
                    raise ValueError("Each update must include 'id' field")
                # Copy so the caller's dicts are not mutated
                values = dict(update_data)
                record_id = values.pop('id')
                updated_ids.append(record_id)
                await self.session.execute(
                    update(self.model)
                    .where(self.model.id == record_id)
                    .values(**values)
                )
            # Clear relevant cache entries using the collected ids
            if self._cache:
                for record_id in updated_ids:
                    cache_key = f"{self.model.__name__}:{record_id}"
                    self._cache.pop(cache_key, None)
logger.debug(f"Bulk updated {len(updates)} {self.model.__name__} records")
return len(updates)
except SQLAlchemyError as e:
await self.session.rollback()
logger.error(f"Database error bulk updating {self.model.__name__}", error=str(e))
raise DatabaseError(f"Failed to update records: {str(e)}")
# ===== SEARCH AND QUERY BUILDING =====
async def search(
self,
search_term: str,
search_fields: List[str],
skip: int = 0,
limit: int = 100
) -> List[Model]:
"""Search records across multiple fields"""
try:
conditions = []
for field in search_fields:
if hasattr(self.model, field):
field_obj = getattr(self.model, field)
# Case-insensitive partial match
conditions.append(field_obj.ilike(f"%{search_term}%"))
if not conditions:
logger.warning(f"No valid search fields provided for {self.model.__name__}")
return []
query = select(self.model).where(or_(*conditions)).offset(skip).limit(limit)
result = await self.session.execute(query)
return result.scalars().all()
except SQLAlchemyError as e:
logger.error(f"Database error searching {self.model.__name__}",
search_term=search_term, error=str(e))
raise DatabaseError(f"Failed to search records: {str(e)}")
async def execute_raw_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> Any:
"""Execute raw SQL query (use with caution)"""
try:
result = await self.session.execute(text(query), params or {})
return result
except SQLAlchemyError as e:
logger.error(f"Database error executing raw query", query=query, error=str(e))
raise DatabaseError(f"Failed to execute query: {str(e)}")
# ===== CACHE MANAGEMENT =====
def clear_cache(self, record_id: Optional[Any] = None):
"""Clear cache for specific record or all records"""
if not self._cache:
return
if record_id:
cache_key = f"{self.model.__name__}:{record_id}"
self._cache.pop(cache_key, None)
else:
# Clear all cache entries for this model
keys_to_remove = [k for k in self._cache.keys() if k.startswith(f"{self.model.__name__}:")]
for key in keys_to_remove:
self._cache.pop(key, None)
logger.debug(f"Cleared cache for {self.model.__name__}", record_id=record_id)
# ===== CONTEXT MANAGERS =====
@asynccontextmanager
async def transaction(self):
"""Context manager for explicit transaction handling"""
try:
yield self.session
await self.session.commit()
except Exception as e:
await self.session.rollback()
logger.error(f"Transaction failed for {self.model.__name__}", error=str(e))
raise

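A concrete repository only adds domain queries; the CRUD surface is inherited. A sketch with a hypothetical User model and Pydantic schemas (none of these names exist in this commit):

from typing import Optional
from pydantic import BaseModel
from shared.database.repository import BaseRepository

class UserCreate(BaseModel):  # hypothetical create schema
    name: str
    email: str

class UserUpdate(BaseModel):  # hypothetical update schema
    name: Optional[str] = None

class UserRepository(BaseRepository):
    # Generic CRUD (create, get_by_id, get_multi, ...) comes from the base
    async def get_by_email(self, email: str):
        return await self.get_by_field("email", email)

# wiring: repo = UserRepository(User, session, cache_ttl=60)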
shared/database/transactions.py Executable file

@@ -0,0 +1,306 @@
"""
Transaction Decorators and Context Managers
Provides convenient transaction handling for service methods
"""
from functools import wraps
from typing import Callable, Any, Optional
from contextlib import asynccontextmanager
import structlog
from .base import DatabaseManager
from .unit_of_work import UnitOfWork
from .exceptions import TransactionError
logger = structlog.get_logger()
def transactional(database_manager: DatabaseManager, auto_commit: bool = True):
"""
Decorator that wraps a method in a database transaction
Args:
database_manager: DatabaseManager instance
auto_commit: Whether to auto-commit on success
Usage:
@transactional(database_manager)
async def create_user_with_profile(self, user_data, profile_data):
# Your business logic here
# Transaction is automatically managed
pass
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(*args, **kwargs):
async with database_manager.get_background_session() as session:
try:
# Inject session into kwargs if not present
if 'session' not in kwargs:
kwargs['session'] = session
result = await func(*args, **kwargs)
# Session is auto-committed by get_background_session
logger.debug(f"Transaction completed successfully for {func.__name__}")
return result
except Exception as e:
# Session is auto-rolled back by get_background_session
logger.error(f"Transaction failed for {func.__name__}", error=str(e))
raise TransactionError(f"Transaction failed: {str(e)}")
return wrapper
return decorator
def unit_of_work_transactional(database_manager: DatabaseManager):
"""
Decorator that provides Unit of Work pattern for complex operations
Usage:
@unit_of_work_transactional(database_manager)
async def complex_business_operation(self, data, uow: UnitOfWork):
user_repo = uow.register_repository("users", UserRepository, User)
sales_repo = uow.register_repository("sales", SalesRepository, SalesData)
user = await user_repo.create(data.user)
sale = await sales_repo.create(data.sale)
# UnitOfWork automatically commits
return {"user": user, "sale": sale}
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(*args, **kwargs):
async with database_manager.get_background_session() as session:
async with UnitOfWork(session, auto_commit=True) as uow:
try:
# Inject UnitOfWork into kwargs
kwargs['uow'] = uow
result = await func(*args, **kwargs)
logger.debug(f"Unit of Work transaction completed for {func.__name__}")
return result
except Exception as e:
logger.error(f"Unit of Work transaction failed for {func.__name__}",
error=str(e))
raise TransactionError(f"Transaction failed: {str(e)}")
return wrapper
return decorator
@asynccontextmanager
async def managed_transaction(database_manager: DatabaseManager):
"""
Context manager for explicit transaction control
Usage:
async with managed_transaction(database_manager) as session:
# Your database operations here
user = User(name="John")
session.add(user)
# Auto-commits on exit, rolls back on exception
"""
async with database_manager.get_background_session() as session:
try:
logger.debug("Starting managed transaction")
yield session
logger.debug("Managed transaction completed successfully")
except Exception as e:
logger.error("Managed transaction failed", error=str(e))
raise
@asynccontextmanager
async def managed_unit_of_work(database_manager: DatabaseManager, event_publisher=None):
"""
Context manager for explicit Unit of Work control
Usage:
async with managed_unit_of_work(database_manager) as uow:
user_repo = uow.register_repository("users", UserRepository, User)
user = await user_repo.create(user_data)
await uow.commit()
"""
async with database_manager.get_background_session() as session:
uow = UnitOfWork(session)
try:
logger.debug("Starting managed Unit of Work")
yield uow
if not uow._committed:
await uow.commit()
logger.debug("Managed Unit of Work completed successfully")
except Exception as e:
if not uow._rolled_back:
await uow.rollback()
logger.error("Managed Unit of Work failed", error=str(e))
raise
class TransactionManager:
"""
Advanced transaction manager for complex scenarios
Usage:
tx_manager = TransactionManager(database_manager)
async with tx_manager.create_transaction() as tx:
await tx.execute_in_transaction(my_business_logic, data)
"""
def __init__(self, database_manager: DatabaseManager):
self.database_manager = database_manager
@asynccontextmanager
async def create_transaction(self, isolation_level: Optional[str] = None):
"""Create a transaction with optional isolation level"""
async with self.database_manager.get_background_session() as session:
transaction_context = TransactionContext(session, isolation_level)
try:
yield transaction_context
except Exception as e:
logger.error("Transaction manager failed", error=str(e))
raise
async def execute_with_retry(
self,
func: Callable,
max_retries: int = 3,
*args,
**kwargs
):
"""Execute function with transaction retry on failure"""
last_error = None
for attempt in range(max_retries):
try:
async with managed_transaction(self.database_manager) as session:
kwargs['session'] = session
result = await func(*args, **kwargs)
logger.debug(f"Transaction succeeded on attempt {attempt + 1}")
return result
except Exception as e:
last_error = e
logger.warning(f"Transaction attempt {attempt + 1} failed",
error=str(e), remaining_attempts=max_retries - attempt - 1)
if attempt == max_retries - 1:
break
logger.error(f"All transaction attempts failed after {max_retries} tries")
raise TransactionError(f"Transaction failed after {max_retries} retries: {str(last_error)}")
class TransactionContext:
"""Context for managing individual transactions"""
def __init__(self, session, isolation_level: Optional[str] = None):
self.session = session
self.isolation_level = isolation_level
async def execute_in_transaction(self, func: Callable, *args, **kwargs):
"""Execute function within the transaction context"""
try:
kwargs['session'] = self.session
result = await func(*args, **kwargs)
return result
except Exception as e:
logger.error("Function execution failed in transaction context", error=str(e))
raise
# ===== UTILITY FUNCTIONS =====
async def run_in_transaction(database_manager: DatabaseManager, func: Callable, *args, **kwargs):
"""
Utility function to run any async function in a transaction
Usage:
result = await run_in_transaction(
database_manager,
my_async_function,
arg1, arg2,
kwarg1="value"
)
"""
async with managed_transaction(database_manager) as session:
kwargs['session'] = session
return await func(*args, **kwargs)
async def run_with_unit_of_work(
database_manager: DatabaseManager,
func: Callable,
*args,
**kwargs
):
"""
Utility function to run any async function with Unit of Work
Usage:
result = await run_with_unit_of_work(
database_manager,
my_complex_function,
arg1, arg2
)
"""
async with managed_unit_of_work(database_manager) as uow:
kwargs['uow'] = uow
return await func(*args, **kwargs)
# ===== BATCH OPERATIONS =====
@asynccontextmanager
async def batch_operation(database_manager: DatabaseManager, batch_size: int = 1000):
"""
Context manager for batch operations with automatic commit batching
Usage:
async with batch_operation(database_manager, batch_size=500) as batch:
for item in large_dataset:
await batch.add_operation(create_record, item)
"""
async with database_manager.get_background_session() as session:
batch_context = BatchOperationContext(session, batch_size)
try:
yield batch_context
await batch_context.flush_remaining()
except Exception as e:
logger.error("Batch operation failed", error=str(e))
raise
class BatchOperationContext:
"""Context for managing batch database operations"""
def __init__(self, session, batch_size: int):
self.session = session
self.batch_size = batch_size
self.operation_count = 0
async def add_operation(self, func: Callable, *args, **kwargs):
"""Add operation to batch"""
kwargs['session'] = self.session
await func(*args, **kwargs)
self.operation_count += 1
if self.operation_count >= self.batch_size:
await self.session.commit()
self.operation_count = 0
logger.debug(f"Batch committed at {self.batch_size} operations")
async def flush_remaining(self):
"""Commit any remaining operations"""
if self.operation_count > 0:
await self.session.commit()
logger.debug(f"Final batch committed with {self.operation_count} operations")

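For background workers, run_in_transaction removes the session boilerplate: the session is injected as a keyword argument and commit/rollback is handled by managed_transaction. A sketch, assuming the earlier database_manager and a hypothetical users table:

from sqlalchemy import text
from shared.database.transactions import run_in_transaction

async def deactivate_user(user_id: int, session=None):
    # `session` is injected by run_in_transaction
    await session.execute(
        text("UPDATE users SET active = false WHERE id = :id"),
        {"id": user_id},
    )

async def worker() -> None:
    await run_in_transaction(database_manager, deactivate_user, 42)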
shared/database/unit_of_work.py Executable file

@@ -0,0 +1,304 @@
"""
Unit of Work Pattern Implementation
Manages transactions across multiple repositories with event publishing
"""
from typing import Dict, Any, List, Optional, Type, TypeVar, Generic
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import SQLAlchemyError
from abc import ABC, abstractmethod
import structlog
from .repository import BaseRepository
from .exceptions import TransactionError
logger = structlog.get_logger()
Model = TypeVar('Model')
Repository = TypeVar('Repository', bound=BaseRepository)
class BaseEvent(ABC):
"""Base class for domain events"""
def __init__(self, event_type: str, data: Dict[str, Any]):
self.event_type = event_type
self.data = data
@abstractmethod
def to_dict(self) -> Dict[str, Any]:
"""Convert event to dictionary for publishing"""
pass
class DomainEvent(BaseEvent):
"""Standard domain event implementation"""
def to_dict(self) -> Dict[str, Any]:
return {
"event_type": self.event_type,
"data": self.data
}
class UnitOfWork:
"""
Unit of Work pattern for managing transactions and coordinating repositories
Usage:
async with UnitOfWork(session) as uow:
user_repo = uow.register_repository("users", UserRepository, User)
sales_repo = uow.register_repository("sales", SalesRepository, SalesData)
user = await user_repo.create(user_data)
sale = await sales_repo.create(sales_data)
await uow.commit()
"""
def __init__(self, session: AsyncSession, auto_commit: bool = False):
self.session = session
self.auto_commit = auto_commit
self._repositories: Dict[str, BaseRepository] = {}
self._events: List[BaseEvent] = []
self._committed = False
self._rolled_back = False
def register_repository(
self,
name: str,
repository_class: Type[Repository],
model_class: Type[Model],
**kwargs
) -> Repository:
"""
Register a repository with the unit of work
Args:
name: Unique name for the repository
repository_class: Repository class to instantiate
model_class: SQLAlchemy model class
**kwargs: Additional arguments for repository
Returns:
Instantiated repository
"""
if name in self._repositories:
logger.warning(f"Repository '{name}' already registered, returning existing instance")
return self._repositories[name]
repository = repository_class(model_class, self.session, **kwargs)
self._repositories[name] = repository
logger.debug(f"Registered repository", name=name, model=model_class.__name__)
return repository
def get_repository(self, name: str) -> Optional[Repository]:
"""Get registered repository by name"""
return self._repositories.get(name)
def add_event(self, event: BaseEvent):
"""Add domain event to be published after commit"""
self._events.append(event)
logger.debug(f"Added event", event_type=event.event_type)
async def commit(self):
"""Commit the transaction and publish events"""
if self._committed:
logger.warning("Unit of Work already committed")
return
if self._rolled_back:
raise TransactionError("Cannot commit after rollback")
try:
await self.session.commit()
self._committed = True
# Publish events after successful commit
await self._publish_events()
logger.debug(f"Unit of Work committed successfully",
repositories=list(self._repositories.keys()),
events_published=len(self._events))
except SQLAlchemyError as e:
await self.rollback()
logger.error("Failed to commit Unit of Work", error=str(e))
raise TransactionError(f"Commit failed: {str(e)}")
async def rollback(self):
"""Rollback the transaction"""
if self._rolled_back:
logger.warning("Unit of Work already rolled back")
return
try:
await self.session.rollback()
self._rolled_back = True
self._events.clear() # Clear events on rollback
logger.debug(f"Unit of Work rolled back",
repositories=list(self._repositories.keys()))
except SQLAlchemyError as e:
logger.error("Failed to rollback Unit of Work", error=str(e))
raise TransactionError(f"Rollback failed: {str(e)}")
async def _publish_events(self):
"""Publish domain events (override in subclasses for actual publishing)"""
if not self._events:
return
# Default implementation just logs events
# Override this method in service-specific implementations
for event in self._events:
logger.info(f"Publishing event",
event_type=event.event_type,
event_data=event.to_dict())
# Clear events after publishing
self._events.clear()
async def __aenter__(self):
"""Async context manager entry"""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit"""
if exc_type is not None:
# Exception occurred, rollback
await self.rollback()
return False
# No exception, auto-commit if enabled
if self.auto_commit and not self._committed:
await self.commit()
return False
class ServiceUnitOfWork(UnitOfWork):
"""
Service-specific Unit of Work with event publishing integration
Example usage with message publishing:
class AuthUnitOfWork(ServiceUnitOfWork):
def __init__(self, session: AsyncSession, message_publisher=None):
super().__init__(session)
self.message_publisher = message_publisher
async def _publish_events(self):
for event in self._events:
if self.message_publisher:
await self.message_publisher.publish(
topic="auth.events",
message=event.to_dict()
)
"""
def __init__(self, session: AsyncSession, event_publisher=None, auto_commit: bool = False):
super().__init__(session, auto_commit)
self.event_publisher = event_publisher
async def _publish_events(self):
"""Publish events using the provided event publisher"""
if not self._events or not self.event_publisher:
return
try:
for event in self._events:
await self.event_publisher.publish(event)
logger.debug(f"Published event via publisher",
event_type=event.event_type)
self._events.clear()
except Exception as e:
logger.error("Failed to publish events", error=str(e))
# Don't raise here to avoid breaking the transaction
# Events will be retried or handled by the event publisher
# ===== TRANSACTION CONTEXT MANAGER =====
@asynccontextmanager
async def transaction_scope(session: AsyncSession, auto_commit: bool = True):
"""
Simple transaction context manager for single-repository operations
Usage:
async with transaction_scope(session) as tx_session:
user = User(name="John")
tx_session.add(user)
# Auto-commits on success, rolls back on exception
"""
try:
yield session
if auto_commit:
await session.commit()
except Exception as e:
await session.rollback()
logger.error("Transaction scope failed", error=str(e))
raise
# ===== UTILITIES =====
class RepositoryRegistry:
"""Registry for commonly used repository configurations"""
_registry: Dict[str, Dict[str, Any]] = {}
    @classmethod
    def register(
        cls,
        name: str,
        repository_class: Type[Repository],
        model_class: Type[Model],
        **kwargs
    ):
        """Register a repository configuration"""
        cls._registry[name] = {
            "repository_class": repository_class,
            "model_class": model_class,
            "kwargs": kwargs
        }
        logger.debug("Registered repository configuration", name=name)
    @classmethod
    def create_repository(cls, name: str, session: AsyncSession) -> Optional[Repository]:
        """Create repository instance from registry"""
        config = cls._registry.get(name)
        if not config:
            logger.warning(f"Repository configuration '{name}' not found in registry")
            return None
        return config["repository_class"](
            config["model_class"],
            session,
            **config["kwargs"]
        )
    @classmethod
    def list_registered(cls) -> List[str]:
        """List all registered repository names"""
        return list(cls._registry.keys())
# ===== FACTORY FUNCTIONS =====
def create_unit_of_work(session: AsyncSession, **kwargs) -> UnitOfWork:
"""Factory function to create Unit of Work instances"""
return UnitOfWork(session, **kwargs)
def create_service_unit_of_work(
session: AsyncSession,
event_publisher=None,
**kwargs
) -> ServiceUnitOfWork:
"""Factory function to create Service Unit of Work instances"""
return ServiceUnitOfWork(session, event_publisher, **kwargs)

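Putting the pieces together, a service operation can coordinate repositories and queue events that are published only after a successful commit. A sketch, assuming the hypothetical User/UserRepository from the repository example and an event publisher exposing a publish coroutine:

from shared.database.unit_of_work import DomainEvent, ServiceUnitOfWork

async def register_user(session, publisher, user_data):
    # auto_commit=True lets __aexit__ commit on success, which then
    # triggers _publish_events via the configured publisher
    async with ServiceUnitOfWork(session, event_publisher=publisher,
                                 auto_commit=True) as uow:
        users = uow.register_repository("users", UserRepository, User)
        user = await users.create(user_data)
        uow.add_event(DomainEvent("user.registered", {"id": user.id}))
        return user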
shared/database/utils.py Executable file

@@ -0,0 +1,402 @@
"""
Database Utilities
Helper functions for database operations and maintenance
"""
import time
from datetime import datetime, timezone
from typing import Dict, Any, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text, inspect
from sqlalchemy.exc import SQLAlchemyError
import structlog
from .exceptions import DatabaseError, HealthCheckError
logger = structlog.get_logger()
class DatabaseUtils:
"""Utility functions for database operations"""
@staticmethod
async def execute_health_check(session: AsyncSession, timeout: int = 5) -> Dict[str, Any]:
"""
Comprehensive database health check
Returns:
Dict with health status, metrics, and diagnostics
"""
try:
            # Basic connectivity test
            start_time = time.time()
            await session.execute(text("SELECT 1"))
            response_time = time.time() - start_time
# Get database info
db_info = await DatabaseUtils._get_database_info(session)
# Connection pool status (if available)
pool_info = await DatabaseUtils._get_pool_info(session)
return {
"status": "healthy",
"response_time_seconds": round(response_time, 4),
"database": db_info,
"connection_pool": pool_info,
"timestamp": __import__('datetime').datetime.utcnow().isoformat()
}
except Exception as e:
logger.error("Database health check failed", error=str(e))
raise HealthCheckError(f"Health check failed: {str(e)}")
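# A wiring sketch for the health check. db_manager and its get_session()
# context manager are assumed to come from the shared DatabaseManager; the
# endpoint shape and service name are illustrative, not tied to a framework.

async def health_endpoint(db_manager):
    async with db_manager.get_session() as session:
        try:
            report = await DatabaseUtils.execute_health_check(session, timeout=2)
            return {"service": "orders", **report}
        except HealthCheckError as exc:
            return {"service": "orders", "status": "unhealthy", "detail": str(exc)}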
@staticmethod
async def _get_database_info(session: AsyncSession) -> Dict[str, Any]:
"""Get database server information"""
try:
# Try to get database version and basic stats
if session.bind.dialect.name == 'postgresql':
version_result = await session.execute(text("SELECT version()"))
version = version_result.scalar()
stats_result = await session.execute(text("""
SELECT
count(*) as active_connections,
(SELECT setting FROM pg_settings WHERE name = 'max_connections') as max_connections
FROM pg_stat_activity
WHERE state = 'active'
"""))
stats = stats_result.fetchone()
return {
"type": "postgresql",
"version": version,
"active_connections": stats.active_connections if stats else 0,
"max_connections": stats.max_connections if stats else "unknown"
}
elif session.bind.dialect.name == 'sqlite':
version_result = await session.execute(text("SELECT sqlite_version()"))
version = version_result.scalar()
return {
"type": "sqlite",
"version": version,
"active_connections": 1,
"max_connections": "unlimited"
}
else:
return {
"type": session.bind.dialect.name,
"version": "unknown",
"active_connections": "unknown",
"max_connections": "unknown"
}
except Exception as e:
logger.warning("Could not retrieve database info", error=str(e))
return {
"type": session.bind.dialect.name,
"version": "unknown",
"error": str(e)
}
@staticmethod
async def _get_pool_info(session: AsyncSession) -> Dict[str, Any]:
"""Get connection pool information"""
try:
pool = session.bind.pool
if pool:
return {
"size": pool.size(),
"checked_in": pool.checkedin(),
"checked_out": pool.checkedout(),
"overflow": pool.overflow(),
"status": pool.status()
}
else:
return {"status": "no_pool"}
except Exception as e:
logger.warning("Could not retrieve pool info", error=str(e))
return {"error": str(e)}
@staticmethod
async def validate_schema(session: AsyncSession, expected_tables: List[str]) -> Dict[str, Any]:
"""
Validate database schema against expected tables
Args:
session: Database session
expected_tables: List of table names that should exist
Returns:
Validation results with missing/extra tables
"""
        try:
            # sqlalchemy.inspect() cannot be used directly on an async
            # engine, so reflection runs on the sync connection via run_sync()
            connection = await session.connection()
            existing_tables = set(
                await connection.run_sync(
                    lambda sync_conn: inspect(sync_conn).get_table_names()
                )
            )
            expected_tables_set = set(expected_tables)
missing_tables = expected_tables_set - existing_tables
extra_tables = existing_tables - expected_tables_set
return {
"valid": len(missing_tables) == 0,
"existing_tables": list(existing_tables),
"expected_tables": expected_tables,
"missing_tables": list(missing_tables),
"extra_tables": list(extra_tables),
"total_tables": len(existing_tables)
}
except Exception as e:
logger.error("Schema validation failed", error=str(e))
raise DatabaseError(f"Schema validation failed: {str(e)}")
@staticmethod
async def get_table_stats(session: AsyncSession, table_names: List[str]) -> Dict[str, Any]:
"""
Get statistics for specified tables
Args:
session: Database session
table_names: List of table names to analyze
Returns:
Dictionary with table statistics
"""
try:
stats = {}
for table_name in table_names:
if session.bind.dialect.name == 'postgresql':
# PostgreSQL specific queries
count_result = await session.execute(
text(f"SELECT COUNT(*) FROM {table_name}")
)
row_count = count_result.scalar()
size_result = await session.execute(
text(f"SELECT pg_total_relation_size('{table_name}')")
)
table_size = size_result.scalar()
stats[table_name] = {
"row_count": row_count,
"size_bytes": table_size,
"size_mb": round(table_size / (1024 * 1024), 2) if table_size else 0
}
elif session.bind.dialect.name == 'sqlite':
# SQLite specific queries
count_result = await session.execute(
text(f"SELECT COUNT(*) FROM {table_name}")
)
row_count = count_result.scalar()
stats[table_name] = {
"row_count": row_count,
"size_bytes": "unknown",
"size_mb": "unknown"
}
else:
# Generic fallback
count_result = await session.execute(
text(f"SELECT COUNT(*) FROM {table_name}")
)
row_count = count_result.scalar()
stats[table_name] = {
"row_count": row_count,
"size_bytes": "unknown",
"size_mb": "unknown"
}
return stats
except Exception as e:
logger.error("Failed to get table statistics",
tables=table_names, error=str(e))
raise DatabaseError(f"Failed to get table stats: {str(e)}")
@staticmethod
async def cleanup_old_records(
session: AsyncSession,
table_name: str,
date_column: str,
days_old: int,
batch_size: int = 1000
) -> int:
"""
Clean up old records from a table
Args:
session: Database session
table_name: Name of table to clean
date_column: Date column to filter by
days_old: Records older than this many days will be deleted
batch_size: Number of records to delete per batch
Returns:
Total number of records deleted
"""
try:
total_deleted = 0
while True:
                if session.bind.dialect.name == 'postgresql':
                    # PostgreSQL does not accept a bind parameter after the
                    # INTERVAL keyword, so the '<N> days' string is cast instead
                    delete_query = text(f"""
                        DELETE FROM {table_name}
                        WHERE ctid IN (
                            SELECT ctid FROM {table_name}
                            WHERE {date_column} < NOW() - CAST(:days_param AS INTERVAL)
                            LIMIT :batch_size
                        )
                    """)
                    params = {
                        "days_param": f"{days_old} days",
                        "batch_size": batch_size
                    }
elif session.bind.dialect.name == 'sqlite':
delete_query = text(f"""
DELETE FROM {table_name}
WHERE {date_column} < datetime('now', :days_param)
AND rowid IN (
SELECT rowid FROM {table_name}
WHERE {date_column} < datetime('now', :days_param)
LIMIT :batch_size
)
""")
params = {
"days_param": f"-{days_old} days",
"batch_size": batch_size
}
else:
# Generic fallback (may not work for all databases)
delete_query = text(f"""
DELETE FROM {table_name}
WHERE {date_column} < DATE_SUB(NOW(), INTERVAL :days_old DAY)
LIMIT :batch_size
""")
params = {
"days_old": days_old,
"batch_size": batch_size
}
result = await session.execute(delete_query, params)
deleted_count = result.rowcount
if deleted_count == 0:
break
total_deleted += deleted_count
await session.commit()
logger.debug(f"Deleted batch from {table_name}",
batch_size=deleted_count,
total_deleted=total_deleted)
logger.info(f"Cleanup completed for {table_name}",
total_deleted=total_deleted,
days_old=days_old)
return total_deleted
except Exception as e:
await session.rollback()
logger.error(f"Cleanup failed for {table_name}", error=str(e))
raise DatabaseError(f"Cleanup failed: {str(e)}")
    @staticmethod
    async def execute_maintenance(session: AsyncSession) -> Dict[str, Any]:
        """
        Execute database maintenance tasks

        VACUUM cannot run inside a transaction block, so maintenance
        statements are executed on a dedicated AUTOCOMMIT connection
        taken from the session's engine.

        Returns:
            Dictionary with maintenance results
        """
        try:
            results = {}
            dialect_name = session.bind.dialect.name
            # AsyncEngine.execution_options() returns a shallow copy of the
            # engine whose connections run in autocommit mode
            autocommit_engine = session.bind.execution_options(
                isolation_level="AUTOCOMMIT"
            )
            if dialect_name == 'postgresql':
                async with autocommit_engine.connect() as conn:
                    # VACUUM ANALYZE reclaims space and refreshes planner stats
                    await conn.execute(text("VACUUM ANALYZE"))
                    results["vacuum"] = "completed"
                    results["analyze"] = "completed"
            elif dialect_name == 'sqlite':
                async with autocommit_engine.connect() as conn:
                    await conn.execute(text("VACUUM"))
                    results["vacuum"] = "completed"
                    await conn.execute(text("ANALYZE"))
                    results["analyze"] = "completed"
            else:
                results["maintenance"] = "not_supported"
            logger.info("Database maintenance completed", results=results)
            return results
        except Exception as e:
            logger.error("Database maintenance failed", error=str(e))
            raise DatabaseError(f"Maintenance failed: {str(e)}")
class QueryLogger:
"""Utility for logging and analyzing database queries"""
def __init__(self, session: AsyncSession):
self.session = session
self._query_log = []
async def log_query(self, query: str, params: Optional[Dict] = None, execution_time: Optional[float] = None):
"""Log a database query with metadata"""
log_entry = {
"query": query,
"params": params,
"execution_time": execution_time,
"timestamp": __import__('datetime').datetime.utcnow().isoformat()
}
self._query_log.append(log_entry)
# Log slow queries
if execution_time and execution_time > 1.0: # 1 second threshold
logger.warning("Slow query detected",
query=query,
execution_time=execution_time)
def get_query_stats(self) -> Dict[str, Any]:
"""Get statistics about logged queries"""
if not self._query_log:
return {"total_queries": 0}
execution_times = [
entry["execution_time"]
for entry in self._query_log
if entry["execution_time"] is not None
]
return {
"total_queries": len(self._query_log),
"avg_execution_time": sum(execution_times) / len(execution_times) if execution_times else 0,
"max_execution_time": max(execution_times) if execution_times else 0,
"slow_queries_count": len([t for t in execution_times if t > 1.0])
}
def clear_log(self):
"""Clear the query log"""
self._query_log.clear()
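# A manual timing sketch: QueryLogger only records what it is handed, so the
# caller measures execution time itself. Nothing here hooks into SQLAlchemy
# events automatically; timed_execute is an illustrative wrapper.

async def timed_execute(query_logger: QueryLogger, session: AsyncSession,
                        sql: str, params: Optional[Dict] = None):
    start = time.time()
    result = await session.execute(text(sql), params)
    await query_logger.log_query(sql, params, execution_time=time.time() - start)
    return result

# After a batch of work, inspect aggregate timings:
#
#     stats = query_logger.get_query_stats()
#     # -> {"total_queries": ..., "avg_execution_time": ..., ...}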