REFACTOR data service
This commit is contained in:
@@ -31,6 +31,10 @@ class EnhancedBakeryDataProcessor:
|
||||
self.scalers = {} # Store scalers for each feature
|
||||
self.imputers = {} # Store imputers for missing value handling
|
||||
self.date_alignment_service = DateAlignmentService()
|
||||
|
||||
def get_scalers(self) -> Dict[str, Any]:
    """Hand back a shallow copy of the stored normalization parameters.

    The result is a new dict, so adding or removing keys on it does not
    alter ``self.scalers``. The values themselves are not copied.
    """
    snapshot = dict(self.scalers)
    return snapshot
|
||||
|
||||
async def _get_repositories(self, session):
|
||||
"""Initialize repositories with session"""
|
||||
@@ -558,9 +562,19 @@ class EnhancedBakeryDataProcessor:
|
||||
|
||||
if traffic_std > 0 and not pd.isna(traffic_std) and not pd.isna(traffic_mean):
|
||||
df['traffic_normalized'] = (df['traffic_volume'] - traffic_mean) / traffic_std
|
||||
|
||||
# Store normalization parameters for later use in predictions
|
||||
self.scalers['traffic_mean'] = float(traffic_mean)
|
||||
self.scalers['traffic_std'] = float(traffic_std)
|
||||
|
||||
logger.info(f"Traffic normalization parameters: mean={traffic_mean:.2f}, std={traffic_std:.2f}")
|
||||
else:
|
||||
logger.warning("Traffic volume has zero standard deviation, using zeros for normalized values")
|
||||
df['traffic_normalized'] = 0.0
|
||||
|
||||
# Store default parameters for consistency
|
||||
self.scalers['traffic_mean'] = 100.0 # Default traffic level used during training
|
||||
self.scalers['traffic_std'] = 50.0 # Reasonable std for traffic normalization
|
||||
|
||||
# Fill any remaining NaN values
|
||||
df['traffic_normalized'] = df['traffic_normalized'].fillna(0.0)
|
||||
|
||||
@@ -349,6 +349,7 @@ class EnhancedBakeryMLTrainer:
|
||||
"training_samples": len(processed_data),
|
||||
"hyperparameters": model_info.get("hyperparameters"),
|
||||
"features_used": list(processed_data.columns),
|
||||
"normalization_params": self.enhanced_data_processor.get_scalers(), # Include scalers for prediction consistency
|
||||
"is_active": True,
|
||||
"is_production": True,
|
||||
"data_quality_score": model_info.get("data_quality_score", 100.0)
|
||||
|
||||
Reference in New Issue
Block a user