Files
bakery-ia/services/training/app/core/database.py
2025-09-29 07:54:25 +02:00

424 lines
16 KiB
Python

# services/training/app/core/database.py
"""
Database configuration for training service
Uses shared database infrastructure
"""
import structlog
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession
from contextlib import asynccontextmanager
from sqlalchemy import text
from shared.database.base import DatabaseManager, Base
from app.core.config import settings
logger = structlog.get_logger()
# Single DatabaseManager instance for this service, built by the shared
# infrastructure class from settings.DATABASE_URL. Every helper in this
# module obtains its engine and sessions through this object.
database_manager = DatabaseManager(settings.DATABASE_URL)
# Module-level alias so callers can import ``get_db`` from this module
# instead of reaching into the manager instance.
get_db = database_manager.get_db
@asynccontextmanager
async def get_background_db_session():
    """Async context manager yielding a session with commit/rollback handling.

    The session is committed once the ``async with`` body of the caller
    finishes without raising. If the body (or the commit itself) raises an
    ``Exception``, the session is rolled back and the error is re-raised.
    The session is closed in every case.
    """
    async with database_manager.async_session_local() as db_session:
        try:
            yield db_session
            # Reached only when the caller's block completed normally.
            await db_session.commit()
        except Exception:
            await db_session.rollback()
            raise
        finally:
            await db_session.close()
async def get_db_health() -> bool:
    """
    Report whether the database answers a trivial query.

    Runs ``SELECT 1`` through the manager's async engine. Any exception is
    logged and converted into a ``False`` result rather than propagated.
    """
    try:
        async with database_manager.async_engine.begin() as connection:
            await connection.execute(text("SELECT 1"))
            logger.debug("Database health check passed")
            return True
    except Exception as exc:
        logger.error("Database health check failed", error=str(exc))
    return False
async def get_comprehensive_db_health() -> dict:
    """
    Comprehensive health check that verifies both connectivity and table existence.

    Returns:
        A dict with the keys ``status`` ("healthy"/"unhealthy"),
        ``connectivity`` (bool), ``tables_exist`` (bool),
        ``tables_verified`` (list of table names that answered a probe),
        ``missing_tables`` (list of table names whose probe failed) and
        ``errors`` (list of human-readable messages).

    The function never raises: unexpected failures are recorded in
    ``errors`` and reported as an unhealthy status.
    """
    # Single source of truth for the tables this service requires.
    required_tables = (
        'model_training_logs', 'trained_models', 'model_performance_metrics',
        'training_job_queue', 'model_artifacts',
    )
    health_status = {
        "status": "healthy",
        "connectivity": False,
        "tables_exist": False,
        "tables_verified": [],
        "missing_tables": [],
        "errors": []
    }
    try:
        # Test basic connectivity
        health_status["connectivity"] = await get_db_health()
        if not health_status["connectivity"]:
            health_status["status"] = "unhealthy"
            health_status["errors"].append("Database connectivity failed")
            return health_status

        # Test table existence
        tables_verified = await _verify_tables_exist()
        health_status["tables_exist"] = tables_verified
        if tables_verified:
            health_status["tables_verified"] = list(required_tables)
        else:
            health_status["status"] = "unhealthy"
            health_status["errors"].append("Required tables missing or inaccessible")
            # Try to identify which specific tables are missing
            try:
                async with database_manager.get_session() as session:
                    for table_name in required_tables:
                        try:
                            await session.execute(text(f"SELECT 1 FROM {table_name} LIMIT 1"))
                            health_status["tables_verified"].append(table_name)
                        except Exception:
                            health_status["missing_tables"].append(table_name)
                            # A failed statement leaves the transaction in an
                            # aborted state on PostgreSQL; without a rollback
                            # every following probe would fail too and all
                            # remaining tables would be reported as missing.
                            await session.rollback()
            except Exception as e:
                health_status["errors"].append(f"Error checking individual tables: {str(e)}")

        logger.debug("Comprehensive database health check completed",
                     status=health_status["status"],
                     connectivity=health_status["connectivity"],
                     tables_exist=health_status["tables_exist"])
    except Exception as e:
        health_status["status"] = "unhealthy"
        health_status["errors"].append(f"Health check failed: {str(e)}")
        logger.error("Comprehensive database health check failed", error=str(e))
    return health_status
# Training service specific database utilities
class TrainingDatabaseUtils:
    """Training service specific database utilities.

    Every public helper is a static coroutine that opens its own session via
    ``database_manager.async_session_local()``.
    """

    @staticmethod
    def _older_than_clause(column: str, days_old: int, is_sqlite: bool) -> tuple:
        """Build a dialect-specific "``column`` is older than N days" filter.

        Args:
            column: Column name to compare. It is interpolated into the SQL
                text, so it must be a trusted constant, never user input.
            days_old: Age threshold in days.
            is_sqlite: ``True`` to emit SQLite syntax, ``False`` for the
                PostgreSQL-style syntax used otherwise.

        Returns:
            ``(sql_fragment, bind_params)``.
        """
        if is_sqlite:
            return (
                f"{column} < datetime('now', :days_param)",
                {"days_param": f"-{days_old} days"},
            )
        # ``INTERVAL :param`` only works when the driver substitutes the value
        # client-side; with server-side binding it becomes ``INTERVAL $1``,
        # which is a syntax error. ``make_interval`` takes a plain integer
        # bind parameter and works with either binding style.
        return (
            f"{column} < NOW() - make_interval(days => :days_param)",
            {"days_param": days_old},
        )

    @staticmethod
    async def cleanup_old_training_logs(days_old: int = 90):
        """Delete training logs whose ``start_time`` is older than ``days_old`` days.

        Returns:
            The driver-reported number of deleted rows.

        Raises:
            Exception: any database error, after it has been logged.
        """
        try:
            async with database_manager.async_session_local() as session:
                is_sqlite = settings.DATABASE_URL.startswith("sqlite")
                age_filter, params = TrainingDatabaseUtils._older_than_clause(
                    "start_time", days_old, is_sqlite
                )
                query = text(f"DELETE FROM model_training_logs WHERE {age_filter}")
                result = await session.execute(query, params)
                await session.commit()
                deleted_count = result.rowcount
                logger.info("Cleaned up old training logs",
                            deleted_count=deleted_count,
                            days_old=days_old)
                return deleted_count
        except Exception as e:
            logger.error("Training logs cleanup failed", error=str(e))
            raise

    @staticmethod
    async def cleanup_old_models(days_old: int = 365):
        """Delete inactive models whose ``created_at`` is older than ``days_old`` days.

        Returns:
            The driver-reported number of deleted rows.

        Raises:
            Exception: any database error, after it has been logged.
        """
        try:
            async with database_manager.async_session_local() as session:
                is_sqlite = settings.DATABASE_URL.startswith("sqlite")
                age_filter, params = TrainingDatabaseUtils._older_than_clause(
                    "created_at", days_old, is_sqlite
                )
                # SQLite stores booleans as integers.
                inactive_literal = "0" if is_sqlite else "false"
                query = text(
                    "DELETE FROM trained_models "
                    f"WHERE is_active = {inactive_literal} AND {age_filter}"
                )
                result = await session.execute(query, params)
                await session.commit()
                deleted_count = result.rowcount
                logger.info("Cleaned up old models",
                            deleted_count=deleted_count,
                            days_old=days_old)
                return deleted_count
        except Exception as e:
            logger.error("Model cleanup failed", error=str(e))
            raise

    @staticmethod
    async def get_training_statistics(tenant_id: str = None) -> dict:
        """Get training statistics, optionally restricted to one tenant.

        Returns:
            A dict with ``training_jobs`` (status -> job count),
            ``active_models``, ``inactive_models`` and ``total_models``.
            On any error the failure is logged and an all-zero/empty
            structure is returned instead of raising.
        """
        try:
            async with database_manager.async_session_local() as session:
                if tenant_id:
                    logs_query = text(
                        "SELECT status, COUNT(*) as count "
                        "FROM model_training_logs "
                        "WHERE tenant_id = :tenant_id "
                        "GROUP BY status"
                    )
                    models_query = text(
                        "SELECT COUNT(*) as count "
                        "FROM trained_models "
                        "WHERE tenant_id = :tenant_id AND is_active = :is_active"
                    )
                    params = {"tenant_id": tenant_id}
                else:
                    logs_query = text(
                        "SELECT status, COUNT(*) as count "
                        "FROM model_training_logs "
                        "GROUP BY status"
                    )
                    models_query = text(
                        "SELECT COUNT(*) as count "
                        "FROM trained_models "
                        "WHERE is_active = :is_active"
                    )
                    params = {}

                # Get training job statistics. Rows are unpacked positionally:
                # attribute access ``row.count`` would resolve to the row's
                # tuple-style ``count`` method rather than the aliased column.
                logs_result = await session.execute(logs_query, params)
                job_stats = {
                    status: job_count
                    for status, job_count in logs_result.fetchall()
                }

                # Get active models count
                active_models_result = await session.execute(
                    models_query,
                    {**params, "is_active": True}
                )
                active_models = active_models_result.scalar() or 0

                # Get inactive models count
                inactive_models_result = await session.execute(
                    models_query,
                    {**params, "is_active": False}
                )
                inactive_models = inactive_models_result.scalar() or 0

                return {
                    "training_jobs": job_stats,
                    "active_models": active_models,
                    "inactive_models": inactive_models,
                    "total_models": active_models + inactive_models
                }
        except Exception as e:
            logger.error("Failed to get training statistics", error=str(e))
            return {
                "training_jobs": {},
                "active_models": 0,
                "inactive_models": 0,
                "total_models": 0
            }

    @staticmethod
    async def check_tenant_data_exists(tenant_id: str) -> bool:
        """Check if tenant has any training data.

        Returns ``False`` (after logging) when the lookup itself fails.
        """
        try:
            async with database_manager.async_session_local() as session:
                # Existence probe: stops at the first matching row. The
                # previous ``COUNT(*) ... LIMIT 1`` still counted every row
                # because LIMIT applies to the single aggregate result.
                query = text(
                    "SELECT 1 "
                    "FROM model_training_logs "
                    "WHERE tenant_id = :tenant_id "
                    "LIMIT 1"
                )
                result = await session.execute(query, {"tenant_id": tenant_id})
                return result.scalar() is not None
        except Exception as e:
            logger.error("Failed to check tenant data existence",
                         tenant_id=tenant_id, error=str(e))
            return False
# Enhanced database session dependency with better error handling
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """
    Async generator that yields one database session with lifecycle logging.

    If an ``Exception`` is raised while the consumer holds the session, it
    is logged with traceback, the session is rolled back and the error is
    re-raised. The session is closed in every case. No commit is issued
    here.
    """
    async with database_manager.async_session_local() as db_session:
        try:
            logger.debug("Database session created")
            yield db_session
        except Exception as exc:
            logger.error("Database session error", error=str(exc), exc_info=True)
            await db_session.rollback()
            raise
        finally:
            await db_session.close()
            logger.debug("Database session closed")
# Database initialization for training service
async def initialize_training_database(max_retries: int = 5, retry_delay: float = 2.0):
    """Initialize database tables for training service with retry logic and verification.

    Each attempt tests connectivity, imports the ORM models so they register
    on ``Base.metadata``, creates the tables and verifies they are queryable.
    Failed attempts are retried with exponential backoff
    (``retry_delay * 2 ** (attempt - 1)`` seconds).

    Args:
        max_retries: Total number of attempts; must be at least 1.
        retry_delay: Base delay in seconds before the first retry.

    Raises:
        ValueError: if ``max_retries`` is smaller than 1.
        Exception: when every attempt failed; the last underlying error is
            attached as the cause.
    """
    import asyncio

    if max_retries < 1:
        # Without this guard the loop below would not run at all and the
        # function would return as if initialization had succeeded.
        raise ValueError("max_retries must be at least 1")

    expected_tables = {
        'model_training_logs', 'trained_models', 'model_performance_metrics',
        'training_job_queue', 'model_artifacts'
    }

    for attempt in range(1, max_retries + 1):
        try:
            logger.info("Initializing training service database",
                        attempt=attempt, max_retries=max_retries)

            # Step 1: Test database connectivity first
            logger.info("Testing database connectivity...")
            connection_ok = await database_manager.test_connection()
            if not connection_ok:
                raise Exception("Database connection test failed")
            logger.info("Database connectivity verified")

            # Step 2: Import models to ensure they're registered.
            # The names are unused on purpose: the import is done for its
            # side effect of registering the tables on Base.metadata.
            logger.info("Importing and registering database models...")
            from app.models.training import (  # noqa: F401
                ModelTrainingLog,
                TrainedModel,
                ModelPerformanceMetric,
                TrainingJobQueue,
                ModelArtifact
            )

            # Verify models are registered in metadata
            registered_tables = set(Base.metadata.tables.keys())
            missing_tables = expected_tables - registered_tables
            if missing_tables:
                raise Exception(f"Models not properly registered: {missing_tables}")
            logger.info("Models registered successfully",
                        tables=list(registered_tables))

            # Step 3: Create tables using shared infrastructure with verification
            logger.info("Creating database tables...")
            await database_manager.create_tables()

            # Step 4: Verify tables were actually created
            logger.info("Verifying table creation...")
            verification_successful = await _verify_tables_exist()
            if not verification_successful:
                raise Exception("Table verification failed - tables were not created properly")

            logger.info("Training service database initialized and verified successfully",
                        attempt=attempt)
            return
        except Exception as e:
            logger.error("Database initialization failed",
                         attempt=attempt,
                         max_retries=max_retries,
                         error=str(e))
            if attempt == max_retries:
                logger.error("All database initialization attempts failed - giving up")
                # Chain explicitly so the root cause stays attached.
                raise Exception(
                    f"Failed to initialize training database after {max_retries} attempts: {str(e)}"
                ) from e

            # Wait before retry with exponential backoff
            wait_time = retry_delay * (2 ** (attempt - 1))
            logger.info("Retrying database initialization",
                        retry_in_seconds=wait_time,
                        next_attempt=attempt + 1)
            await asyncio.sleep(wait_time)
async def _verify_tables_exist() -> bool:
    """Verify that all required tables exist in the database.

    Each required table is probed with ``SELECT 1 FROM <table> LIMIT 1``.

    Returns:
        ``False`` as soon as a probe fails with a "table is missing" error,
        or when the session itself cannot be used; ``True`` otherwise.
        Probe failures of any other kind are logged as warnings and do not
        fail the verification.
    """
    required_tables = [
        'model_training_logs',
        'trained_models',
        'model_performance_metrics',
        'training_job_queue',
        'model_artifacts'
    ]
    # Message fragments that identify a missing table. PostgreSQL reports
    # 'relation "x" does not exist'; SQLite reports 'no such table: x'.
    # Other branches in this module handle SQLite URLs, so both are covered.
    missing_table_markers = ("does not exist", "no such table")
    try:
        async with database_manager.get_session() as session:
            # Check each required table exists and is accessible
            for table_name in required_tables:
                try:
                    # Try to query the table structure
                    await session.execute(
                        text(f"SELECT 1 FROM {table_name} LIMIT 1")
                    )
                    logger.debug(f"Table {table_name} exists and is accessible")
                except Exception as table_error:
                    message = str(table_error).lower()
                    # Reset the transaction first: after a failed statement
                    # PostgreSQL rejects every further command until a
                    # rollback, which would hide the real outcome of the
                    # remaining probes.
                    await session.rollback()
                    if any(marker in message for marker in missing_table_markers):
                        logger.error(f"Table {table_name} does not exist", error=str(table_error))
                        return False
                    # If it's an empty table, that's fine - table exists
                    if "no data" in message:
                        logger.debug(f"Table {table_name} exists but is empty (normal)")
                    else:
                        logger.warning(f"Unexpected error querying {table_name}", error=str(table_error))
            logger.info("All required tables verified successfully",
                        tables=required_tables)
            return True
    except Exception as e:
        logger.error("Table verification failed", error=str(e))
        return False
# Database cleanup for training service
async def cleanup_training_database():
    """Dispose of the manager's async engine, if one is present.

    Failures are logged and swallowed; this function does not raise.
    """
    try:
        logger.info("Cleaning up training service database connections")
        # Close engine connections
        engine = getattr(database_manager, 'async_engine', None)
        if engine:
            await engine.dispose()
        logger.info("Training service database cleanup completed")
    except Exception as exc:
        logger.error("Failed to cleanup training service database", error=str(exc))
# Public names exported by ``from <this module> import *``; kept stable to
# maintain compatibility with existing importers.
# NOTE(review): get_background_db_session and get_comprehensive_db_health are
# defined in this module without a leading underscore but are not listed
# here - confirm whether that omission is intentional.
__all__ = [
'Base',
'database_manager',
'get_db',
'get_db_session',
'get_db_health',
'TrainingDatabaseUtils',
'initialize_training_database',
'cleanup_training_database'
]