Improve training code

This commit is contained in:
Urtzi Alfaro
2025-07-28 19:28:39 +02:00
parent 946015b80c
commit 98f546af12
15 changed files with 2534 additions and 2812 deletions

View File

@@ -37,37 +37,6 @@ class ModelTrainingLog(Base):
created_at = Column(DateTime, default=datetime.now)
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
class TrainedModel(Base):
"""
Table to store information about trained models.
"""
__tablename__ = "trained_models"
id = Column(Integer, primary_key=True, index=True)
model_id = Column(String(255), unique=True, index=True, nullable=False)
tenant_id = Column(UUID(as_uuid=True), nullable=False, index=True)
product_name = Column(String(255), index=True, nullable=False)
# Model information
model_type = Column(String(50), nullable=False, default="prophet") # prophet, arima, etc.
model_path = Column(String(1000), nullable=False) # Path to stored model file
version = Column(Integer, nullable=False, default=1)
# Training information
training_samples = Column(Integer, nullable=False, default=0)
features = Column(ARRAY(String), nullable=True) # List of features used
hyperparameters = Column(JSON, nullable=True) # Model hyperparameters
training_metrics = Column(JSON, nullable=True) # Training performance metrics
# Data period information
data_period_start = Column(DateTime, nullable=True)
data_period_end = Column(DateTime, nullable=True)
# Status and metadata
is_active = Column(Boolean, default=True, index=True)
created_at = Column(DateTime, default=datetime.now)
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
class ModelPerformanceMetric(Base):
"""
Table to track model performance over time.
@@ -150,4 +119,73 @@ class ModelArtifact(Base):
# Metadata
created_at = Column(DateTime, default=datetime.now)
expires_at = Column(DateTime, nullable=True) # For automatic cleanup
expires_at = Column(DateTime, nullable=True) # For automatic cleanup
class TrainedModel(Base):
__tablename__ = "trained_models"
# Primary identification
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
tenant_id = Column(String, nullable=False, index=True)
product_name = Column(String, nullable=False, index=True)
# Model information
model_type = Column(String, default="prophet_optimized")
model_version = Column(String, default="1.0")
job_id = Column(String, nullable=False)
# File storage
model_path = Column(String, nullable=False) # Path to the .pkl file
metadata_path = Column(String) # Path to metadata JSON
# Training metrics
mape = Column(Float)
mae = Column(Float)
rmse = Column(Float)
r2_score = Column(Float)
training_samples = Column(Integer)
# Hyperparameters and features
hyperparameters = Column(JSON) # Store optimized parameters
features_used = Column(JSON) # List of regressor columns
# Model status
is_active = Column(Boolean, default=True)
is_production = Column(Boolean, default=False)
# Timestamps
created_at = Column(DateTime, default=datetime.utcnow)
last_used_at = Column(DateTime)
# Training data info
training_start_date = Column(DateTime)
training_end_date = Column(DateTime)
data_quality_score = Column(Float)
# Additional metadata
notes = Column(Text)
created_by = Column(String) # User who triggered training
def to_dict(self):
return {
"id": self.id,
"tenant_id": self.tenant_id,
"product_name": self.product_name,
"model_type": self.model_type,
"model_version": self.model_version,
"model_path": self.model_path,
"mape": self.mape,
"mae": self.mae,
"rmse": self.rmse,
"r2_score": self.r2_score,
"training_samples": self.training_samples,
"hyperparameters": self.hyperparameters,
"features_used": self.features_used,
"is_active": self.is_active,
"is_production": self.is_production,
"created_at": self.created_at.isoformat() if self.created_at else None,
"last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
"training_start_date": self.training_start_date.isoformat() if self.training_start_date else None,
"training_end_date": self.training_end_date.isoformat() if self.training_end_date else None,
"data_quality_score": self.data_quality_score
}