# services/training/app/core/database.py
"""
Database configuration for training service

Uses shared database infrastructure
"""

import asyncio
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional

import structlog
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from shared.database.base import DatabaseManager, Base
from app.core.config import settings

logger = structlog.get_logger()

# Tables this service owns. Kept in ONE place so the health checks,
# initialization and verification logic can never drift out of sync.
REQUIRED_TABLES = (
    'model_training_logs',
    'trained_models',
    'model_performance_metrics',
    'training_job_queue',
    'model_artifacts',
)

# Initialize database manager with connection pooling configuration
database_manager = DatabaseManager(
    settings.DATABASE_URL,
    pool_size=settings.DB_POOL_SIZE,
    max_overflow=settings.DB_MAX_OVERFLOW,
    pool_timeout=settings.DB_POOL_TIMEOUT,
    pool_recycle=settings.DB_POOL_RECYCLE,
    pool_pre_ping=settings.DB_POOL_PRE_PING,
    echo=settings.DB_ECHO
)

# Alias for convenience - matches the existing interface
get_db = database_manager.get_db


@asynccontextmanager
async def get_background_db_session():
    """
    Session context manager for background tasks.

    Commits on successful exit, rolls back (and re-raises) on any exception.
    The ``async with`` on the session factory closes the session on exit,
    so no explicit close is needed here.
    """
    async with database_manager.async_session_local() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise


async def get_db_health() -> bool:
    """
    Health check function for database connectivity

    Enhanced version of the shared functionality. Returns True when a
    trivial ``SELECT 1`` succeeds, False on any error (never raises).
    """
    try:
        async with database_manager.async_engine.begin() as conn:
            await conn.execute(text("SELECT 1"))
        logger.debug("Database health check passed")
        return True
    except Exception as e:
        logger.error("Database health check failed", error=str(e))
        return False


async def get_comprehensive_db_health() -> dict:
    """
    Comprehensive health check that verifies both connectivity and table existence

    Returns a dict with overall ``status``, a ``connectivity`` flag, the lists
    of verified/missing tables, and any errors collected along the way.
    This function never raises; failures are reported in the result dict.
    """
    health_status = {
        "status": "healthy",
        "connectivity": False,
        "tables_exist": False,
        "tables_verified": [],
        "missing_tables": [],
        "errors": []
    }

    try:
        # Test basic connectivity
        health_status["connectivity"] = await get_db_health()
        if not health_status["connectivity"]:
            health_status["status"] = "unhealthy"
            health_status["errors"].append("Database connectivity failed")
            return health_status

        # Test table existence
        tables_verified = await _verify_tables_exist()
        health_status["tables_exist"] = tables_verified

        if tables_verified:
            health_status["tables_verified"] = list(REQUIRED_TABLES)
        else:
            health_status["status"] = "unhealthy"
            health_status["errors"].append("Required tables missing or inaccessible")

            # Try to identify which specific tables are missing
            try:
                async with database_manager.get_session() as session:
                    for table_name in REQUIRED_TABLES:
                        try:
                            # Table names come from a fixed internal tuple,
                            # so the f-string is not an injection risk.
                            await session.execute(
                                text(f"SELECT 1 FROM {table_name} LIMIT 1")
                            )
                            health_status["tables_verified"].append(table_name)
                        except Exception:
                            health_status["missing_tables"].append(table_name)
            except Exception as e:
                health_status["errors"].append(
                    f"Error checking individual tables: {str(e)}"
                )

        logger.debug("Comprehensive database health check completed",
                     status=health_status["status"],
                     connectivity=health_status["connectivity"],
                     tables_exist=health_status["tables_exist"])

    except Exception as e:
        health_status["status"] = "unhealthy"
        health_status["errors"].append(f"Health check failed: {str(e)}")
        logger.error("Comprehensive database health check failed", error=str(e))

    return health_status


# Training service specific database utilities
class TrainingDatabaseUtils:
    """Training service specific database utilities"""

    @staticmethod
    async def cleanup_old_training_logs(days_old: int = 90):
        """
        Clean up old training logs.

        Deletes ``model_training_logs`` rows whose ``start_time`` is older
        than ``days_old`` days. Returns the number of deleted rows; re-raises
        on failure after logging.
        """
        try:
            async with database_manager.async_session_local() as session:
                if settings.DATABASE_URL.startswith("sqlite"):
                    query = text(
                        "DELETE FROM model_training_logs "
                        "WHERE start_time < datetime('now', :days_param)"
                    )
                    params = {"days_param": f"-{days_old} days"}
                else:
                    # PostgreSQL: a bind parameter cannot follow the INTERVAL
                    # keyword directly, so cast the bound string instead.
                    query = text(
                        "DELETE FROM model_training_logs "
                        "WHERE start_time < NOW() - CAST(:days_param AS INTERVAL)"
                    )
                    params = {"days_param": f"{days_old} days"}

                result = await session.execute(query, params)
                await session.commit()

                deleted_count = result.rowcount
                logger.info("Cleaned up old training logs",
                            deleted_count=deleted_count,
                            days_old=days_old)
                return deleted_count

        except Exception as e:
            logger.error("Training logs cleanup failed", error=str(e))
            raise

    @staticmethod
    async def cleanup_old_models(days_old: int = 365):
        """
        Clean up old inactive models.

        Deletes inactive ``trained_models`` rows older than ``days_old`` days.
        Returns the number of deleted rows; re-raises on failure after logging.
        """
        try:
            async with database_manager.async_session_local() as session:
                if settings.DATABASE_URL.startswith("sqlite"):
                    # SQLite stores booleans as 0/1.
                    query = text(
                        "DELETE FROM trained_models "
                        "WHERE is_active = 0 AND created_at < datetime('now', :days_param)"
                    )
                    params = {"days_param": f"-{days_old} days"}
                else:
                    # PostgreSQL: CAST the bound string — INTERVAL cannot take
                    # a bind parameter directly.
                    query = text(
                        "DELETE FROM trained_models "
                        "WHERE is_active = false AND created_at < NOW() - CAST(:days_param AS INTERVAL)"
                    )
                    params = {"days_param": f"{days_old} days"}

                result = await session.execute(query, params)
                await session.commit()

                deleted_count = result.rowcount
                logger.info("Cleaned up old models",
                            deleted_count=deleted_count,
                            days_old=days_old)
                return deleted_count

        except Exception as e:
            logger.error("Model cleanup failed", error=str(e))
            raise

    @staticmethod
    async def get_training_statistics(tenant_id: Optional[str] = None) -> dict:
        """
        Get training statistics.

        Returns a dict with per-status training job counts plus active /
        inactive / total model counts, optionally scoped to ``tenant_id``.
        On any error a zeroed-out dict is returned (never raises).
        """
        try:
            async with database_manager.async_session_local() as session:
                # Build queries; tenant-scoped variants add a WHERE filter.
                if tenant_id:
                    logs_query = text(
                        "SELECT status, COUNT(*) as count "
                        "FROM model_training_logs "
                        "WHERE tenant_id = :tenant_id "
                        "GROUP BY status"
                    )
                    models_query = text(
                        "SELECT COUNT(*) as count "
                        "FROM trained_models "
                        "WHERE tenant_id = :tenant_id AND is_active = :is_active"
                    )
                    params = {"tenant_id": tenant_id}
                else:
                    logs_query = text(
                        "SELECT status, COUNT(*) as count "
                        "FROM model_training_logs "
                        "GROUP BY status"
                    )
                    models_query = text(
                        "SELECT COUNT(*) as count "
                        "FROM trained_models "
                        "WHERE is_active = :is_active"
                    )
                    params = {}

                # Get training job statistics keyed by status
                logs_result = await session.execute(logs_query, params)
                job_stats = {row.status: row.count for row in logs_result.fetchall()}

                # Get active models count
                active_models_result = await session.execute(
                    models_query, {**params, "is_active": True}
                )
                active_models = active_models_result.scalar() or 0

                # Get inactive models count
                inactive_models_result = await session.execute(
                    models_query, {**params, "is_active": False}
                )
                inactive_models = inactive_models_result.scalar() or 0

                return {
                    "training_jobs": job_stats,
                    "active_models": active_models,
                    "inactive_models": inactive_models,
                    "total_models": active_models + inactive_models
                }

        except Exception as e:
            logger.error("Failed to get training statistics", error=str(e))
            return {
                "training_jobs": {},
                "active_models": 0,
                "inactive_models": 0,
                "total_models": 0
            }

    @staticmethod
    async def check_tenant_data_exists(tenant_id: str) -> bool:
        """
        Check if tenant has any training data.

        Uses an existence probe (``SELECT 1 ... LIMIT 1``) rather than a full
        COUNT so the database can stop at the first matching row.
        Returns False on any error (never raises).
        """
        try:
            async with database_manager.async_session_local() as session:
                query = text(
                    "SELECT 1 "
                    "FROM model_training_logs "
                    "WHERE tenant_id = :tenant_id "
                    "LIMIT 1"
                )
                result = await session.execute(query, {"tenant_id": tenant_id})
                # scalar() is None when no row matched.
                return result.scalar() is not None

        except Exception as e:
            logger.error("Failed to check tenant data existence",
                         tenant_id=tenant_id,
                         error=str(e))
            return False


# Enhanced database session dependency with better error handling
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """
    Enhanced database session dependency with better logging and error handling

    Rolls back and re-raises on error. The ``async with`` closes the session
    on exit, so no explicit close is needed.
    """
    async with database_manager.async_session_local() as session:
        try:
            logger.debug("Database session created")
            yield session
        except Exception as e:
            logger.error("Database session error", error=str(e), exc_info=True)
            await session.rollback()
            raise
        finally:
            logger.debug("Database session closed")


# Database initialization for training service
async def initialize_training_database():
    """
    Initialize database tables for training service with retry logic and verification.

    Retries up to 5 times with exponential backoff (2s, 4s, 8s, ...).
    Raises after the final failed attempt.
    """
    max_retries = 5
    retry_delay = 2.0

    for attempt in range(1, max_retries + 1):
        try:
            logger.info("Initializing training service database",
                        attempt=attempt,
                        max_retries=max_retries)

            # Step 1: Test database connectivity first
            logger.info("Testing database connectivity...")
            connection_ok = await database_manager.test_connection()
            if not connection_ok:
                raise Exception("Database connection test failed")
            logger.info("Database connectivity verified")

            # Step 2: Import models so they register on Base.metadata.
            # Imported here (not at module top) to avoid circular imports.
            logger.info("Importing and registering database models...")
            from app.models.training import (
                ModelTrainingLog, TrainedModel, ModelPerformanceMetric,
                TrainingJobQueue, ModelArtifact
            )

            # Verify models are registered in metadata
            expected_tables = set(REQUIRED_TABLES)
            registered_tables = set(Base.metadata.tables.keys())
            missing_tables = expected_tables - registered_tables
            if missing_tables:
                raise Exception(f"Models not properly registered: {missing_tables}")
            logger.info("Models registered successfully",
                        tables=list(registered_tables))

            # Step 3: Create tables using shared infrastructure with verification
            logger.info("Creating database tables...")
            await database_manager.create_tables()

            # Step 4: Verify tables were actually created
            logger.info("Verifying table creation...")
            verification_successful = await _verify_tables_exist()
            if not verification_successful:
                raise Exception("Table verification failed - tables were not created properly")

            logger.info("Training service database initialized and verified successfully",
                        attempt=attempt)
            return

        except Exception as e:
            logger.error("Database initialization failed",
                         attempt=attempt,
                         max_retries=max_retries,
                         error=str(e))

            if attempt == max_retries:
                logger.error("All database initialization attempts failed - giving up")
                raise Exception(
                    f"Failed to initialize training database after {max_retries} attempts: {str(e)}"
                )

            # Wait before retry with exponential backoff
            wait_time = retry_delay * (2 ** (attempt - 1))
            logger.info("Retrying database initialization",
                        retry_in_seconds=wait_time,
                        next_attempt=attempt + 1)
            await asyncio.sleep(wait_time)


async def _verify_tables_exist() -> bool:
    """
    Verify that all required tables exist in the database.

    Probes each table with ``SELECT 1 ... LIMIT 1`` and classifies failures
    by error message text. Returns False only when a table is definitively
    missing or the session itself fails.
    """
    try:
        async with database_manager.get_session() as session:
            for table_name in REQUIRED_TABLES:
                try:
                    # Table names come from a fixed internal tuple (no injection
                    # risk). An empty table still answers this probe.
                    await session.execute(
                        text(f"SELECT 1 FROM {table_name} LIMIT 1")
                    )
                    logger.debug(f"Table {table_name} exists and is accessible")
                except Exception as table_error:
                    # "relation does not exist" style error => creation failed
                    if "does not exist" in str(table_error).lower():
                        logger.error(f"Table {table_name} does not exist",
                                     error=str(table_error))
                        return False
                    # An empty table is fine - the table exists
                    elif "no data" in str(table_error).lower():
                        logger.debug(f"Table {table_name} exists but is empty (normal)")
                    else:
                        logger.warning(f"Unexpected error querying {table_name}",
                                       error=str(table_error))

            logger.info("All required tables verified successfully",
                        tables=list(REQUIRED_TABLES))
            return True

    except Exception as e:
        logger.error("Table verification failed", error=str(e))
        return False


# Database cleanup for training service
async def cleanup_training_database():
    """Cleanup database connections for training service"""
    try:
        logger.info("Cleaning up training service database connections")

        # Close engine connections
        if hasattr(database_manager, 'async_engine') and database_manager.async_engine:
            await database_manager.async_engine.dispose()

        logger.info("Training service database cleanup completed")

    except Exception as e:
        logger.error("Failed to cleanup training service database", error=str(e))


# Export the commonly used items to maintain compatibility
__all__ = [
    'Base',
    'database_manager',
    'get_db',
    'get_db_session',
    'get_db_health',
    'TrainingDatabaseUtils',
    'initialize_training_database',
    'cleanup_training_database'
]