REFACTOR - Database logic
@@ -238,7 +238,7 @@ async def verify_tenant_access_dep(
     Raises:
         HTTPException: If user doesn't have access to tenant
     """
-    has_access = await tenant_access_manager.verify_user_tenant_access(current_user["user_id"], tenant_id)
+    has_access = await tenant_access_manager.verify_basic_tenant_access(current_user["user_id"], tenant_id)
     if not has_access:
         logger.warning(f"Access denied to tenant",
                        user_id=current_user["user_id"],
@@ -276,7 +276,7 @@ async def verify_tenant_permission_dep(
        HTTPException: If user doesn't have access or permission
    """
    # First verify basic tenant access
-   has_access = await tenant_access_manager.verify_user_tenant_access(current_user["user_id"], tenant_id)
+   has_access = await tenant_access_manager.verify_basic_tenant_access(current_user["user_id"], tenant_id)
    if not has_access:
        raise HTTPException(
            status_code=403,
shared/clients/README.md (new file, 390 lines)
@@ -0,0 +1,390 @@
# Enhanced Inter-Service Communication System

This directory contains the enhanced inter-service communication system that integrates with the new repository pattern architecture. The system provides circuit breakers, caching, monitoring, and event tracking for all service-to-service communications.

## Architecture Overview

### Base Components

1. **BaseServiceClient** - Foundation class providing authentication, retries, and basic HTTP operations
2. **EnhancedServiceClient** - Adds circuit breaker, caching, and monitoring capabilities
3. **ServiceRegistry** - Central registry for managing all enhanced service clients

### Enhanced Service Clients

Each service has a specialized enhanced client:

- **EnhancedDataServiceClient** - Sales data, weather, traffic, products with optimized caching
- **EnhancedAuthServiceClient** - Authentication, user management, permissions with security focus
- **EnhancedTrainingServiceClient** - ML training, model management, deployment with pipeline monitoring
- **EnhancedForecastingServiceClient** - Forecasting, predictions, scenarios with analytics
- **EnhancedTenantServiceClient** - Tenant management, memberships, organization features
- **EnhancedNotificationServiceClient** - Notifications, templates, delivery tracking

## Key Features

### Circuit Breaker Pattern
- **States**: Closed (normal), Open (failing), Half-Open (testing recovery)
- **Configuration**: Failure threshold, recovery timeout, success threshold
- **Monitoring**: State changes tracked and logged
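
The state machine behind these bullets can be summarized with a minimal, self-contained sketch. This is illustrative only and is not the `EnhancedServiceClient` implementation; the names and defaults below are assumptions:

```python
import time

class SimpleCircuitBreaker:
    """Illustrative sketch: Closed -> Open after repeated failures,
    Open -> Half-Open after a recovery timeout, Half-Open -> Closed on successes."""

    def __init__(self, failure_threshold=5, recovery_timeout=60, success_threshold=2):
        self.failure_threshold = failure_threshold
        self.recovery_timeout = recovery_timeout
        self.success_threshold = success_threshold
        self.state = "closed"
        self.failures = 0
        self.successes = 0
        self.opened_at = 0.0

    def allow_request(self) -> bool:
        if self.state == "open":
            # After the recovery timeout, let a single trial request through (Half-Open)
            if time.monotonic() - self.opened_at >= self.recovery_timeout:
                self.state = "half_open"
                return True
            return False
        return True

    def record_success(self):
        if self.state == "half_open":
            self.successes += 1
            if self.successes >= self.success_threshold:
                self.state, self.failures, self.successes = "closed", 0, 0
        else:
            self.failures = 0

    def record_failure(self):
        self.failures += 1
        if self.state == "half_open" or self.failures >= self.failure_threshold:
            self.state = "open"
            self.opened_at = time.monotonic()
            self.successes = 0
```

Each enhanced client applies the same idea around its HTTP calls: requests are skipped while the breaker is open, and a trial request is allowed once the recovery timeout elapses.
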
### Intelligent Caching
- **TTL-based**: Different cache durations for different data types
- **Invalidation**: Pattern-based cache invalidation on updates
- **Statistics**: Hit/miss ratios and performance metrics
- **Manual Control**: Clear specific cache patterns when needed
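
As a rough illustration of the TTL, statistics, and pattern-based invalidation described above, here is a minimal standalone sketch. The real clients ship their own cache; the key format shown is hypothetical:

```python
import time
import fnmatch

class SimpleTTLCache:
    """Illustrative TTL cache with glob-pattern invalidation and hit/miss counters."""

    def __init__(self):
        self._entries = {}  # key -> (value, expires_at)
        self.hits = 0
        self.misses = 0

    def set(self, key, value, ttl_seconds):
        self._entries[key] = (value, time.monotonic() + ttl_seconds)

    def get(self, key):
        entry = self._entries.get(key)
        if entry and entry[1] > time.monotonic():
            self.hits += 1
            return entry[0]
        self._entries.pop(key, None)  # expired or missing
        self.misses += 1
        return None

    def invalidate(self, pattern):
        """Drop all keys matching a glob pattern, e.g. 'sales:tenant-123:*' (hypothetical key scheme)."""
        for key in [k for k in self._entries if fnmatch.fnmatch(k, pattern)]:
            del self._entries[key]

    @property
    def hit_rate(self):
        total = self.hits + self.misses
        return self.hits / total if total else 0.0
```
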
### Event Integration
- **Repository Events**: Entity created/updated/deleted events
- **Correlation IDs**: Track operations across services
- **Metadata**: Rich event metadata for debugging and monitoring

### Monitoring & Metrics
- **Request Metrics**: Success/failure rates, latencies
- **Cache Metrics**: Hit rates, entry counts
- **Circuit Breaker Metrics**: State changes, failure counts
- **Health Checks**: Per-service and aggregate health status

## Usage Examples

### Basic Usage with Service Registry

```python
from shared.clients.enhanced_service_client import ServiceRegistry
from shared.config.base import BaseServiceSettings

# Initialize registry
config = BaseServiceSettings()
registry = ServiceRegistry(config, calling_service="forecasting")

# Get enhanced clients
data_client = registry.get_data_client()
auth_client = registry.get_auth_client()
training_client = registry.get_training_client()

# Use with full features
sales_data = await data_client.get_all_sales_data_with_monitoring(
    tenant_id="tenant-123",
    start_date="2024-01-01",
    end_date="2024-12-31",
    correlation_id="forecast-job-456"
)
```

### Data Service Operations

```python
# Get sales data with intelligent caching
sales_data = await data_client.get_sales_data_cached(
    tenant_id="tenant-123",
    start_date="2024-01-01",
    end_date="2024-01-31",
    aggregation="daily"
)

# Upload sales data with cache invalidation and events
result = await data_client.upload_sales_data_with_events(
    tenant_id="tenant-123",
    sales_data=sales_records,
    correlation_id="data-import-789"
)

# Get weather data with caching (30 min TTL)
weather_data = await data_client.get_weather_historical_cached(
    tenant_id="tenant-123",
    start_date="2024-01-01",
    end_date="2024-01-31"
)
```

### Authentication & User Management

```python
# Authenticate with security monitoring
auth_result = await auth_client.authenticate_user_cached(
    email="user@example.com",
    password="password"
)

# Check permissions with caching
has_access = await auth_client.check_user_permissions_cached(
    user_id="user-123",
    tenant_id="tenant-456",
    resource="sales_data",
    action="read"
)

# Create user with events
user = await auth_client.create_user_with_events(
    user_data={
        "email": "new@example.com",
        "name": "New User",
        "role": "analyst"
    },
    tenant_id="tenant-123",
    correlation_id="user-creation-789"
)
```

### Training & ML Operations

```python
# Create training job with monitoring
job = await training_client.create_training_job_with_monitoring(
    tenant_id="tenant-123",
    include_weather=True,
    include_traffic=False,
    min_data_points=30,
    correlation_id="training-pipeline-456"
)

# Get active model with caching
model = await training_client.get_active_model_for_product_cached(
    tenant_id="tenant-123",
    product_name="croissants"
)

# Deploy model with events
deployment = await training_client.deploy_model_with_events(
    tenant_id="tenant-123",
    model_id="model-789",
    correlation_id="deployment-123"
)

# Get pipeline status
status = await training_client.get_training_pipeline_status("tenant-123")
```

### Forecasting & Predictions

```python
# Create forecast with monitoring
forecast = await forecasting_client.create_forecast_with_monitoring(
    tenant_id="tenant-123",
    model_id="model-456",
    start_date="2024-02-01",
    end_date="2024-02-29",
    correlation_id="forecast-creation-789"
)

# Get predictions with caching
predictions = await forecasting_client.get_predictions_cached(
    tenant_id="tenant-123",
    forecast_id="forecast-456",
    start_date="2024-02-01",
    end_date="2024-02-07"
)

# Real-time prediction with caching
prediction = await forecasting_client.create_realtime_prediction_with_monitoring(
    tenant_id="tenant-123",
    model_id="model-456",
    target_date="2024-02-01",
    features={"temperature": 20, "day_of_week": 1},
    correlation_id="realtime-pred-123"
)

# Get forecasting dashboard
dashboard = await forecasting_client.get_forecasting_dashboard("tenant-123")
```

### Tenant Management

```python
# Create tenant with monitoring
tenant = await tenant_client.create_tenant_with_monitoring(
    name="New Bakery Chain",
    owner_id="user-123",
    description="Multi-location bakery chain",
    correlation_id="tenant-creation-456"
)

# Add member with events
membership = await tenant_client.add_tenant_member_with_events(
    tenant_id="tenant-123",
    user_id="user-456",
    role="manager",
    correlation_id="member-add-789"
)

# Get tenant analytics
analytics = await tenant_client.get_tenant_analytics("tenant-123")
```

### Notification Management

```python
# Send notification with monitoring
notification = await notification_client.send_notification_with_monitoring(
    recipient_id="user-123",
    notification_type="forecast_ready",
    title="Forecast Complete",
    message="Your weekly forecast is ready for review",
    tenant_id="tenant-456",
    priority="high",
    channels=["email", "in_app"],
    correlation_id="forecast-notification-789"
)

# Send bulk notification
bulk_result = await notification_client.send_bulk_notification_with_monitoring(
    recipients=["user-123", "user-456", "user-789"],
    notification_type="system_update",
    title="System Maintenance",
    message="Scheduled maintenance tonight at 2 AM",
    priority="normal",
    correlation_id="maintenance-notification-123"
)

# Get delivery analytics
analytics = await notification_client.get_delivery_analytics(
    tenant_id="tenant-123",
    start_date="2024-01-01",
    end_date="2024-01-31"
)
```

## Health Monitoring

### Individual Service Health

```python
# Get specific service health
data_health = data_client.get_data_service_health()
auth_health = auth_client.get_auth_service_health()
training_health = training_client.get_training_service_health()

# Health includes:
# - Circuit breaker status
# - Cache statistics and configuration
# - Service-specific features
# - Supported endpoints
```

### Registry-Level Health

```python
# Get all service health status
all_health = registry.get_all_health_status()

# Get aggregate metrics
metrics = registry.get_aggregate_metrics()
# Returns:
# - Total cache hits/misses and hit rate
# - Circuit breaker states for all services
# - Count of healthy vs total services
```

## Configuration

### Cache TTL Configuration

Each enhanced client has optimized cache TTL values:

```python
# Data Service
sales_cache_ttl = 600      # 10 minutes
weather_cache_ttl = 1800   # 30 minutes
traffic_cache_ttl = 3600   # 1 hour
product_cache_ttl = 300    # 5 minutes

# Auth Service
user_cache_ttl = 300        # 5 minutes
token_cache_ttl = 60        # 1 minute
permission_cache_ttl = 900  # 15 minutes

# Training Service
job_cache_ttl = 180      # 3 minutes
model_cache_ttl = 600    # 10 minutes
metrics_cache_ttl = 300  # 5 minutes

# And so on...
```

### Circuit Breaker Configuration

```python
CircuitBreakerConfig(
    failure_threshold=5,    # Failures before opening
    recovery_timeout=60,    # Seconds before testing recovery
    success_threshold=2,    # Successes needed to close
    timeout=30              # Request timeout in seconds
)
```

## Event System Integration

All enhanced clients integrate with the enhanced event system:

### Event Types
- **EntityCreatedEvent** - When entities are created
- **EntityUpdatedEvent** - When entities are modified
- **EntityDeletedEvent** - When entities are removed

### Event Metadata
- **correlation_id** - Track operations across services
- **source_service** - Service that generated the event
- **destination_service** - Target service
- **tenant_id** - Tenant context
- **user_id** - User context
- **tags** - Additional metadata
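
Taken together, an event carrying this metadata might look roughly like the dictionary below. The field names follow the list above; the concrete values and the outer envelope are illustrative assumptions, not a wire-format specification:

```python
# Illustrative only: the approximate shape of an entity event and its metadata.
entity_created_event = {
    "event_type": "entity_created",
    "entity_type": "forecast",          # hypothetical entity type
    "entity_id": "forecast-456",        # hypothetical identifier
    "metadata": {
        "correlation_id": "forecast-creation-789",
        "source_service": "forecasting",
        "destination_service": "notification",
        "tenant_id": "tenant-123",
        "user_id": "user-123",
        "tags": {"trigger": "scheduled_job"},
    },
}
```
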
### Usage in Enhanced Clients
Events are automatically published for:
- Data uploads and modifications
- User creation/updates/deletion
- Training job lifecycle
- Model deployments
- Forecast creation
- Tenant management operations
- Notification delivery

## Error Handling & Resilience

### Circuit Breaker Protection
- Automatically stops requests when services are failing
- Provides fallback to cached data when available
- Gradually tests service recovery

### Retry Logic
- Exponential backoff for transient failures
- Configurable retry counts and delays
- Authentication token refresh on 401 errors
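
A minimal sketch of exponential backoff with jitter, assuming the actual retry and token-refresh logic lives in `BaseServiceClient` (this standalone helper is illustrative only):

```python
import asyncio
import random

async def call_with_backoff(operation, max_retries=3, base_delay=0.5):
    """Retry an async operation with exponential backoff and jitter (illustrative)."""
    for attempt in range(max_retries):
        try:
            return await operation()
        except Exception:
            if attempt == max_retries - 1:
                raise
            # 0.5s, 1s, 2s, ... plus a little jitter to avoid thundering herds
            delay = base_delay * (2 ** attempt) + random.uniform(0, 0.1)
            await asyncio.sleep(delay)

# Usage (hypothetical wrapper around a client call shown earlier in this README):
# result = await call_with_backoff(lambda: training_client.get_training_pipeline_status("tenant-123"))
```
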
### Cache Fallbacks
- Returns cached data when services are unavailable
- Graceful degradation with stale data warnings
- Manual cache invalidation for data consistency

## Integration with Repository Pattern

The enhanced clients integrate with the new repository pattern:

### Service Layer Integration

```python
class ForecastingService:
    def __init__(self,
                 forecast_repository: ForecastRepository,
                 service_registry: ServiceRegistry):
        self.forecast_repository = forecast_repository
        self.data_client = service_registry.get_data_client()
        self.training_client = service_registry.get_training_client()

    async def create_forecast(self, tenant_id: str, model_id: str):
        # Get data through enhanced client
        sales_data = await self.data_client.get_all_sales_data_with_monitoring(
            tenant_id=tenant_id,
            correlation_id=f"forecast_data_{datetime.utcnow().isoformat()}"
        )

        # Use repository for database operations
        forecast = await self.forecast_repository.create({
            "tenant_id": tenant_id,
            "model_id": model_id,
            "status": "pending"
        })

        return forecast
```

This enhanced inter-service communication system integrates with the new repository pattern architecture, providing resilience, monitoring, and event tracking for all service interactions.
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Shared Database Infrastructure
|
||||
Provides consistent database patterns across all microservices
|
||||
"""
|
||||
|
||||
from .base import DatabaseManager, Base, create_database_manager
|
||||
from .repository import BaseRepository
|
||||
from .unit_of_work import UnitOfWork, ServiceUnitOfWork, RepositoryRegistry
|
||||
from .transactions import (
|
||||
transactional,
|
||||
unit_of_work_transactional,
|
||||
managed_transaction,
|
||||
managed_unit_of_work,
|
||||
TransactionManager,
|
||||
run_in_transaction,
|
||||
run_with_unit_of_work
|
||||
)
|
||||
from .exceptions import (
|
||||
DatabaseError,
|
||||
ConnectionError,
|
||||
RecordNotFoundError,
|
||||
DuplicateRecordError,
|
||||
ConstraintViolationError,
|
||||
TransactionError,
|
||||
ValidationError,
|
||||
MigrationError,
|
||||
HealthCheckError
|
||||
)
|
||||
from .utils import DatabaseUtils, QueryLogger
|
||||
|
||||
__all__ = [
|
||||
# Core components
|
||||
"DatabaseManager",
|
||||
"Base",
|
||||
"create_database_manager",
|
||||
|
||||
# Repository pattern
|
||||
"BaseRepository",
|
||||
|
||||
# Unit of Work pattern
|
||||
"UnitOfWork",
|
||||
"ServiceUnitOfWork",
|
||||
"RepositoryRegistry",
|
||||
|
||||
# Transaction management
|
||||
"transactional",
|
||||
"unit_of_work_transactional",
|
||||
"managed_transaction",
|
||||
"managed_unit_of_work",
|
||||
"TransactionManager",
|
||||
"run_in_transaction",
|
||||
"run_with_unit_of_work",
|
||||
|
||||
# Exceptions
|
||||
"DatabaseError",
|
||||
"ConnectionError",
|
||||
"RecordNotFoundError",
|
||||
"DuplicateRecordError",
|
||||
"ConstraintViolationError",
|
||||
"TransactionError",
|
||||
"ValidationError",
|
||||
"MigrationError",
|
||||
"HealthCheckError",
|
||||
|
||||
# Utilities
|
||||
"DatabaseUtils",
|
||||
"QueryLogger"
|
||||
]
|
||||
@@ -1,78 +1,298 @@
|
||||
"""
|
||||
Base database configuration for all microservices
|
||||
Enhanced Base Database Configuration for All Microservices
|
||||
Provides DatabaseManager with connection pooling, health checks, and multi-database support
|
||||
"""
|
||||
|
||||
import os
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from typing import Optional, Dict, Any, List
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from sqlalchemy.pool import StaticPool, QueuePool
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
import structlog
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from .exceptions import DatabaseError, ConnectionError, HealthCheckError
|
||||
from .utils import DatabaseUtils
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class DatabaseManager:
|
||||
"""Database manager for microservices"""
|
||||
"""Enhanced Database Manager for Microservices
|
||||
|
||||
def __init__(self, database_url: str):
|
||||
Provides:
|
||||
- Connection pooling with configurable settings
|
||||
- Health checks and monitoring
|
||||
- Multi-database support
|
||||
- Session lifecycle management
|
||||
- Background task session support
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
database_url: str,
|
||||
service_name: str = "unknown",
|
||||
pool_size: int = 20,
|
||||
max_overflow: int = 30,
|
||||
pool_recycle: int = 3600,
|
||||
pool_pre_ping: bool = True,
|
||||
echo: bool = False,
|
||||
connect_timeout: int = 30,
|
||||
**engine_kwargs
|
||||
):
|
||||
self.database_url = database_url
|
||||
self.async_engine = create_async_engine(
|
||||
database_url,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=300,
|
||||
pool_size=20,
|
||||
max_overflow=30
|
||||
)
|
||||
self.service_name = service_name
|
||||
self.pool_size = pool_size
|
||||
self.max_overflow = max_overflow
|
||||
|
||||
self.async_session_local = sessionmaker(
|
||||
# Configure pool class based on database type
|
||||
poolclass = QueuePool
|
||||
if "sqlite" in database_url.lower():
|
||||
poolclass = StaticPool
|
||||
pool_size = 1
|
||||
max_overflow = 0
|
||||
|
||||
# Create async engine with enhanced configuration
|
||||
engine_config = {
|
||||
"echo": echo,
|
||||
"pool_pre_ping": pool_pre_ping,
|
||||
"pool_recycle": pool_recycle,
|
||||
"pool_size": pool_size,
|
||||
"max_overflow": max_overflow,
|
||||
"poolclass": poolclass,
|
||||
"connect_args": {"command_timeout": connect_timeout},
|
||||
**engine_kwargs
|
||||
}
|
||||
|
||||
self.async_engine = create_async_engine(database_url, **engine_config)
|
||||
|
||||
# Create session factory
|
||||
self.async_session_local = async_sessionmaker(
|
||||
self.async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False
|
||||
)
|
||||
|
||||
logger.info(f"DatabaseManager initialized for {service_name}",
|
||||
pool_size=pool_size,
|
||||
max_overflow=max_overflow,
|
||||
database_type=self._get_database_type())
|
||||
|
||||
async def get_db(self):
|
||||
"""Get database session for request handlers"""
|
||||
"""Get database session for request handlers (FastAPI dependency)"""
|
||||
async with self.async_session_local() as session:
|
||||
try:
|
||||
logger.debug("Database session created for request")
|
||||
yield session
|
||||
except Exception as e:
|
||||
logger.error(f"Database session error: {e}")
|
||||
# Don't wrap HTTPExceptions - let them pass through
|
||||
if hasattr(e, 'status_code') and hasattr(e, 'detail'):
|
||||
# This is likely an HTTPException - don't wrap it
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
error_msg = str(e) if str(e) else f"{type(e).__name__}: {repr(e)}"
|
||||
logger.error(f"Database session error: {error_msg}", service=self.service_name)
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
# Handle specific ASGI stream issues more gracefully
|
||||
if "EndOfStream" in str(type(e)) or "WouldBlock" in str(type(e)):
|
||||
raise DatabaseError(f"Session error: Request stream disconnected ({type(e).__name__})")
|
||||
else:
|
||||
raise DatabaseError(f"Session error: {error_msg}")
|
||||
finally:
|
||||
await session.close()
|
||||
logger.debug("Database session closed")
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_background_session(self):
|
||||
"""
|
||||
✅ NEW: Get database session for background tasks
|
||||
Get database session for background tasks with auto-commit
|
||||
|
||||
Usage:
|
||||
async with database_manager.get_background_session() as session:
|
||||
# Your background task code here
|
||||
await session.commit()
|
||||
# Auto-commits on success, rolls back on exception
|
||||
"""
|
||||
async with self.async_session_local() as session:
|
||||
try:
|
||||
logger.debug("Background session created", service=self.service_name)
|
||||
yield session
|
||||
await session.commit()
|
||||
logger.debug("Background session committed")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Background task database error: {e}")
|
||||
raise
|
||||
logger.error(f"Background task database error: {e}",
|
||||
service=self.service_name)
|
||||
raise DatabaseError(f"Background task failed: {str(e)}")
|
||||
finally:
|
||||
await session.close()
|
||||
logger.debug("Background session closed")
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_session(self):
|
||||
"""Get a plain database session (no auto-commit)"""
|
||||
async with self.async_session_local() as session:
|
||||
try:
|
||||
yield session
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Session error: {e}", service=self.service_name)
|
||||
raise DatabaseError(f"Session error: {str(e)}")
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def create_tables(self):
|
||||
"""Create database tables"""
|
||||
async with self.async_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
# ===== TABLE MANAGEMENT =====
|
||||
|
||||
async def drop_tables(self):
|
||||
async def create_tables(self, metadata=None):
|
||||
"""Create database tables"""
|
||||
try:
|
||||
target_metadata = metadata or Base.metadata
|
||||
async with self.async_engine.begin() as conn:
|
||||
await conn.run_sync(target_metadata.create_all)
|
||||
logger.info("Database tables created successfully", service=self.service_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create tables: {e}", service=self.service_name)
|
||||
raise DatabaseError(f"Table creation failed: {str(e)}")
|
||||
|
||||
async def drop_tables(self, metadata=None):
|
||||
"""Drop database tables"""
|
||||
async with self.async_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
try:
|
||||
target_metadata = metadata or Base.metadata
|
||||
async with self.async_engine.begin() as conn:
|
||||
await conn.run_sync(target_metadata.drop_all)
|
||||
logger.info("Database tables dropped successfully", service=self.service_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to drop tables: {e}", service=self.service_name)
|
||||
raise DatabaseError(f"Table drop failed: {str(e)}")
|
||||
|
||||
# ===== HEALTH CHECKS AND MONITORING =====
|
||||
|
||||
async def health_check(self) -> Dict[str, Any]:
|
||||
"""Comprehensive health check for the database"""
|
||||
try:
|
||||
async with self.get_session() as session:
|
||||
return await DatabaseUtils.execute_health_check(session)
|
||||
except Exception as e:
|
||||
logger.error(f"Health check failed: {e}", service=self.service_name)
|
||||
raise HealthCheckError(f"Health check failed: {str(e)}")
|
||||
|
||||
async def get_connection_info(self) -> Dict[str, Any]:
|
||||
"""Get database connection information"""
|
||||
try:
|
||||
pool = self.async_engine.pool
|
||||
return {
|
||||
"service_name": self.service_name,
|
||||
"database_type": self._get_database_type(),
|
||||
"pool_size": self.pool_size,
|
||||
"max_overflow": self.max_overflow,
|
||||
"current_checked_in": pool.checkedin() if pool else 0,
|
||||
"current_checked_out": pool.checkedout() if pool else 0,
|
||||
"current_overflow": pool.overflow() if pool else 0,
|
||||
"invalid_connections": pool.invalid() if pool else 0
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get connection info: {e}", service=self.service_name)
|
||||
return {"error": str(e)}
|
||||
|
||||
def _get_database_type(self) -> str:
|
||||
"""Get database type from URL"""
|
||||
return self.database_url.split("://")[0].lower() if "://" in self.database_url else "unknown"
|
||||
|
||||
# ===== CLEANUP AND MAINTENANCE =====
|
||||
|
||||
async def close_connections(self):
|
||||
"""Close all database connections"""
|
||||
try:
|
||||
await self.async_engine.dispose()
|
||||
logger.info("Database connections closed", service=self.service_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to close connections: {e}", service=self.service_name)
|
||||
raise DatabaseError(f"Connection cleanup failed: {str(e)}")
|
||||
|
||||
async def execute_maintenance(self) -> Dict[str, Any]:
|
||||
"""Execute database maintenance tasks"""
|
||||
try:
|
||||
async with self.get_session() as session:
|
||||
return await DatabaseUtils.execute_maintenance(session)
|
||||
except Exception as e:
|
||||
logger.error(f"Maintenance failed: {e}", service=self.service_name)
|
||||
raise DatabaseError(f"Maintenance failed: {str(e)}")
|
||||
|
||||
# ===== UTILITY METHODS =====
|
||||
|
||||
async def test_connection(self) -> bool:
|
||||
"""Test database connectivity"""
|
||||
try:
|
||||
async with self.async_engine.begin() as conn:
|
||||
await conn.execute(text("SELECT 1"))
|
||||
logger.debug("Connection test successful", service=self.service_name)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Connection test failed: {e}", service=self.service_name)
|
||||
return False
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"DatabaseManager(service='{self.service_name}', type='{self._get_database_type()}')"
|
||||
|
||||
|
||||
# ===== CONVENIENCE FUNCTIONS =====
|
||||
|
||||
def create_database_manager(
|
||||
database_url: str,
|
||||
service_name: str,
|
||||
**kwargs
|
||||
) -> DatabaseManager:
|
||||
"""Factory function to create DatabaseManager instances"""
|
||||
return DatabaseManager(database_url, service_name, **kwargs)
|
||||
|
||||
|
||||
# ===== LEGACY COMPATIBILITY =====
|
||||
|
||||
# Keep backward compatibility for existing code
|
||||
engine = None
|
||||
AsyncSessionLocal = None
|
||||
|
||||
def init_legacy_compatibility(database_url: str):
|
||||
"""Initialize legacy global variables for backward compatibility"""
|
||||
global engine, AsyncSessionLocal
|
||||
|
||||
engine = create_async_engine(
|
||||
database_url,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=300,
|
||||
pool_size=20,
|
||||
max_overflow=30
|
||||
)
|
||||
|
||||
AsyncSessionLocal = async_sessionmaker(
|
||||
engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
logger.warning("Using legacy database configuration - consider migrating to DatabaseManager")
|
||||
|
||||
|
||||
async def get_legacy_db():
|
||||
"""Legacy database session getter for backward compatibility"""
|
||||
if not AsyncSessionLocal:
|
||||
raise RuntimeError("Legacy database not initialized - call init_legacy_compatibility first")
|
||||
|
||||
async with AsyncSessionLocal() as session:
|
||||
try:
|
||||
yield session
|
||||
except Exception as e:
|
||||
logger.error(f"Legacy database session error: {e}")
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
shared/database/base.py.backup (new file, 78 lines)
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Base database configuration for all microservices
|
||||
"""
|
||||
|
||||
import os
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker, declarative_base
|
||||
from sqlalchemy.pool import StaticPool
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class DatabaseManager:
|
||||
"""Database manager for microservices"""
|
||||
|
||||
def __init__(self, database_url: str):
|
||||
self.database_url = database_url
|
||||
self.async_engine = create_async_engine(
|
||||
database_url,
|
||||
echo=False,
|
||||
pool_pre_ping=True,
|
||||
pool_recycle=300,
|
||||
pool_size=20,
|
||||
max_overflow=30
|
||||
)
|
||||
|
||||
self.async_session_local = sessionmaker(
|
||||
self.async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False
|
||||
)
|
||||
|
||||
async def get_db(self):
|
||||
"""Get database session for request handlers"""
|
||||
async with self.async_session_local() as session:
|
||||
try:
|
||||
yield session
|
||||
except Exception as e:
|
||||
logger.error(f"Database session error: {e}")
|
||||
await session.rollback()
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_background_session(self):
|
||||
"""
|
||||
✅ NEW: Get database session for background tasks
|
||||
|
||||
Usage:
|
||||
async with database_manager.get_background_session() as session:
|
||||
# Your background task code here
|
||||
await session.commit()
|
||||
"""
|
||||
async with self.async_session_local() as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Background task database error: {e}")
|
||||
raise
|
||||
finally:
|
||||
await session.close()
|
||||
|
||||
async def create_tables(self):
|
||||
"""Create database tables"""
|
||||
async with self.async_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
async def drop_tables(self):
|
||||
"""Drop database tables"""
|
||||
async with self.async_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
shared/database/exceptions.py (new file, 52 lines)
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
Custom Database Exceptions
|
||||
Provides consistent error handling across all microservices
|
||||
"""
|
||||
|
||||
class DatabaseError(Exception):
|
||||
"""Base exception for database-related errors"""
|
||||
|
||||
def __init__(self, message: str, details: dict = None):
|
||||
self.message = message
|
||||
self.details = details or {}
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class ConnectionError(DatabaseError):
|
||||
"""Raised when database connection fails"""
|
||||
pass
|
||||
|
||||
|
||||
class RecordNotFoundError(DatabaseError):
|
||||
"""Raised when a requested record is not found"""
|
||||
pass
|
||||
|
||||
|
||||
class DuplicateRecordError(DatabaseError):
|
||||
"""Raised when trying to create a duplicate record"""
|
||||
pass
|
||||
|
||||
|
||||
class ConstraintViolationError(DatabaseError):
|
||||
"""Raised when database constraints are violated"""
|
||||
pass
|
||||
|
||||
|
||||
class TransactionError(DatabaseError):
|
||||
"""Raised when transaction operations fail"""
|
||||
pass
|
||||
|
||||
|
||||
class ValidationError(DatabaseError):
|
||||
"""Raised when data validation fails before database operations"""
|
||||
pass
|
||||
|
||||
|
||||
class MigrationError(DatabaseError):
|
||||
"""Raised when database migration operations fail"""
|
||||
pass
|
||||
|
||||
|
||||
class HealthCheckError(DatabaseError):
|
||||
"""Raised when database health checks fail"""
|
||||
pass
|
||||
shared/database/repository.py (new file, 422 lines)
@@ -0,0 +1,422 @@
|
||||
"""
|
||||
Base Repository Pattern for Database Operations
|
||||
Provides generic CRUD operations, query building, and caching
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any, TypeVar, Generic, Type, Union
|
||||
from abc import ABC, abstractmethod
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import declarative_base
|
||||
from sqlalchemy import select, update, delete, and_, or_, desc, asc, func, text
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||
from contextlib import asynccontextmanager
|
||||
import structlog
|
||||
|
||||
from .exceptions import (
|
||||
DatabaseError,
|
||||
RecordNotFoundError,
|
||||
DuplicateRecordError,
|
||||
ConstraintViolationError
|
||||
)
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Type variables for generic repository
|
||||
Model = TypeVar('Model', bound=declarative_base())
|
||||
CreateSchema = TypeVar('CreateSchema')
|
||||
UpdateSchema = TypeVar('UpdateSchema')
|
||||
|
||||
|
||||
class BaseRepository(Generic[Model, CreateSchema, UpdateSchema], ABC):
|
||||
"""
|
||||
Base repository providing generic CRUD operations
|
||||
|
||||
Args:
|
||||
model: SQLAlchemy model class
|
||||
session: Database session
|
||||
cache_ttl: Cache time-to-live in seconds (optional)
|
||||
"""
|
||||
|
||||
def __init__(self, model: Type[Model], session: AsyncSession, cache_ttl: Optional[int] = None):
|
||||
self.model = model
|
||||
self.session = session
|
||||
self.cache_ttl = cache_ttl
|
||||
self._cache = {} if cache_ttl else None
|
||||
|
||||
# ===== CORE CRUD OPERATIONS =====
|
||||
|
||||
async def create(self, obj_in: CreateSchema, **kwargs) -> Model:
|
||||
"""Create a new record"""
|
||||
try:
|
||||
# Convert schema to dict if needed
|
||||
if hasattr(obj_in, 'model_dump'):
|
||||
obj_data = obj_in.model_dump()
|
||||
elif hasattr(obj_in, 'dict'):
|
||||
obj_data = obj_in.dict()
|
||||
else:
|
||||
obj_data = obj_in
|
||||
|
||||
# Merge with additional kwargs
|
||||
obj_data.update(kwargs)
|
||||
|
||||
db_obj = self.model(**obj_data)
|
||||
self.session.add(db_obj)
|
||||
await self.session.flush() # Get ID without committing
|
||||
await self.session.refresh(db_obj)
|
||||
|
||||
logger.debug(f"Created {self.model.__name__}", record_id=getattr(db_obj, 'id', None))
|
||||
return db_obj
|
||||
|
||||
except IntegrityError as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Integrity error creating {self.model.__name__}", error=str(e))
|
||||
raise DuplicateRecordError(f"Record with provided data already exists")
|
||||
except SQLAlchemyError as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Database error creating {self.model.__name__}", error=str(e))
|
||||
raise DatabaseError(f"Failed to create record: {str(e)}")
|
||||
|
||||
async def get_by_id(self, record_id: Any) -> Optional[Model]:
|
||||
"""Get record by ID with optional caching"""
|
||||
cache_key = f"{self.model.__name__}:{record_id}"
|
||||
|
||||
# Check cache first
|
||||
if self._cache and cache_key in self._cache:
|
||||
logger.debug(f"Cache hit for {cache_key}")
|
||||
return self._cache[cache_key]
|
||||
|
||||
try:
|
||||
result = await self.session.execute(
|
||||
select(self.model).where(self.model.id == record_id)
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
|
||||
# Cache the result
|
||||
if self._cache and record:
|
||||
self._cache[cache_key] = record
|
||||
|
||||
return record
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error getting {self.model.__name__} by ID",
|
||||
record_id=record_id, error=str(e))
|
||||
raise DatabaseError(f"Failed to get record: {str(e)}")
|
||||
|
||||
async def get_by_field(self, field_name: str, value: Any) -> Optional[Model]:
|
||||
"""Get record by specific field"""
|
||||
try:
|
||||
result = await self.session.execute(
|
||||
select(self.model).where(getattr(self.model, field_name) == value)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
except AttributeError:
|
||||
raise ValueError(f"Field '{field_name}' not found in {self.model.__name__}")
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error getting {self.model.__name__} by {field_name}",
|
||||
value=value, error=str(e))
|
||||
raise DatabaseError(f"Failed to get record: {str(e)}")
|
||||
|
||||
async def get_multi(
|
||||
self,
|
||||
skip: int = 0,
|
||||
limit: int = 100,
|
||||
order_by: Optional[str] = None,
|
||||
order_desc: bool = False,
|
||||
filters: Optional[Dict[str, Any]] = None
|
||||
) -> List[Model]:
|
||||
"""Get multiple records with pagination, sorting, and filtering"""
|
||||
try:
|
||||
query = select(self.model)
|
||||
|
||||
# Apply filters
|
||||
if filters:
|
||||
conditions = []
|
||||
for field, value in filters.items():
|
||||
if hasattr(self.model, field):
|
||||
if isinstance(value, list):
|
||||
conditions.append(getattr(self.model, field).in_(value))
|
||||
else:
|
||||
conditions.append(getattr(self.model, field) == value)
|
||||
|
||||
if conditions:
|
||||
query = query.where(and_(*conditions))
|
||||
|
||||
# Apply ordering
|
||||
if order_by and hasattr(self.model, order_by):
|
||||
order_field = getattr(self.model, order_by)
|
||||
if order_desc:
|
||||
query = query.order_by(desc(order_field))
|
||||
else:
|
||||
query = query.order_by(asc(order_field))
|
||||
|
||||
# Apply pagination
|
||||
query = query.offset(skip).limit(limit)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error getting multiple {self.model.__name__} records",
|
||||
error=str(e))
|
||||
raise DatabaseError(f"Failed to get records: {str(e)}")
|
||||
|
||||
async def update(self, record_id: Any, obj_in: UpdateSchema, **kwargs) -> Optional[Model]:
|
||||
"""Update record by ID"""
|
||||
try:
|
||||
# Convert schema to dict if needed
|
||||
if hasattr(obj_in, 'model_dump'):
|
||||
update_data = obj_in.model_dump(exclude_unset=True)
|
||||
elif hasattr(obj_in, 'dict'):
|
||||
update_data = obj_in.dict(exclude_unset=True)
|
||||
else:
|
||||
update_data = obj_in
|
||||
|
||||
# Merge with additional kwargs
|
||||
update_data.update(kwargs)
|
||||
|
||||
# Remove None values
|
||||
update_data = {k: v for k, v in update_data.items() if v is not None}
|
||||
|
||||
if not update_data:
|
||||
logger.warning(f"No data to update for {self.model.__name__}", record_id=record_id)
|
||||
return await self.get_by_id(record_id)
|
||||
|
||||
# Perform update
|
||||
result = await self.session.execute(
|
||||
update(self.model)
|
||||
.where(self.model.id == record_id)
|
||||
.values(**update_data)
|
||||
.returning(self.model)
|
||||
)
|
||||
|
||||
updated_record = result.scalar_one_or_none()
|
||||
|
||||
if not updated_record:
|
||||
raise RecordNotFoundError(f"{self.model.__name__} with id {record_id} not found")
|
||||
|
||||
# Clear cache
|
||||
if self._cache:
|
||||
cache_key = f"{self.model.__name__}:{record_id}"
|
||||
self._cache.pop(cache_key, None)
|
||||
|
||||
logger.debug(f"Updated {self.model.__name__}", record_id=record_id)
|
||||
return updated_record
|
||||
|
||||
except IntegrityError as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Integrity error updating {self.model.__name__}",
|
||||
record_id=record_id, error=str(e))
|
||||
raise ConstraintViolationError(f"Update violates database constraints")
|
||||
except SQLAlchemyError as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Database error updating {self.model.__name__}",
|
||||
record_id=record_id, error=str(e))
|
||||
raise DatabaseError(f"Failed to update record: {str(e)}")
|
||||
|
||||
async def delete(self, record_id: Any) -> bool:
|
||||
"""Delete record by ID"""
|
||||
try:
|
||||
result = await self.session.execute(
|
||||
delete(self.model).where(self.model.id == record_id)
|
||||
)
|
||||
|
||||
deleted_count = result.rowcount
|
||||
|
||||
if deleted_count == 0:
|
||||
raise RecordNotFoundError(f"{self.model.__name__} with id {record_id} not found")
|
||||
|
||||
# Clear cache
|
||||
if self._cache:
|
||||
cache_key = f"{self.model.__name__}:{record_id}"
|
||||
self._cache.pop(cache_key, None)
|
||||
|
||||
logger.debug(f"Deleted {self.model.__name__}", record_id=record_id)
|
||||
return True
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Database error deleting {self.model.__name__}",
|
||||
record_id=record_id, error=str(e))
|
||||
raise DatabaseError(f"Failed to delete record: {str(e)}")
|
||||
|
||||
# ===== ADVANCED QUERY OPERATIONS =====
|
||||
|
||||
async def count(self, filters: Optional[Dict[str, Any]] = None) -> int:
|
||||
"""Count records with optional filters"""
|
||||
try:
|
||||
query = select(func.count(self.model.id))
|
||||
|
||||
if filters:
|
||||
conditions = []
|
||||
for field, value in filters.items():
|
||||
if hasattr(self.model, field):
|
||||
if isinstance(value, list):
|
||||
conditions.append(getattr(self.model, field).in_(value))
|
||||
else:
|
||||
conditions.append(getattr(self.model, field) == value)
|
||||
|
||||
if conditions:
|
||||
query = query.where(and_(*conditions))
|
||||
|
||||
result = await self.session.execute(query)
|
||||
return result.scalar() or 0
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error counting {self.model.__name__} records", error=str(e))
|
||||
raise DatabaseError(f"Failed to count records: {str(e)}")
|
||||
|
||||
async def exists(self, record_id: Any) -> bool:
|
||||
"""Check if record exists by ID"""
|
||||
try:
|
||||
result = await self.session.execute(
|
||||
select(func.count(self.model.id)).where(self.model.id == record_id)
|
||||
)
|
||||
count = result.scalar() or 0
|
||||
return count > 0
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error checking existence of {self.model.__name__}",
|
||||
record_id=record_id, error=str(e))
|
||||
raise DatabaseError(f"Failed to check record existence: {str(e)}")
|
||||
|
||||
async def bulk_create(self, objects: List[CreateSchema]) -> List[Model]:
|
||||
"""Create multiple records in bulk"""
|
||||
try:
|
||||
if not objects:
|
||||
return []
|
||||
|
||||
db_objects = []
|
||||
for obj_in in objects:
|
||||
if hasattr(obj_in, 'model_dump'):
|
||||
obj_data = obj_in.model_dump()
|
||||
elif hasattr(obj_in, 'dict'):
|
||||
obj_data = obj_in.dict()
|
||||
else:
|
||||
obj_data = obj_in
|
||||
|
||||
db_objects.append(self.model(**obj_data))
|
||||
|
||||
self.session.add_all(db_objects)
|
||||
await self.session.flush()
|
||||
|
||||
for db_obj in db_objects:
|
||||
await self.session.refresh(db_obj)
|
||||
|
||||
logger.debug(f"Bulk created {len(db_objects)} {self.model.__name__} records")
|
||||
return db_objects
|
||||
|
||||
except IntegrityError as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Integrity error bulk creating {self.model.__name__}", error=str(e))
|
||||
raise DuplicateRecordError(f"One or more records already exist")
|
||||
except SQLAlchemyError as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Database error bulk creating {self.model.__name__}", error=str(e))
|
||||
raise DatabaseError(f"Failed to create records: {str(e)}")
|
||||
|
||||
async def bulk_update(self, updates: List[Dict[str, Any]]) -> int:
|
||||
"""Update multiple records in bulk"""
|
||||
try:
|
||||
if not updates:
|
||||
return 0
|
||||
|
||||
# Group updates by fields being updated for efficiency
|
||||
for update_data in updates:
|
||||
if 'id' not in update_data:
|
||||
raise ValueError("Each update must include 'id' field")
|
||||
|
||||
record_id = update_data.pop('id')
|
||||
await self.session.execute(
|
||||
update(self.model)
|
||||
.where(self.model.id == record_id)
|
||||
.values(**update_data)
|
||||
)
|
||||
|
||||
# Clear relevant cache entries
|
||||
if self._cache:
|
||||
for update_data in updates:
|
||||
record_id = update_data.get('id')
|
||||
if record_id:
|
||||
cache_key = f"{self.model.__name__}:{record_id}"
|
||||
self._cache.pop(cache_key, None)
|
||||
|
||||
logger.debug(f"Bulk updated {len(updates)} {self.model.__name__} records")
|
||||
return len(updates)
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Database error bulk updating {self.model.__name__}", error=str(e))
|
||||
raise DatabaseError(f"Failed to update records: {str(e)}")
|
||||
|
||||
# ===== SEARCH AND QUERY BUILDING =====
|
||||
|
||||
async def search(
|
||||
self,
|
||||
search_term: str,
|
||||
search_fields: List[str],
|
||||
skip: int = 0,
|
||||
limit: int = 100
|
||||
) -> List[Model]:
|
||||
"""Search records across multiple fields"""
|
||||
try:
|
||||
conditions = []
|
||||
for field in search_fields:
|
||||
if hasattr(self.model, field):
|
||||
field_obj = getattr(self.model, field)
|
||||
# Case-insensitive partial match
|
||||
conditions.append(field_obj.ilike(f"%{search_term}%"))
|
||||
|
||||
if not conditions:
|
||||
logger.warning(f"No valid search fields provided for {self.model.__name__}")
|
||||
return []
|
||||
|
||||
query = select(self.model).where(or_(*conditions)).offset(skip).limit(limit)
|
||||
result = await self.session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error searching {self.model.__name__}",
|
||||
search_term=search_term, error=str(e))
|
||||
raise DatabaseError(f"Failed to search records: {str(e)}")
|
||||
|
||||
async def execute_raw_query(self, query: str, params: Optional[Dict[str, Any]] = None) -> Any:
|
||||
"""Execute raw SQL query (use with caution)"""
|
||||
try:
|
||||
result = await self.session.execute(text(query), params or {})
|
||||
return result
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error executing raw query", query=query, error=str(e))
|
||||
raise DatabaseError(f"Failed to execute query: {str(e)}")
|
||||
|
||||
# ===== CACHE MANAGEMENT =====
|
||||
|
||||
def clear_cache(self, record_id: Optional[Any] = None):
|
||||
"""Clear cache for specific record or all records"""
|
||||
if not self._cache:
|
||||
return
|
||||
|
||||
if record_id:
|
||||
cache_key = f"{self.model.__name__}:{record_id}"
|
||||
self._cache.pop(cache_key, None)
|
||||
else:
|
||||
# Clear all cache entries for this model
|
||||
keys_to_remove = [k for k in self._cache.keys() if k.startswith(f"{self.model.__name__}:")]
|
||||
for key in keys_to_remove:
|
||||
self._cache.pop(key, None)
|
||||
|
||||
logger.debug(f"Cleared cache for {self.model.__name__}", record_id=record_id)
|
||||
|
||||
# ===== CONTEXT MANAGERS =====
|
||||
|
||||
@asynccontextmanager
|
||||
async def transaction(self):
|
||||
"""Context manager for explicit transaction handling"""
|
||||
try:
|
||||
yield self.session
|
||||
await self.session.commit()
|
||||
except Exception as e:
|
||||
await self.session.rollback()
|
||||
logger.error(f"Transaction failed for {self.model.__name__}", error=str(e))
|
||||
raise
|
||||
shared/database/transactions.py (new file, 306 lines)
@@ -0,0 +1,306 @@
|
||||
"""
|
||||
Transaction Decorators and Context Managers
|
||||
Provides convenient transaction handling for service methods
|
||||
"""
|
||||
|
||||
from functools import wraps
|
||||
from typing import Callable, Any, Optional
|
||||
from contextlib import asynccontextmanager
|
||||
import structlog
|
||||
|
||||
from .base import DatabaseManager
|
||||
from .unit_of_work import UnitOfWork
|
||||
from .exceptions import TransactionError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
def transactional(database_manager: DatabaseManager, auto_commit: bool = True):
|
||||
"""
|
||||
Decorator that wraps a method in a database transaction
|
||||
|
||||
Args:
|
||||
database_manager: DatabaseManager instance
|
||||
auto_commit: Whether to auto-commit on success
|
||||
|
||||
Usage:
|
||||
@transactional(database_manager)
|
||||
async def create_user_with_profile(self, user_data, profile_data):
|
||||
# Your business logic here
|
||||
# Transaction is automatically managed
|
||||
pass
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
async with database_manager.get_background_session() as session:
|
||||
try:
|
||||
# Inject session into kwargs if not present
|
||||
if 'session' not in kwargs:
|
||||
kwargs['session'] = session
|
||||
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
# Session is auto-committed by get_background_session
|
||||
logger.debug(f"Transaction completed successfully for {func.__name__}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
# Session is auto-rolled back by get_background_session
|
||||
logger.error(f"Transaction failed for {func.__name__}", error=str(e))
|
||||
raise TransactionError(f"Transaction failed: {str(e)}")
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
def unit_of_work_transactional(database_manager: DatabaseManager):
|
||||
"""
|
||||
Decorator that provides Unit of Work pattern for complex operations
|
||||
|
||||
Usage:
|
||||
@unit_of_work_transactional(database_manager)
|
||||
async def complex_business_operation(self, data, uow: UnitOfWork):
|
||||
user_repo = uow.register_repository("users", UserRepository, User)
|
||||
sales_repo = uow.register_repository("sales", SalesRepository, SalesData)
|
||||
|
||||
user = await user_repo.create(data.user)
|
||||
sale = await sales_repo.create(data.sale)
|
||||
|
||||
# UnitOfWork automatically commits
|
||||
return {"user": user, "sale": sale}
|
||||
"""
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
async with database_manager.get_background_session() as session:
|
||||
async with UnitOfWork(session, auto_commit=True) as uow:
|
||||
try:
|
||||
# Inject UnitOfWork into kwargs
|
||||
kwargs['uow'] = uow
|
||||
|
||||
result = await func(*args, **kwargs)
|
||||
|
||||
logger.debug(f"Unit of Work transaction completed for {func.__name__}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unit of Work transaction failed for {func.__name__}",
|
||||
error=str(e))
|
||||
raise TransactionError(f"Transaction failed: {str(e)}")
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def managed_transaction(database_manager: DatabaseManager):
|
||||
"""
|
||||
Context manager for explicit transaction control
|
||||
|
||||
Usage:
|
||||
async with managed_transaction(database_manager) as session:
|
||||
# Your database operations here
|
||||
user = User(name="John")
|
||||
session.add(user)
|
||||
# Auto-commits on exit, rolls back on exception
|
||||
"""
|
||||
async with database_manager.get_background_session() as session:
|
||||
try:
|
||||
logger.debug("Starting managed transaction")
|
||||
yield session
|
||||
logger.debug("Managed transaction completed successfully")
|
||||
except Exception as e:
|
||||
logger.error("Managed transaction failed", error=str(e))
|
||||
raise
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def managed_unit_of_work(database_manager: DatabaseManager, event_publisher=None):
|
||||
"""
|
||||
Context manager for explicit Unit of Work control
|
||||
|
||||
Usage:
|
||||
async with managed_unit_of_work(database_manager) as uow:
|
||||
user_repo = uow.register_repository("users", UserRepository, User)
|
||||
user = await user_repo.create(user_data)
|
||||
await uow.commit()
|
||||
"""
|
||||
async with database_manager.get_background_session() as session:
|
||||
uow = UnitOfWork(session)
|
||||
try:
|
||||
logger.debug("Starting managed Unit of Work")
|
||||
yield uow
|
||||
|
||||
if not uow._committed:
|
||||
await uow.commit()
|
||||
|
||||
logger.debug("Managed Unit of Work completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
if not uow._rolled_back:
|
||||
await uow.rollback()
|
||||
logger.error("Managed Unit of Work failed", error=str(e))
|
||||
raise
|
||||
|
||||
|
||||
class TransactionManager:
|
||||
"""
|
||||
Advanced transaction manager for complex scenarios
|
||||
|
||||
Usage:
|
||||
tx_manager = TransactionManager(database_manager)
|
||||
|
||||
async with tx_manager.create_transaction() as tx:
|
||||
await tx.execute_in_transaction(my_business_logic, data)
|
||||
"""
|
||||
|
||||
def __init__(self, database_manager: DatabaseManager):
|
||||
self.database_manager = database_manager
|
||||
|
||||
@asynccontextmanager
|
||||
async def create_transaction(self, isolation_level: Optional[str] = None):
|
||||
"""Create a transaction with optional isolation level"""
|
||||
async with self.database_manager.get_background_session() as session:
|
||||
transaction_context = TransactionContext(session, isolation_level)
|
||||
try:
|
||||
yield transaction_context
|
||||
except Exception as e:
|
||||
logger.error("Transaction manager failed", error=str(e))
|
||||
raise
|
||||
|
||||
async def execute_with_retry(
|
||||
self,
|
||||
func: Callable,
|
||||
max_retries: int = 3,
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
"""Execute function with transaction retry on failure"""
|
||||
last_error = None
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
async with managed_transaction(self.database_manager) as session:
|
||||
kwargs['session'] = session
|
||||
result = await func(*args, **kwargs)
|
||||
logger.debug(f"Transaction succeeded on attempt {attempt + 1}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
logger.warning(f"Transaction attempt {attempt + 1} failed",
|
||||
error=str(e), remaining_attempts=max_retries - attempt - 1)
|
||||
|
||||
if attempt == max_retries - 1:
|
||||
break
|
||||
|
||||
logger.error(f"All transaction attempts failed after {max_retries} tries")
|
||||
raise TransactionError(f"Transaction failed after {max_retries} retries: {str(last_error)}")
|
||||
|
||||
|
||||
class TransactionContext:
|
||||
"""Context for managing individual transactions"""
|
||||
|
||||
def __init__(self, session, isolation_level: Optional[str] = None):
|
||||
self.session = session
|
||||
self.isolation_level = isolation_level
|
||||
|
||||
async def execute_in_transaction(self, func: Callable, *args, **kwargs):
|
||||
"""Execute function within the transaction context"""
|
||||
try:
|
||||
kwargs['session'] = self.session
|
||||
result = await func(*args, **kwargs)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error("Function execution failed in transaction context", error=str(e))
|
||||
raise
|
||||
|
||||
|
||||
# ===== UTILITY FUNCTIONS =====

async def run_in_transaction(database_manager: DatabaseManager, func: Callable, *args, **kwargs):
    """
    Utility function to run any async function in a transaction

    Usage:
        result = await run_in_transaction(
            database_manager,
            my_async_function,
            arg1, arg2,
            kwarg1="value"
        )
    """
    async with managed_transaction(database_manager) as session:
        kwargs['session'] = session
        return await func(*args, **kwargs)


async def run_with_unit_of_work(
    database_manager: DatabaseManager,
    func: Callable,
    *args,
    **kwargs
):
    """
    Utility function to run any async function with Unit of Work

    Usage:
        result = await run_with_unit_of_work(
            database_manager,
            my_complex_function,
            arg1, arg2
        )
    """
    async with managed_unit_of_work(database_manager) as uow:
        kwargs['uow'] = uow
        return await func(*args, **kwargs)

# ===== BATCH OPERATIONS =====

@asynccontextmanager
async def batch_operation(database_manager: DatabaseManager, batch_size: int = 1000):
    """
    Context manager for batch operations with automatic commit batching

    Usage:
        async with batch_operation(database_manager, batch_size=500) as batch:
            for item in large_dataset:
                await batch.add_operation(create_record, item)
    """
    async with database_manager.get_background_session() as session:
        batch_context = BatchOperationContext(session, batch_size)
        try:
            yield batch_context
            await batch_context.flush_remaining()
        except Exception as e:
            logger.error("Batch operation failed", error=str(e))
            raise

class BatchOperationContext:
    """Context for managing batch database operations"""

    def __init__(self, session, batch_size: int):
        self.session = session
        self.batch_size = batch_size
        self.operation_count = 0

    async def add_operation(self, func: Callable, *args, **kwargs):
        """Add operation to batch"""
        kwargs['session'] = self.session
        await func(*args, **kwargs)

        self.operation_count += 1

        if self.operation_count >= self.batch_size:
            await self.session.commit()
            self.operation_count = 0
            logger.debug(f"Batch committed at {self.batch_size} operations")

    async def flush_remaining(self):
        """Commit any remaining operations"""
        if self.operation_count > 0:
            await self.session.commit()
            logger.debug(f"Final batch committed with {self.operation_count} operations")

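# A usage sketch for batch_operation; `insert_reading` and `readings` are
# hypothetical, and each call receives the shared session via the injected
# `session` keyword argument:
async def _example_batch_insert(database_manager: DatabaseManager, readings: list) -> None:
    async def insert_reading(reading: dict, session=None):
        # Build and add an ORM object, or execute an INSERT, using `session`.
        ...

    async with batch_operation(database_manager, batch_size=500) as batch:
        for reading in readings:
            await batch.add_operation(insert_reading, reading)
    # Commits every 500 operations; flush_remaining() commits the final partial batch.
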
304
shared/database/unit_of_work.py
Normal file
@@ -0,0 +1,304 @@
"""
|
||||
Unit of Work Pattern Implementation
|
||||
Manages transactions across multiple repositories with event publishing
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional, Type, TypeVar, Generic
|
||||
from contextlib import asynccontextmanager
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from abc import ABC, abstractmethod
|
||||
import structlog
|
||||
|
||||
from .repository import BaseRepository
|
||||
from .exceptions import TransactionError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
Model = TypeVar('Model')
|
||||
Repository = TypeVar('Repository', bound=BaseRepository)
|
||||
|
||||
|
||||
class BaseEvent(ABC):
    """Base class for domain events"""

    def __init__(self, event_type: str, data: Dict[str, Any]):
        self.event_type = event_type
        self.data = data

    @abstractmethod
    def to_dict(self) -> Dict[str, Any]:
        """Convert event to dictionary for publishing"""
        pass


class DomainEvent(BaseEvent):
    """Standard domain event implementation"""

    def to_dict(self) -> Dict[str, Any]:
        return {
            "event_type": self.event_type,
            "data": self.data
        }

class UnitOfWork:
    """
    Unit of Work pattern for managing transactions and coordinating repositories

    Usage:
        async with UnitOfWork(session) as uow:
            user_repo = uow.register_repository("users", UserRepository, User)
            sales_repo = uow.register_repository("sales", SalesRepository, SalesData)

            user = await user_repo.create(user_data)
            sale = await sales_repo.create(sales_data)

            await uow.commit()
    """

    def __init__(self, session: AsyncSession, auto_commit: bool = False):
        self.session = session
        self.auto_commit = auto_commit
        self._repositories: Dict[str, BaseRepository] = {}
        self._events: List[BaseEvent] = []
        self._committed = False
        self._rolled_back = False

    def register_repository(
        self,
        name: str,
        repository_class: Type[Repository],
        model_class: Type[Model],
        **kwargs
    ) -> Repository:
        """
        Register a repository with the unit of work

        Args:
            name: Unique name for the repository
            repository_class: Repository class to instantiate
            model_class: SQLAlchemy model class
            **kwargs: Additional arguments for repository

        Returns:
            Instantiated repository
        """
        if name in self._repositories:
            logger.warning(f"Repository '{name}' already registered, returning existing instance")
            return self._repositories[name]

        repository = repository_class(model_class, self.session, **kwargs)
        self._repositories[name] = repository

        logger.debug(f"Registered repository", name=name, model=model_class.__name__)
        return repository

    def get_repository(self, name: str) -> Optional[Repository]:
        """Get registered repository by name"""
        return self._repositories.get(name)

    def add_event(self, event: BaseEvent):
        """Add domain event to be published after commit"""
        self._events.append(event)
        logger.debug(f"Added event", event_type=event.event_type)

    async def commit(self):
        """Commit the transaction and publish events"""
        if self._committed:
            logger.warning("Unit of Work already committed")
            return

        if self._rolled_back:
            raise TransactionError("Cannot commit after rollback")

        try:
            await self.session.commit()
            self._committed = True

            # Publish events after successful commit
            # (capture the count first, since publishing clears the event list)
            events_to_publish = len(self._events)
            await self._publish_events()

            logger.debug(f"Unit of Work committed successfully",
                         repositories=list(self._repositories.keys()),
                         events_published=events_to_publish)

        except SQLAlchemyError as e:
            await self.rollback()
            logger.error("Failed to commit Unit of Work", error=str(e))
            raise TransactionError(f"Commit failed: {str(e)}")

    async def rollback(self):
        """Rollback the transaction"""
        if self._rolled_back:
            logger.warning("Unit of Work already rolled back")
            return

        try:
            await self.session.rollback()
            self._rolled_back = True
            self._events.clear()  # Clear events on rollback

            logger.debug(f"Unit of Work rolled back",
                         repositories=list(self._repositories.keys()))

        except SQLAlchemyError as e:
            logger.error("Failed to rollback Unit of Work", error=str(e))
            raise TransactionError(f"Rollback failed: {str(e)}")

    async def _publish_events(self):
        """Publish domain events (override in subclasses for actual publishing)"""
        if not self._events:
            return

        # Default implementation just logs events
        # Override this method in service-specific implementations
        for event in self._events:
            logger.info(f"Publishing event",
                        event_type=event.event_type,
                        event_data=event.to_dict())

        # Clear events after publishing
        self._events.clear()

    async def __aenter__(self):
        """Async context manager entry"""
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit"""
        if exc_type is not None:
            # Exception occurred, rollback
            await self.rollback()
            return False

        # No exception, auto-commit if enabled
        if self.auto_commit and not self._committed:
            await self.commit()

        return False

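# A sketch of coordinating one repository and a domain event in a single commit.
# The repository and model classes are passed in as parameters here because the
# concrete classes live in each service's repository layer; field names on the
# created entity are assumptions:
async def _example_unit_of_work(session: AsyncSession, repo_cls, model_cls, payload: dict) -> None:
    async with UnitOfWork(session) as uow:
        repo = uow.register_repository("items", repo_cls, model_cls)
        item = await repo.create(payload)

        # Queued events are published only after a successful commit.
        uow.add_event(DomainEvent("item.created", {"id": str(getattr(item, "id", ""))}))
        await uow.commit()
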
class ServiceUnitOfWork(UnitOfWork):
    """
    Service-specific Unit of Work with event publishing integration

    Example usage with message publishing:

        class AuthUnitOfWork(ServiceUnitOfWork):
            def __init__(self, session: AsyncSession, message_publisher=None):
                super().__init__(session)
                self.message_publisher = message_publisher

            async def _publish_events(self):
                for event in self._events:
                    if self.message_publisher:
                        await self.message_publisher.publish(
                            topic="auth.events",
                            message=event.to_dict()
                        )
    """

    def __init__(self, session: AsyncSession, event_publisher=None, auto_commit: bool = False):
        super().__init__(session, auto_commit)
        self.event_publisher = event_publisher

    async def _publish_events(self):
        """Publish events using the provided event publisher"""
        if not self._events or not self.event_publisher:
            return

        try:
            for event in self._events:
                await self.event_publisher.publish(event)
                logger.debug(f"Published event via publisher",
                             event_type=event.event_type)

            self._events.clear()

        except Exception as e:
            logger.error("Failed to publish events", error=str(e))
            # Don't raise here to avoid breaking the transaction
            # Events will be retried or handled by the event publisher

# ===== TRANSACTION CONTEXT MANAGER =====

@asynccontextmanager
async def transaction_scope(session: AsyncSession, auto_commit: bool = True):
    """
    Simple transaction context manager for single-repository operations

    Usage:
        async with transaction_scope(session) as tx_session:
            user = User(name="John")
            tx_session.add(user)
            # Auto-commits on success, rolls back on exception
    """
    try:
        yield session
        if auto_commit:
            await session.commit()
    except Exception as e:
        await session.rollback()
        logger.error("Transaction scope failed", error=str(e))
        raise

# ===== UTILITIES =====

class RepositoryRegistry:
    """Registry for commonly used repository configurations"""

    _registry: Dict[str, Dict[str, Any]] = {}

    @classmethod
    def register(
        cls,
        name: str,
        repository_class: Type[Repository],
        model_class: Type[Model],
        **kwargs
    ):
        """Register a repository configuration"""
        cls._registry[name] = {
            "repository_class": repository_class,
            "model_class": model_class,
            "kwargs": kwargs
        }
        logger.debug(f"Registered repository configuration", name=name)

    @classmethod
    def create_repository(cls, name: str, session: AsyncSession) -> Optional[Repository]:
        """Create repository instance from registry"""
        config = cls._registry.get(name)
        if not config:
            logger.warning(f"Repository configuration '{name}' not found in registry")
            return None

        return config["repository_class"](
            config["model_class"],
            session,
            **config["kwargs"]
        )

    @classmethod
    def list_registered(cls) -> List[str]:
        """List all registered repository names"""
        return list(cls._registry.keys())

# ===== FACTORY FUNCTIONS =====

def create_unit_of_work(session: AsyncSession, **kwargs) -> UnitOfWork:
    """Factory function to create Unit of Work instances"""
    return UnitOfWork(session, **kwargs)


def create_service_unit_of_work(
    session: AsyncSession,
    event_publisher=None,
    **kwargs
) -> ServiceUnitOfWork:
    """Factory function to create Service Unit of Work instances"""
    return ServiceUnitOfWork(session, event_publisher, **kwargs)

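# A sketch of the factory in use; `event_publisher` is assumed to expose an
# async publish(event) method, matching what ServiceUnitOfWork._publish_events calls:
async def _example_service_unit_of_work(session: AsyncSession, event_publisher) -> None:
    uow = create_service_unit_of_work(session, event_publisher=event_publisher, auto_commit=True)
    async with uow:
        uow.add_event(DomainEvent("tenant.updated", {"tenant_id": "t-001"}))
        # auto_commit=True commits on clean exit and then publishes the queued events.
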
402
shared/database/utils.py
Normal file
@@ -0,0 +1,402 @@
"""
|
||||
Database Utilities
|
||||
Helper functions for database operations and maintenance
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import text, inspect
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
import structlog
|
||||
|
||||
from .exceptions import DatabaseError, HealthCheckError
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class DatabaseUtils:
    """Utility functions for database operations"""

    @staticmethod
    async def execute_health_check(session: AsyncSession, timeout: int = 5) -> Dict[str, Any]:
        """
        Comprehensive database health check

        Returns:
            Dict with health status, metrics, and diagnostics
        """
        try:
            # Basic connectivity test
            start_time = time.time()
            await session.execute(text("SELECT 1"))
            response_time = time.time() - start_time

            # Get database info
            db_info = await DatabaseUtils._get_database_info(session)

            # Connection pool status (if available)
            pool_info = await DatabaseUtils._get_pool_info(session)

            return {
                "status": "healthy",
                "response_time_seconds": round(response_time, 4),
                "database": db_info,
                "connection_pool": pool_info,
                "timestamp": datetime.utcnow().isoformat()
            }

        except Exception as e:
            logger.error("Database health check failed", error=str(e))
            raise HealthCheckError(f"Health check failed: {str(e)}")

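    # A sketch of calling the health check from a service health endpoint; the
    # session would normally come from the service's dependency that yields an
    # AsyncSession:
    @staticmethod
    async def _example_database_health(session: AsyncSession) -> Dict[str, Any]:
        try:
            return await DatabaseUtils.execute_health_check(session)
        except HealthCheckError as e:
            # Report a degraded status instead of failing the whole endpoint.
            return {"status": "unhealthy", "error": str(e)}
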
    @staticmethod
    async def _get_database_info(session: AsyncSession) -> Dict[str, Any]:
        """Get database server information"""
        try:
            # Try to get database version and basic stats
            if session.bind.dialect.name == 'postgresql':
                version_result = await session.execute(text("SELECT version()"))
                version = version_result.scalar()

                stats_result = await session.execute(text("""
                    SELECT
                        count(*) as active_connections,
                        (SELECT setting FROM pg_settings WHERE name = 'max_connections') as max_connections
                    FROM pg_stat_activity
                    WHERE state = 'active'
                """))
                stats = stats_result.fetchone()

                return {
                    "type": "postgresql",
                    "version": version,
                    "active_connections": stats.active_connections if stats else 0,
                    "max_connections": stats.max_connections if stats else "unknown"
                }

            elif session.bind.dialect.name == 'sqlite':
                version_result = await session.execute(text("SELECT sqlite_version()"))
                version = version_result.scalar()

                return {
                    "type": "sqlite",
                    "version": version,
                    "active_connections": 1,
                    "max_connections": "unlimited"
                }

            else:
                return {
                    "type": session.bind.dialect.name,
                    "version": "unknown",
                    "active_connections": "unknown",
                    "max_connections": "unknown"
                }

        except Exception as e:
            logger.warning("Could not retrieve database info", error=str(e))
            return {
                "type": session.bind.dialect.name,
                "version": "unknown",
                "error": str(e)
            }

    @staticmethod
    async def _get_pool_info(session: AsyncSession) -> Dict[str, Any]:
        """Get connection pool information"""
        try:
            pool = session.bind.pool
            if pool:
                return {
                    "size": pool.size(),
                    "checked_in": pool.checkedin(),
                    "checked_out": pool.checkedout(),
                    "overflow": pool.overflow()
                }
            else:
                return {"status": "no_pool"}

        except Exception as e:
            logger.warning("Could not retrieve pool info", error=str(e))
            return {"error": str(e)}

    @staticmethod
    async def validate_schema(session: AsyncSession, expected_tables: List[str]) -> Dict[str, Any]:
        """
        Validate database schema against expected tables

        Args:
            session: Database session
            expected_tables: List of table names that should exist

        Returns:
            Validation results with missing/extra tables
        """
        try:
            # Get existing tables; inspect() needs a synchronous connection,
            # so run it through the session's run_sync() bridge
            def _list_tables(sync_session):
                return inspect(sync_session.bind).get_table_names()

            existing_tables = set(await session.run_sync(_list_tables))
            expected_tables_set = set(expected_tables)

            missing_tables = expected_tables_set - existing_tables
            extra_tables = existing_tables - expected_tables_set

            return {
                "valid": len(missing_tables) == 0,
                "existing_tables": list(existing_tables),
                "expected_tables": expected_tables,
                "missing_tables": list(missing_tables),
                "extra_tables": list(extra_tables),
                "total_tables": len(existing_tables)
            }

        except Exception as e:
            logger.error("Schema validation failed", error=str(e))
            raise DatabaseError(f"Schema validation failed: {str(e)}")

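    # A sketch of a startup check built on validate_schema; the expected table
    # list is illustrative only:
    @staticmethod
    async def _example_validate_startup_schema(session: AsyncSession) -> None:
        result = await DatabaseUtils.validate_schema(session, ["users", "tenants", "sales_data"])
        if not result["valid"]:
            raise DatabaseError(f"Missing tables: {result['missing_tables']}")
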
    @staticmethod
    async def get_table_stats(session: AsyncSession, table_names: List[str]) -> Dict[str, Any]:
        """
        Get statistics for specified tables

        Args:
            session: Database session
            table_names: List of table names to analyze

        Returns:
            Dictionary with table statistics
        """
        try:
            stats = {}

            # NOTE: table names are interpolated directly into SQL below,
            # so only trusted identifiers should ever be passed in
            for table_name in table_names:
                if session.bind.dialect.name == 'postgresql':
                    # PostgreSQL specific queries
                    count_result = await session.execute(
                        text(f"SELECT COUNT(*) FROM {table_name}")
                    )
                    row_count = count_result.scalar()

                    size_result = await session.execute(
                        text(f"SELECT pg_total_relation_size('{table_name}')")
                    )
                    table_size = size_result.scalar()

                    stats[table_name] = {
                        "row_count": row_count,
                        "size_bytes": table_size,
                        "size_mb": round(table_size / (1024 * 1024), 2) if table_size else 0
                    }

                elif session.bind.dialect.name == 'sqlite':
                    # SQLite specific queries
                    count_result = await session.execute(
                        text(f"SELECT COUNT(*) FROM {table_name}")
                    )
                    row_count = count_result.scalar()

                    stats[table_name] = {
                        "row_count": row_count,
                        "size_bytes": "unknown",
                        "size_mb": "unknown"
                    }

                else:
                    # Generic fallback
                    count_result = await session.execute(
                        text(f"SELECT COUNT(*) FROM {table_name}")
                    )
                    row_count = count_result.scalar()

                    stats[table_name] = {
                        "row_count": row_count,
                        "size_bytes": "unknown",
                        "size_mb": "unknown"
                    }

            return stats

        except Exception as e:
            logger.error("Failed to get table statistics",
                         tables=table_names, error=str(e))
            raise DatabaseError(f"Failed to get table stats: {str(e)}")

    @staticmethod
    async def cleanup_old_records(
        session: AsyncSession,
        table_name: str,
        date_column: str,
        days_old: int,
        batch_size: int = 1000
    ) -> int:
        """
        Clean up old records from a table

        Args:
            session: Database session
            table_name: Name of table to clean
            date_column: Date column to filter by
            days_old: Records older than this many days will be deleted
            batch_size: Number of records to delete per batch

        Returns:
            Total number of records deleted
        """
        try:
            total_deleted = 0

            while True:
                if session.bind.dialect.name == 'postgresql':
                    # A bound parameter cannot follow the INTERVAL keyword directly,
                    # so cast the bound string to an interval instead
                    delete_query = text(f"""
                        DELETE FROM {table_name}
                        WHERE {date_column} < NOW() - CAST(:days_param AS INTERVAL)
                        AND ctid IN (
                            SELECT ctid FROM {table_name}
                            WHERE {date_column} < NOW() - CAST(:days_param AS INTERVAL)
                            LIMIT :batch_size
                        )
                    """)
                    params = {
                        "days_param": f"{days_old} days",
                        "batch_size": batch_size
                    }

                elif session.bind.dialect.name == 'sqlite':
                    delete_query = text(f"""
                        DELETE FROM {table_name}
                        WHERE {date_column} < datetime('now', :days_param)
                        AND rowid IN (
                            SELECT rowid FROM {table_name}
                            WHERE {date_column} < datetime('now', :days_param)
                            LIMIT :batch_size
                        )
                    """)
                    params = {
                        "days_param": f"-{days_old} days",
                        "batch_size": batch_size
                    }

                else:
                    # Generic fallback (may not work for all databases)
                    delete_query = text(f"""
                        DELETE FROM {table_name}
                        WHERE {date_column} < DATE_SUB(NOW(), INTERVAL :days_old DAY)
                        LIMIT :batch_size
                    """)
                    params = {
                        "days_old": days_old,
                        "batch_size": batch_size
                    }

                result = await session.execute(delete_query, params)
                deleted_count = result.rowcount

                if deleted_count == 0:
                    break

                total_deleted += deleted_count
                await session.commit()

                logger.debug(f"Deleted batch from {table_name}",
                             batch_size=deleted_count,
                             total_deleted=total_deleted)

            logger.info(f"Cleanup completed for {table_name}",
                        total_deleted=total_deleted,
                        days_old=days_old)

            return total_deleted

        except Exception as e:
            await session.rollback()
            logger.error(f"Cleanup failed for {table_name}", error=str(e))
            raise DatabaseError(f"Cleanup failed: {str(e)}")

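    # A sketch of a periodic retention job; the table and column names are
    # illustrative, and the session should come from a background session factory:
    @staticmethod
    async def _example_purge_old_notifications(session: AsyncSession) -> int:
        # Deletes notifications older than 90 days in batches of 500.
        return await DatabaseUtils.cleanup_old_records(
            session,
            table_name="notifications",
            date_column="created_at",
            days_old=90,
            batch_size=500,
        )
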
    @staticmethod
    async def execute_maintenance(session: AsyncSession) -> Dict[str, Any]:
        """
        Execute database maintenance tasks

        Returns:
            Dictionary with maintenance results
        """
        try:
            results = {}

            if session.bind.dialect.name == 'postgresql':
                # PostgreSQL maintenance
                # NOTE: PostgreSQL rejects VACUUM inside a transaction block, so the
                # session typically needs AUTOCOMMIT execution options for this call
                await session.execute(text("VACUUM ANALYZE"))
                results["vacuum"] = "completed"

                # Update statistics
                await session.execute(text("ANALYZE"))
                results["analyze"] = "completed"

            elif session.bind.dialect.name == 'sqlite':
                # SQLite maintenance
                await session.execute(text("VACUUM"))
                results["vacuum"] = "completed"

                await session.execute(text("ANALYZE"))
                results["analyze"] = "completed"

            else:
                results["maintenance"] = "not_supported"

            await session.commit()

            logger.info("Database maintenance completed", results=results)
            return results

        except Exception as e:
            await session.rollback()
            logger.error("Database maintenance failed", error=str(e))
            raise DatabaseError(f"Maintenance failed: {str(e)}")

class QueryLogger:
    """Utility for logging and analyzing database queries"""

    def __init__(self, session: AsyncSession):
        self.session = session
        self._query_log = []

    async def log_query(self, query: str, params: Optional[Dict] = None, execution_time: Optional[float] = None):
        """Log a database query with metadata"""
        log_entry = {
            "query": query,
            "params": params,
            "execution_time": execution_time,
            "timestamp": datetime.utcnow().isoformat()
        }

        self._query_log.append(log_entry)

        # Log slow queries
        if execution_time and execution_time > 1.0:  # 1 second threshold
            logger.warning("Slow query detected",
                           query=query,
                           execution_time=execution_time)

    def get_query_stats(self) -> Dict[str, Any]:
        """Get statistics about logged queries"""
        if not self._query_log:
            return {"total_queries": 0}

        execution_times = [
            entry["execution_time"]
            for entry in self._query_log
            if entry["execution_time"] is not None
        ]

        return {
            "total_queries": len(self._query_log),
            "avg_execution_time": sum(execution_times) / len(execution_times) if execution_times else 0,
            "max_execution_time": max(execution_times) if execution_times else 0,
            "slow_queries_count": len([t for t in execution_times if t > 1.0])
        }

    def clear_log(self):
        """Clear the query log"""
        self._query_log.clear()

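# A sketch of timing a statement and recording it with QueryLogger; the table
# name is illustrative, and anything slower than one second is also emitted as
# a slow-query warning:
async def _example_logged_query(session: AsyncSession) -> None:
    query_logger = QueryLogger(session)
    statement = "SELECT COUNT(*) FROM sales_data"

    started = time.time()
    await session.execute(text(statement))
    await query_logger.log_query(statement, execution_time=time.time() - started)

    logger.info("Query stats", **query_logger.get_query_stats())
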
@@ -100,6 +100,27 @@ class MetricsCollector:
            self._histograms[name] = histogram
            logger.info(f"Registered histogram: {name} for {self.service_name}")
            return histogram
        except ValueError as e:
            if "Duplicated timeseries" in str(e):
                # Metric already exists in global registry, try to find it
                from prometheus_client import REGISTRY
                metric_name = f"{self.service_name.replace('-', '_')}_{name}"
                for collector in REGISTRY._collector_to_names.keys():
                    if hasattr(collector, '_name') and collector._name == metric_name:
                        self._histograms[name] = collector
                        logger.warning(f"Reusing existing histogram: {name} for {self.service_name}")
                        return collector
                # If we can't find it, create a new name with suffix
                import time
                suffix = str(int(time.time() * 1000))[-6:]  # Last 6 digits of timestamp
                histogram = Histogram(f"{self.service_name.replace('-', '_')}_{name}_{suffix}",
                                      documentation, labelnames=labels, buckets=buckets)
                self._histograms[name] = histogram
                logger.warning(f"Created histogram with suffix: {name}_{suffix} for {self.service_name}")
                return histogram
            else:
                logger.error(f"Failed to register histogram {name} for {self.service_name}: {e}")
                raise
        except Exception as e:
            logger.error(f"Failed to register histogram {name} for {self.service_name}: {e}")
            raise
@@ -295,3 +316,14 @@ def setup_metrics_early(app, service_name: str = None) -> MetricsCollector:
    logger.info(f"Metrics setup completed for service: {service_name}")
    return metrics_collector


# Additional helper function for endpoint tracking
def track_endpoint_metrics(endpoint_name: str = None, service_name: str = None):
    """Decorator for tracking endpoint metrics"""
    def decorator(func):
        def wrapper(*args, **kwargs):
            # For now, just pass through - metrics are handled by middleware
            return func(*args, **kwargs)
        return wrapper
    return decorator
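
# A sketch of applying the decorator; the endpoint below is hypothetical, and the
# decorator is currently a pass-through since metrics are recorded by middleware:
@track_endpoint_metrics(endpoint_name="list_forecasts", service_name="forecasting-service")
def list_forecasts_example(tenant_id: str):
    return {"tenant_id": tenant_id, "forecasts": []}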