# services/training/app/models/training.py
"""
Database models for the training service.

All models below share a single declarative ``Base``. ``TrainedModel.features``
uses ``ARRAY`` from ``sqlalchemy.dialects.postgresql``, so these tables target a
PostgreSQL database.

NOTE(review): every timestamp default in this module is ``datetime.now``, which
produces a naive *local* time, and the ``DateTime`` columns are declared without
``timezone=True``. Stored values therefore depend on the server's local
timezone — confirm this is intended before comparing with UTC-based data.
"""

from sqlalchemy import Column, Integer, String, DateTime, Text, Boolean, JSON, Float
from sqlalchemy.dialects.postgresql import UUID, ARRAY
from datetime import datetime
import uuid  # NOTE(review): `uuid` and `UUID` are not referenced by the models below

try:
    # SQLAlchemy >= 1.4: declarative_base lives in sqlalchemy.orm.
    from sqlalchemy.orm import declarative_base
except ImportError:  # SQLAlchemy < 1.4
    # Legacy location; deprecated since 1.4 (the function moved to sqlalchemy.orm).
    from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()


class ModelTrainingLog(Base):
    """
    Table to track training job execution and status.
    Replaces the old Celery task tracking.

    One row per training job, identified by the unique ``job_id``.
    """

    __tablename__ = "model_training_logs"

    id = Column(Integer, primary_key=True, index=True)
    job_id = Column(String(255), unique=True, index=True, nullable=False)
    tenant_id = Column(String(255), index=True, nullable=False)
    status = Column(String(50), nullable=False, default="pending")  # pending, running, completed, failed, cancelled
    progress = Column(Integer, default=0)  # 0-100 percentage (range is not enforced by a DB constraint)
    current_step = Column(String(500), default="")

    # Timestamps
    # start_time defaults to the row's insert time unless set explicitly.
    start_time = Column(DateTime, default=datetime.now)
    end_time = Column(DateTime, nullable=True)

    # Configuration and results
    config = Column(JSON, nullable=True)  # Training job configuration
    results = Column(JSON, nullable=True)  # Training results
    error_message = Column(Text, nullable=True)

    # Metadata
    created_at = Column(DateTime, default=datetime.now)
    # onupdate is applied by SQLAlchemy when it issues an UPDATE for the row.
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)


class TrainedModel(Base):
    """
    Table to store information about trained models.

    One row per trained model, identified by the unique ``model_id`` and scoped
    by ``tenant_id`` and ``product_name``.
    """

    __tablename__ = "trained_models"

    id = Column(Integer, primary_key=True, index=True)
    model_id = Column(String(255), unique=True, index=True, nullable=False)
    tenant_id = Column(String(255), index=True, nullable=False)
    product_name = Column(String(255), index=True, nullable=False)

    # Model information
    model_type = Column(String(50), nullable=False, default="prophet")  # prophet, arima, etc.
    model_path = Column(String(1000), nullable=False)  # Path to stored model file
    version = Column(Integer, nullable=False, default=1)

    # Training information
    training_samples = Column(Integer, nullable=False, default=0)
    features = Column(ARRAY(String), nullable=True)  # List of features used (PostgreSQL ARRAY)
    hyperparameters = Column(JSON, nullable=True)  # Model hyperparameters
    training_metrics = Column(JSON, nullable=True)  # Training performance metrics

    # Data period information
    data_period_start = Column(DateTime, nullable=True)
    data_period_end = Column(DateTime, nullable=True)

    # Status and metadata
    is_active = Column(Boolean, default=True, index=True)
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)


class ModelPerformanceMetric(Base):
    """
    Table to track model performance over time.

    ``model_id`` is indexed but not unique, so a model may have many metric
    rows. It is a plain string column, not a ForeignKey to
    ``trained_models.model_id``.
    """

    __tablename__ = "model_performance_metrics"

    id = Column(Integer, primary_key=True, index=True)
    model_id = Column(String(255), index=True, nullable=False)
    tenant_id = Column(String(255), index=True, nullable=False)
    product_name = Column(String(255), index=True, nullable=False)

    # Performance metrics
    mae = Column(Float, nullable=True)  # Mean Absolute Error
    mse = Column(Float, nullable=True)  # Mean Squared Error
    rmse = Column(Float, nullable=True)  # Root Mean Squared Error
    mape = Column(Float, nullable=True)  # Mean Absolute Percentage Error
    r2_score = Column(Float, nullable=True)  # R-squared score

    # Additional metrics
    accuracy_percentage = Column(Float, nullable=True)
    prediction_confidence = Column(Float, nullable=True)

    # Evaluation information
    evaluation_period_start = Column(DateTime, nullable=True)
    evaluation_period_end = Column(DateTime, nullable=True)
    evaluation_samples = Column(Integer, nullable=True)

    # Metadata
    measured_at = Column(DateTime, default=datetime.now)
    created_at = Column(DateTime, default=datetime.now)


class TrainingJobQueue(Base):
    """
    Table to manage training job queue and scheduling.

    One row per queued job, identified by the unique ``job_id``.
    """

    __tablename__ = "training_job_queue"

    id = Column(Integer, primary_key=True, index=True)
    job_id = Column(String(255), unique=True, index=True, nullable=False)
    tenant_id = Column(String(255), index=True, nullable=False)

    # Job configuration
    job_type = Column(String(50), nullable=False)  # full_training, single_product, evaluation
    priority = Column(Integer, default=1)  # Higher number = higher priority
    config = Column(JSON, nullable=True)

    # Scheduling information
    scheduled_at = Column(DateTime, nullable=True)
    started_at = Column(DateTime, nullable=True)
    estimated_duration_minutes = Column(Integer, nullable=True)

    # Status
    status = Column(String(50), nullable=False, default="queued")  # queued, running, completed, failed
    retry_count = Column(Integer, default=0)
    max_retries = Column(Integer, default=3)

    # Metadata
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)


class ModelArtifact(Base):
    """
    Table to track model files and artifacts.

    ``model_id`` is indexed but not unique, so one model may have several
    artifact rows (e.g. model file, metadata, training data).
    """

    __tablename__ = "model_artifacts"

    id = Column(Integer, primary_key=True, index=True)
    model_id = Column(String(255), index=True, nullable=False)
    tenant_id = Column(String(255), index=True, nullable=False)

    # Artifact information
    artifact_type = Column(String(50), nullable=False)  # model_file, metadata, training_data, etc.
    file_path = Column(String(1000), nullable=False)
    # NOTE(review): Integer maps to a 32-bit INTEGER on PostgreSQL (max 2_147_483_647),
    # so artifacts larger than ~2 GiB cannot be recorded here; consider BigInteger
    # (requires a schema migration).
    file_size_bytes = Column(Integer, nullable=True)
    checksum = Column(String(255), nullable=True)  # For file integrity

    # Storage information
    storage_location = Column(String(100), nullable=False, default="local")  # local, s3, gcs, etc.
    compression = Column(String(50), nullable=True)  # gzip, lz4, etc.

    # Metadata
    created_at = Column(DateTime, default=datetime.now)
    # For automatic cleanup; the cleanup job itself is not part of this module.
    expires_at = Column(DateTime, nullable=True)