# services/training/app/core/database.py
"""
Database configuration for training service
Uses shared database infrastructure
"""
|
|
|
|
import structlog
|
|
from typing import AsyncGenerator
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from contextlib import asynccontextmanager
|
|
from sqlalchemy import text
|
|
|
|
from shared.database.base import DatabaseManager, Base
|
|
from app.core.config import settings
|
|
|
|
# Module-wide structlog logger used by every helper below.
logger = structlog.get_logger()

# One DatabaseManager for the whole service, built from the configured
# DATABASE_URL on top of the shared database infrastructure.
database_manager = DatabaseManager(settings.DATABASE_URL)

# Convenience alias: lets callers keep importing `get_db` from this module.
get_db = database_manager.get_db
|
|
|
|
@asynccontextmanager
async def get_background_db_session():
    """
    Async context manager yielding a session from
    ``database_manager.async_session_local()``.

    The session is committed once the ``async with`` body completes without
    error. If the body (or the commit) raises an ``Exception``, the session is
    rolled back and the exception is re-raised. The session is closed in
    every case.
    """
    async with database_manager.async_session_local() as db_session:
        try:
            yield db_session
            await db_session.commit()
        except Exception:
            # Discard pending work, then let the caller see the failure.
            await db_session.rollback()
            raise
        finally:
            await db_session.close()
|
|
|
|
async def get_db_health() -> bool:
    """
    Report whether the database answers a trivial query.

    Opens a connection through ``database_manager.async_engine.begin()`` and
    executes ``SELECT 1``.

    Returns:
        True when the probe succeeds; False when any ``Exception`` is raised
        (the error is logged and not propagated).
    """
    try:
        async with database_manager.async_engine.begin() as connection:
            await connection.execute(text("SELECT 1"))
            logger.debug("Database health check passed")
            return True

    except Exception as exc:
        logger.error("Database health check failed", error=str(exc))
        return False
|
|
|
|
# Training service specific database utilities
|
|
class TrainingDatabaseUtils:
    """
    Training service specific database utilities.

    Every public method is static and opens its own session from
    ``database_manager.async_session_local()``.
    """

    @staticmethod
    def _cutoff_expression(days_old: int, is_sqlite: bool) -> tuple:
        """
        Build the SQL expression for "now minus ``days_old`` days".

        Args:
            days_old: Age threshold in days.
            is_sqlite: True to emit SQLite syntax, False for PostgreSQL-style
                syntax.

        Returns:
            A ``(sql_expression, params)`` pair. ``sql_expression`` contains
            the ``:days_param`` placeholder and ``params`` carries its value.
        """
        if is_sqlite:
            return (
                "datetime('now', :days_param)",
                {"days_param": f"-{days_old} days"},
            )
        # `INTERVAL :param` is not valid once the driver binds parameters
        # server-side (the INTERVAL keyword only accepts a literal), so
        # multiply a fixed one-day interval by a bound integer instead.
        return (
            "NOW() - (CAST(:days_param AS INTEGER) * INTERVAL '1 day')",
            {"days_param": int(days_old)},
        )

    @staticmethod
    async def cleanup_old_training_logs(days_old: int = 90):
        """
        Delete rows from ``model_training_logs`` older than ``days_old`` days.

        Args:
            days_old: Rows whose ``start_time`` is older than this many days
                are deleted.

        Returns:
            The ``rowcount`` reported for the DELETE statement.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        try:
            async with database_manager.async_session_local() as session:
                is_sqlite = settings.DATABASE_URL.startswith("sqlite")
                cutoff_expr, params = TrainingDatabaseUtils._cutoff_expression(
                    days_old, is_sqlite
                )
                query = text(
                    "DELETE FROM model_training_logs "
                    f"WHERE start_time < {cutoff_expr}"
                )

                result = await session.execute(query, params)
                await session.commit()

                deleted_count = result.rowcount
                logger.info("Cleaned up old training logs",
                            deleted_count=deleted_count,
                            days_old=days_old)

                return deleted_count

        except Exception as e:
            logger.error("Training logs cleanup failed", error=str(e))
            raise

    @staticmethod
    async def cleanup_old_models(days_old: int = 365):
        """
        Delete inactive rows from ``trained_models`` older than ``days_old`` days.

        Args:
            days_old: Inactive rows whose ``created_at`` is older than this
                many days are deleted.

        Returns:
            The ``rowcount`` reported for the DELETE statement.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        try:
            async with database_manager.async_session_local() as session:
                is_sqlite = settings.DATABASE_URL.startswith("sqlite")
                cutoff_expr, params = TrainingDatabaseUtils._cutoff_expression(
                    days_old, is_sqlite
                )
                # Keep the per-dialect boolean literal the queries always used.
                inactive_literal = "0" if is_sqlite else "false"
                query = text(
                    "DELETE FROM trained_models "
                    f"WHERE is_active = {inactive_literal} "
                    f"AND created_at < {cutoff_expr}"
                )

                result = await session.execute(query, params)
                await session.commit()

                deleted_count = result.rowcount
                logger.info("Cleaned up old models",
                            deleted_count=deleted_count,
                            days_old=days_old)

                return deleted_count

        except Exception as e:
            logger.error("Model cleanup failed", error=str(e))
            raise

    @staticmethod
    async def get_training_statistics(tenant_id: str = None) -> dict:
        """
        Summarise training jobs and model counts.

        Args:
            tenant_id: When truthy, restrict both queries to this tenant;
                otherwise aggregate over all rows.

        Returns:
            A dict with keys ``training_jobs`` (status -> row count),
            ``active_models``, ``inactive_models`` and ``total_models``.
            On any ``Exception`` the error is logged and the same shape is
            returned with empty/zero values.
        """
        try:
            async with database_manager.async_session_local() as session:
                if tenant_id:
                    logs_query = text(
                        "SELECT status, COUNT(*) as count "
                        "FROM model_training_logs "
                        "WHERE tenant_id = :tenant_id "
                        "GROUP BY status"
                    )
                    models_query = text(
                        "SELECT COUNT(*) as count "
                        "FROM trained_models "
                        "WHERE tenant_id = :tenant_id AND is_active = :is_active"
                    )
                    params = {"tenant_id": tenant_id}
                else:
                    logs_query = text(
                        "SELECT status, COUNT(*) as count "
                        "FROM model_training_logs "
                        "GROUP BY status"
                    )
                    models_query = text(
                        "SELECT COUNT(*) as count "
                        "FROM trained_models "
                        "WHERE is_active = :is_active"
                    )
                    params = {}

                # Training job counts per status. Rows are unpacked by
                # position: attribute access `row.count` would hit the
                # tuple-style `count()` method of the row object rather than
                # the column aliased "count".
                logs_result = await session.execute(logs_query, params)
                job_stats = {
                    status: count for status, count in logs_result.fetchall()
                }

                # Active models count
                active_models_result = await session.execute(
                    models_query,
                    {**params, "is_active": True}
                )
                active_models = active_models_result.scalar() or 0

                # Inactive models count
                inactive_models_result = await session.execute(
                    models_query,
                    {**params, "is_active": False}
                )
                inactive_models = inactive_models_result.scalar() or 0

                return {
                    "training_jobs": job_stats,
                    "active_models": active_models,
                    "inactive_models": inactive_models,
                    "total_models": active_models + inactive_models
                }

        except Exception as e:
            logger.error("Failed to get training statistics", error=str(e))
            return {
                "training_jobs": {},
                "active_models": 0,
                "inactive_models": 0,
                "total_models": 0
            }

    @staticmethod
    async def check_tenant_data_exists(tenant_id: str) -> bool:
        """
        Check whether any ``model_training_logs`` row exists for a tenant.

        Args:
            tenant_id: Tenant identifier to look up.

        Returns:
            True if at least one row matches; False if none match or if the
            lookup raises an ``Exception`` (which is logged).
        """
        try:
            async with database_manager.async_session_local() as session:
                # Existence probe: stop at the first matching row instead of
                # counting every row for the tenant.
                query = text(
                    "SELECT 1 "
                    "FROM model_training_logs "
                    "WHERE tenant_id = :tenant_id "
                    "LIMIT 1"
                )

                result = await session.execute(query, {"tenant_id": tenant_id})
                return result.scalar() is not None

        except Exception as e:
            logger.error("Failed to check tenant data existence",
                         tenant_id=tenant_id, error=str(e))
            return False
|
|
|
|
# Enhanced database session dependency with better error handling
|
|
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """
    Yield a database session, logging its lifecycle.

    A session is opened from ``database_manager.async_session_local()`` and
    yielded to the caller. If an ``Exception`` is raised while the caller
    holds the session, it is logged with traceback info, the session is
    rolled back, and the exception is re-raised. The session is always
    closed afterwards. This function does not commit.
    """
    async with database_manager.async_session_local() as db_session:
        try:
            logger.debug("Database session created")
            yield db_session
        except Exception as exc:
            logger.error("Database session error", error=str(exc), exc_info=True)
            await db_session.rollback()
            raise
        finally:
            await db_session.close()
            logger.debug("Database session closed")
|
|
|
|
# Database initialization for training service
|
|
async def initialize_training_database():
    """
    Create the training service tables.

    Imports the training model classes and then calls
    ``database_manager.create_tables()``.

    Raises:
        Exception: Any failure is logged and re-raised.
    """
    try:
        logger.info("Initializing training service database")

        # Imported only for the side effect of loading the model classes
        # before tables are created; the names themselves are not used here.
        from app.models.training import (  # noqa: F401
            ModelTrainingLog,
            TrainedModel,
            ModelPerformanceMetric,
            TrainingJobQueue,
            ModelArtifact,
        )

        # Table creation is delegated to the shared infrastructure.
        await database_manager.create_tables()

        logger.info("Training service database initialized successfully")

    except Exception as exc:
        logger.error("Failed to initialize training service database", error=str(exc))
        raise
|
|
|
|
# Database cleanup for training service
|
|
async def cleanup_training_database():
    """
    Dispose of the training service's database engine.

    If ``database_manager`` has a truthy ``async_engine``, its ``dispose()``
    is awaited. Any ``Exception`` is logged and not re-raised.
    """
    try:
        logger.info("Cleaning up training service database connections")

        # Dispose of the engine only when the manager actually has one.
        engine = getattr(database_manager, 'async_engine', None)
        if engine:
            await engine.dispose()

        logger.info("Training service database cleanup completed")

    except Exception as exc:
        logger.error("Failed to cleanup training service database", error=str(exc))
|
|
|
|
# Export the commonly used items to maintain compatibility
|
|
__all__ = [
    "Base",
    "database_manager",
    "get_db",
    "get_db_session",
    "get_db_health",
    "TrainingDatabaseUtils",
    "initialize_training_database",
    "cleanup_training_database",
]