# NOTE(review): the lines below were web-viewer/VCS residue captured with the
# file (file listing, size, language, revision timestamp) — not source code.
# File: bakery-ia/services/training/app/core/database.py (424 lines, 16 KiB, Python)
# Revision: 2025-07-19 16:59:37 +02:00
# services/training/app/core/database.py
"""
Database configuration for training service
2025-07-19 16:59:37 +02:00
Uses shared database infrastructure
"""
2025-07-19 16:59:37 +02:00
import structlog
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession
# 2025-08-01 16:26:36 +02:00
from contextlib import asynccontextmanager
# 2025-07-19 16:59:37 +02:00
from sqlalchemy import text
from shared.database.base import DatabaseManager, Base
from app.core.config import settings
# 2025-07-19 16:59:37 +02:00
logger = structlog.get_logger()
# Initialize database manager using shared infrastructure
database_manager = DatabaseManager(settings.DATABASE_URL)
# 2025-07-19 16:59:37 +02:00
# Alias for convenience - matches the existing interface
get_db = database_manager.get_db
# 2025-08-01 16:26:36 +02:00
@asynccontextmanager
async def get_background_db_session():
    """Async context manager yielding a DB session for background tasks.

    Commits on successful exit of the ``with`` body, rolls back on any
    exception and re-raises it. The ``async with`` on the session already
    guarantees the session is closed on exit, so no explicit ``close()``
    is needed (the original called it redundantly in a ``finally``).
    """
    async with database_manager.async_session_local() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
# 2025-07-19 16:59:37 +02:00
async def get_db_health() -> bool:
    """
    Health check function for database connectivity
    Enhanced version of the shared functionality
    """
    try:
        # A trivial round-trip query proves the engine can reach the DB.
        async with database_manager.async_engine.begin() as conn:
            await conn.execute(text("SELECT 1"))
    except Exception as exc:
        logger.error("Database health check failed", error=str(exc))
        return False
    logger.debug("Database health check passed")
    return True
async def get_comprehensive_db_health() -> dict:
    """
    Comprehensive health check that verifies both connectivity and table existence

    Returns a dict with keys: status ("healthy"/"unhealthy"), connectivity,
    tables_exist, tables_verified, missing_tables, errors.
    """
    # Single source of truth for the tables this service requires
    # (the original repeated this list as two separate literals).
    required_tables = [
        'model_training_logs', 'trained_models', 'model_performance_metrics',
        'training_job_queue', 'model_artifacts'
    ]
    health_status = {
        "status": "healthy",
        "connectivity": False,
        "tables_exist": False,
        "tables_verified": [],
        "missing_tables": [],
        "errors": []
    }
    try:
        # Test basic connectivity first; table checks are pointless without it.
        health_status["connectivity"] = await get_db_health()
        if not health_status["connectivity"]:
            health_status["status"] = "unhealthy"
            health_status["errors"].append("Database connectivity failed")
            return health_status
        # Test table existence (all-or-nothing result from the helper).
        tables_verified = await _verify_tables_exist()
        health_status["tables_exist"] = tables_verified
        if tables_verified:
            health_status["tables_verified"] = list(required_tables)
        else:
            health_status["status"] = "unhealthy"
            health_status["errors"].append("Required tables missing or inaccessible")
            # Try to identify which specific tables are missing
            try:
                async with database_manager.get_session() as session:
                    for table_name in required_tables:
                        try:
                            # Table names come from the constant list above,
                            # never user input, so the f-string is safe here.
                            await session.execute(text(f"SELECT 1 FROM {table_name} LIMIT 1"))
                            health_status["tables_verified"].append(table_name)
                        except Exception:
                            health_status["missing_tables"].append(table_name)
            except Exception as e:
                health_status["errors"].append(f"Error checking individual tables: {str(e)}")
        logger.debug("Comprehensive database health check completed",
                     status=health_status["status"],
                     connectivity=health_status["connectivity"],
                     tables_exist=health_status["tables_exist"])
    except Exception as e:
        health_status["status"] = "unhealthy"
        health_status["errors"].append(f"Health check failed: {str(e)}")
        logger.error("Comprehensive database health check failed", error=str(e))
    return health_status
# 2025-07-19 16:59:37 +02:00
# Training service specific database utilities
class TrainingDatabaseUtils:
    """Training service specific database utilities"""

    @staticmethod
    async def cleanup_old_training_logs(days_old: int = 90):
        """Clean up training logs older than ``days_old`` days.

        Returns the number of deleted rows; re-raises on failure so callers
        can react to a failed maintenance run.
        """
        try:
            async with database_manager.async_session_local() as session:
                if settings.DATABASE_URL.startswith("sqlite"):
                    query = text(
                        "DELETE FROM model_training_logs "
                        "WHERE start_time < datetime('now', :days_param)"
                    )
                    params = {"days_param": f"-{days_old} days"}
                else:
                    # PostgreSQL does not accept a bind parameter directly
                    # after the INTERVAL keyword ("INTERVAL :p" is a syntax
                    # error once the driver substitutes the placeholder);
                    # cast the bound string to an interval instead.
                    query = text(
                        "DELETE FROM model_training_logs "
                        "WHERE start_time < NOW() - CAST(:days_param AS INTERVAL)"
                    )
                    params = {"days_param": f"{days_old} days"}
                result = await session.execute(query, params)
                await session.commit()
                deleted_count = result.rowcount
                logger.info("Cleaned up old training logs",
                            deleted_count=deleted_count,
                            days_old=days_old)
                return deleted_count
        except Exception as e:
            logger.error("Training logs cleanup failed", error=str(e))
            raise

    @staticmethod
    async def cleanup_old_models(days_old: int = 365):
        """Clean up inactive models older than ``days_old`` days.

        Only rows with is_active false are deleted. Returns the number of
        deleted rows; re-raises on failure.
        """
        try:
            async with database_manager.async_session_local() as session:
                if settings.DATABASE_URL.startswith("sqlite"):
                    query = text(
                        "DELETE FROM trained_models "
                        "WHERE is_active = 0 AND created_at < datetime('now', :days_param)"
                    )
                    params = {"days_param": f"-{days_old} days"}
                else:
                    # Same INTERVAL-binding workaround as cleanup_old_training_logs.
                    query = text(
                        "DELETE FROM trained_models "
                        "WHERE is_active = false AND created_at < NOW() - CAST(:days_param AS INTERVAL)"
                    )
                    params = {"days_param": f"{days_old} days"}
                result = await session.execute(query, params)
                await session.commit()
                deleted_count = result.rowcount
                logger.info("Cleaned up old models",
                            deleted_count=deleted_count,
                            days_old=days_old)
                return deleted_count
        except Exception as e:
            logger.error("Model cleanup failed", error=str(e))
            raise

    @staticmethod
    async def get_training_statistics(tenant_id: "str | None" = None) -> dict:
        """Get training statistics, optionally scoped to one tenant.

        Returns a dict with training_jobs (status -> count), active_models,
        inactive_models and total_models; all-zero/empty values on error.
        """
        try:
            async with database_manager.async_session_local() as session:
                # Build tenant-scoped or global queries; both model-count
                # variants share the :is_active placeholder filled per call.
                if tenant_id:
                    logs_query = text(
                        "SELECT status, COUNT(*) as count "
                        "FROM model_training_logs "
                        "WHERE tenant_id = :tenant_id "
                        "GROUP BY status"
                    )
                    models_query = text(
                        "SELECT COUNT(*) as count "
                        "FROM trained_models "
                        "WHERE tenant_id = :tenant_id AND is_active = :is_active"
                    )
                    params = {"tenant_id": tenant_id}
                else:
                    logs_query = text(
                        "SELECT status, COUNT(*) as count "
                        "FROM model_training_logs "
                        "GROUP BY status"
                    )
                    models_query = text(
                        "SELECT COUNT(*) as count "
                        "FROM trained_models "
                        "WHERE is_active = :is_active"
                    )
                    params = {}
                # Training-job counts grouped by status
                logs_result = await session.execute(logs_query, params)
                job_stats = {row.status: row.count for row in logs_result.fetchall()}
                # Active models count
                active_models_result = await session.execute(
                    models_query,
                    {**params, "is_active": True}
                )
                active_models = active_models_result.scalar() or 0
                # Inactive models count
                inactive_models_result = await session.execute(
                    models_query,
                    {**params, "is_active": False}
                )
                inactive_models = inactive_models_result.scalar() or 0
                return {
                    "training_jobs": job_stats,
                    "active_models": active_models,
                    "inactive_models": inactive_models,
                    "total_models": active_models + inactive_models
                }
        except Exception as e:
            logger.error("Failed to get training statistics", error=str(e))
            return {
                "training_jobs": {},
                "active_models": 0,
                "inactive_models": 0,
                "total_models": 0
            }

    @staticmethod
    async def check_tenant_data_exists(tenant_id: str) -> bool:
        """Check if the tenant has any training data.

        Uses ``SELECT 1 ... LIMIT 1`` so the database can stop at the first
        matching row. (The original counted every matching row with
        COUNT(*) and its LIMIT clause was a no-op on the aggregate result.)
        """
        try:
            async with database_manager.async_session_local() as session:
                query = text(
                    "SELECT 1 "
                    "FROM model_training_logs "
                    "WHERE tenant_id = :tenant_id "
                    "LIMIT 1"
                )
                result = await session.execute(query, {"tenant_id": tenant_id})
                return result.first() is not None
        except Exception as e:
            logger.error("Failed to check tenant data existence",
                         tenant_id=tenant_id, error=str(e))
            return False
# Enhanced database session dependency with better error handling
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """
    Enhanced database session dependency with better logging and error handling
    """
    async with database_manager.async_session_local() as session:
        logger.debug("Database session created")
        try:
            yield session
        except Exception as exc:
            # Roll back the transaction before propagating to the caller.
            logger.error("Database session error", error=str(exc), exc_info=True)
            await session.rollback()
            raise
        finally:
            await session.close()
            logger.debug("Database session closed")
# Database initialization for training service
async def initialize_training_database():
    """Initialize database tables for training service with retry logic and verification.

    Runs up to ``max_retries`` attempts with exponential backoff and raises
    if every attempt fails. Each attempt: (1) connectivity test, (2) model
    registration check, (3) table creation, (4) table verification.
    """
    import asyncio
    # NOTE: the original also re-imported sqlalchemy.text here but never
    # used it in this function (it is already imported at module level).

    max_retries = 5
    retry_delay = 2.0
    # Tables that importing app.models.training must register on Base.metadata.
    expected_tables = {
        'model_training_logs', 'trained_models', 'model_performance_metrics',
        'training_job_queue', 'model_artifacts'
    }
    for attempt in range(1, max_retries + 1):
        try:
            logger.info("Initializing training service database",
                        attempt=attempt, max_retries=max_retries)
            # Step 1: Test database connectivity first
            logger.info("Testing database connectivity...")
            if not await database_manager.test_connection():
                raise Exception("Database connection test failed")
            logger.info("Database connectivity verified")
            # Step 2: Import models so they register themselves on Base.metadata
            logger.info("Importing and registering database models...")
            from app.models.training import (
                ModelTrainingLog,
                TrainedModel,
                ModelPerformanceMetric,
                TrainingJobQueue,
                ModelArtifact
            )
            # Verify the import actually registered every expected table.
            registered_tables = set(Base.metadata.tables.keys())
            missing_tables = expected_tables - registered_tables
            if missing_tables:
                raise Exception(f"Models not properly registered: {missing_tables}")
            logger.info("Models registered successfully",
                        tables=list(registered_tables))
            # Step 3: Create tables using shared infrastructure
            logger.info("Creating database tables...")
            await database_manager.create_tables()
            # Step 4: Verify tables were actually created
            logger.info("Verifying table creation...")
            if not await _verify_tables_exist():
                raise Exception("Table verification failed - tables were not created properly")
            logger.info("Training service database initialized and verified successfully",
                        attempt=attempt)
            return
        except Exception as e:
            logger.error("Database initialization failed",
                         attempt=attempt,
                         max_retries=max_retries,
                         error=str(e))
            if attempt == max_retries:
                logger.error("All database initialization attempts failed - giving up")
                raise Exception(f"Failed to initialize training database after {max_retries} attempts: {str(e)}")
            # Exponential backoff: 2s, 4s, 8s, 16s between attempts.
            wait_time = retry_delay * (2 ** (attempt - 1))
            logger.info("Retrying database initialization",
                        retry_in_seconds=wait_time,
                        next_attempt=attempt + 1)
            await asyncio.sleep(wait_time)
async def _verify_tables_exist() -> bool:
    """Verify that all required tables exist in the database.

    Probes each table with a trivial SELECT. Returns False only when the
    driver reports the relation "does not exist"; other errors are logged
    and tolerated (an empty table still proves the table exists).
    """
    try:
        async with database_manager.get_session() as session:
            # Check each required table exists and is accessible
            required_tables = [
                'model_training_logs',
                'trained_models',
                'model_performance_metrics',
                'training_job_queue',
                'model_artifacts'
            ]
            for table_name in required_tables:
                try:
                    # Probe the table; the query result itself is unused
                    # (the original bound it to a never-read variable).
                    # Table names come from the constant list above, so the
                    # f-string cannot inject user input.
                    await session.execute(
                        text(f"SELECT 1 FROM {table_name} LIMIT 1")
                    )
                    logger.debug(f"Table {table_name} exists and is accessible")
                except Exception as table_error:
                    # Heuristic classification of the driver's error message:
                    # "relation does not exist" means table creation failed.
                    if "does not exist" in str(table_error).lower():
                        logger.error(f"Table {table_name} does not exist", error=str(table_error))
                        return False
                    # An empty table is fine - the table exists.
                    elif "no data" in str(table_error).lower():
                        logger.debug(f"Table {table_name} exists but is empty (normal)")
                    else:
                        logger.warning(f"Unexpected error querying {table_name}", error=str(table_error))
            logger.info("All required tables verified successfully",
                        tables=required_tables)
            return True
    except Exception as e:
        logger.error("Table verification failed", error=str(e))
        return False
# 2025-07-19 16:59:37 +02:00
# Database cleanup for training service
async def cleanup_training_database():
    """Cleanup database connections for training service"""
    try:
        logger.info("Cleaning up training service database connections")
        # Dispose the async engine's connection pool, if one was created.
        engine = getattr(database_manager, 'async_engine', None)
        if engine:
            await engine.dispose()
        logger.info("Training service database cleanup completed")
    except Exception as exc:
        logger.error("Failed to cleanup training service database", error=str(exc))
# Export the commonly used items to maintain compatibility
__all__ = [
'Base',
'database_manager',
'get_db',
'get_db_session',
'get_db_health',
'TrainingDatabaseUtils',
'initialize_training_database',
'cleanup_training_database'
]