Files
bakery-ia/services/training/app/core/database.py

273 lines
10 KiB
Python
Raw Normal View History

2025-07-19 16:59:37 +02:00
# services/training/app/core/database.py
"""
Database configuration for training service
2025-07-19 16:59:37 +02:00
Uses shared database infrastructure
"""
2025-07-19 16:59:37 +02:00
import structlog
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession
2025-08-01 16:26:36 +02:00
from contextlib import asynccontextmanager
2025-07-19 16:59:37 +02:00
from sqlalchemy import text
from shared.database.base import DatabaseManager, Base
from app.core.config import settings
2025-07-19 16:59:37 +02:00
logger = structlog.get_logger()
# Initialize database manager using shared infrastructure
database_manager = DatabaseManager(settings.DATABASE_URL)
2025-07-19 16:59:37 +02:00
# Alias for convenience - matches the existing interface
get_db = database_manager.get_db
2025-08-01 16:26:36 +02:00
@asynccontextmanager
async def get_background_db_session():
async with database_manager.async_session_local() as session:
try:
yield session
await session.commit()
except Exception as e:
await session.rollback()
raise
finally:
await session.close()
2025-07-19 16:59:37 +02:00
async def get_db_health() -> bool:
"""
Health check function for database connectivity
Enhanced version of the shared functionality
"""
try:
async with database_manager.async_engine.begin() as conn:
await conn.execute(text("SELECT 1"))
logger.debug("Database health check passed")
return True
except Exception as e:
logger.error("Database health check failed", error=str(e))
return False
# Training service specific database utilities
class TrainingDatabaseUtils:
"""Training service specific database utilities"""
@staticmethod
async def cleanup_old_training_logs(days_old: int = 90):
"""Clean up old training logs"""
try:
async with database_manager.async_session_local() as session:
if settings.DATABASE_URL.startswith("sqlite"):
query = text(
"DELETE FROM model_training_logs "
"WHERE start_time < datetime('now', :days_param)"
)
params = {"days_param": f"-{days_old} days"}
else:
query = text(
"DELETE FROM model_training_logs "
"WHERE start_time < NOW() - INTERVAL :days_param"
)
params = {"days_param": f"{days_old} days"}
result = await session.execute(query, params)
await session.commit()
deleted_count = result.rowcount
logger.info("Cleaned up old training logs",
deleted_count=deleted_count,
days_old=days_old)
return deleted_count
except Exception as e:
logger.error("Training logs cleanup failed", error=str(e))
raise
@staticmethod
async def cleanup_old_models(days_old: int = 365):
"""Clean up old inactive models"""
try:
async with database_manager.async_session_local() as session:
if settings.DATABASE_URL.startswith("sqlite"):
query = text(
"DELETE FROM trained_models "
"WHERE is_active = 0 AND created_at < datetime('now', :days_param)"
)
params = {"days_param": f"-{days_old} days"}
else:
query = text(
"DELETE FROM trained_models "
"WHERE is_active = false AND created_at < NOW() - INTERVAL :days_param"
)
params = {"days_param": f"{days_old} days"}
result = await session.execute(query, params)
await session.commit()
deleted_count = result.rowcount
logger.info("Cleaned up old models",
deleted_count=deleted_count,
days_old=days_old)
return deleted_count
except Exception as e:
logger.error("Model cleanup failed", error=str(e))
raise
@staticmethod
async def get_training_statistics(tenant_id: str = None) -> dict:
"""Get training statistics"""
try:
async with database_manager.async_session_local() as session:
# Base query for training logs
if tenant_id:
logs_query = text(
"SELECT status, COUNT(*) as count "
"FROM model_training_logs "
"WHERE tenant_id = :tenant_id "
"GROUP BY status"
)
models_query = text(
"SELECT COUNT(*) as count "
"FROM trained_models "
"WHERE tenant_id = :tenant_id AND is_active = :is_active"
)
params = {"tenant_id": tenant_id}
else:
logs_query = text(
"SELECT status, COUNT(*) as count "
"FROM model_training_logs "
"GROUP BY status"
)
models_query = text(
"SELECT COUNT(*) as count "
"FROM trained_models "
"WHERE is_active = :is_active"
)
params = {}
# Get training job statistics
logs_result = await session.execute(logs_query, params)
job_stats = {row.status: row.count for row in logs_result.fetchall()}
# Get active models count
active_models_result = await session.execute(
models_query,
{**params, "is_active": True}
)
active_models = active_models_result.scalar() or 0
# Get inactive models count
inactive_models_result = await session.execute(
models_query,
{**params, "is_active": False}
)
inactive_models = inactive_models_result.scalar() or 0
return {
"training_jobs": job_stats,
"active_models": active_models,
"inactive_models": inactive_models,
"total_models": active_models + inactive_models
}
except Exception as e:
logger.error("Failed to get training statistics", error=str(e))
return {
"training_jobs": {},
"active_models": 0,
"inactive_models": 0,
"total_models": 0
}
@staticmethod
async def check_tenant_data_exists(tenant_id: str) -> bool:
"""Check if tenant has any training data"""
try:
async with database_manager.async_session_local() as session:
query = text(
"SELECT COUNT(*) as count "
"FROM model_training_logs "
"WHERE tenant_id = :tenant_id "
"LIMIT 1"
)
result = await session.execute(query, {"tenant_id": tenant_id})
count = result.scalar() or 0
return count > 0
except Exception as e:
logger.error("Failed to check tenant data existence",
tenant_id=tenant_id, error=str(e))
return False
# Enhanced database session dependency with better error handling
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
"""
Enhanced database session dependency with better logging and error handling
"""
async with database_manager.async_session_local() as session:
try:
logger.debug("Database session created")
yield session
except Exception as e:
logger.error("Database session error", error=str(e), exc_info=True)
await session.rollback()
raise
finally:
await session.close()
logger.debug("Database session closed")
# Database initialization for training service
async def initialize_training_database():
"""Initialize database tables for training service"""
try:
logger.info("Initializing training service database")
# Import models to ensure they're registered
from app.models.training import (
ModelTrainingLog,
TrainedModel,
ModelPerformanceMetric,
TrainingJobQueue,
ModelArtifact
)
# Create tables using shared infrastructure
await database_manager.create_tables()
logger.info("Training service database initialized successfully")
except Exception as e:
logger.error("Failed to initialize training service database", error=str(e))
raise
# Database cleanup for training service
async def cleanup_training_database():
"""Cleanup database connections for training service"""
try:
logger.info("Cleaning up training service database connections")
# Close engine connections
if hasattr(database_manager, 'async_engine') and database_manager.async_engine:
await database_manager.async_engine.dispose()
logger.info("Training service database cleanup completed")
except Exception as e:
logger.error("Failed to cleanup training service database", error=str(e))
# Export the commonly used items to maintain compatibility
__all__ = [
'Base',
'database_manager',
'get_db',
'get_db_session',
'get_db_health',
'TrainingDatabaseUtils',
'initialize_training_database',
'cleanup_training_database'
]