Improve training code 4
This commit is contained in:
@@ -25,23 +25,6 @@ class TrainingJobRequest(BaseModel):
|
||||
products: Optional[List[str]] = Field(None, description="Specific products to train (if None, trains all)")
|
||||
start_date: Optional[datetime] = Field(None, description="Start date for training data")
|
||||
end_date: Optional[datetime] = Field(None, description="End date for training data")
|
||||
|
||||
# Prophet-specific parameters
|
||||
seasonality_mode: str = Field("additive", description="Prophet seasonality mode", pattern="^(additive|multiplicative)$")
|
||||
daily_seasonality: bool = Field(True, description="Enable daily seasonality")
|
||||
weekly_seasonality: bool = Field(True, description="Enable weekly seasonality")
|
||||
yearly_seasonality: bool = Field(True, description="Enable yearly seasonality")
|
||||
|
||||
# Advanced configuration
|
||||
force_retrain: bool = Field(False, description="Force retraining even if recent model exists")
|
||||
parallel_training: bool = Field(True, description="Train products in parallel")
|
||||
max_workers: int = Field(4, description="Maximum parallel workers", ge=1, le=10)
|
||||
|
||||
@validator('seasonality_mode')
|
||||
def validate_seasonality_mode(cls, v):
|
||||
if v not in ['additive', 'multiplicative']:
|
||||
raise ValueError('seasonality_mode must be either "additive" or "multiplicative"')
|
||||
return v
|
||||
|
||||
|
||||
class SingleProductTrainingRequest(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user