# services/training/app/models/training.py """ Database models for training service """ from sqlalchemy import Column, Integer, String, DateTime, Text, Boolean, JSON, Float from sqlalchemy.dialects.postgresql import UUID, ARRAY from shared.database.base import Base from datetime import datetime, timezone import uuid class ModelTrainingLog(Base): """ Table to track training job execution and status. Replaces the old Celery task tracking. """ __tablename__ = "model_training_logs" id = Column(Integer, primary_key=True, index=True) job_id = Column(String(255), unique=True, index=True, nullable=False) tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True) status = Column(String(50), nullable=False, default="pending") # pending, running, completed, failed, cancelled progress = Column(Integer, default=0) # 0-100 percentage current_step = Column(String(500), default="") # Timestamps start_time = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) end_time = Column(DateTime(timezone=True), nullable=True) # Configuration and results config = Column(JSON, nullable=True) # Training job configuration results = Column(JSON, nullable=True) # Training results error_message = Column(Text, nullable=True) # Metadata created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) class ModelPerformanceMetric(Base): """ Table to track model performance over time. """ __tablename__ = "model_performance_metrics" id = Column(Integer, primary_key=True, index=True) model_id = Column(String(255), index=True, nullable=False) tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True) inventory_product_id = Column(UUID(as_uuid=True), index=True, nullable=False) # Performance metrics mae = Column(Float, nullable=True) # Mean Absolute Error mse = Column(Float, nullable=True) # Mean Squared Error rmse = Column(Float, nullable=True) # Root Mean Squared Error mape = Column(Float, nullable=True) # Mean Absolute Percentage Error r2_score = Column(Float, nullable=True) # R-squared score # Additional metrics accuracy_percentage = Column(Float, nullable=True) prediction_confidence = Column(Float, nullable=True) # Evaluation information evaluation_period_start = Column(DateTime, nullable=True) evaluation_period_end = Column(DateTime, nullable=True) evaluation_samples = Column(Integer, nullable=True) # Metadata measured_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) class TrainingJobQueue(Base): """ Table to manage training job queue and scheduling. """ __tablename__ = "training_job_queue" id = Column(Integer, primary_key=True, index=True) job_id = Column(String(255), unique=True, index=True, nullable=False) tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True) # Job configuration job_type = Column(String(50), nullable=False) # full_training, single_product, evaluation priority = Column(Integer, default=1) # Higher number = higher priority config = Column(JSON, nullable=True) # Scheduling information scheduled_at = Column(DateTime, nullable=True) started_at = Column(DateTime, nullable=True) estimated_duration_minutes = Column(Integer, nullable=True) # Status status = Column(String(50), nullable=False, default="queued") # queued, running, completed, failed retry_count = Column(Integer, default=0) max_retries = Column(Integer, default=3) # Metadata created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) cancelled_by = Column(String, nullable=True) class ModelArtifact(Base): """ Table to track model files and artifacts. """ __tablename__ = "model_artifacts" id = Column(Integer, primary_key=True, index=True) model_id = Column(String(255), index=True, nullable=False) tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True) # Artifact information artifact_type = Column(String(50), nullable=False) # model_file, metadata, training_data, etc. file_path = Column(String(1000), nullable=False) file_size_bytes = Column(Integer, nullable=True) checksum = Column(String(255), nullable=True) # For file integrity # Storage information storage_location = Column(String(100), nullable=False, default="local") # local, s3, gcs, etc. compression = Column(String(50), nullable=True) # gzip, lz4, etc. # Metadata created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) expires_at = Column(DateTime(timezone=True), nullable=True) # For automatic cleanup class TrainedModel(Base): __tablename__ = "trained_models" # Primary identification - Updated to use UUID properly id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True) inventory_product_id = Column(UUID(as_uuid=True), nullable=False, index=True) # Model information model_type = Column(String, default="prophet_optimized") model_version = Column(String, default="1.0") job_id = Column(String, nullable=False) # File storage model_path = Column(String, nullable=False) # Path to the .pkl file metadata_path = Column(String) # Path to metadata JSON # Training metrics mape = Column(Float) mae = Column(Float) rmse = Column(Float) r2_score = Column(Float) training_samples = Column(Integer) # Hyperparameters and features hyperparameters = Column(JSON) # Store optimized parameters features_used = Column(JSON) # List of regressor columns normalization_params = Column(JSON) # Store feature normalization parameters for consistent predictions product_category = Column(String, nullable=True) # Product category for category-specific forecasting # Model status is_active = Column(Boolean, default=True) is_production = Column(Boolean, default=False) # Timestamps - Updated to be timezone-aware with proper defaults created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc)) last_used_at = Column(DateTime(timezone=True)) # Training data info training_start_date = Column(DateTime(timezone=True)) training_end_date = Column(DateTime(timezone=True)) data_quality_score = Column(Float) # Additional metadata notes = Column(Text) created_by = Column(String) # User who triggered training def to_dict(self): return { "id": str(self.id), "model_id": str(self.id), "tenant_id": str(self.tenant_id), "inventory_product_id": str(self.inventory_product_id), "model_type": self.model_type, "model_version": self.model_version, "model_path": self.model_path, "mape": self.mape, "mae": self.mae, "rmse": self.rmse, "r2_score": self.r2_score, "training_samples": self.training_samples, "hyperparameters": self.hyperparameters, "features_used": self.features_used, "product_category": self.product_category, "is_active": self.is_active, "is_production": self.is_production, "created_at": self.created_at.isoformat() if self.created_at else None, "updated_at": self.updated_at.isoformat() if self.updated_at else None, "last_used_at": self.last_used_at.isoformat() if self.last_used_at else None, "training_start_date": self.training_start_date.isoformat() if self.training_start_date else None, "training_end_date": self.training_end_date.isoformat() if self.training_end_date else None, "data_quality_score": self.data_quality_score } class TrainingPerformanceMetrics(Base): """ Table to track historical training performance for time estimation. Stores aggregated metrics from completed training jobs. """ __tablename__ = "training_performance_metrics" id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True) job_id = Column(String(255), nullable=False, index=True) # Training job statistics total_products = Column(Integer, nullable=False) successful_products = Column(Integer, nullable=False) failed_products = Column(Integer, nullable=False) # Time metrics total_duration_seconds = Column(Float, nullable=False) avg_time_per_product = Column(Float, nullable=False) # Key metric for estimation data_analysis_time_seconds = Column(Float, nullable=True) training_time_seconds = Column(Float, nullable=True) finalization_time_seconds = Column(Float, nullable=True) # Job metadata completed_at = Column(DateTime(timezone=True), nullable=False, default=lambda: datetime.now(timezone.utc)) created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)) def __repr__(self): return ( f"" ) def to_dict(self): return { "id": str(self.id), "tenant_id": str(self.tenant_id), "job_id": self.job_id, "total_products": self.total_products, "successful_products": self.successful_products, "failed_products": self.failed_products, "total_duration_seconds": self.total_duration_seconds, "avg_time_per_product": self.avg_time_per_product, "data_analysis_time_seconds": self.data_analysis_time_seconds, "training_time_seconds": self.training_time_seconds, "finalization_time_seconds": self.finalization_time_seconds, "completed_at": self.completed_at.isoformat() if self.completed_at else None, "created_at": self.created_at.isoformat() if self.created_at else None }