REFACTOR external service and improve websocket training
This commit is contained in:
@@ -41,25 +41,16 @@ class TrainingSettings(BaseServiceSettings):
|
||||
REDIS_DB: int = 1
|
||||
|
||||
# ML Model Storage
|
||||
MODEL_STORAGE_PATH: str = os.getenv("MODEL_STORAGE_PATH", "/app/models")
|
||||
MODEL_BACKUP_ENABLED: bool = os.getenv("MODEL_BACKUP_ENABLED", "true").lower() == "true"
|
||||
MODEL_VERSIONING_ENABLED: bool = os.getenv("MODEL_VERSIONING_ENABLED", "true").lower() == "true"
|
||||
|
||||
# Training Configuration
|
||||
MAX_TRAINING_TIME_MINUTES: int = int(os.getenv("MAX_TRAINING_TIME_MINUTES", "30"))
|
||||
MAX_CONCURRENT_TRAINING_JOBS: int = int(os.getenv("MAX_CONCURRENT_TRAINING_JOBS", "3"))
|
||||
MIN_TRAINING_DATA_DAYS: int = int(os.getenv("MIN_TRAINING_DATA_DAYS", "30"))
|
||||
TRAINING_BATCH_SIZE: int = int(os.getenv("TRAINING_BATCH_SIZE", "1000"))
|
||||
|
||||
# Prophet Specific Configuration
|
||||
PROPHET_SEASONALITY_MODE: str = os.getenv("PROPHET_SEASONALITY_MODE", "additive")
|
||||
PROPHET_CHANGEPOINT_PRIOR_SCALE: float = float(os.getenv("PROPHET_CHANGEPOINT_PRIOR_SCALE", "0.05"))
|
||||
PROPHET_SEASONALITY_PRIOR_SCALE: float = float(os.getenv("PROPHET_SEASONALITY_PRIOR_SCALE", "10.0"))
|
||||
PROPHET_HOLIDAYS_PRIOR_SCALE: float = float(os.getenv("PROPHET_HOLIDAYS_PRIOR_SCALE", "10.0"))
|
||||
|
||||
# Spanish Holiday Integration
|
||||
ENABLE_SPANISH_HOLIDAYS: bool = True
|
||||
ENABLE_MADRID_HOLIDAYS: bool = True
|
||||
ENABLE_CUSTOM_HOLIDAYS: bool = os.getenv("ENABLE_CUSTOM_HOLIDAYS", "true").lower() == "true"
|
||||
|
||||
# Data Processing
|
||||
@@ -79,6 +70,8 @@ class TrainingSettings(BaseServiceSettings):
|
||||
PROPHET_DAILY_SEASONALITY: bool = True
|
||||
PROPHET_WEEKLY_SEASONALITY: bool = True
|
||||
PROPHET_YEARLY_SEASONALITY: bool = True
|
||||
PROPHET_SEASONALITY_MODE: str = "additive"
|
||||
|
||||
settings = TrainingSettings()
|
||||
# Throttling settings for parallel training to prevent heartbeat blocking
|
||||
MAX_CONCURRENT_TRAININGS: int = int(os.getenv("MAX_CONCURRENT_TRAININGS", "3"))
|
||||
|
||||
settings = TrainingSettings()
|
||||
|
||||
97
services/training/app/core/constants.py
Normal file
97
services/training/app/core/constants.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Training Service Constants
|
||||
Centralized constants to avoid magic numbers throughout the codebase
|
||||
"""
|
||||
|
||||
# Data Validation Thresholds
|
||||
MIN_DATA_POINTS_REQUIRED = 30
|
||||
RECOMMENDED_DATA_POINTS = 90
|
||||
MAX_ZERO_RATIO_ERROR = 0.9 # 90% zeros = error
|
||||
HIGH_ZERO_RATIO_WARNING = 0.7 # 70% zeros = warning
|
||||
MAX_ZERO_RATIO_INTERMITTENT = 0.8 # Products with >80% zeros are intermittent
|
||||
MODERATE_SPARSITY_THRESHOLD = 0.6 # 60% zeros = moderate sparsity
|
||||
|
||||
# Training Time Periods (in days)
|
||||
MIN_NON_ZERO_DAYS = 30 # Minimum days with non-zero sales
|
||||
DATA_QUALITY_DAY_THRESHOLD_LOW = 90
|
||||
DATA_QUALITY_DAY_THRESHOLD_HIGH = 365
|
||||
MAX_TRAINING_RANGE_DAYS = 730 # 2 years
|
||||
MIN_TRAINING_RANGE_DAYS = 30
|
||||
|
||||
# Product Classification Thresholds
|
||||
HIGH_VOLUME_MEAN_SALES = 10.0
|
||||
HIGH_VOLUME_ZERO_RATIO = 0.3
|
||||
MEDIUM_VOLUME_MEAN_SALES = 5.0
|
||||
MEDIUM_VOLUME_ZERO_RATIO = 0.5
|
||||
LOW_VOLUME_MEAN_SALES = 2.0
|
||||
LOW_VOLUME_ZERO_RATIO = 0.7
|
||||
|
||||
# Hyperparameter Optimization
|
||||
OPTUNA_TRIALS_HIGH_VOLUME = 30
|
||||
OPTUNA_TRIALS_MEDIUM_VOLUME = 25
|
||||
OPTUNA_TRIALS_LOW_VOLUME = 20
|
||||
OPTUNA_TRIALS_INTERMITTENT = 15
|
||||
OPTUNA_TIMEOUT_SECONDS = 600
|
||||
|
||||
# Prophet Uncertainty Sampling
|
||||
UNCERTAINTY_SAMPLES_SPARSE_MIN = 100
|
||||
UNCERTAINTY_SAMPLES_SPARSE_MAX = 200
|
||||
UNCERTAINTY_SAMPLES_LOW_MIN = 150
|
||||
UNCERTAINTY_SAMPLES_LOW_MAX = 300
|
||||
UNCERTAINTY_SAMPLES_MEDIUM_MIN = 200
|
||||
UNCERTAINTY_SAMPLES_MEDIUM_MAX = 500
|
||||
UNCERTAINTY_SAMPLES_HIGH_MIN = 300
|
||||
UNCERTAINTY_SAMPLES_HIGH_MAX = 800
|
||||
|
||||
# MAPE Calculation
|
||||
MAPE_LOW_VOLUME_THRESHOLD = 2.0
|
||||
MAPE_MEDIUM_VOLUME_THRESHOLD = 5.0
|
||||
MAPE_CALCULATION_MIN_THRESHOLD = 0.5
|
||||
MAPE_CALCULATION_MID_THRESHOLD = 1.0
|
||||
MAPE_MAX_CAP = 200.0 # Cap MAPE at 200%
|
||||
MAPE_MEDIUM_CAP = 150.0
|
||||
|
||||
# Baseline MAPE estimates for improvement calculation
|
||||
BASELINE_MAPE_VERY_SPARSE = 80.0
|
||||
BASELINE_MAPE_SPARSE = 60.0
|
||||
BASELINE_MAPE_HIGH_VOLUME = 25.0
|
||||
BASELINE_MAPE_MEDIUM_VOLUME = 35.0
|
||||
BASELINE_MAPE_LOW_VOLUME = 45.0
|
||||
IMPROVEMENT_SIGNIFICANCE_THRESHOLD = 0.8 # Only claim improvement if MAPE < 80% of baseline
|
||||
|
||||
# Cross-validation
|
||||
CV_N_SPLITS = 2
|
||||
CV_MIN_VALIDATION_DAYS = 7
|
||||
|
||||
# Progress tracking
|
||||
PROGRESS_DATA_PREPARATION_START = 0
|
||||
PROGRESS_DATA_PREPARATION_END = 45
|
||||
PROGRESS_MODEL_TRAINING_START = 45
|
||||
PROGRESS_MODEL_TRAINING_END = 85
|
||||
PROGRESS_FINALIZATION_START = 85
|
||||
PROGRESS_FINALIZATION_END = 100
|
||||
|
||||
# HTTP Client Configuration
|
||||
HTTP_TIMEOUT_DEFAULT = 30.0 # seconds
|
||||
HTTP_TIMEOUT_LONG_RUNNING = 60.0 # for training data fetches
|
||||
HTTP_MAX_RETRIES = 3
|
||||
HTTP_RETRY_BACKOFF_FACTOR = 2.0
|
||||
|
||||
# WebSocket Configuration
|
||||
WEBSOCKET_PING_TIMEOUT = 60.0 # seconds
|
||||
WEBSOCKET_ACTIVITY_WARNING_THRESHOLD = 90.0 # seconds
|
||||
WEBSOCKET_CONSUMER_HEARTBEAT_INTERVAL = 10.0 # seconds
|
||||
|
||||
# Synthetic Data Generation
|
||||
SYNTHETIC_TEMP_DEFAULT = 50.0
|
||||
SYNTHETIC_TEMP_VARIATION = 100.0
|
||||
SYNTHETIC_TRAFFIC_DEFAULT = 50.0
|
||||
SYNTHETIC_TRAFFIC_VARIATION = 100.0
|
||||
|
||||
# Model Storage
|
||||
MODEL_FILE_EXTENSION = ".pkl"
|
||||
METADATA_FILE_EXTENSION = ".json"
|
||||
|
||||
# Data Quality Scoring
|
||||
MIN_QUALITY_SCORE = 0.1
|
||||
MAX_QUALITY_SCORE = 1.0
|
||||
@@ -15,8 +15,16 @@ from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Initialize database manager using shared infrastructure
|
||||
database_manager = DatabaseManager(settings.DATABASE_URL)
|
||||
# Initialize database manager with connection pooling configuration
|
||||
database_manager = DatabaseManager(
|
||||
settings.DATABASE_URL,
|
||||
pool_size=settings.DB_POOL_SIZE,
|
||||
max_overflow=settings.DB_MAX_OVERFLOW,
|
||||
pool_timeout=settings.DB_POOL_TIMEOUT,
|
||||
pool_recycle=settings.DB_POOL_RECYCLE,
|
||||
pool_pre_ping=settings.DB_POOL_PRE_PING,
|
||||
echo=settings.DB_ECHO
|
||||
)
|
||||
|
||||
# Alias for convenience - matches the existing interface
|
||||
get_db = database_manager.get_db
|
||||
|
||||
Reference in New Issue
Block a user