80 lines
2.9 KiB
Python
80 lines
2.9 KiB
Python
|
|
# services/training/app/models/training_models.py
"""
Database models for trained ML models.
"""
|
||
|
|
|
||
|
|
from sqlalchemy import Column, String, DateTime, Float, Integer, Text, Boolean, JSON
|
||
|
|
from sqlalchemy.ext.declarative import declarative_base
|
||
|
|
from datetime import datetime
|
||
|
|
import uuid
|
||
|
|
|
||
|
|
# Declarative base class; the ORM model defined below inherits from it.
Base = declarative_base()
|
||
|
|
|
||
|
|
class TrainedModel(Base):
    """ORM row in ``trained_models`` describing one stored trained model.

    Each row carries the owning tenant and product, where the model
    artifact lives on disk, the metrics recorded for it, and status flags.
    """

    __tablename__ = "trained_models"

    # Primary identification
    # ``id`` defaults to a freshly generated UUID4 rendered as a string.
    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    tenant_id = Column(String, nullable=False, index=True)
    product_name = Column(String, nullable=False, index=True)

    # Model information
    model_type = Column(String, default="prophet_optimized")
    model_version = Column(String, default="1.0")
    job_id = Column(String, nullable=False)

    # File storage
    model_path = Column(String, nullable=False)  # Path to the .pkl file
    metadata_path = Column(String)  # Path to metadata JSON

    # Training metrics
    mape = Column(Float)
    mae = Column(Float)
    rmse = Column(Float)
    r2_score = Column(Float)
    training_samples = Column(Integer)

    # Hyperparameters and features
    hyperparameters = Column(JSON)  # Store optimized parameters
    features_used = Column(JSON)  # List of regressor columns

    # Model status
    is_active = Column(Boolean, default=True)
    is_production = Column(Boolean, default=False)

    # Timestamps
    # ``created_at`` is filled from datetime.utcnow (a naive datetime).
    created_at = Column(DateTime, default=datetime.utcnow)
    last_used_at = Column(DateTime)

    # Training data info
    training_start_date = Column(DateTime)
    training_end_date = Column(DateTime)
    data_quality_score = Column(Float)

    # Additional metadata
    notes = Column(Text)
    created_by = Column(String)  # User who triggered training

    def to_dict(self):
        """Return a plain ``dict`` view of this row.

        Datetime columns are rendered with ``isoformat()``; a falsy value
        (e.g. ``None``) is emitted as ``None``. ``job_id``,
        ``metadata_path``, ``notes`` and ``created_by`` are not included.
        """
        # Columns copied through unchanged, in output order.
        passthrough = (
            "id",
            "tenant_id",
            "product_name",
            "model_type",
            "model_version",
            "model_path",
            "mape",
            "mae",
            "rmse",
            "r2_score",
            "training_samples",
            "hyperparameters",
            "features_used",
            "is_active",
            "is_production",
        )
        # Datetime columns that are converted to ISO-8601 strings.
        timestamps = (
            "created_at",
            "last_used_at",
            "training_start_date",
            "training_end_date",
        )

        payload = {field: getattr(self, field) for field in passthrough}
        for field in timestamps:
            moment = getattr(self, field)
            payload[field] = moment.isoformat() if moment else None
        # Kept last so the key order matches the established layout.
        payload["data_quality_score"] = self.data_quality_score
        return payload
|