# services/training/app/models/training.py
"""
SQLAlchemy models backing the training service.
"""

import uuid
from datetime import datetime

from sqlalchemy import Column, Integer, String, DateTime, Text, Boolean, JSON, Float
from sqlalchemy.dialects.postgresql import UUID, ARRAY

from shared.database.base import Base

# NOTE(review): every timestamp default below is the callable ``datetime.now``,
# which yields a naive datetime in the local timezone of the application
# process -- confirm whether UTC was intended.


class ModelTrainingLog(Base):
    """
    Execution and status record for a training job.

    Replaces the old Celery task tracking.
    """

    __tablename__ = "model_training_logs"

    # --- Identity ---
    id = Column(Integer, primary_key=True, index=True)
    job_id = Column(String(255), nullable=False, unique=True, index=True)
    tenant_id = Column(String(255), nullable=False, index=True)

    # --- Progress ---
    # One of: pending, running, completed, failed, cancelled
    status = Column(String(50), nullable=False, default="pending")
    # Percentage in the range 0-100
    progress = Column(Integer, default=0)
    current_step = Column(String(500), default="")

    # --- Timestamps ---
    start_time = Column(DateTime, default=datetime.now)
    end_time = Column(DateTime, nullable=True)

    # --- Configuration and results ---
    # Training job configuration
    config = Column(JSON, nullable=True)
    # Training results
    results = Column(JSON, nullable=True)
    error_message = Column(Text, nullable=True)

    # --- Metadata ---
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)


class TrainedModel(Base):
    """
    Descriptive record for a trained model.
    """

    __tablename__ = "trained_models"

    # --- Identity ---
    id = Column(Integer, primary_key=True, index=True)
    model_id = Column(String(255), nullable=False, unique=True, index=True)
    tenant_id = Column(String(255), nullable=False, index=True)
    product_name = Column(String(255), nullable=False, index=True)

    # --- Model information ---
    # prophet, arima, etc.
    model_type = Column(String(50), nullable=False, default="prophet")
    # Path to the stored model file
    model_path = Column(String(1000), nullable=False)
    version = Column(Integer, nullable=False, default=1)

    # --- Training information ---
    training_samples = Column(Integer, nullable=False, default=0)
    # List of features used
    features = Column(ARRAY(String), nullable=True)
    # Model hyperparameters
    hyperparameters = Column(JSON, nullable=True)
    # Training performance metrics
    training_metrics = Column(JSON, nullable=True)

    # --- Data period information ---
    data_period_start = Column(DateTime, nullable=True)
    data_period_end = Column(DateTime, nullable=True)

    # --- Status and metadata ---
    is_active = Column(Boolean, default=True, index=True)
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)


class ModelPerformanceMetric(Base):
    """
    Model performance measurements recorded over time.
    """

    __tablename__ = "model_performance_metrics"

    # --- Identity ---
    id = Column(Integer, primary_key=True, index=True)
    model_id = Column(String(255), nullable=False, index=True)
    tenant_id = Column(String(255), nullable=False, index=True)
    product_name = Column(String(255), nullable=False, index=True)

    # --- Performance metrics ---
    # Mean Absolute Error
    mae = Column(Float, nullable=True)
    # Mean Squared Error
    mse = Column(Float, nullable=True)
    # Root Mean Squared Error
    rmse = Column(Float, nullable=True)
    # Mean Absolute Percentage Error
    mape = Column(Float, nullable=True)
    # R-squared score
    r2_score = Column(Float, nullable=True)

    # --- Additional metrics ---
    accuracy_percentage = Column(Float, nullable=True)
    prediction_confidence = Column(Float, nullable=True)

    # --- Evaluation information ---
    evaluation_period_start = Column(DateTime, nullable=True)
    evaluation_period_end = Column(DateTime, nullable=True)
    evaluation_samples = Column(Integer, nullable=True)

    # --- Metadata ---
    measured_at = Column(DateTime, default=datetime.now)
    created_at = Column(DateTime, default=datetime.now)


class TrainingJobQueue(Base):
    """
    Queue and scheduling record for a training job.
    """

    __tablename__ = "training_job_queue"

    # --- Identity ---
    id = Column(Integer, primary_key=True, index=True)
    job_id = Column(String(255), nullable=False, unique=True, index=True)
    tenant_id = Column(String(255), nullable=False, index=True)

    # --- Job configuration ---
    # full_training, single_product, evaluation
    job_type = Column(String(50), nullable=False)
    # Higher number = higher priority
    priority = Column(Integer, default=1)
    config = Column(JSON, nullable=True)

    # --- Scheduling information ---
    scheduled_at = Column(DateTime, nullable=True)
    started_at = Column(DateTime, nullable=True)
    estimated_duration_minutes = Column(Integer, nullable=True)

    # --- Status ---
    # One of: queued, running, completed, failed
    status = Column(String(50), nullable=False, default="queued")
    retry_count = Column(Integer, default=0)
    max_retries = Column(Integer, default=3)

    # --- Metadata ---
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)


class ModelArtifact(Base):
    """
    Record of a model file or other artifact.
    """

    __tablename__ = "model_artifacts"

    # --- Identity ---
    id = Column(Integer, primary_key=True, index=True)
    model_id = Column(String(255), nullable=False, index=True)
    tenant_id = Column(String(255), nullable=False, index=True)

    # --- Artifact information ---
    # model_file, metadata, training_data, etc.
    artifact_type = Column(String(50), nullable=False)
    file_path = Column(String(1000), nullable=False)
    file_size_bytes = Column(Integer, nullable=True)
    # For file integrity
    checksum = Column(String(255), nullable=True)

    # --- Storage information ---
    # local, s3, gcs, etc.
    storage_location = Column(String(100), nullable=False, default="local")
    # gzip, lz4, etc.
    compression = Column(String(50), nullable=True)

    # --- Metadata ---
    created_at = Column(DateTime, default=datetime.now)
    # For automatic cleanup
    expires_at = Column(DateTime, nullable=True)