REFACTOR - Database logic

Urtzi Alfaro
2025-08-08 09:08:41 +02:00
parent 0154365bfc
commit 488bb3ef93
113 changed files with 22842 additions and 6503 deletions


@@ -238,7 +238,7 @@ async def verify_tenant_access_dep(
Raises:
HTTPException: If user doesn't have access to tenant
"""
has_access = await tenant_access_manager.verify_user_tenant_access(current_user["user_id"], tenant_id)
has_access = await tenant_access_manager.verify_basic_tenant_access(current_user["user_id"], tenant_id)
if not has_access:
logger.warning(f"Access denied to tenant",
user_id=current_user["user_id"],
@@ -276,7 +276,7 @@ async def verify_tenant_permission_dep(
HTTPException: If user doesn't have access or permission
"""
# First verify basic tenant access
has_access = await tenant_access_manager.verify_user_tenant_access(current_user["user_id"], tenant_id)
has_access = await tenant_access_manager.verify_basic_tenant_access(current_user["user_id"], tenant_id)
if not has_access:
raise HTTPException(
status_code=403,

shared/clients/README.md (new file, 390 lines)

@@ -0,0 +1,390 @@
# Enhanced Inter-Service Communication System
This directory contains the enhanced inter-service communication system that integrates with the new repository pattern architecture. The system provides circuit breakers, caching, monitoring, and event tracking for all service-to-service communications.
## Architecture Overview
### Base Components
1. **BaseServiceClient** - Foundation class providing authentication, retries, and basic HTTP operations
2. **EnhancedServiceClient** - Adds circuit breaker, caching, and monitoring capabilities
3. **ServiceRegistry** - Central registry for managing all enhanced service clients
### Enhanced Service Clients
Each service has a specialized enhanced client:
- **EnhancedDataServiceClient** - Sales data, weather, traffic, products with optimized caching
- **EnhancedAuthServiceClient** - Authentication, user management, permissions with security focus
- **EnhancedTrainingServiceClient** - ML training, model management, deployment with pipeline monitoring
- **EnhancedForecastingServiceClient** - Forecasting, predictions, scenarios with analytics
- **EnhancedTenantServiceClient** - Tenant management, memberships, organization features
- **EnhancedNotificationServiceClient** - Notifications, templates, delivery tracking
## Key Features
### Circuit Breaker Pattern
- **States**: Closed (normal), Open (failing), Half-Open (testing recovery)
- **Configuration**: Failure threshold, recovery timeout, success threshold
- **Monitoring**: State changes tracked and logged
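The three states form a small state machine. As a rough illustration of the pattern (not the actual `EnhancedServiceClient` internals, whose class and attribute names may differ), a minimal async circuit breaker could look like this:

```python
import time

class SimpleCircuitBreaker:
    """Illustrative only: closed -> open -> half-open -> closed."""

    def __init__(self, failure_threshold=5, recovery_timeout=60, success_threshold=2):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.success_threshold = success_threshold
        self.state = "closed"
        self.failures = 0
        self.successes = 0
        self.opened_at = 0.0

    async def call(self, func, *args, **kwargs):
        if self.state == "open":
            if time.monotonic() - self.opened_at < self.recovery_timeout:
                raise RuntimeError("circuit open: request rejected")
            self.state = "half-open"  # probe the downstream service again
        try:
            result = await func(*args, **kwargs)
        except Exception:
            self.failures += 1
            if self.state == "half-open" or self.failures >= self.failure_threshold:
                self.state = "open"
                self.opened_at = time.monotonic()
                self.successes = 0
            raise
        if self.state == "half-open":
            self.successes += 1
            if self.successes >= self.success_threshold:
                self.state, self.failures, self.successes = "closed", 0, 0
        else:
            self.failures = 0
        return result
```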
### Intelligent Caching
- **TTL-based**: Different cache durations for different data types
- **Invalidation**: Pattern-based cache invalidation on updates
- **Statistics**: Hit/miss ratios and performance metrics
- **Manual Control**: Clear specific cache patterns when needed
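For intuition, a stripped-down cache along these lines (not the project's actual cache class) keys entries by request signature, stamps each entry with an expiry, and supports prefix-based invalidation:

```python
import time
from typing import Any, Dict, Optional, Tuple

class TTLCache:
    """Illustrative TTL cache with hit/miss counters and prefix invalidation."""

    def __init__(self) -> None:
        self._entries: Dict[str, Tuple[float, Any]] = {}
        self.hits = 0
        self.misses = 0

    def get(self, key: str) -> Optional[Any]:
        entry = self._entries.get(key)
        if entry and entry[0] > time.monotonic():
            self.hits += 1
            return entry[1]
        self._entries.pop(key, None)  # drop expired or missing entries
        self.misses += 1
        return None

    def set(self, key: str, value: Any, ttl: int) -> None:
        self._entries[key] = (time.monotonic() + ttl, value)

    def invalidate(self, prefix: str) -> None:
        # e.g. invalidate("sales:tenant-123:") after an upload for that tenant
        for key in [k for k in self._entries if k.startswith(prefix)]:
            del self._entries[key]
```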
### Event Integration
- **Repository Events**: Entity created/updated/deleted events
- **Correlation IDs**: Track operations across services
- **Metadata**: Rich event metadata for debugging and monitoring
### Monitoring & Metrics
- **Request Metrics**: Success/failure rates, latencies
- **Cache Metrics**: Hit rates, entry counts
- **Circuit Breaker Metrics**: State changes, failure counts
- **Health Checks**: Per-service and aggregate health status
## Usage Examples
### Basic Usage with Service Registry
```python
from shared.clients.enhanced_service_client import ServiceRegistry
from shared.config.base import BaseServiceSettings
# Initialize registry
config = BaseServiceSettings()
registry = ServiceRegistry(config, calling_service="forecasting")
# Get enhanced clients
data_client = registry.get_data_client()
auth_client = registry.get_auth_client()
training_client = registry.get_training_client()
# Use with full features
sales_data = await data_client.get_all_sales_data_with_monitoring(
tenant_id="tenant-123",
start_date="2024-01-01",
end_date="2024-12-31",
correlation_id="forecast-job-456"
)
```
### Data Service Operations
```python
# Get sales data with intelligent caching
sales_data = await data_client.get_sales_data_cached(
tenant_id="tenant-123",
start_date="2024-01-01",
end_date="2024-01-31",
aggregation="daily"
)
# Upload sales data with cache invalidation and events
result = await data_client.upload_sales_data_with_events(
tenant_id="tenant-123",
sales_data=sales_records,
correlation_id="data-import-789"
)
# Get weather data with caching (30 min TTL)
weather_data = await data_client.get_weather_historical_cached(
tenant_id="tenant-123",
start_date="2024-01-01",
end_date="2024-01-31"
)
```
### Authentication & User Management
```python
# Authenticate with security monitoring
auth_result = await auth_client.authenticate_user_cached(
email="user@example.com",
password="password"
)
# Check permissions with caching
has_access = await auth_client.check_user_permissions_cached(
user_id="user-123",
tenant_id="tenant-456",
resource="sales_data",
action="read"
)
# Create user with events
user = await auth_client.create_user_with_events(
user_data={
"email": "new@example.com",
"name": "New User",
"role": "analyst"
},
tenant_id="tenant-123",
correlation_id="user-creation-789"
)
```
### Training & ML Operations
```python
# Create training job with monitoring
job = await training_client.create_training_job_with_monitoring(
tenant_id="tenant-123",
include_weather=True,
include_traffic=False,
min_data_points=30,
correlation_id="training-pipeline-456"
)
# Get active model with caching
model = await training_client.get_active_model_for_product_cached(
tenant_id="tenant-123",
product_name="croissants"
)
# Deploy model with events
deployment = await training_client.deploy_model_with_events(
tenant_id="tenant-123",
model_id="model-789",
correlation_id="deployment-123"
)
# Get pipeline status
status = await training_client.get_training_pipeline_status("tenant-123")
```
### Forecasting & Predictions
```python
# Create forecast with monitoring
forecast = await forecasting_client.create_forecast_with_monitoring(
tenant_id="tenant-123",
model_id="model-456",
start_date="2024-02-01",
end_date="2024-02-29",
correlation_id="forecast-creation-789"
)
# Get predictions with caching
predictions = await forecasting_client.get_predictions_cached(
tenant_id="tenant-123",
forecast_id="forecast-456",
start_date="2024-02-01",
end_date="2024-02-07"
)
# Real-time prediction with caching
prediction = await forecasting_client.create_realtime_prediction_with_monitoring(
tenant_id="tenant-123",
model_id="model-456",
target_date="2024-02-01",
features={"temperature": 20, "day_of_week": 1},
correlation_id="realtime-pred-123"
)
# Get forecasting dashboard
dashboard = await forecasting_client.get_forecasting_dashboard("tenant-123")
```
### Tenant Management
```python
# Create tenant with monitoring
tenant = await tenant_client.create_tenant_with_monitoring(
name="New Bakery Chain",
owner_id="user-123",
description="Multi-location bakery chain",
correlation_id="tenant-creation-456"
)
# Add member with events
membership = await tenant_client.add_tenant_member_with_events(
tenant_id="tenant-123",
user_id="user-456",
role="manager",
correlation_id="member-add-789"
)
# Get tenant analytics
analytics = await tenant_client.get_tenant_analytics("tenant-123")
```
### Notification Management
```python
# Send notification with monitoring
notification = await notification_client.send_notification_with_monitoring(
recipient_id="user-123",
notification_type="forecast_ready",
title="Forecast Complete",
message="Your weekly forecast is ready for review",
tenant_id="tenant-456",
priority="high",
channels=["email", "in_app"],
correlation_id="forecast-notification-789"
)
# Send bulk notification
bulk_result = await notification_client.send_bulk_notification_with_monitoring(
recipients=["user-123", "user-456", "user-789"],
notification_type="system_update",
title="System Maintenance",
message="Scheduled maintenance tonight at 2 AM",
priority="normal",
correlation_id="maintenance-notification-123"
)
# Get delivery analytics
analytics = await notification_client.get_delivery_analytics(
tenant_id="tenant-123",
start_date="2024-01-01",
end_date="2024-01-31"
)
```
## Health Monitoring
### Individual Service Health
```python
# Get specific service health
data_health = data_client.get_data_service_health()
auth_health = auth_client.get_auth_service_health()
training_health = training_client.get_training_service_health()
# Health includes:
# - Circuit breaker status
# - Cache statistics and configuration
# - Service-specific features
# - Supported endpoints
```
### Registry-Level Health
```python
# Get all service health status
all_health = registry.get_all_health_status()
# Get aggregate metrics
metrics = registry.get_aggregate_metrics()
# Returns:
# - Total cache hits/misses and hit rate
# - Circuit breaker states for all services
# - Count of healthy vs total services
```
## Configuration
### Cache TTL Configuration
Each enhanced client has optimized cache TTL values:
```python
# Data Service
sales_cache_ttl = 600 # 10 minutes
weather_cache_ttl = 1800 # 30 minutes
traffic_cache_ttl = 3600 # 1 hour
product_cache_ttl = 300 # 5 minutes
# Auth Service
user_cache_ttl = 300 # 5 minutes
token_cache_ttl = 60 # 1 minute
permission_cache_ttl = 900 # 15 minutes
# Training Service
job_cache_ttl = 180 # 3 minutes
model_cache_ttl = 600 # 10 minutes
metrics_cache_ttl = 300 # 5 minutes
# And so on...
```
### Circuit Breaker Configuration
```python
CircuitBreakerConfig(
failure_threshold=5, # Failures before opening
recovery_timeout=60, # Seconds before testing recovery
success_threshold=2, # Successes needed to close
timeout=30 # Request timeout in seconds
)
```
## Event System Integration
All enhanced clients integrate with the enhanced event system:
### Event Types
- **EntityCreatedEvent** - When entities are created
- **EntityUpdatedEvent** - When entities are modified
- **EntityDeletedEvent** - When entities are removed
### Event Metadata
- **correlation_id** - Track operations across services
- **source_service** - Service that generated the event
- **destination_service** - Target service
- **tenant_id** - Tenant context
- **user_id** - User context
- **tags** - Additional metadata
### Usage in Enhanced Clients
Events are automatically published for:
- Data uploads and modifications
- User creation/updates/deletion
- Training job lifecycle
- Model deployments
- Forecast creation
- Tenant management operations
- Notification delivery
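The concrete event classes live in the shared event system. Purely as an illustration of the metadata fields listed above, a hypothetical payload might be assembled like this (the dataclass and its defaults are assumptions, not the real API):

```python
from dataclasses import dataclass, field, asdict
from datetime import datetime, timezone
from typing import Dict, Optional

@dataclass
class EntityCreatedEvent:
    """Illustrative shape of a repository event payload."""
    entity_type: str
    entity_id: str
    correlation_id: str
    source_service: str
    destination_service: Optional[str] = None
    tenant_id: Optional[str] = None
    user_id: Optional[str] = None
    tags: Dict[str, str] = field(default_factory=dict)
    occurred_at: str = field(default_factory=lambda: datetime.now(timezone.utc).isoformat())

# e.g. after upload_sales_data_with_events() succeeds, an event like this could be published:
event = EntityCreatedEvent(
    entity_type="sales_data",
    entity_id="batch-42",          # hypothetical identifier
    correlation_id="data-import-789",
    source_service="data",
    tenant_id="tenant-123",
    tags={"rows": "1500"},
)
payload = asdict(event)  # serialized and handed to the message bus
```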
## Error Handling & Resilience
### Circuit Breaker Protection
- Automatically stops requests when services are failing
- Provides fallback to cached data when available
- Gradually tests service recovery
### Retry Logic
- Exponential backoff for transient failures
- Configurable retry counts and delays
- Authentication token refresh on 401 errors
### Cache Fallbacks
- Returns cached data when services are unavailable
- Graceful degradation with stale data warnings
- Manual cache invalidation for data consistency
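The retry behaviour can be pictured with a small generic wrapper; the delays, exception handling, and token-refresh hook below are placeholders rather than the clients' actual settings:

```python
import asyncio
import random

async def retry_with_backoff(func, *, retries=3, base_delay=0.5, max_delay=10.0):
    """Illustrative exponential backoff with jitter for transient failures."""
    for attempt in range(retries + 1):
        try:
            return await func()
        except Exception:
            if attempt == retries:
                raise  # out of attempts, surface the original error
            # e.g. on a 401 the real client would refresh its auth token here
            delay = min(max_delay, base_delay * (2 ** attempt))
            delay += random.uniform(0, delay / 2)  # jitter to avoid thundering herds
            await asyncio.sleep(delay)

# usage sketch: await retry_with_backoff(lambda: client.get("/sales"), retries=3)
```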
## Integration with Repository Pattern
The enhanced clients seamlessly integrate with the new repository pattern:
### Service Layer Integration
```python
class ForecastingService:
def __init__(self,
forecast_repository: ForecastRepository,
service_registry: ServiceRegistry):
self.forecast_repository = forecast_repository
self.data_client = service_registry.get_data_client()
self.training_client = service_registry.get_training_client()
async def create_forecast(self, tenant_id: str, model_id: str):
# Get data through enhanced client
sales_data = await self.data_client.get_all_sales_data_with_monitoring(
tenant_id=tenant_id,
correlation_id=f"forecast_data_{datetime.utcnow().isoformat()}"
)
# Use repository for database operations
forecast = await self.forecast_repository.create({
"tenant_id": tenant_id,
"model_id": model_id,
"status": "pending"
})
return forecast
```
Together, these clients form the enhanced inter-service communication layer: they integrate with the repository pattern architecture and add resilience, caching, monitoring, and event tracking to every service interaction.


@@ -0,0 +1,68 @@
"""
Shared Database Infrastructure
Provides consistent database patterns across all microservices
"""
from .base import DatabaseManager, Base, create_database_manager
from .repository import BaseRepository
from .unit_of_work import UnitOfWork, ServiceUnitOfWork, RepositoryRegistry
from .transactions import (
transactional,
unit_of_work_transactional,
managed_transaction,
managed_unit_of_work,
TransactionManager,
run_in_transaction,
run_with_unit_of_work
)
from .exceptions import (
DatabaseError,
ConnectionError,
RecordNotFoundError,
DuplicateRecordError,
ConstraintViolationError,
TransactionError,
ValidationError,
MigrationError,
HealthCheckError
)
from .utils import DatabaseUtils, QueryLogger
__all__ = [
# Core components
"DatabaseManager",
"Base",
"create_database_manager",
# Repository pattern
"BaseRepository",
# Unit of Work pattern
"UnitOfWork",
"ServiceUnitOfWork",
"RepositoryRegistry",
# Transaction management
"transactional",
"unit_of_work_transactional",
"managed_transaction",
"managed_unit_of_work",
"TransactionManager",
"run_in_transaction",
"run_with_unit_of_work",
# Exceptions
"DatabaseError",
"ConnectionError",
"RecordNotFoundError",
"DuplicateRecordError",
"ConstraintViolationError",
"TransactionError",
"ValidationError",
"MigrationError",
"HealthCheckError",
# Utilities
"DatabaseUtils",
"QueryLogger"
]
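A hedged end-to-end sketch of wiring these exports together in a service follows; the `Product` model, connection URL, and repository subclass are placeholders for illustration:

```python
from sqlalchemy import Column, String
from shared.database import Base, BaseRepository, create_database_manager

class Product(Base):  # placeholder model for illustration
    __tablename__ = "products"
    id = Column(String, primary_key=True)
    name = Column(String, nullable=False)

class ProductRepository(BaseRepository):
    """Concrete repository; inherits generic CRUD from BaseRepository."""

db_manager = create_database_manager(
    "postgresql+asyncpg://user:pass@localhost/products_db",  # placeholder URL
    service_name="data",
)

async def list_products():
    async with db_manager.get_session() as session:
        repo = ProductRepository(Product, session)
        return await repo.get_multi(limit=50, order_by="name")
```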


@@ -1,78 +1,298 @@
"""
Base database configuration for all microservices
Enhanced Base Database Configuration for All Microservices
Provides DatabaseManager with connection pooling, health checks, and multi-database support
"""
import os
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from typing import Optional, Dict, Any, List
from sqlalchemy import create_engine, text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.pool import StaticPool
from sqlalchemy.pool import StaticPool, QueuePool
from contextlib import asynccontextmanager
import logging
import structlog
import time
logger = logging.getLogger(__name__)
from .exceptions import DatabaseError, ConnectionError, HealthCheckError
from .utils import DatabaseUtils
logger = structlog.get_logger()
Base = declarative_base()
class DatabaseManager:
"""Enhanced Database Manager for Microservices
Provides:
- Connection pooling with configurable settings
- Health checks and monitoring
- Multi-database support
- Session lifecycle management
- Background task session support
"""
def __init__(
self,
database_url: str,
service_name: str = "unknown",
pool_size: int = 20,
max_overflow: int = 30,
pool_recycle: int = 3600,
pool_pre_ping: bool = True,
echo: bool = False,
connect_timeout: int = 30,
**engine_kwargs
):
self.database_url = database_url
self.service_name = service_name
self.pool_size = pool_size
self.max_overflow = max_overflow
# Configure pool class based on database type
poolclass = QueuePool
if "sqlite" in database_url.lower():
poolclass = StaticPool
pool_size = 1
max_overflow = 0
# Create async engine with enhanced configuration
engine_config = {
"echo": echo,
"pool_pre_ping": pool_pre_ping,
"pool_recycle": pool_recycle,
"pool_size": pool_size,
"max_overflow": max_overflow,
"poolclass": poolclass,
"connect_args": {"command_timeout": connect_timeout},
**engine_kwargs
}
self.async_engine = create_async_engine(database_url, **engine_config)
# Create session factory
self.async_session_local = async_sessionmaker(
self.async_engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False
)
logger.info(f"DatabaseManager initialized for {service_name}",
pool_size=pool_size,
max_overflow=max_overflow,
database_type=self._get_database_type())
async def get_db(self):
"""Get database session for request handlers (FastAPI dependency)"""
async with self.async_session_local() as session:
try:
logger.debug("Database session created for request")
yield session
except Exception as e:
# Don't wrap HTTPExceptions - let them pass through
if hasattr(e, 'status_code') and hasattr(e, 'detail'):
# This is likely an HTTPException - don't wrap it
await session.rollback()
raise
error_msg = str(e) if str(e) else f"{type(e).__name__}: {repr(e)}"
logger.error(f"Database session error: {error_msg}", service=self.service_name)
await session.rollback()
# Handle specific ASGI stream issues more gracefully
if "EndOfStream" in str(type(e)) or "WouldBlock" in str(type(e)):
raise DatabaseError(f"Session error: Request stream disconnected ({type(e).__name__})")
else:
raise DatabaseError(f"Session error: {error_msg}")
finally:
await session.close()
logger.debug("Database session closed")
@asynccontextmanager
async def get_background_session(self):
"""
Get database session for background tasks with auto-commit
Usage:
async with database_manager.get_background_session() as session:
# Your background task code here
# Auto-commits on success, rolls back on exception
"""
async with self.async_session_local() as session:
try:
logger.debug("Background session created", service=self.service_name)
yield session
await session.commit()
logger.debug("Background session committed")
except Exception as e:
await session.rollback()
logger.error(f"Background task database error: {e}",
service=self.service_name)
raise DatabaseError(f"Background task failed: {str(e)}")
finally:
await session.close()
logger.debug("Background session closed")
@asynccontextmanager
async def get_session(self):
"""Get a plain database session (no auto-commit)"""
async with self.async_session_local() as session:
try:
yield session
except Exception as e:
await session.rollback()
logger.error(f"Session error: {e}", service=self.service_name)
raise DatabaseError(f"Session error: {str(e)}")
finally:
await session.close()
# ===== TABLE MANAGEMENT =====
async def create_tables(self, metadata=None):
"""Create database tables"""
try:
target_metadata = metadata or Base.metadata
async with self.async_engine.begin() as conn:
await conn.run_sync(target_metadata.create_all)
logger.info("Database tables created successfully", service=self.service_name)
except Exception as e:
logger.error(f"Failed to create tables: {e}", service=self.service_name)
raise DatabaseError(f"Table creation failed: {str(e)}")
async def drop_tables(self, metadata=None):
"""Drop database tables"""
try:
target_metadata = metadata or Base.metadata
async with self.async_engine.begin() as conn:
await conn.run_sync(target_metadata.drop_all)
logger.info("Database tables dropped successfully", service=self.service_name)
except Exception as e:
logger.error(f"Failed to drop tables: {e}", service=self.service_name)
raise DatabaseError(f"Table drop failed: {str(e)}")
# ===== HEALTH CHECKS AND MONITORING =====
async def health_check(self) -> Dict[str, Any]:
"""Comprehensive health check for the database"""
try:
async with self.get_session() as session:
return await DatabaseUtils.execute_health_check(session)
except Exception as e:
logger.error(f"Health check failed: {e}", service=self.service_name)
raise HealthCheckError(f"Health check failed: {str(e)}")
async def get_connection_info(self) -> Dict[str, Any]:
"""Get database connection information"""
try:
pool = self.async_engine.pool
return {
"service_name": self.service_name,
"database_type": self._get_database_type(),
"pool_size": self.pool_size,
"max_overflow": self.max_overflow,
"current_checked_in": pool.checkedin() if pool else 0,
"current_checked_out": pool.checkedout() if pool else 0,
"current_overflow": pool.overflow() if pool else 0,
"invalid_connections": pool.invalid() if pool else 0
}
except Exception as e:
logger.error(f"Failed to get connection info: {e}", service=self.service_name)
return {"error": str(e)}
def _get_database_type(self) -> str:
"""Get database type from URL"""
return self.database_url.split("://")[0].lower() if "://" in self.database_url else "unknown"
# ===== CLEANUP AND MAINTENANCE =====
async def close_connections(self):
"""Close all database connections"""
try:
await self.async_engine.dispose()
logger.info("Database connections closed", service=self.service_name)
except Exception as e:
logger.error(f"Failed to close connections: {e}", service=self.service_name)
raise DatabaseError(f"Connection cleanup failed: {str(e)}")
async def execute_maintenance(self) -> Dict[str, Any]:
"""Execute database maintenance tasks"""
try:
async with self.get_session() as session:
return await DatabaseUtils.execute_maintenance(session)
except Exception as e:
logger.error(f"Maintenance failed: {e}", service=self.service_name)
raise DatabaseError(f"Maintenance failed: {str(e)}")
# ===== UTILITY METHODS =====
async def test_connection(self) -> bool:
"""Test database connectivity"""
try:
async with self.async_engine.begin() as conn:
await conn.execute(text("SELECT 1"))
logger.debug("Connection test successful", service=self.service_name)
return True
except Exception as e:
logger.error(f"Connection test failed: {e}", service=self.service_name)
return False
def __repr__(self) -> str:
return f"DatabaseManager(service='{self.service_name}', type='{self._get_database_type()}')"
# ===== CONVENIENCE FUNCTIONS =====
def create_database_manager(
database_url: str,
service_name: str,
**kwargs
) -> DatabaseManager:
"""Factory function to create DatabaseManager instances"""
return DatabaseManager(database_url, service_name, **kwargs)
# ===== LEGACY COMPATIBILITY =====
# Keep backward compatibility for existing code
engine = None
AsyncSessionLocal = None
def init_legacy_compatibility(database_url: str):
"""Initialize legacy global variables for backward compatibility"""
global engine, AsyncSessionLocal
engine = create_async_engine(
database_url,
echo=False,
pool_pre_ping=True,
pool_recycle=300,
pool_size=20,
max_overflow=30
)
AsyncSessionLocal = async_sessionmaker(
engine,
class_=AsyncSession,
expire_on_commit=False
)
logger.warning("Using legacy database configuration - consider migrating to DatabaseManager")
async def get_legacy_db():
"""Legacy database session getter for backward compatibility"""
if not AsyncSessionLocal:
raise RuntimeError("Legacy database not initialized - call init_legacy_compatibility first")
async with AsyncSessionLocal() as session:
try:
yield session
except Exception as e:
logger.error(f"Legacy database session error: {e}")
await session.rollback()
raise
finally:
await session.close()
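Either entry point ultimately yields an `AsyncSession`. As a rough sketch of how a FastAPI service could consume the manager (the routes, URL, and service name are illustrative, not part of this module):

```python
from fastapi import Depends, FastAPI
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from shared.database import create_database_manager

app = FastAPI()
db_manager = create_database_manager(
    "postgresql+asyncpg://user:pass@localhost/service_db",  # placeholder URL
    service_name="forecasting",
)

@app.get("/health/db")
async def db_health():
    # Delegates to DatabaseUtils.execute_health_check via DatabaseManager.health_check()
    return await db_manager.health_check()

@app.get("/ping-db")
async def ping_db(session: AsyncSession = Depends(db_manager.get_db)):
    result = await session.execute(text("SELECT 1"))
    return {"ok": result.scalar() == 1}
```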


@@ -0,0 +1,78 @@
"""
Base database configuration for all microservices
"""
import os
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker, declarative_base
from sqlalchemy.pool import StaticPool
from contextlib import asynccontextmanager
import logging
logger = logging.getLogger(__name__)
Base = declarative_base()
class DatabaseManager:
"""Database manager for microservices"""
def __init__(self, database_url: str):
self.database_url = database_url
self.async_engine = create_async_engine(
database_url,
echo=False,
pool_pre_ping=True,
pool_recycle=300,
pool_size=20,
max_overflow=30
)
self.async_session_local = sessionmaker(
self.async_engine,
class_=AsyncSession,
expire_on_commit=False
)
async def get_db(self):
"""Get database session for request handlers"""
async with self.async_session_local() as session:
try:
yield session
except Exception as e:
logger.error(f"Database session error: {e}")
await session.rollback()
raise
finally:
await session.close()
@asynccontextmanager
async def get_background_session(self):
"""
✅ NEW: Get database session for background tasks
Usage:
async with database_manager.get_background_session() as session:
# Your background task code here
await session.commit()
"""
async with self.async_session_local() as session:
try:
yield session
await session.commit()
except Exception as e:
await session.rollback()
logger.error(f"Background task database error: {e}")
raise
finally:
await session.close()
async def create_tables(self):
"""Create database tables"""
async with self.async_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
async def drop_tables(self):
"""Drop database tables"""
async with self.async_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)


@@ -0,0 +1,52 @@
"""
Custom Database Exceptions
Provides consistent error handling across all microservices
"""
class DatabaseError(Exception):
"""Base exception for database-related errors"""
def __init__(self, message: str, details: dict = None):
self.message = message
self.details = details or {}
super().__init__(self.message)
class ConnectionError(DatabaseError):
"""Raised when database connection fails"""
pass
class RecordNotFoundError(DatabaseError):
"""Raised when a requested record is not found"""
pass
class DuplicateRecordError(DatabaseError):
"""Raised when trying to create a duplicate record"""
pass
class ConstraintViolationError(DatabaseError):
"""Raised when database constraints are violated"""
pass
class TransactionError(DatabaseError):
"""Raised when transaction operations fail"""
pass
class ValidationError(DatabaseError):
"""Raised when data validation fails before database operations"""
pass
class MigrationError(DatabaseError):
"""Raised when database migration operations fail"""
pass
class HealthCheckError(DatabaseError):
"""Raised when database health checks fail"""
pass
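Because every exception above derives from `DatabaseError`, callers can catch broadly or narrowly. A short illustrative mapping to HTTP responses (the FastAPI wiring is an assumption, not part of this module):

```python
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from shared.database.exceptions import DatabaseError, DuplicateRecordError, RecordNotFoundError

app = FastAPI()

@app.exception_handler(DatabaseError)
async def database_error_handler(request: Request, exc: DatabaseError):
    # Map domain-level database errors onto HTTP status codes
    if isinstance(exc, RecordNotFoundError):
        status = 404
    elif isinstance(exc, DuplicateRecordError):
        status = 409
    else:
        status = 500
    return JSONResponse(status_code=status, content={"detail": exc.message, **exc.details})
```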


@@ -0,0 +1,422 @@
"""
Base Repository Pattern for Database Operations
Provides generic CRUD operations, query building, and caching
"""
from typing import Optional, List, Dict, Any, TypeVar, Generic, Type, Union
from abc import ABC, abstractmethod
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import declarative_base
from sqlalchemy import select, update, delete, and_, or_, desc, asc, func, text
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from contextlib import asynccontextmanager
import structlog
from .exceptions import (
DatabaseError,
RecordNotFoundError,
DuplicateRecordError,
ConstraintViolationError
)
logger = structlog.get_logger()
# Type variables for generic repository
Model = TypeVar('Model', bound=declarative_base())
CreateSchema = TypeVar('CreateSchema')
UpdateSchema = TypeVar('UpdateSchema')
class BaseRepository(Generic[Model, CreateSchema, UpdateSchema], ABC):
"""
Base repository providing generic CRUD operations
Args:
model: SQLAlchemy model class
session: Database session
cache_ttl: Cache time-to-live in seconds (optional)
"""
def __init__(self, model: Type[Model], session: AsyncSession, cache_ttl: Optional[int] = None):
self.model = model
self.session = session
self.cache_ttl = cache_ttl
self._cache = {} if cache_ttl else None
# ===== CORE CRUD OPERATIONS =====
async def create(self, obj_in: CreateSchema, **kwargs) -> Model:
"""Create a new record"""
try:
# Convert schema to dict if needed
if hasattr(obj_in, 'model_dump'):
obj_data = obj_in.model_dump()
elif hasattr(obj_in, 'dict'):
obj_data = obj_in.dict()
else:
obj_data = obj_in
# Merge with additional kwargs
obj_data.update(kwargs)
db_obj = self.model(**obj_data)
self.session.add(db_obj)
await self.session.flush() # Get ID without committing
await self.session.refresh(db_obj)
logger.debug(f"Created {self.model.__name__}", record_id=getattr(db_obj, 'id', None))
return db_obj
except IntegrityError as e:
await self.session.rollback()
logger.error(f"Integrity error creating {self.model.__name__}", error=str(e))
raise DuplicateRecordError(f"Record with provided data already exists")
except SQLAlchemyError as e:
await self.session.rollback()
logger.error(f"Database error creating {self.model.__name__}", error=str(e))
raise DatabaseError(f"Failed to create record: {str(e)}")
async def get_by_id(self, record_id: Any) -> Optional[Model]:
"""Get record by ID with optional caching"""
cache_key = f"{self.model.__name__}:{record_id}"
# Check cache first
if self._cache and cache_key in self._cache:
logger.debug(f"Cache hit for {cache_key}")
return self._cache[cache_key]
try:
result = await self.session.execute(
select(self.model).where(self.model.id == record_id)
)
record = result.scalar_one_or_none()
# Cache the result
if self._cache and record:
self._cache[cache_key] = record
return record
except SQLAlchemyError as e:
logger.error(f"Database error getting {self.model.__name__} by ID",
record_id=record_id, error=str(e))
raise DatabaseError(f"Failed to get record: {str(e)}")
async def get_by_field(self, field_name: str, value: Any) -> Optional[Model]:
"""Get record by specific field"""
try:
result = await self.session.execute(
select(self.model).where(getattr(self.model, field_name) == value)
)
return result.scalar_one_or_none()
except AttributeError:
raise ValueError(f"Field '{field_name}' not found in {self.model.__name__}")
except SQLAlchemyError as e:
logger.error(f"Database error getting {self.model.__name__} by {field_name}",
value=value, error=str(e))
raise DatabaseError(f"Failed to get record: {str(e)}")
async def get_multi(
self,
skip: int = 0,
limit: int = 100,
order_by: Optional[str] = None,
order_desc: bool = False,
filters: Optional[Dict[str, Any]] = None
) -> List[Model]:
"""Get multiple records with pagination, sorting, and filtering"""
try:
query = select(self.model)
# Apply filters
if filters:
conditions = []
for field, value in filters.items():
if hasattr(self.model, field):
if isinstance(value, list):
conditions.append(getattr(self.model, field).in_(value))
else:
conditions.append(getattr(self.model, field) == value)
if conditions:
query = query.where(and_(*conditions))
# Apply ordering
if order_by and hasattr(self.model, order_by):
order_field = getattr(self.model, order_by)
if order_desc:
query = query.order_by(desc(order_field))
else:
query = query.order_by(asc(order_field))
# Apply pagination
query = query.offset(skip).limit(limit)
result = await self.session.execute(query)
return result.scalars().all()
except SQLAlchemyError as e:
logger.error(f"Database error getting multiple {self.model.__name__} records",
error=str(e))
raise DatabaseError(f"Failed to get records: {str(e)}")
async def update(self, record_id: Any, obj_in: UpdateSchema, **kwargs) -> Optional[Model]:
"""Update record by ID"""
try:
# Convert schema to dict if needed
if hasattr(obj_in, 'model_dump'):
update_data = obj_in.model_dump(exclude_unset=True)
elif hasattr(obj_in, 'dict'):
update_data = obj_in.dict(exclude_unset=True)
else:
update_data = obj_in
# Merge with additional kwargs
update_data.update(kwargs)
# Remove None values
update_data = {k: v for k, v in update_data.items() if v is not None}
if not update_data:
logger.warning(f"No data to update for {self.model.__name__}", record_id=record_id)
return await self.get_by_id(record_id)
# Perform update
result = await self.session.execute(
update(self.model)
.where(self.model.id == record_id)
.values(**update_data)
.returning(self.model)
)
updated_record = result.scalar_one_or_none()
if not updated_record:
raise RecordNotFoundError(f"{self.model.__name__} with id {record_id} not found")
# Clear cache
if self._cache:
cache_key = f"{self.model.__name__}:{record_id}"
self._cache.pop(cache_key, None)
logger.debug(f"Updated {self.model.__name__}", record_id=record_id)
return updated_record
except IntegrityError as e:
await self.session.rollback()
logger.error(f"Integrity error updating {self.model.__name__}",
record_id=record_id, error=str(e))
raise ConstraintViolationError(f"Update violates database constraints")
except SQLAlchemyError as e:
await self.session.rollback()
logger.error(f"Database error updating {self.model.__name__}",
record_id=record_id, error=str(e))
raise DatabaseError(f"Failed to update record: {str(e)}")
async def delete(self, record_id: Any) -> bool:
"""Delete record by ID"""
try:
result = await self.session.execute(
delete(self.model).where(self.model.id == record_id)
)
deleted_count = result.rowcount
if deleted_count == 0:
raise RecordNotFoundError(f"{self.model.__name__} with id {record_id} not found")
# Clear cache
if self._cache:
cache_key = f"{self.model.__name__}:{record_id}"
self._cache.pop(cache_key, None)
logger.debug(f"Deleted {self.model.__name__}", record_id=record_id)
return True
except SQLAlchemyError as e:
await self.session.rollback()
logger.error(f"Database error deleting {self.model.__name__}",
record_id=record_id, error=str(e))
raise DatabaseError(f"Failed to delete record: {str(e)}")
# ===== ADVANCED QUERY OPERATIONS =====
async def count(self, filters: Optional[Dict[str, Any]] = None) -> int:
"""Count records with optional filters"""
try:
query = select(func.count(self.model.id))
if filters:
conditions = []
for field, value in filters.items():
if hasattr(self.model, field):
if isinstance(value, list):
conditions.append(getattr(self.model, field).in_(value))
else:
conditions.append(getattr(self.model, field) == value)
if conditions:
query = query.where(and_(*conditions))
result = await self.session.execute(query)
return result.scalar() or 0
except SQLAlchemyError as e:
logger.error(f"Database error counting {self.model.__name__} records", error=str(e))
raise DatabaseError(f"Failed to count records: {str(e)}")
async def exists(self, record_id: Any) -> bool:
"""Check if record exists by ID"""
try:
result = await self.session.execute(
select(func.count(self.model.id)).where(self.model.id == record_id)
)
count = result.scalar() or 0
return count > 0
except SQLAlchemyError as e:
logger.error(f"Database error checking existence of {self.model.__name__}",
record_id=record_id, error=str(e))
raise DatabaseError(f"Failed to check record existence: {str(e)}")
async def bulk_create(self, objects: List[CreateSchema]) -> List[Model]:
"""Create multiple records in bulk"""
try:
if not objects:
return []
db_objects = []
for obj_in in objects:
if hasattr(obj_in, 'model_dump'):
obj_data = obj_in.model_dump()
elif hasattr(obj_in, 'dict'):
obj_data = obj_in.dict()
else:
obj_data = obj_in
db_objects.append(self.model(**obj_data))
self.session.add_all(db_objects)
await self.session.flush()
for db_obj in db_objects:
await self.session.refresh(db_obj)
logger.debug(f"Bulk created {len(db_objects)} {self.model.__name__} records")
return db_objects
except IntegrityError as e:
await self.session.rollback()
logger.error(f"Integrity error bulk creating {self.model.__name__}", error=str(e))
raise DuplicateRecordError(f"One or more records already exist")
except SQLAlchemyError as e:
await self.session.rollback()
logger.error(f"Database error bulk creating {self.model.__name__}", error=str(e))
raise DatabaseError(f"Failed to create records: {str(e)}")
async def bulk_update(self, updates: List[Dict[str, Any]]) -> int:
"""Update multiple records in bulk"""
try:
if not updates:
return 0
# Group updates by fields being updated for efficiency
for update_data in updates:
if 'id' not in update_data:
raise ValueError("Each update must include 'id' field")
record_id = update_data.pop('id')
await self.session.execute(
update(self.model)
.where(self.model.id == record_id)
.values(**update_data)
)
# Clear relevant cache entries
if self._cache:
for update_data in updates:
record_id = update_data.get('id')
if record_id:
cache_key = f"{self.model.__name__}:{record_id}"
self._cache.pop(cache_key, None)
logger.debug(f"Bulk updated {len(updates)} {self.model.__name__} records")
return len(updates)
except SQLAlchemyError as e:
await self.session.rollback()
logger.error(f"Database error bulk updating {self.model.__name__}", error=str(e))
raise DatabaseError(f"Failed to update records: {str(e)}")
# ===== SEARCH AND QUERY BUILDING =====
async def search(
self,
search_term: str,
search_fields: List[str],
skip: int = 0,
limit: int = 100
) -> List[Model]:
"""Search records across multiple fields"""
try:
conditions = []
for field in search_fields:
if hasattr(self.model, field):
field_obj = getattr(self.model, field)
# Case-insensitive partial match
conditions.append(field_obj.ilike(f"%{search_term}%"))
if not conditions:
logger.warning(f"No valid search fields provided for {self.model.__name__}")
return []
query = select(self.model).where(or_(*conditions)).offset(skip).limit(limit)
result = await self.session.execute(query)
return result.scalars().all()
except SQLAlchemyError as e:
logger.error(f"Database error searching {self.model.__name__}",
search_term=search_term, error=str(e))
raise DatabaseError(f"Failed to search records: {str(e)}")
async def execute_raw_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> Any:
"""Execute raw SQL query (use with caution)"""
try:
result = await self.session.execute(text(query), params or {})
return result
except SQLAlchemyError as e:
logger.error(f"Database error executing raw query", query=query, error=str(e))
raise DatabaseError(f"Failed to execute query: {str(e)}")
# ===== CACHE MANAGEMENT =====
def clear_cache(self, record_id: Optional[Any] = None):
"""Clear cache for specific record or all records"""
if not self._cache:
return
if record_id:
cache_key = f"{self.model.__name__}:{record_id}"
self._cache.pop(cache_key, None)
else:
# Clear all cache entries for this model
keys_to_remove = [k for k in self._cache.keys() if k.startswith(f"{self.model.__name__}:")]
for key in keys_to_remove:
self._cache.pop(key, None)
logger.debug(f"Cleared cache for {self.model.__name__}", record_id=record_id)
# ===== CONTEXT MANAGERS =====
@asynccontextmanager
async def transaction(self):
"""Context manager for explicit transaction handling"""
try:
yield self.session
await self.session.commit()
except Exception as e:
await self.session.rollback()
logger.error(f"Transaction failed for {self.model.__name__}", error=str(e))
raise
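As a concrete but hypothetical example of specialising the base class, a repository might add one tenant-scoped query on top of the inherited CRUD; the `SalesRecord` model is a placeholder:

```python
from sqlalchemy import Column, Date, Numeric, String
from shared.database import Base, BaseRepository

class SalesRecord(Base):  # placeholder model for illustration
    __tablename__ = "sales_records"
    id = Column(String, primary_key=True)
    tenant_id = Column(String, index=True)
    product_name = Column(String)
    sale_date = Column(Date)
    amount = Column(Numeric)

class SalesRepository(BaseRepository):
    """Adds a tenant-scoped query on top of the generic CRUD operations."""

    async def get_for_tenant(self, tenant_id: str, limit: int = 100):
        return await self.get_multi(
            limit=limit,
            order_by="sale_date",
            order_desc=True,
            filters={"tenant_id": tenant_id},
        )

# usage: repo = SalesRepository(SalesRecord, session, cache_ttl=300)
```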


@@ -0,0 +1,306 @@
"""
Transaction Decorators and Context Managers
Provides convenient transaction handling for service methods
"""
from functools import wraps
from typing import Callable, Any, Optional
from contextlib import asynccontextmanager
import structlog
from .base import DatabaseManager
from .unit_of_work import UnitOfWork
from .exceptions import TransactionError
logger = structlog.get_logger()
def transactional(database_manager: DatabaseManager, auto_commit: bool = True):
"""
Decorator that wraps a method in a database transaction
Args:
database_manager: DatabaseManager instance
auto_commit: Whether to auto-commit on success
Usage:
@transactional(database_manager)
async def create_user_with_profile(self, user_data, profile_data):
# Your business logic here
# Transaction is automatically managed
pass
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(*args, **kwargs):
async with database_manager.get_background_session() as session:
try:
# Inject session into kwargs if not present
if 'session' not in kwargs:
kwargs['session'] = session
result = await func(*args, **kwargs)
# Session is auto-committed by get_background_session
logger.debug(f"Transaction completed successfully for {func.__name__}")
return result
except Exception as e:
# Session is auto-rolled back by get_background_session
logger.error(f"Transaction failed for {func.__name__}", error=str(e))
raise TransactionError(f"Transaction failed: {str(e)}")
return wrapper
return decorator
def unit_of_work_transactional(database_manager: DatabaseManager):
"""
Decorator that provides Unit of Work pattern for complex operations
Usage:
@unit_of_work_transactional(database_manager)
async def complex_business_operation(self, data, uow: UnitOfWork):
user_repo = uow.register_repository("users", UserRepository, User)
sales_repo = uow.register_repository("sales", SalesRepository, SalesData)
user = await user_repo.create(data.user)
sale = await sales_repo.create(data.sale)
# UnitOfWork automatically commits
return {"user": user, "sale": sale}
"""
def decorator(func: Callable) -> Callable:
@wraps(func)
async def wrapper(*args, **kwargs):
async with database_manager.get_background_session() as session:
async with UnitOfWork(session, auto_commit=True) as uow:
try:
# Inject UnitOfWork into kwargs
kwargs['uow'] = uow
result = await func(*args, **kwargs)
logger.debug(f"Unit of Work transaction completed for {func.__name__}")
return result
except Exception as e:
logger.error(f"Unit of Work transaction failed for {func.__name__}",
error=str(e))
raise TransactionError(f"Transaction failed: {str(e)}")
return wrapper
return decorator
@asynccontextmanager
async def managed_transaction(database_manager: DatabaseManager):
"""
Context manager for explicit transaction control
Usage:
async with managed_transaction(database_manager) as session:
# Your database operations here
user = User(name="John")
session.add(user)
# Auto-commits on exit, rolls back on exception
"""
async with database_manager.get_background_session() as session:
try:
logger.debug("Starting managed transaction")
yield session
logger.debug("Managed transaction completed successfully")
except Exception as e:
logger.error("Managed transaction failed", error=str(e))
raise
@asynccontextmanager
async def managed_unit_of_work(database_manager: DatabaseManager, event_publisher=None):
"""
Context manager for explicit Unit of Work control
Usage:
async with managed_unit_of_work(database_manager) as uow:
user_repo = uow.register_repository("users", UserRepository, User)
user = await user_repo.create(user_data)
await uow.commit()
"""
async with database_manager.get_background_session() as session:
uow = UnitOfWork(session)
try:
logger.debug("Starting managed Unit of Work")
yield uow
if not uow._committed:
await uow.commit()
logger.debug("Managed Unit of Work completed successfully")
except Exception as e:
if not uow._rolled_back:
await uow.rollback()
logger.error("Managed Unit of Work failed", error=str(e))
raise
class TransactionManager:
"""
Advanced transaction manager for complex scenarios
Usage:
tx_manager = TransactionManager(database_manager)
async with tx_manager.create_transaction() as tx:
await tx.execute_in_transaction(my_business_logic, data)
"""
def __init__(self, database_manager: DatabaseManager):
self.database_manager = database_manager
@asynccontextmanager
async def create_transaction(self, isolation_level: Optional[str] = None):
"""Create a transaction with optional isolation level"""
async with self.database_manager.get_background_session() as session:
transaction_context = TransactionContext(session, isolation_level)
try:
yield transaction_context
except Exception as e:
logger.error("Transaction manager failed", error=str(e))
raise
async def execute_with_retry(
self,
func: Callable,
max_retries: int = 3,
*args,
**kwargs
):
"""Execute function with transaction retry on failure"""
last_error = None
for attempt in range(max_retries):
try:
async with managed_transaction(self.database_manager) as session:
kwargs['session'] = session
result = await func(*args, **kwargs)
logger.debug(f"Transaction succeeded on attempt {attempt + 1}")
return result
except Exception as e:
last_error = e
logger.warning(f"Transaction attempt {attempt + 1} failed",
error=str(e), remaining_attempts=max_retries - attempt - 1)
if attempt == max_retries - 1:
break
logger.error(f"All transaction attempts failed after {max_retries} tries")
raise TransactionError(f"Transaction failed after {max_retries} retries: {str(last_error)}")
class TransactionContext:
"""Context for managing individual transactions"""
def __init__(self, session, isolation_level: Optional[str] = None):
self.session = session
self.isolation_level = isolation_level
async def execute_in_transaction(self, func: Callable, *args, **kwargs):
"""Execute function within the transaction context"""
try:
kwargs['session'] = self.session
result = await func(*args, **kwargs)
return result
except Exception as e:
logger.error("Function execution failed in transaction context", error=str(e))
raise
# ===== UTILITY FUNCTIONS =====
async def run_in_transaction(database_manager: DatabaseManager, func: Callable, *args, **kwargs):
"""
Utility function to run any async function in a transaction
Usage:
result = await run_in_transaction(
database_manager,
my_async_function,
arg1, arg2,
kwarg1="value"
)
"""
async with managed_transaction(database_manager) as session:
kwargs['session'] = session
return await func(*args, **kwargs)
async def run_with_unit_of_work(
database_manager: DatabaseManager,
func: Callable,
*args,
**kwargs
):
"""
Utility function to run any async function with Unit of Work
Usage:
result = await run_with_unit_of_work(
database_manager,
my_complex_function,
arg1, arg2
)
"""
async with managed_unit_of_work(database_manager) as uow:
kwargs['uow'] = uow
return await func(*args, **kwargs)
# ===== BATCH OPERATIONS =====
@asynccontextmanager
async def batch_operation(database_manager: DatabaseManager, batch_size: int = 1000):
"""
Context manager for batch operations with automatic commit batching
Usage:
async with batch_operation(database_manager, batch_size=500) as batch:
for item in large_dataset:
await batch.add_operation(create_record, item)
"""
async with database_manager.get_background_session() as session:
batch_context = BatchOperationContext(session, batch_size)
try:
yield batch_context
await batch_context.flush_remaining()
except Exception as e:
logger.error("Batch operation failed", error=str(e))
raise
class BatchOperationContext:
"""Context for managing batch database operations"""
def __init__(self, session, batch_size: int):
self.session = session
self.batch_size = batch_size
self.operation_count = 0
async def add_operation(self, func: Callable, *args, **kwargs):
"""Add operation to batch"""
kwargs['session'] = self.session
await func(*args, **kwargs)
self.operation_count += 1
if self.operation_count >= self.batch_size:
await self.session.commit()
self.operation_count = 0
logger.debug(f"Batch committed at {self.batch_size} operations")
async def flush_remaining(self):
"""Commit any remaining operations"""
if self.operation_count > 0:
await self.session.commit()
logger.debug(f"Final batch committed with {self.operation_count} operations")


@@ -0,0 +1,304 @@
"""
Unit of Work Pattern Implementation
Manages transactions across multiple repositories with event publishing
"""
from typing import Dict, Any, List, Optional, Type, TypeVar, Generic
from contextlib import asynccontextmanager
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.exc import SQLAlchemyError
from abc import ABC, abstractmethod
import structlog
from .repository import BaseRepository
from .exceptions import TransactionError
logger = structlog.get_logger()
Model = TypeVar('Model')
Repository = TypeVar('Repository', bound=BaseRepository)
class BaseEvent(ABC):
"""Base class for domain events"""
def __init__(self, event_type: str, data: Dict[str, Any]):
self.event_type = event_type
self.data = data
@abstractmethod
def to_dict(self) -> Dict[str, Any]:
"""Convert event to dictionary for publishing"""
pass
class DomainEvent(BaseEvent):
"""Standard domain event implementation"""
def to_dict(self) -> Dict[str, Any]:
return {
"event_type": self.event_type,
"data": self.data
}
class UnitOfWork:
"""
Unit of Work pattern for managing transactions and coordinating repositories
Usage:
async with UnitOfWork(session) as uow:
user_repo = uow.register_repository("users", UserRepository, User)
sales_repo = uow.register_repository("sales", SalesRepository, SalesData)
user = await user_repo.create(user_data)
sale = await sales_repo.create(sales_data)
await uow.commit()
"""
def __init__(self, session: AsyncSession, auto_commit: bool = False):
self.session = session
self.auto_commit = auto_commit
self._repositories: Dict[str, BaseRepository] = {}
self._events: List[BaseEvent] = []
self._committed = False
self._rolled_back = False
def register_repository(
self,
name: str,
repository_class: Type[Repository],
model_class: Type[Model],
**kwargs
) -> Repository:
"""
Register a repository with the unit of work
Args:
name: Unique name for the repository
repository_class: Repository class to instantiate
model_class: SQLAlchemy model class
**kwargs: Additional arguments for repository
Returns:
Instantiated repository
"""
if name in self._repositories:
logger.warning(f"Repository '{name}' already registered, returning existing instance")
return self._repositories[name]
repository = repository_class(model_class, self.session, **kwargs)
self._repositories[name] = repository
logger.debug(f"Registered repository", name=name, model=model_class.__name__)
return repository
def get_repository(self, name: str) -> Optional[Repository]:
"""Get registered repository by name"""
return self._repositories.get(name)
def add_event(self, event: BaseEvent):
"""Add domain event to be published after commit"""
self._events.append(event)
logger.debug(f"Added event", event_type=event.event_type)
async def commit(self):
"""Commit the transaction and publish events"""
if self._committed:
logger.warning("Unit of Work already committed")
return
if self._rolled_back:
raise TransactionError("Cannot commit after rollback")
try:
await self.session.commit()
self._committed = True
# Publish events after successful commit
await self._publish_events()
logger.debug(f"Unit of Work committed successfully",
repositories=list(self._repositories.keys()),
events_published=len(self._events))
except SQLAlchemyError as e:
await self.rollback()
logger.error("Failed to commit Unit of Work", error=str(e))
raise TransactionError(f"Commit failed: {str(e)}")
async def rollback(self):
"""Rollback the transaction"""
if self._rolled_back:
logger.warning("Unit of Work already rolled back")
return
try:
await self.session.rollback()
self._rolled_back = True
self._events.clear() # Clear events on rollback
logger.debug(f"Unit of Work rolled back",
repositories=list(self._repositories.keys()))
except SQLAlchemyError as e:
logger.error("Failed to rollback Unit of Work", error=str(e))
raise TransactionError(f"Rollback failed: {str(e)}")
async def _publish_events(self):
"""Publish domain events (override in subclasses for actual publishing)"""
if not self._events:
return
# Default implementation just logs events
# Override this method in service-specific implementations
for event in self._events:
logger.info(f"Publishing event",
event_type=event.event_type,
event_data=event.to_dict())
# Clear events after publishing
self._events.clear()
async def __aenter__(self):
"""Async context manager entry"""
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit"""
if exc_type is not None:
# Exception occurred, rollback
await self.rollback()
return False
# No exception, auto-commit if enabled
if self.auto_commit and not self._committed:
await self.commit()
return False
class ServiceUnitOfWork(UnitOfWork):
"""
Service-specific Unit of Work with event publishing integration
Example usage with message publishing:
class AuthUnitOfWork(ServiceUnitOfWork):
def __init__(self, session: AsyncSession, message_publisher=None):
super().__init__(session)
self.message_publisher = message_publisher
async def _publish_events(self):
for event in self._events:
if self.message_publisher:
await self.message_publisher.publish(
topic="auth.events",
message=event.to_dict()
)
"""
def __init__(self, session: AsyncSession, event_publisher=None, auto_commit: bool = False):
super().__init__(session, auto_commit)
self.event_publisher = event_publisher
async def _publish_events(self):
"""Publish events using the provided event publisher"""
if not self._events or not self.event_publisher:
return
try:
for event in self._events:
await self.event_publisher.publish(event)
logger.debug(f"Published event via publisher",
event_type=event.event_type)
self._events.clear()
except Exception as e:
logger.error("Failed to publish events", error=str(e))
# Don't raise here to avoid breaking the transaction
# Events will be retried or handled by the event publisher
# ===== TRANSACTION CONTEXT MANAGER =====
@asynccontextmanager
async def transaction_scope(session: AsyncSession, auto_commit: bool = True):
"""
Simple transaction context manager for single-repository operations
Usage:
async with transaction_scope(session) as tx_session:
user = User(name="John")
tx_session.add(user)
# Auto-commits on success, rolls back on exception
"""
try:
yield session
if auto_commit:
await session.commit()
except Exception as e:
await session.rollback()
logger.error("Transaction scope failed", error=str(e))
raise
# ===== UTILITIES =====
class RepositoryRegistry:
"""Registry for commonly used repository configurations"""
_registry: Dict[str, Dict[str, Any]] = {}
@classmethod
def register(
cls,
name: str,
repository_class: Type[Repository],
model_class: Type[Model],
**kwargs
):
"""Register a repository configuration"""
cls._registry[name] = {
"repository_class": repository_class,
"model_class": model_class,
"kwargs": kwargs
}
logger.debug("Registered repository configuration", name=name)
@classmethod
def create_repository(cls, name: str, session: AsyncSession) -> Optional[Repository]:
"""Create repository instance from registry"""
config = cls._registry.get(name)
if not config:
logger.warning(f"Repository configuration '{name}' not found in registry")
return None
return config["repository_class"](
config["model_class"],
session,
**config["kwargs"]
)
@classmethod
def list_registered(cls) -> List[str]:
"""List all registered repository names"""
return list(cls._registry.keys())
# ===== FACTORY FUNCTIONS =====
def create_unit_of_work(session: AsyncSession, **kwargs) -> UnitOfWork:
"""Factory function to create Unit of Work instances"""
return UnitOfWork(session, **kwargs)
def create_service_unit_of_work(
session: AsyncSession,
event_publisher=None,
**kwargs
) -> ServiceUnitOfWork:
"""Factory function to create Service Unit of Work instances"""
return ServiceUnitOfWork(session, event_publisher, **kwargs)
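A brief sketch of how the registry, the service unit of work, and domain events fit together; `UserRepository`, `User`, `SalesRepository`, and `SalesData` are the same placeholder names used in the docstrings above:

```python
# At service start-up, register repository configurations once
RepositoryRegistry.register("users", UserRepository, User, cache_ttl=300)
RepositoryRegistry.register("sales", SalesRepository, SalesData)

async def onboard_user(session, event_publisher, user_data, sale_data):
    async with ServiceUnitOfWork(session, event_publisher, auto_commit=True) as uow:
        user_repo = uow.register_repository("users", UserRepository, User)
        sales_repo = uow.register_repository("sales", SalesRepository, SalesData)
        user = await user_repo.create(user_data)
        sale = await sales_repo.create(sale_data, user_id=user.id)
        uow.add_event(DomainEvent("user.onboarded", {"user_id": str(user.id)}))
        # auto_commit=True commits on clean exit and then publishes queued events
        return {"user": user, "sale": sale}
```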

shared/database/utils.py (new file, 402 lines)

@@ -0,0 +1,402 @@
"""
Database Utilities
Helper functions for database operations and maintenance
"""
from typing import Dict, Any, List, Optional
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import text, inspect
from sqlalchemy.exc import SQLAlchemyError
import structlog
from .exceptions import DatabaseError, HealthCheckError
logger = structlog.get_logger()
class DatabaseUtils:
"""Utility functions for database operations"""
@staticmethod
async def execute_health_check(session: AsyncSession, timeout: int = 5) -> Dict[str, Any]:
"""
Comprehensive database health check
Returns:
Dict with health status, metrics, and diagnostics
"""
try:
# Basic connectivity test
start_time = __import__('time').time()
await session.execute(text("SELECT 1"))
response_time = __import__('time').time() - start_time
# Get database info
db_info = await DatabaseUtils._get_database_info(session)
# Connection pool status (if available)
pool_info = await DatabaseUtils._get_pool_info(session)
return {
"status": "healthy",
"response_time_seconds": round(response_time, 4),
"database": db_info,
"connection_pool": pool_info,
"timestamp": __import__('datetime').datetime.utcnow().isoformat()
}
except Exception as e:
logger.error("Database health check failed", error=str(e))
raise HealthCheckError(f"Health check failed: {str(e)}")
@staticmethod
async def _get_database_info(session: AsyncSession) -> Dict[str, Any]:
"""Get database server information"""
try:
# Try to get database version and basic stats
if session.bind.dialect.name == 'postgresql':
version_result = await session.execute(text("SELECT version()"))
version = version_result.scalar()
stats_result = await session.execute(text("""
SELECT
count(*) as active_connections,
(SELECT setting FROM pg_settings WHERE name = 'max_connections') as max_connections
FROM pg_stat_activity
WHERE state = 'active'
"""))
stats = stats_result.fetchone()
return {
"type": "postgresql",
"version": version,
"active_connections": stats.active_connections if stats else 0,
"max_connections": stats.max_connections if stats else "unknown"
}
elif session.bind.dialect.name == 'sqlite':
version_result = await session.execute(text("SELECT sqlite_version()"))
version = version_result.scalar()
return {
"type": "sqlite",
"version": version,
"active_connections": 1,
"max_connections": "unlimited"
}
else:
return {
"type": session.bind.dialect.name,
"version": "unknown",
"active_connections": "unknown",
"max_connections": "unknown"
}
except Exception as e:
logger.warning("Could not retrieve database info", error=str(e))
return {
"type": session.bind.dialect.name,
"version": "unknown",
"error": str(e)
}
@staticmethod
async def _get_pool_info(session: AsyncSession) -> Dict[str, Any]:
"""Get connection pool information"""
try:
pool = session.bind.pool
if pool:
return {
"size": pool.size(),
"checked_in": pool.checkedin(),
"checked_out": pool.checkedout(),
"overflow": pool.overflow(),
"invalid": pool.invalid()
}
else:
return {"status": "no_pool"}
except Exception as e:
logger.warning("Could not retrieve pool info", error=str(e))
return {"error": str(e)}
@staticmethod
async def validate_schema(session: AsyncSession, expected_tables: List[str]) -> Dict[str, Any]:
"""
Validate database schema against expected tables
Args:
session: Database session
expected_tables: List of table names that should exist
Returns:
Validation results with missing/extra tables
"""
try:
# Get existing tables (the sync inspector must run through the async connection)
connection = await session.connection()
existing_tables = set(
await connection.run_sync(lambda sync_conn: inspect(sync_conn).get_table_names())
)
expected_tables_set = set(expected_tables)
missing_tables = expected_tables_set - existing_tables
extra_tables = existing_tables - expected_tables_set
return {
"valid": len(missing_tables) == 0,
"existing_tables": list(existing_tables),
"expected_tables": expected_tables,
"missing_tables": list(missing_tables),
"extra_tables": list(extra_tables),
"total_tables": len(existing_tables)
}
except Exception as e:
logger.error("Schema validation failed", error=str(e))
raise DatabaseError(f"Schema validation failed: {str(e)}")
@staticmethod
async def get_table_stats(session: AsyncSession, table_names: List[str]) -> Dict[str, Any]:
"""
Get statistics for specified tables
Args:
session: Database session
table_names: List of table names to analyze
Returns:
Dictionary with table statistics
"""
try:
stats = {}
for table_name in table_names:
if session.bind.dialect.name == 'postgresql':
# PostgreSQL specific queries
count_result = await session.execute(
text(f"SELECT COUNT(*) FROM {table_name}")
)
row_count = count_result.scalar()
size_result = await session.execute(
text(f"SELECT pg_total_relation_size('{table_name}')")
)
table_size = size_result.scalar()
stats[table_name] = {
"row_count": row_count,
"size_bytes": table_size,
"size_mb": round(table_size / (1024 * 1024), 2) if table_size else 0
}
elif session.bind.dialect.name == 'sqlite':
# SQLite specific queries
count_result = await session.execute(
text(f"SELECT COUNT(*) FROM {table_name}")
)
row_count = count_result.scalar()
stats[table_name] = {
"row_count": row_count,
"size_bytes": "unknown",
"size_mb": "unknown"
}
else:
# Generic fallback
count_result = await session.execute(
text(f"SELECT COUNT(*) FROM {table_name}")
)
row_count = count_result.scalar()
stats[table_name] = {
"row_count": row_count,
"size_bytes": "unknown",
"size_mb": "unknown"
}
return stats
except Exception as e:
logger.error("Failed to get table statistics",
tables=table_names, error=str(e))
raise DatabaseError(f"Failed to get table stats: {str(e)}")
@staticmethod
async def cleanup_old_records(
session: AsyncSession,
table_name: str,
date_column: str,
days_old: int,
batch_size: int = 1000
) -> int:
"""
Clean up old records from a table
Args:
session: Database session
table_name: Name of table to clean
date_column: Date column to filter by
days_old: Records older than this many days will be deleted
batch_size: Number of records to delete per batch
Returns:
Total number of records deleted
"""
try:
total_deleted = 0
while True:
if session.bind.dialect.name == 'postgresql':
# NOTE: the interval must be cast from the bound parameter; a bare
# "INTERVAL :days_param" is a syntax error in PostgreSQL.
delete_query = text(f"""
DELETE FROM {table_name}
WHERE {date_column} < NOW() - CAST(:days_param AS interval)
AND ctid IN (
SELECT ctid FROM {table_name}
WHERE {date_column} < NOW() - CAST(:days_param AS interval)
LIMIT :batch_size
)
""")
params = {
"days_param": f"{days_old} days",
"batch_size": batch_size
}
elif session.bind.dialect.name == 'sqlite':
delete_query = text(f"""
DELETE FROM {table_name}
WHERE {date_column} < datetime('now', :days_param)
AND rowid IN (
SELECT rowid FROM {table_name}
WHERE {date_column} < datetime('now', :days_param)
LIMIT :batch_size
)
""")
params = {
"days_param": f"-{days_old} days",
"batch_size": batch_size
}
else:
# Generic fallback (may not work for all databases)
delete_query = text(f"""
DELETE FROM {table_name}
WHERE {date_column} < DATE_SUB(NOW(), INTERVAL :days_old DAY)
LIMIT :batch_size
""")
params = {
"days_old": days_old,
"batch_size": batch_size
}
result = await session.execute(delete_query, params)
deleted_count = result.rowcount
if deleted_count == 0:
break
total_deleted += deleted_count
await session.commit()
logger.debug(f"Deleted batch from {table_name}",
batch_size=deleted_count,
total_deleted=total_deleted)
logger.info(f"Cleanup completed for {table_name}",
total_deleted=total_deleted,
days_old=days_old)
return total_deleted
except Exception as e:
await session.rollback()
logger.error(f"Cleanup failed for {table_name}", error=str(e))
raise DatabaseError(f"Cleanup failed: {str(e)}")
@staticmethod
async def execute_maintenance(session: AsyncSession) -> Dict[str, Any]:
"""
Execute database maintenance tasks
Returns:
Dictionary with maintenance results
"""
try:
results = {}
if session.bind.dialect.name == 'postgresql':
# PostgreSQL maintenance.
# Caveat: VACUUM cannot run inside a transaction block; this call assumes the
# engine is configured to issue maintenance statements in autocommit mode,
# otherwise PostgreSQL will reject it.
await session.execute(text("VACUUM ANALYZE"))
results["vacuum"] = "completed"
# Update statistics
await session.execute(text("ANALYZE"))
results["analyze"] = "completed"
elif session.bind.dialect.name == 'sqlite':
# SQLite maintenance
await session.execute(text("VACUUM"))
results["vacuum"] = "completed"
await session.execute(text("ANALYZE"))
results["analyze"] = "completed"
else:
results["maintenance"] = "not_supported"
await session.commit()
logger.info("Database maintenance completed", results=results)
return results
except Exception as e:
await session.rollback()
logger.error("Database maintenance failed", error=str(e))
raise DatabaseError(f"Maintenance failed: {str(e)}")
class QueryLogger:
"""Utility for logging and analyzing database queries"""
def __init__(self, session: AsyncSession):
self.session = session
self._query_log = []
async def log_query(self, query: str, params: Optional[Dict] = None, execution_time: Optional[float] = None):
"""Log a database query with metadata"""
log_entry = {
"query": query,
"params": params,
"execution_time": execution_time,
"timestamp": __import__('datetime').datetime.utcnow().isoformat()
}
self._query_log.append(log_entry)
# Log slow queries
if execution_time and execution_time > 1.0: # 1 second threshold
logger.warning("Slow query detected",
query=query,
execution_time=execution_time)
def get_query_stats(self) -> Dict[str, Any]:
"""Get statistics about logged queries"""
if not self._query_log:
return {"total_queries": 0}
execution_times = [
entry["execution_time"]
for entry in self._query_log
if entry["execution_time"] is not None
]
return {
"total_queries": len(self._query_log),
"avg_execution_time": sum(execution_times) / len(execution_times) if execution_times else 0,
"max_execution_time": max(execution_times) if execution_times else 0,
"slow_queries_count": len([t for t in execution_times if t > 1.0])
}
def clear_log(self):
"""Clear the query log"""
self._query_log.clear()
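# Usage sketch: timing a statement by hand and recording it with QueryLogger.
# The statement below is an arbitrary example.
async def _example_logged_query(session: AsyncSession) -> Dict[str, Any]:
    query_logger = QueryLogger(session)
    start = time.time()
    await session.execute(text("SELECT 1"))
    await query_logger.log_query("SELECT 1", execution_time=time.time() - start)
    return query_logger.get_query_stats()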

View File

@@ -100,6 +100,27 @@ class MetricsCollector:
self._histograms[name] = histogram
logger.info(f"Registered histogram: {name} for {self.service_name}")
return histogram
except ValueError as e:
if "Duplicated timeseries" in str(e):
# Metric already exists in global registry, try to find it
from prometheus_client import REGISTRY
metric_name = f"{self.service_name.replace('-', '_')}_{name}"
for collector in REGISTRY._collector_to_names.keys():
if hasattr(collector, '_name') and collector._name == metric_name:
self._histograms[name] = collector
logger.warning(f"Reusing existing histogram: {name} for {self.service_name}")
return collector
# If we can't find it, create a new name with suffix
import time
suffix = str(int(time.time() * 1000))[-6:] # Last 6 digits of timestamp
histogram = Histogram(f"{self.service_name.replace('-', '_')}_{name}_{suffix}",
documentation, labelnames=labels, buckets=buckets)
self._histograms[name] = histogram
logger.warning(f"Created histogram with suffix: {name}_{suffix} for {self.service_name}")
return histogram
else:
logger.error(f"Failed to register histogram {name} for {self.service_name}: {e}")
raise
except Exception as e:
logger.error(f"Failed to register histogram {name} for {self.service_name}: {e}")
raise
@@ -295,3 +316,14 @@ def setup_metrics_early(app, service_name: str = None) -> MetricsCollector:
logger.info(f"Metrics setup completed for service: {service_name}")
return metrics_collector
# Additional helper function for endpoint tracking
import functools

def track_endpoint_metrics(endpoint_name: str = None, service_name: str = None):
"""Decorator for tracking endpoint metrics"""
def decorator(func):
@functools.wraps(func)  # keep the wrapped endpoint's signature and metadata visible to the framework
def wrapper(*args, **kwargs):
# For now, just pass through - metrics are handled by middleware
return func(*args, **kwargs)
return wrapper
return decorator
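# Usage sketch: the decorator is currently a pass-through, so applying it is
# harmless; actual metrics collection happens in the middleware. The function
# below is an example only.
@track_endpoint_metrics(endpoint_name="health", service_name="example-service")
def _example_tracked_endpoint() -> dict:
    return {"status": "ok"}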