Initial commit - production deployment
0     services/training/app/core/__init__.py            Normal file
89    services/training/app/core/config.py               Normal file
@@ -0,0 +1,89 @@
# ================================================================
# TRAINING SERVICE CONFIGURATION
# services/training/app/core/config.py
# ================================================================

"""
Training service configuration
ML model training and management
"""

from shared.config.base import BaseServiceSettings
import os


class TrainingSettings(BaseServiceSettings):
    """Training service specific settings"""

    # Service Identity
    APP_NAME: str = "Training Service"
    SERVICE_NAME: str = "training-service"
    DESCRIPTION: str = "Machine learning model training service"

    # Database configuration (secure approach - build from components)
    @property
    def DATABASE_URL(self) -> str:
        """Build database URL from secure components"""
        # Try complete URL first (for backward compatibility)
        complete_url = os.getenv("TRAINING_DATABASE_URL")
        if complete_url:
            return complete_url

        # Build from components (secure approach)
        user = os.getenv("TRAINING_DB_USER", "training_user")
        password = os.getenv("TRAINING_DB_PASSWORD", "training_pass123")
        host = os.getenv("TRAINING_DB_HOST", "localhost")
        port = os.getenv("TRAINING_DB_PORT", "5432")
        name = os.getenv("TRAINING_DB_NAME", "training_db")

        return f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{name}"

    # Redis Database (dedicated for training cache)
    REDIS_DB: int = 1

    # ML Model Storage
    MODEL_BACKUP_ENABLED: bool = os.getenv("MODEL_BACKUP_ENABLED", "true").lower() == "true"
    MODEL_VERSIONING_ENABLED: bool = os.getenv("MODEL_VERSIONING_ENABLED", "true").lower() == "true"

    # MinIO Configuration
    MINIO_ENDPOINT: str = os.getenv("MINIO_ENDPOINT", "minio.bakery-ia.svc.cluster.local:9000")
    MINIO_ACCESS_KEY: str = os.getenv("MINIO_ACCESS_KEY", "training-service")
    MINIO_SECRET_KEY: str = os.getenv("MINIO_SECRET_KEY", "training-secret-key")
    MINIO_USE_SSL: bool = os.getenv("MINIO_USE_SSL", "true").lower() == "true"
    MINIO_MODEL_BUCKET: str = os.getenv("MINIO_MODEL_BUCKET", "training-models")
    MINIO_CONSOLE_PORT: str = os.getenv("MINIO_CONSOLE_PORT", "9001")
    MINIO_API_PORT: str = os.getenv("MINIO_API_PORT", "9000")
    MINIO_REGION: str = os.getenv("MINIO_REGION", "us-east-1")
    MINIO_MODEL_LIFECYCLE_DAYS: int = int(os.getenv("MINIO_MODEL_LIFECYCLE_DAYS", "90"))
    MINIO_CACHE_TTL_SECONDS: int = int(os.getenv("MINIO_CACHE_TTL_SECONDS", "3600"))

    # Training Configuration
    MAX_CONCURRENT_TRAINING_JOBS: int = int(os.getenv("MAX_CONCURRENT_TRAINING_JOBS", "3"))

    # Prophet Specific Configuration
    PROPHET_HOLIDAYS_PRIOR_SCALE: float = float(os.getenv("PROPHET_HOLIDAYS_PRIOR_SCALE", "10.0"))

    # Spanish Holiday Integration
    ENABLE_CUSTOM_HOLIDAYS: bool = os.getenv("ENABLE_CUSTOM_HOLIDAYS", "true").lower() == "true"

    # Data Processing
    DATA_PREPROCESSING_ENABLED: bool = True
    OUTLIER_DETECTION_ENABLED: bool = os.getenv("OUTLIER_DETECTION_ENABLED", "true").lower() == "true"
    SEASONAL_DECOMPOSITION_ENABLED: bool = os.getenv("SEASONAL_DECOMPOSITION_ENABLED", "true").lower() == "true"

    # Model Validation
    CROSS_VALIDATION_ENABLED: bool = os.getenv("CROSS_VALIDATION_ENABLED", "true").lower() == "true"
    VALIDATION_SPLIT_RATIO: float = float(os.getenv("VALIDATION_SPLIT_RATIO", "0.2"))
    MIN_MODEL_ACCURACY: float = float(os.getenv("MIN_MODEL_ACCURACY", "0.7"))

    # Distributed Training (for future scaling)
    DISTRIBUTED_TRAINING_ENABLED: bool = os.getenv("DISTRIBUTED_TRAINING_ENABLED", "false").lower() == "true"
    TRAINING_WORKER_COUNT: int = int(os.getenv("TRAINING_WORKER_COUNT", "1"))

    PROPHET_DAILY_SEASONALITY: bool = True
    PROPHET_WEEKLY_SEASONALITY: bool = True
    PROPHET_YEARLY_SEASONALITY: bool = True

    # Throttling settings for parallel training to prevent heartbeat blocking
    MAX_CONCURRENT_TRAININGS: int = int(os.getenv("MAX_CONCURRENT_TRAININGS", "3"))


settings = TrainingSettings()
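Illustrative only, not part of this commit: a minimal sketch of how the settings above resolve at runtime. It assumes the service package (`app.core.config`) is importable; the environment variable names come from the file itself.

```python
# Sketch: DATABASE_URL is a property, so it re-reads the environment on access.
import os

from app.core.config import settings

# Without TRAINING_DATABASE_URL set, the URL is assembled from its components.
os.environ.pop("TRAINING_DATABASE_URL", None)
os.environ["TRAINING_DB_HOST"] = "db.internal"  # hypothetical host override

print(settings.APP_NAME)       # "Training Service"
print(settings.DATABASE_URL)   # postgresql+asyncpg://training_user:...@db.internal:5432/training_db
print(settings.MAX_CONCURRENT_TRAINING_JOBS)
```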

97    services/training/app/core/constants.py            Normal file
@@ -0,0 +1,97 @@
"""
Training Service Constants
Centralized constants to avoid magic numbers throughout the codebase
"""

# Data Validation Thresholds
MIN_DATA_POINTS_REQUIRED = 30
RECOMMENDED_DATA_POINTS = 90
MAX_ZERO_RATIO_ERROR = 0.9  # 90% zeros = error
HIGH_ZERO_RATIO_WARNING = 0.7  # 70% zeros = warning
MAX_ZERO_RATIO_INTERMITTENT = 0.8  # Products with >80% zeros are intermittent
MODERATE_SPARSITY_THRESHOLD = 0.6  # 60% zeros = moderate sparsity

# Training Time Periods (in days)
MIN_NON_ZERO_DAYS = 30  # Minimum days with non-zero sales
DATA_QUALITY_DAY_THRESHOLD_LOW = 90
DATA_QUALITY_DAY_THRESHOLD_HIGH = 365
MAX_TRAINING_RANGE_DAYS = 730  # 2 years
MIN_TRAINING_RANGE_DAYS = 30

# Product Classification Thresholds
HIGH_VOLUME_MEAN_SALES = 10.0
HIGH_VOLUME_ZERO_RATIO = 0.3
MEDIUM_VOLUME_MEAN_SALES = 5.0
MEDIUM_VOLUME_ZERO_RATIO = 0.5
LOW_VOLUME_MEAN_SALES = 2.0
LOW_VOLUME_ZERO_RATIO = 0.7

# Hyperparameter Optimization
OPTUNA_TRIALS_HIGH_VOLUME = 30
OPTUNA_TRIALS_MEDIUM_VOLUME = 25
OPTUNA_TRIALS_LOW_VOLUME = 20
OPTUNA_TRIALS_INTERMITTENT = 15
OPTUNA_TIMEOUT_SECONDS = 600

# Prophet Uncertainty Sampling
UNCERTAINTY_SAMPLES_SPARSE_MIN = 100
UNCERTAINTY_SAMPLES_SPARSE_MAX = 200
UNCERTAINTY_SAMPLES_LOW_MIN = 150
UNCERTAINTY_SAMPLES_LOW_MAX = 300
UNCERTAINTY_SAMPLES_MEDIUM_MIN = 200
UNCERTAINTY_SAMPLES_MEDIUM_MAX = 500
UNCERTAINTY_SAMPLES_HIGH_MIN = 300
UNCERTAINTY_SAMPLES_HIGH_MAX = 800

# MAPE Calculation
MAPE_LOW_VOLUME_THRESHOLD = 2.0
MAPE_MEDIUM_VOLUME_THRESHOLD = 5.0
MAPE_CALCULATION_MIN_THRESHOLD = 0.5
MAPE_CALCULATION_MID_THRESHOLD = 1.0
MAPE_MAX_CAP = 200.0  # Cap MAPE at 200%
MAPE_MEDIUM_CAP = 150.0

# Baseline MAPE estimates for improvement calculation
BASELINE_MAPE_VERY_SPARSE = 80.0
BASELINE_MAPE_SPARSE = 60.0
BASELINE_MAPE_HIGH_VOLUME = 25.0
BASELINE_MAPE_MEDIUM_VOLUME = 35.0
BASELINE_MAPE_LOW_VOLUME = 45.0
IMPROVEMENT_SIGNIFICANCE_THRESHOLD = 0.8  # Only claim improvement if MAPE < 80% of baseline

# Cross-validation
CV_N_SPLITS = 2
CV_MIN_VALIDATION_DAYS = 7

# Progress tracking
PROGRESS_DATA_PREPARATION_START = 0
PROGRESS_DATA_PREPARATION_END = 45
PROGRESS_MODEL_TRAINING_START = 45
PROGRESS_MODEL_TRAINING_END = 85
PROGRESS_FINALIZATION_START = 85
PROGRESS_FINALIZATION_END = 100

# HTTP Client Configuration
HTTP_TIMEOUT_DEFAULT = 30.0  # seconds
HTTP_TIMEOUT_LONG_RUNNING = 60.0  # for training data fetches
HTTP_MAX_RETRIES = 3
HTTP_RETRY_BACKOFF_FACTOR = 2.0

# WebSocket Configuration
WEBSOCKET_PING_TIMEOUT = 60.0  # seconds
WEBSOCKET_ACTIVITY_WARNING_THRESHOLD = 90.0  # seconds
WEBSOCKET_CONSUMER_HEARTBEAT_INTERVAL = 10.0  # seconds

# Synthetic Data Generation
SYNTHETIC_TEMP_DEFAULT = 50.0
SYNTHETIC_TEMP_VARIATION = 100.0
SYNTHETIC_TRAFFIC_DEFAULT = 50.0
SYNTHETIC_TRAFFIC_VARIATION = 100.0

# Model Storage
MODEL_FILE_EXTENSION = ".pkl"
METADATA_FILE_EXTENSION = ".json"

# Data Quality Scoring
MIN_QUALITY_SCORE = 0.1
MAX_QUALITY_SCORE = 1.0
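A hypothetical helper, not part of this commit, showing one way the classification and Optuna constants above could be combined to pick a hyperparameter search budget per product. The function name and thresholds-to-budget mapping are illustrative assumptions.

```python
# Sketch: map a product's sales profile to an Optuna trial budget.
from app.core.constants import (
    HIGH_VOLUME_MEAN_SALES, HIGH_VOLUME_ZERO_RATIO,
    MEDIUM_VOLUME_MEAN_SALES, MEDIUM_VOLUME_ZERO_RATIO,
    MAX_ZERO_RATIO_INTERMITTENT,
    OPTUNA_TRIALS_HIGH_VOLUME, OPTUNA_TRIALS_MEDIUM_VOLUME,
    OPTUNA_TRIALS_LOW_VOLUME, OPTUNA_TRIALS_INTERMITTENT,
)


def optuna_trials_for(mean_sales: float, zero_ratio: float) -> int:
    """Return a trial count based on how dense the product's sales history is."""
    if zero_ratio > MAX_ZERO_RATIO_INTERMITTENT:
        return OPTUNA_TRIALS_INTERMITTENT
    if mean_sales >= HIGH_VOLUME_MEAN_SALES and zero_ratio <= HIGH_VOLUME_ZERO_RATIO:
        return OPTUNA_TRIALS_HIGH_VOLUME
    if mean_sales >= MEDIUM_VOLUME_MEAN_SALES and zero_ratio <= MEDIUM_VOLUME_ZERO_RATIO:
        return OPTUNA_TRIALS_MEDIUM_VOLUME
    return OPTUNA_TRIALS_LOW_VOLUME
```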

432   services/training/app/core/database.py             Normal file
@@ -0,0 +1,432 @@
# services/training/app/core/database.py
"""
Database configuration for training service
Uses shared database infrastructure
"""

import structlog
from typing import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession
from contextlib import asynccontextmanager
from sqlalchemy import text

from shared.database.base import DatabaseManager, Base
from app.core.config import settings

logger = structlog.get_logger()

# Initialize database manager with connection pooling configuration
database_manager = DatabaseManager(
    settings.DATABASE_URL,
    pool_size=settings.DB_POOL_SIZE,
    max_overflow=settings.DB_MAX_OVERFLOW,
    pool_timeout=settings.DB_POOL_TIMEOUT,
    pool_recycle=settings.DB_POOL_RECYCLE,
    pool_pre_ping=settings.DB_POOL_PRE_PING,
    echo=settings.DB_ECHO
)

# Alias for convenience - matches the existing interface
get_db = database_manager.get_db


@asynccontextmanager
async def get_background_db_session():
    async with database_manager.async_session_local() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()


async def get_db_health() -> bool:
    """
    Health check function for database connectivity
    Enhanced version of the shared functionality
    """
    try:
        async with database_manager.async_engine.begin() as conn:
            await conn.execute(text("SELECT 1"))
        logger.debug("Database health check passed")
        return True

    except Exception as e:
        logger.error("Database health check failed", error=str(e))
        return False


async def get_comprehensive_db_health() -> dict:
    """
    Comprehensive health check that verifies both connectivity and table existence
    """
    health_status = {
        "status": "healthy",
        "connectivity": False,
        "tables_exist": False,
        "tables_verified": [],
        "missing_tables": [],
        "errors": []
    }

    try:
        # Test basic connectivity
        health_status["connectivity"] = await get_db_health()

        if not health_status["connectivity"]:
            health_status["status"] = "unhealthy"
            health_status["errors"].append("Database connectivity failed")
            return health_status

        # Test table existence
        tables_verified = await _verify_tables_exist()
        health_status["tables_exist"] = tables_verified

        if tables_verified:
            health_status["tables_verified"] = [
                'model_training_logs', 'trained_models', 'model_performance_metrics',
                'training_job_queue', 'model_artifacts'
            ]
        else:
            health_status["status"] = "unhealthy"
            health_status["errors"].append("Required tables missing or inaccessible")

            # Try to identify which specific tables are missing
            try:
                async with database_manager.get_session() as session:
                    for table_name in ['model_training_logs', 'trained_models', 'model_performance_metrics',
                                       'training_job_queue', 'model_artifacts']:
                        try:
                            await session.execute(text(f"SELECT 1 FROM {table_name} LIMIT 1"))
                            health_status["tables_verified"].append(table_name)
                        except Exception:
                            health_status["missing_tables"].append(table_name)
            except Exception as e:
                health_status["errors"].append(f"Error checking individual tables: {str(e)}")

        logger.debug("Comprehensive database health check completed",
                     status=health_status["status"],
                     connectivity=health_status["connectivity"],
                     tables_exist=health_status["tables_exist"])

    except Exception as e:
        health_status["status"] = "unhealthy"
        health_status["errors"].append(f"Health check failed: {str(e)}")
        logger.error("Comprehensive database health check failed", error=str(e))

    return health_status
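A hypothetical consumer of the health check above, not included in this commit. It assumes the service exposes HTTP endpoints via FastAPI (suggested but not shown by the diff); the route path and import are illustrative.

```python
# Sketch: readiness endpoint that surfaces the comprehensive health report.
from fastapi import APIRouter
from fastapi.responses import JSONResponse

from app.core.database import get_comprehensive_db_health  # assumed import path

router = APIRouter()


@router.get("/health/db")
async def database_health():
    """Return 200 when connectivity and tables check out, 503 otherwise."""
    report = await get_comprehensive_db_health()
    status_code = 200 if report["status"] == "healthy" else 503
    return JSONResponse(content=report, status_code=status_code)
```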
# Training service specific database utilities
class TrainingDatabaseUtils:
    """Training service specific database utilities"""

    @staticmethod
    async def cleanup_old_training_logs(days_old: int = 90):
        """Clean up old training logs"""
        try:
            async with database_manager.async_session_local() as session:
                if settings.DATABASE_URL.startswith("sqlite"):
                    query = text(
                        "DELETE FROM model_training_logs "
                        "WHERE start_time < datetime('now', :days_param)"
                    )
                    params = {"days_param": f"-{days_old} days"}
                else:
                    query = text(
                        "DELETE FROM model_training_logs "
                        "WHERE start_time < NOW() - INTERVAL :days_param"
                    )
                    params = {"days_param": f"{days_old} days"}

                result = await session.execute(query, params)
                await session.commit()

                deleted_count = result.rowcount
                logger.info("Cleaned up old training logs",
                            deleted_count=deleted_count,
                            days_old=days_old)

                return deleted_count

        except Exception as e:
            logger.error("Training logs cleanup failed", error=str(e))
            raise

    @staticmethod
    async def cleanup_old_models(days_old: int = 365):
        """Clean up old inactive models"""
        try:
            async with database_manager.async_session_local() as session:
                if settings.DATABASE_URL.startswith("sqlite"):
                    query = text(
                        "DELETE FROM trained_models "
                        "WHERE is_active = 0 AND created_at < datetime('now', :days_param)"
                    )
                    params = {"days_param": f"-{days_old} days"}
                else:
                    query = text(
                        "DELETE FROM trained_models "
                        "WHERE is_active = false AND created_at < NOW() - INTERVAL :days_param"
                    )
                    params = {"days_param": f"{days_old} days"}

                result = await session.execute(query, params)
                await session.commit()

                deleted_count = result.rowcount
                logger.info("Cleaned up old models",
                            deleted_count=deleted_count,
                            days_old=days_old)

                return deleted_count

        except Exception as e:
            logger.error("Model cleanup failed", error=str(e))
            raise

    @staticmethod
    async def get_training_statistics(tenant_id: str = None) -> dict:
        """Get training statistics"""
        try:
            async with database_manager.async_session_local() as session:
                # Base query for training logs
                if tenant_id:
                    logs_query = text(
                        "SELECT status, COUNT(*) as count "
                        "FROM model_training_logs "
                        "WHERE tenant_id = :tenant_id "
                        "GROUP BY status"
                    )
                    models_query = text(
                        "SELECT COUNT(*) as count "
                        "FROM trained_models "
                        "WHERE tenant_id = :tenant_id AND is_active = :is_active"
                    )
                    params = {"tenant_id": tenant_id}
                else:
                    logs_query = text(
                        "SELECT status, COUNT(*) as count "
                        "FROM model_training_logs "
                        "GROUP BY status"
                    )
                    models_query = text(
                        "SELECT COUNT(*) as count "
                        "FROM trained_models "
                        "WHERE is_active = :is_active"
                    )
                    params = {}

                # Get training job statistics
                logs_result = await session.execute(logs_query, params)
                job_stats = {row.status: row.count for row in logs_result.fetchall()}

                # Get active models count
                active_models_result = await session.execute(
                    models_query,
                    {**params, "is_active": True}
                )
                active_models = active_models_result.scalar() or 0

                # Get inactive models count
                inactive_models_result = await session.execute(
                    models_query,
                    {**params, "is_active": False}
                )
                inactive_models = inactive_models_result.scalar() or 0

                return {
                    "training_jobs": job_stats,
                    "active_models": active_models,
                    "inactive_models": inactive_models,
                    "total_models": active_models + inactive_models
                }

        except Exception as e:
            logger.error("Failed to get training statistics", error=str(e))
            return {
                "training_jobs": {},
                "active_models": 0,
                "inactive_models": 0,
                "total_models": 0
            }

    @staticmethod
    async def check_tenant_data_exists(tenant_id: str) -> bool:
        """Check if tenant has any training data"""
        try:
            async with database_manager.async_session_local() as session:
                query = text(
                    "SELECT COUNT(*) as count "
                    "FROM model_training_logs "
                    "WHERE tenant_id = :tenant_id "
                    "LIMIT 1"
                )

                result = await session.execute(query, {"tenant_id": tenant_id})
                count = result.scalar() or 0

                return count > 0

        except Exception as e:
            logger.error("Failed to check tenant data existence",
                         tenant_id=tenant_id, error=str(e))
            return False


# Enhanced database session dependency with better error handling
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """
    Enhanced database session dependency with better logging and error handling
    """
    async with database_manager.async_session_local() as session:
        try:
            logger.debug("Database session created")
            yield session
        except Exception as e:
            logger.error("Database session error", error=str(e), exc_info=True)
            await session.rollback()
            raise
        finally:
            await session.close()
            logger.debug("Database session closed")


# Database initialization for training service
async def initialize_training_database():
    """Initialize database tables for training service with retry logic and verification"""
    import asyncio

    max_retries = 5
    retry_delay = 2.0

    for attempt in range(1, max_retries + 1):
        try:
            logger.info("Initializing training service database",
                        attempt=attempt, max_retries=max_retries)

            # Step 1: Test database connectivity first
            logger.info("Testing database connectivity...")
            connection_ok = await database_manager.test_connection()
            if not connection_ok:
                raise Exception("Database connection test failed")
            logger.info("Database connectivity verified")

            # Step 2: Import models to ensure they're registered
            logger.info("Importing and registering database models...")
            from app.models.training import (
                ModelTrainingLog,
                TrainedModel,
                ModelPerformanceMetric,
                TrainingJobQueue,
                ModelArtifact
            )

            # Verify models are registered in metadata
            expected_tables = {
                'model_training_logs', 'trained_models', 'model_performance_metrics',
                'training_job_queue', 'model_artifacts'
            }
            registered_tables = set(Base.metadata.tables.keys())
            missing_tables = expected_tables - registered_tables
            if missing_tables:
                raise Exception(f"Models not properly registered: {missing_tables}")

            logger.info("Models registered successfully",
                        tables=list(registered_tables))

            # Step 3: Create tables using shared infrastructure with verification
            logger.info("Creating database tables...")
            await database_manager.create_tables()

            # Step 4: Verify tables were actually created
            logger.info("Verifying table creation...")
            verification_successful = await _verify_tables_exist()

            if not verification_successful:
                raise Exception("Table verification failed - tables were not created properly")

            logger.info("Training service database initialized and verified successfully",
                        attempt=attempt)
            return

        except Exception as e:
            logger.error("Database initialization failed",
                         attempt=attempt,
                         max_retries=max_retries,
                         error=str(e))

            if attempt == max_retries:
                logger.error("All database initialization attempts failed - giving up")
                raise Exception(f"Failed to initialize training database after {max_retries} attempts: {str(e)}")

            # Wait before retry with exponential backoff
            wait_time = retry_delay * (2 ** (attempt - 1))
            logger.info("Retrying database initialization",
                        retry_in_seconds=wait_time,
                        next_attempt=attempt + 1)
            await asyncio.sleep(wait_time)


async def _verify_tables_exist() -> bool:
    """Verify that all required tables exist in the database"""
    try:
        async with database_manager.get_session() as session:
            # Check each required table exists and is accessible
            required_tables = [
                'model_training_logs',
                'trained_models',
                'model_performance_metrics',
                'training_job_queue',
                'model_artifacts'
            ]

            for table_name in required_tables:
                try:
                    # Try to query the table structure
                    await session.execute(
                        text(f"SELECT 1 FROM {table_name} LIMIT 1")
                    )
                    logger.debug(f"Table {table_name} exists and is accessible")
                except Exception as table_error:
                    # If it's a "relation does not exist" error, table creation failed
                    if "does not exist" in str(table_error).lower():
                        logger.error(f"Table {table_name} does not exist", error=str(table_error))
                        return False
                    # If it's an empty table, that's fine - table exists
                    elif "no data" in str(table_error).lower():
                        logger.debug(f"Table {table_name} exists but is empty (normal)")
                    else:
                        logger.warning(f"Unexpected error querying {table_name}", error=str(table_error))

            logger.info("All required tables verified successfully",
                        tables=required_tables)
            return True

    except Exception as e:
        logger.error("Table verification failed", error=str(e))
        return False


# Database cleanup for training service
async def cleanup_training_database():
    """Cleanup database connections for training service"""
    try:
        logger.info("Cleaning up training service database connections")

        # Close engine connections
        if hasattr(database_manager, 'async_engine') and database_manager.async_engine:
            await database_manager.async_engine.dispose()

        logger.info("Training service database cleanup completed")

    except Exception as e:
        logger.error("Failed to cleanup training service database", error=str(e))


# Export the commonly used items to maintain compatibility
__all__ = [
    'Base',
    'database_manager',
    'get_db',
    'get_db_session',
    'get_db_health',
    'TrainingDatabaseUtils',
    'initialize_training_database',
    'cleanup_training_database'
]
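A hypothetical maintenance task, not part of this commit, sketching how the cleanup and statistics utilities above might be driven; the scheduling mechanism and function name are assumptions.

```python
# Sketch: nightly maintenance using TrainingDatabaseUtils.
import asyncio

from app.core.database import TrainingDatabaseUtils  # assumed import path


async def nightly_maintenance() -> None:
    """Prune old rows, then log the remaining model counts."""
    deleted_logs = await TrainingDatabaseUtils.cleanup_old_training_logs(days_old=90)
    deleted_models = await TrainingDatabaseUtils.cleanup_old_models(days_old=365)
    stats = await TrainingDatabaseUtils.get_training_statistics()
    print(deleted_logs, deleted_models, stats["total_models"])


if __name__ == "__main__":
    asyncio.run(nightly_maintenance())
```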

35    services/training/app/core/training_constants.py   Normal file
@@ -0,0 +1,35 @@
"""
Training Progress Constants
Centralized constants for training progress tracking and timing
"""

# Progress Milestones (percentage)
PROGRESS_STARTED = 0
PROGRESS_DATA_VALIDATION = 10
PROGRESS_DATA_ANALYSIS = 20
PROGRESS_DATA_PREPARATION_COMPLETE = 30
PROGRESS_ML_TRAINING_START = 40
PROGRESS_TRAINING_COMPLETE = 85
PROGRESS_STORING_MODELS = 92
PROGRESS_STORING_METRICS = 94
PROGRESS_COMPLETED = 100

# Progress Ranges
PROGRESS_TRAINING_RANGE_START = 20  # After data analysis
PROGRESS_TRAINING_RANGE_END = 80  # Before finalization
PROGRESS_TRAINING_RANGE_WIDTH = PROGRESS_TRAINING_RANGE_END - PROGRESS_TRAINING_RANGE_START  # 60%

# Time Limits and Intervals (seconds)
MAX_ESTIMATED_TIME_REMAINING_SECONDS = 1800  # 30 minutes
WEBSOCKET_HEARTBEAT_INTERVAL_SECONDS = 30
WEBSOCKET_RECONNECT_MAX_ATTEMPTS = 3
WEBSOCKET_RECONNECT_INITIAL_DELAY_SECONDS = 1
WEBSOCKET_RECONNECT_MAX_DELAY_SECONDS = 10

# Training Timeouts (seconds)
TRAINING_SKIP_OPTION_DELAY_SECONDS = 120  # 2 minutes
HTTP_POLLING_INTERVAL_MS = 5000  # 5 seconds
HTTP_POLLING_DEBOUNCE_MS = 5000  # 5 seconds before enabling after WebSocket disconnect

# Frontend Display
TRAINING_COMPLETION_DELAY_MS = 2000  # Delay before navigating after completion
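A hypothetical helper, not in this commit, showing how the progress-range constants above could scale a per-job completion fraction into the overall percentage reported during the ML training phase; the function name is an assumption.

```python
# Sketch: map "models trained so far" onto the reserved training progress range.
from app.core.training_constants import (
    PROGRESS_TRAINING_RANGE_START,
    PROGRESS_TRAINING_RANGE_WIDTH,
)


def overall_progress(models_done: int, models_total: int) -> int:
    """Return the overall percentage while the ML training phase is running."""
    fraction = models_done / max(models_total, 1)
    return PROGRESS_TRAINING_RANGE_START + round(fraction * PROGRESS_TRAINING_RANGE_WIDTH)


# e.g. 3 of 10 models trained -> 20 + round(0.3 * 60) = 38
```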