Improve training code
This commit is contained in:
@@ -37,37 +37,6 @@ class ModelTrainingLog(Base):
|
||||
created_at = Column(DateTime, default=datetime.now)
|
||||
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
|
||||
|
||||
class TrainedModel(Base):
    """
    ORM model for the ``trained_models`` table, which stores information
    about trained models.
    """

    __tablename__ = "trained_models"

    # Identification: integer surrogate key plus a unique string model id,
    # scoped by tenant and product.
    id = Column(Integer, index=True, primary_key=True)
    model_id = Column(String(255), nullable=False, index=True, unique=True)
    tenant_id = Column(UUID(as_uuid=True), index=True, nullable=False)
    product_name = Column(String(255), nullable=False, index=True)

    # What kind of model this is and where its file is stored.
    # model_type examples: prophet, arima, etc.
    model_type = Column(String(50), default="prophet", nullable=False)
    # Path to the stored model file.
    model_path = Column(String(1000), nullable=False)
    version = Column(Integer, default=1, nullable=False)

    # Details of the training run.
    training_samples = Column(Integer, default=0, nullable=False)
    # List of features used.
    features = Column(ARRAY(String), nullable=True)
    # Model hyperparameters.
    hyperparameters = Column(JSON, nullable=True)
    # Training performance metrics.
    training_metrics = Column(JSON, nullable=True)

    # Period covered by the data; both ends are optional.
    data_period_start = Column(DateTime, nullable=True)
    data_period_end = Column(DateTime, nullable=True)

    # Lifecycle flag and bookkeeping timestamps. ``updated_at`` is refreshed
    # on every UPDATE through ``onupdate``.
    is_active = Column(Boolean, index=True, default=True)
    created_at = Column(DateTime, default=datetime.now)
    updated_at = Column(
        DateTime,
        onupdate=datetime.now,
        default=datetime.now,
    )
|
||||
|
||||
class ModelPerformanceMetric(Base):
|
||||
"""
|
||||
Table to track model performance over time.
|
||||
@@ -150,4 +119,73 @@ class ModelArtifact(Base):
|
||||
|
||||
# Metadata
|
||||
created_at = Column(DateTime, default=datetime.now)
|
||||
expires_at = Column(DateTime, nullable=True) # For automatic cleanup
|
||||
expires_at = Column(DateTime, nullable=True) # For automatic cleanup
|
||||
|
||||
class TrainedModel(Base):
    """
    ORM model for the ``trained_models`` table.

    Each row records a trained model: who it belongs to, where its files
    are stored, its training metrics and its status flags. ``to_dict``
    gives a plain-dict view of a subset of the columns.
    """

    __tablename__ = "trained_models"

    # --- Identification -------------------------------------------------
    # The primary key is a string; a random UUID4 is generated when none
    # is supplied.
    id = Column(String, default=lambda: str(uuid.uuid4()), primary_key=True)
    tenant_id = Column(String, index=True, nullable=False)
    product_name = Column(String, index=True, nullable=False)

    # --- Model information ----------------------------------------------
    model_type = Column(String, default="prophet_optimized")
    model_version = Column(String, default="1.0")
    job_id = Column(String, nullable=False)

    # --- File storage ---------------------------------------------------
    # Path to the .pkl file (required).
    model_path = Column(String, nullable=False)
    # Path to the metadata JSON (optional).
    metadata_path = Column(String)

    # --- Training metrics -----------------------------------------------
    mape = Column(Float)
    mae = Column(Float)
    rmse = Column(Float)
    r2_score = Column(Float)
    training_samples = Column(Integer)

    # --- Hyperparameters and features -----------------------------------
    # Optimized parameters.
    hyperparameters = Column(JSON)
    # List of regressor columns.
    features_used = Column(JSON)

    # --- Model status ---------------------------------------------------
    is_active = Column(Boolean, default=True)
    is_production = Column(Boolean, default=False)

    # --- Timestamps -----------------------------------------------------
    # ``datetime.utcnow`` yields a naive datetime expressed in UTC.
    created_at = Column(DateTime, default=datetime.utcnow)
    last_used_at = Column(DateTime)

    # --- Training data info ---------------------------------------------
    training_start_date = Column(DateTime)
    training_end_date = Column(DateTime)
    data_quality_score = Column(Float)

    # --- Additional metadata --------------------------------------------
    notes = Column(Text)
    # User who triggered training.
    created_by = Column(String)

    def to_dict(self):
        """
        Return this row as a plain dictionary.

        Datetime columns are rendered with ``isoformat()``; a datetime that
        is unset (falsy) becomes ``None``. All other values are passed
        through unchanged.

        Note: ``job_id``, ``metadata_path``, ``notes`` and ``created_by``
        are not part of the returned dictionary.
        """
        passthrough_fields = (
            "id",
            "tenant_id",
            "product_name",
            "model_type",
            "model_version",
            "model_path",
            "mape",
            "mae",
            "rmse",
            "r2_score",
            "training_samples",
            "hyperparameters",
            "features_used",
            "is_active",
            "is_production",
        )
        timestamp_fields = (
            "created_at",
            "last_used_at",
            "training_start_date",
            "training_end_date",
        )

        payload = {name: getattr(self, name) for name in passthrough_fields}

        for name in timestamp_fields:
            moment = getattr(self, name)
            payload[name] = moment.isoformat() if moment else None

        # Added last so the key order stays: plain fields, timestamps, score.
        payload["data_quality_score"] = self.data_quality_score
        return payload
|
||||
Reference in New Issue
Block a user