REFACTOR data service

This commit is contained in:
Urtzi Alfaro
2025-08-12 18:17:30 +02:00
parent 7c237c0acc
commit fbe7470ad9
149 changed files with 8528 additions and 7393 deletions

View File

@@ -31,6 +31,10 @@ class EnhancedBakeryDataProcessor:
self.scalers = {} # Store scalers for each feature
self.imputers = {} # Store imputers for missing value handling
self.date_alignment_service = DateAlignmentService()
def get_scalers(self) -> Dict[str, Any]:
"""Return the scalers/normalization parameters for use during prediction"""
return self.scalers.copy()
async def _get_repositories(self, session):
"""Initialize repositories with session"""
@@ -558,9 +562,19 @@ class EnhancedBakeryDataProcessor:
if traffic_std > 0 and not pd.isna(traffic_std) and not pd.isna(traffic_mean):
df['traffic_normalized'] = (df['traffic_volume'] - traffic_mean) / traffic_std
# Store normalization parameters for later use in predictions
self.scalers['traffic_mean'] = float(traffic_mean)
self.scalers['traffic_std'] = float(traffic_std)
logger.info(f"Traffic normalization parameters: mean={traffic_mean:.2f}, std={traffic_std:.2f}")
else:
logger.warning("Traffic volume has zero standard deviation, using zeros for normalized values")
df['traffic_normalized'] = 0.0
# Store default parameters for consistency
self.scalers['traffic_mean'] = 100.0 # Default traffic level used during training
self.scalers['traffic_std'] = 50.0 # Reasonable std for traffic normalization
# Fill any remaining NaN values
df['traffic_normalized'] = df['traffic_normalized'].fillna(0.0)

View File

@@ -349,6 +349,7 @@ class EnhancedBakeryMLTrainer:
"training_samples": len(processed_data),
"hyperparameters": model_info.get("hyperparameters"),
"features_used": list(processed_data.columns),
"normalization_params": self.enhanced_data_processor.get_scalers(), # Include scalers for prediction consistency
"is_active": True,
"is_production": True,
"data_quality_score": model_info.get("data_quality_score", 100.0)

View File

@@ -149,6 +149,7 @@ class TrainedModel(Base):
# Hyperparameters and features
hyperparameters = Column(JSON) # Store optimized parameters
features_used = Column(JSON) # List of regressor columns
normalization_params = Column(JSON) # Store feature normalization parameters for consistent predictions
# Model status
is_active = Column(Boolean, default=True)

View File

@@ -9,7 +9,7 @@ from typing import Dict, Any, List, Optional
from datetime import datetime
# Import the shared clients
from shared.clients import get_data_client, get_service_clients
from shared.clients import get_sales_client, get_external_client, get_service_clients
from app.core.config import settings
logger = structlog.get_logger()
@@ -21,19 +21,20 @@ class DataClient:
"""
def __init__(self):
# Get the shared data client configured for this service
self.data_client = get_data_client(settings, "training")
# Get the new specialized clients
self.sales_client = get_sales_client(settings, "training")
self.external_client = get_external_client(settings, "training")
# Check if the new method is available for stored traffic data
if hasattr(self.data_client, 'get_stored_traffic_data_for_training'):
if hasattr(self.external_client, 'get_stored_traffic_data_for_training'):
self.supports_stored_traffic_data = True
else:
self.supports_stored_traffic_data = False
logger.warning("Stored traffic data method not available in data client")
logger.warning("Stored traffic data method not available in external client")
# Or alternatively, get all clients at once:
# self.clients = get_service_clients(settings, "training")
# Then use: self.clients.data.get_sales_data(...)
# Then use: self.clients.sales.get_sales_data(...) and self.clients.external.get_weather_forecast(...)
async def fetch_sales_data(
self,
@@ -57,18 +58,18 @@ class DataClient:
try:
if fetch_all:
# Use paginated method to get ALL records (original behavior)
sales_data = await self.data_client.get_all_sales_data(
sales_data = await self.sales_client.get_all_sales_data(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
product_id=product_id,
aggregation="daily",
page_size=5000, # Match original page size
page_size=1000, # Comply with API limit
max_pages=100 # Safety limit (500k records max)
)
else:
# Use standard method for limited results
sales_data = await self.data_client.get_sales_data(
sales_data = await self.sales_client.get_sales_data(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
@@ -102,7 +103,7 @@ class DataClient:
All the error handling and retry logic is now in the base client!
"""
try:
weather_data = await self.data_client.get_weather_historical(
weather_data = await self.external_client.get_weather_historical(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
@@ -134,7 +135,7 @@ class DataClient:
Fetch traffic data for training
"""
try:
traffic_data = await self.data_client.get_traffic_data(
traffic_data = await self.external_client.get_traffic_data(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
@@ -169,7 +170,7 @@ class DataClient:
try:
if self.supports_stored_traffic_data:
# Use the dedicated stored traffic data method
stored_traffic_data = await self.data_client.get_stored_traffic_data_for_training(
stored_traffic_data = await self.external_client.get_stored_traffic_data_for_training(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date,
@@ -209,11 +210,15 @@ class DataClient:
Validate data quality before training
"""
try:
validation_result = await self.data_client.validate_data_quality(
tenant_id=tenant_id,
start_date=start_date,
end_date=end_date
)
# Note: validation_data_quality may need to be implemented in one of the new services
# validation_result = await self.sales_client.validate_data_quality(
# tenant_id=tenant_id,
# start_date=start_date,
# end_date=end_date
# )
# Temporary implementation - assume data is valid for now
validation_result = {"is_valid": True, "message": "Validation temporarily disabled"}
if validation_result:
logger.info("Data validation completed",