Improve training code 4
This commit is contained in:
80
services/training/app/models/training_models.py
Normal file
80
services/training/app/models/training_models.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# services/training/app/models/training_models.py
|
||||
"""
|
||||
Database models for trained ML models
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, DateTime, Float, Integer, Text, Boolean, JSON
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class TrainedModel(Base):
|
||||
__tablename__ = "trained_models"
|
||||
|
||||
# Primary identification
|
||||
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
tenant_id = Column(String, nullable=False, index=True)
|
||||
product_name = Column(String, nullable=False, index=True)
|
||||
|
||||
# Model information
|
||||
model_type = Column(String, default="prophet_optimized")
|
||||
model_version = Column(String, default="1.0")
|
||||
job_id = Column(String, nullable=False)
|
||||
|
||||
# File storage
|
||||
model_path = Column(String, nullable=False) # Path to the .pkl file
|
||||
metadata_path = Column(String) # Path to metadata JSON
|
||||
|
||||
# Training metrics
|
||||
mape = Column(Float)
|
||||
mae = Column(Float)
|
||||
rmse = Column(Float)
|
||||
r2_score = Column(Float)
|
||||
training_samples = Column(Integer)
|
||||
|
||||
# Hyperparameters and features
|
||||
hyperparameters = Column(JSON) # Store optimized parameters
|
||||
features_used = Column(JSON) # List of regressor columns
|
||||
|
||||
# Model status
|
||||
is_active = Column(Boolean, default=True)
|
||||
is_production = Column(Boolean, default=False)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
last_used_at = Column(DateTime)
|
||||
|
||||
# Training data info
|
||||
training_start_date = Column(DateTime)
|
||||
training_end_date = Column(DateTime)
|
||||
data_quality_score = Column(Float)
|
||||
|
||||
# Additional metadata
|
||||
notes = Column(Text)
|
||||
created_by = Column(String) # User who triggered training
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"tenant_id": self.tenant_id,
|
||||
"product_name": self.product_name,
|
||||
"model_type": self.model_type,
|
||||
"model_version": self.model_version,
|
||||
"model_path": self.model_path,
|
||||
"mape": self.mape,
|
||||
"mae": self.mae,
|
||||
"rmse": self.rmse,
|
||||
"r2_score": self.r2_score,
|
||||
"training_samples": self.training_samples,
|
||||
"hyperparameters": self.hyperparameters,
|
||||
"features_used": self.features_used,
|
||||
"is_active": self.is_active,
|
||||
"is_production": self.is_production,
|
||||
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||
"last_used_at": self.last_used_at.isoformat() if self.last_used_at else None,
|
||||
"training_start_date": self.training_start_date.isoformat() if self.training_start_date else None,
|
||||
"training_end_date": self.training_end_date.isoformat() if self.training_end_date else None,
|
||||
"data_quality_score": self.data_quality_score
|
||||
}
|
||||
@@ -25,23 +25,6 @@ class TrainingJobRequest(BaseModel):
|
||||
products: Optional[List[str]] = Field(None, description="Specific products to train (if None, trains all)")
|
||||
start_date: Optional[datetime] = Field(None, description="Start date for training data")
|
||||
end_date: Optional[datetime] = Field(None, description="End date for training data")
|
||||
|
||||
# Prophet-specific parameters
|
||||
seasonality_mode: str = Field("additive", description="Prophet seasonality mode", pattern="^(additive|multiplicative)$")
|
||||
daily_seasonality: bool = Field(True, description="Enable daily seasonality")
|
||||
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
|
||||
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
|
||||
|
||||
# Advanced configuration
|
||||
force_retrain: bool = Field(False, description="Force retraining even if recent model exists")
|
||||
parallel_training: bool = Field(True, description="Train products in parallel")
|
||||
max_workers: int = Field(4, description="Maximum parallel workers", ge=1, le=10)
|
||||
|
||||
@validator('seasonality_mode')
|
||||
def validate_seasonality_mode(cls, v):
|
||||
if v not in ['additive', 'multiplicative']:
|
||||
raise ValueError('seasonality_mode must be either "additive" or "multiplicative"')
|
||||
return v
|
||||
|
||||
|
||||
class SingleProductTrainingRequest(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user