imporve features
This commit is contained in:
@@ -844,6 +844,9 @@ class EnhancedBakeryMLTrainer:
|
||||
# Extract training period from the processed data
|
||||
training_start_date = None
|
||||
training_end_date = None
|
||||
data_freshness_days = None
|
||||
data_coverage_days = None
|
||||
|
||||
if 'ds' in processed_data.columns and not processed_data.empty:
|
||||
# Ensure ds column is datetime64 before extracting dates (prevents object dtype issues)
|
||||
ds_datetime = pd.to_datetime(processed_data['ds'])
|
||||
@@ -857,6 +860,15 @@ class EnhancedBakeryMLTrainer:
|
||||
training_start_date = pd.Timestamp(min_ts).to_pydatetime().replace(tzinfo=None)
|
||||
if pd.notna(max_ts):
|
||||
training_end_date = pd.Timestamp(max_ts).to_pydatetime().replace(tzinfo=None)
|
||||
|
||||
# Calculate data freshness metrics
|
||||
if training_end_date:
|
||||
from datetime import datetime
|
||||
data_freshness_days = (datetime.now() - training_end_date).days
|
||||
|
||||
# Calculate data coverage period
|
||||
if training_start_date and training_end_date:
|
||||
data_coverage_days = (training_end_date - training_start_date).days
|
||||
|
||||
# Ensure features are clean string list
|
||||
try:
|
||||
@@ -864,6 +876,13 @@ class EnhancedBakeryMLTrainer:
|
||||
except Exception:
|
||||
features_used = []
|
||||
|
||||
# Prepare hyperparameters with data freshness metrics
|
||||
hyperparameters = model_info.get("hyperparameters", {})
|
||||
if data_freshness_days is not None:
|
||||
hyperparameters["data_freshness_days"] = data_freshness_days
|
||||
if data_coverage_days is not None:
|
||||
hyperparameters["data_coverage_days"] = data_coverage_days
|
||||
|
||||
model_data = {
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
@@ -876,7 +895,7 @@ class EnhancedBakeryMLTrainer:
|
||||
"rmse": float(model_info.get("training_metrics", {}).get("rmse", 0)) if model_info.get("training_metrics", {}).get("rmse") is not None else 0,
|
||||
"r2_score": float(model_info.get("training_metrics", {}).get("r2", 0)) if model_info.get("training_metrics", {}).get("r2") is not None else 0,
|
||||
"training_samples": int(len(processed_data)),
|
||||
"hyperparameters": self._serialize_scalers(model_info.get("hyperparameters", {})),
|
||||
"hyperparameters": self._serialize_scalers(hyperparameters),
|
||||
"features_used": [str(f) for f in features_used] if features_used else [],
|
||||
"normalization_params": self._serialize_scalers(self.enhanced_data_processor.get_scalers()) or {}, # Include scalers for prediction consistency
|
||||
"product_category": model_info.get("product_category", "unknown"), # Store product category
|
||||
@@ -890,7 +909,9 @@ class EnhancedBakeryMLTrainer:
|
||||
model_record = await repos['model'].create_model(model_data)
|
||||
logger.info("Created enhanced model record",
|
||||
inventory_product_id=inventory_product_id,
|
||||
model_id=model_record.id)
|
||||
model_id=model_record.id,
|
||||
data_freshness_days=data_freshness_days,
|
||||
data_coverage_days=data_coverage_days)
|
||||
|
||||
# Create artifacts for model files
|
||||
if model_info.get("model_path"):
|
||||
|
||||
Reference in New Issue
Block a user