Improve AI logic
@@ -14,6 +14,9 @@ import asyncio
 from app.ml.data_processor import EnhancedBakeryDataProcessor
 from app.ml.prophet_manager import BakeryProphetManager
+from app.ml.product_categorizer import ProductCategorizer, ProductCategory
+from app.ml.model_selector import ModelSelector
+from app.ml.hybrid_trainer import HybridProphetXGBoost
 from app.services.training_orchestrator import TrainingDataSet
 from app.core.config import settings

@@ -49,6 +52,9 @@ class EnhancedBakeryMLTrainer:
         self.database_manager = database_manager or create_database_manager(settings.DATABASE_URL, "training-service")
         self.enhanced_data_processor = EnhancedBakeryDataProcessor(self.database_manager)
         self.prophet_manager = BakeryProphetManager(database_manager=self.database_manager)
+        self.hybrid_trainer = HybridProphetXGBoost(database_manager=self.database_manager)
+        self.model_selector = ModelSelector()
+        self.product_categorizer = ProductCategorizer()

     async def _get_repositories(self, session):
         """Initialize repositories with session"""
@@ -169,6 +175,16 @@ class EnhancedBakeryMLTrainer:
                 sales_df, weather_df, traffic_df, products, tenant_id, job_id
             )

+            # Categorize all products for category-specific forecasting
+            logger.info("Categorizing products for optimized forecasting")
+            product_categories = await self._categorize_all_products(
+                sales_df, processed_data
+            )
+            logger.info("Product categorization complete",
+                        total_products=len(product_categories),
+                        categories_breakdown={cat.value: sum(1 for c in product_categories.values() if c == cat)
+                                              for cat in set(product_categories.values())})
+
             # Event 2: Data Analysis (20%)
             # Recalculate time remaining based on elapsed time
             elapsed_seconds = (datetime.now(timezone.utc) - repos['training_log']._get_start_time(job_id) if hasattr(repos['training_log'], '_get_start_time') else 0) or 0
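
Note: the categories_breakdown comprehension above is just a per-category count. A standalone sketch of the same computation (the ProductCategory members here are illustrative stand-ins, not necessarily the project's actual enum):

    from collections import Counter
    from enum import Enum

    class ProductCategory(Enum):  # minimal stand-in for app.ml.product_categorizer
        BREAD = "bread"
        PASTRY = "pastry"
        UNKNOWN = "unknown"

    product_categories = {"sku-1": ProductCategory.BREAD,
                          "sku-2": ProductCategory.BREAD,
                          "sku-3": ProductCategory.PASTRY}

    # Equivalent to the dict comprehension in the diff, via Counter
    breakdown = {cat.value: n for cat, n in Counter(product_categories.values()).items()}
    print(breakdown)  # {'bread': 2, 'pastry': 1}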
@@ -202,7 +218,7 @@ class EnhancedBakeryMLTrainer:
             )

             training_results = await self._train_all_models_enhanced(
-                tenant_id, processed_data, job_id, repos, progress_tracker
+                tenant_id, processed_data, job_id, repos, progress_tracker, product_categories
             )

             # Calculate overall training summary with enhanced metrics
@@ -269,6 +285,149 @@ class EnhancedBakeryMLTrainer:

             raise

+    async def train_single_product_model(self,
+                                          tenant_id: str,
+                                          inventory_product_id: str,
+                                          training_data: pd.DataFrame,
+                                          job_id: str = None) -> Dict[str, Any]:
+        """
+        Train a model for a single product using repository pattern.
+
+        Args:
+            tenant_id: Tenant identifier
+            inventory_product_id: Specific inventory product to train
+            training_data: Prepared training DataFrame for the product
+            job_id: Training job identifier (optional)
+
+        Returns:
+            Dictionary with model training results
+        """
+        if not job_id:
+            job_id = f"single_product_{tenant_id}_{inventory_product_id}_{uuid.uuid4().hex[:8]}"
+
+        logger.info("Starting single product model training",
+                    job_id=job_id,
+                    tenant_id=tenant_id,
+                    inventory_product_id=inventory_product_id,
+                    data_points=len(training_data))
+
+        try:
+            # Get database session and repositories
+            async with self.database_manager.get_session() as db_session:
+                repos = await self._get_repositories(db_session)
+
+                # Validate input data
+                if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS:
+                    raise ValueError(f"Insufficient training data: need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(training_data)}")
+
+                # Validate required columns
+                required_columns = ['ds', 'y']
+                missing_cols = [col for col in required_columns if col not in training_data.columns]
+                if missing_cols:
+                    raise ValueError(f"Missing required columns in training data: {missing_cols}")
+
+                # Create a simple progress tracker for single product
+                from app.services.progress_tracker import ParallelProductProgressTracker
+                progress_tracker = ParallelProductProgressTracker(
+                    job_id=job_id,
+                    tenant_id=tenant_id,
+                    total_products=1
+                )
+
+                # Ensure training data has proper data types before training
+                if 'ds' in training_data.columns:
+                    training_data['ds'] = pd.to_datetime(training_data['ds'])
+                if 'y' in training_data.columns:
+                    training_data['y'] = pd.to_numeric(training_data['y'], errors='coerce')
+
+                # Remove any rows with NaN values
+                training_data = training_data.dropna()
+
+                # Train the model using the existing _train_single_product method
+                product_id, result = await self._train_single_product(
+                    tenant_id=tenant_id,
+                    inventory_product_id=inventory_product_id,
+                    product_data=training_data,
+                    job_id=job_id,
+                    repos=repos,
+                    progress_tracker=progress_tracker
+                )
+
+                logger.info("Single product training completed",
+                            job_id=job_id,
+                            inventory_product_id=inventory_product_id,
+                            result_status=result.get('status'))
+
+                # Get training metrics and filter out non-numeric values
+                raw_metrics = result.get('model_info', {}).get('training_metrics', {})
+                # Filter metrics to only include numeric values (per Pydantic schema requirement)
+                filtered_metrics = {}
+                for key, value in raw_metrics.items():
+                    if key == 'product_category':
+                        # Skip product_category as it's a string value, not a numeric metric
+                        continue
+                    try:
+                        # Try to convert to float for validation
+                        filtered_metrics[key] = float(value) if value is not None else 0.0
+                    except (ValueError, TypeError):
+                        # Skip non-numeric values
+                        continue
+
+                # Return appropriate result format
+                return {
+                    "job_id": job_id,
+                    "tenant_id": tenant_id,
+                    "inventory_product_id": inventory_product_id,
+                    "status": result.get('status', 'success'),
+                    "model_id": str(result.get('model_record_id', '')) if result.get('model_record_id') else None,
+                    "training_metrics": filtered_metrics,
+                    "training_time": result.get('training_time_seconds', 0),
+                    "data_points": result.get('data_points', 0),
+                    "message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}"
+                }
+
+        except Exception as e:
+            logger.error("Single product model training failed",
+                         job_id=job_id,
+                         inventory_product_id=inventory_product_id,
+                         error=str(e))
+            raise
+
+    def _serialize_scalers(self, scalers: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Serialize scaler objects to basic Python types that can be stored in database.
+        This prevents issues with storing complex sklearn objects in JSON fields.
+        """
+        if not scalers:
+            return {}
+
+        serialized = {}
+        for key, value in scalers.items():
+            try:
+                # Convert numpy scalars to Python native types
+                if hasattr(value, 'item'):  # numpy scalars
+                    serialized[key] = value.item()
+                elif isinstance(value, (np.integer, np.floating)):
+                    serialized[key] = value.item()  # Convert numpy types to Python types
+                elif isinstance(value, (int, float, str, bool, type(None))):
+                    serialized[key] = value  # Already basic type
+                elif isinstance(value, (list, tuple)):
+                    # Convert list/tuple elements to basic types
+                    serialized[key] = [v.item() if hasattr(v, 'item') else v for v in value]
+                else:
+                    # For complex objects, try to convert to string representation
+                    # or store as float if it's numeric
+                    try:
+                        serialized[key] = float(value)
+                    except (ValueError, TypeError):
+                        # If all else fails, convert to string
+                        serialized[key] = str(value)
+            except Exception:
+                # If serialization fails, set to None to prevent database errors
+                serialized[key] = None
+
+        return serialized
+
     async def _process_all_products_enhanced(self,
                                              sales_df: pd.DataFrame,
                                              weather_df: pd.DataFrame,
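
Note: _serialize_scalers above reduces arbitrary scaler values to JSON-safe types. A minimal standalone sketch of the same fallback chain (to_basic is a hypothetical helper, not part of the commit):

    import json
    import numpy as np

    def to_basic(value):
        # numpy scalars expose .item(); sequences are converted element-wise;
        # anything else falls back to float(), then str()
        if hasattr(value, "item"):
            return value.item()
        if isinstance(value, (list, tuple)):
            return [to_basic(v) for v in value]
        if isinstance(value, (int, float, str, bool, type(None))):
            return value
        try:
            return float(value)
        except (ValueError, TypeError):
            return str(value)

    scalers = {"y_mean": np.float64(12.5), "range": (np.int64(0), np.int64(1))}
    print(json.dumps({k: to_basic(v) for k, v in scalers.items()}))
    # {"y_mean": 12.5, "range": [0, 1]}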
@@ -321,12 +480,15 @@ class EnhancedBakeryMLTrainer:
                                     product_data: pd.DataFrame,
                                     job_id: str,
                                     repos: Dict,
-                                    progress_tracker: ParallelProductProgressTracker) -> tuple[str, Dict[str, Any]]:
+                                    progress_tracker: ParallelProductProgressTracker,
+                                    product_category: ProductCategory = ProductCategory.UNKNOWN) -> tuple[str, Dict[str, Any]]:
         """Train a single product model - used for parallel execution with progress aggregation"""
         product_start_time = time.time()

         try:
-            logger.info("Training model", inventory_product_id=inventory_product_id)
+            logger.info("Training model",
+                        inventory_product_id=inventory_product_id,
+                        category=product_category.value)

             # Check if we have enough data
             if len(product_data) < settings.MIN_TRAINING_DATA_DAYS:
@@ -343,14 +505,58 @@ class EnhancedBakeryMLTrainer:
                                min_required=settings.MIN_TRAINING_DATA_DAYS)
                 return inventory_product_id, result

-            # Train the model using Prophet manager
-            model_info = await self.prophet_manager.train_bakery_model(
-                tenant_id=tenant_id,
-                inventory_product_id=inventory_product_id,
+            # Get category-specific hyperparameters
+            category_characteristics = self.product_categorizer.get_category_characteristics(product_category)
+
+            # Determine which model type to use (Prophet vs Hybrid)
+            model_type = self.model_selector.select_model_type(
                 df=product_data,
-                job_id=job_id
+                product_category=product_category.value
             )

+            logger.info("Model type selected",
+                        inventory_product_id=inventory_product_id,
+                        model_type=model_type,
+                        category=product_category.value)
+
+            # Train the selected model
+            if model_type == "hybrid":
+                # Train hybrid Prophet + XGBoost model
+                model_info = await self.hybrid_trainer.train_hybrid_model(
+                    tenant_id=tenant_id,
+                    inventory_product_id=inventory_product_id,
+                    df=product_data,
+                    job_id=job_id
+                )
+                model_info['model_type'] = 'hybrid_prophet_xgboost'
+            else:
+                # Train Prophet-only model with category-specific settings
+                model_info = await self.prophet_manager.train_bakery_model(
+                    tenant_id=tenant_id,
+                    inventory_product_id=inventory_product_id,
+                    df=product_data,
+                    job_id=job_id,
+                    product_category=product_category,
+                    category_hyperparameters=category_characteristics.get('prophet_params', {})
+                )
+                model_info['model_type'] = 'prophet_optimized'
+
+            # Filter training metrics to exclude non-numeric values (e.g., product_category)
+            if 'training_metrics' in model_info and model_info['training_metrics']:
+                raw_metrics = model_info['training_metrics']
+                filtered_metrics = {}
+                for key, value in raw_metrics.items():
+                    if key == 'product_category':
+                        # Skip product_category as it's a string value, not a numeric metric
+                        continue
+                    try:
+                        # Try to convert to float for validation
+                        filtered_metrics[key] = float(value) if value is not None else 0.0
+                    except (ValueError, TypeError):
+                        # Skip non-numeric values
+                        continue
+                model_info['training_metrics'] = filtered_metrics
+
             # Store model record using repository
             model_record = await self._create_model_record(
                 repos, tenant_id, inventory_product_id, model_info, job_id, product_data
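
Note: only the call site of ModelSelector.select_model_type appears in this commit; the selector's internals are not shown. A rough, hypothetical sketch of what a Prophet-vs-hybrid heuristic could look like (the 90-day and volatility thresholds are invented for illustration):

    import pandas as pd

    def select_model_type(df: pd.DataFrame, product_category: str) -> str:
        # Short histories: stay with Prophet alone
        if len(df) < 90:
            return "prophet"
        # Volatile demand tends to benefit from the Prophet + XGBoost hybrid
        mean = df["y"].mean()
        volatility = df["y"].std() / mean if mean else 0.0
        return "hybrid" if volatility > 0.5 else "prophet"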
@@ -366,7 +572,7 @@ class EnhancedBakeryMLTrainer:
             result = {
                 'status': 'success',
                 'model_info': model_info,
-                'model_record_id': model_record.id if model_record else None,
+                'model_record_id': str(model_record.id) if model_record else None,
                 'data_points': len(product_data),
                 'training_time_seconds': time.time() - product_start_time,
                 'trained_at': datetime.now().isoformat()
@@ -403,7 +609,8 @@ class EnhancedBakeryMLTrainer:
                                         processed_data: Dict[str, pd.DataFrame],
                                         job_id: str,
                                         repos: Dict,
-                                        progress_tracker: ParallelProductProgressTracker) -> Dict[str, Any]:
+                                        progress_tracker: ParallelProductProgressTracker,
+                                        product_categories: Dict[str, ProductCategory] = None) -> Dict[str, Any]:
         """Train models with throttled parallel execution and progress tracking"""
         total_products = len(processed_data)
         logger.info(f"Starting throttled parallel training for {total_products} products")
@@ -416,7 +623,8 @@ class EnhancedBakeryMLTrainer:
                 product_data=product_data,
                 job_id=job_id,
                 repos=repos,
-                progress_tracker=progress_tracker
+                progress_tracker=progress_tracker,
+                product_category=product_categories.get(inventory_product_id, ProductCategory.UNKNOWN) if product_categories else ProductCategory.UNKNOWN
             )
             for inventory_product_id, product_data in processed_data.items()
         ]
@@ -478,6 +686,29 @@ class EnhancedBakeryMLTrainer:
                                    processed_data: pd.DataFrame):
         """Create model record using repository"""
         try:
+            # Extract training period from the processed data
+            training_start_date = None
+            training_end_date = None
+            if 'ds' in processed_data.columns and not processed_data.empty:
+                # Ensure ds column is datetime64 before extracting dates (prevents object dtype issues)
+                ds_datetime = pd.to_datetime(processed_data['ds'])
+
+                # Get min/max as pandas Timestamps (guaranteed to work correctly)
+                min_ts = ds_datetime.min()
+                max_ts = ds_datetime.max()
+
+                # Convert to python datetime with timezone removal
+                if pd.notna(min_ts):
+                    training_start_date = pd.Timestamp(min_ts).to_pydatetime().replace(tzinfo=None)
+                if pd.notna(max_ts):
+                    training_end_date = pd.Timestamp(max_ts).to_pydatetime().replace(tzinfo=None)
+
+            # Ensure features are clean string list
+            try:
+                features_used = [str(col) for col in processed_data.columns]
+            except Exception:
+                features_used = []
+
             model_data = {
                 "tenant_id": tenant_id,
                 "inventory_product_id": inventory_product_id,
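
Note: the added block above pins the training window before it is persisted: coerce ds to datetime64, take min/max, then strip tzinfo so the values bind to a timezone-naive database column. A condensed standalone equivalent with illustrative data:

    import pandas as pd

    ds = pd.to_datetime(pd.Series(["2024-01-01T06:00:00+00:00",
                                   "2024-03-31T06:00:00+00:00"]))
    start = pd.Timestamp(ds.min()).to_pydatetime().replace(tzinfo=None)
    end = pd.Timestamp(ds.max()).to_pydatetime().replace(tzinfo=None)
    print(start, end)  # 2024-01-01 06:00:00 2024-03-31 06:00:00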
@@ -485,17 +716,20 @@ class EnhancedBakeryMLTrainer:
                 "model_type": "enhanced_prophet",
                 "model_path": model_info.get("model_path"),
                 "metadata_path": model_info.get("metadata_path"),
-                "mape": model_info.get("training_metrics", {}).get("mape"),
-                "mae": model_info.get("training_metrics", {}).get("mae"),
-                "rmse": model_info.get("training_metrics", {}).get("rmse"),
-                "r2_score": model_info.get("training_metrics", {}).get("r2"),
-                "training_samples": len(processed_data),
-                "hyperparameters": model_info.get("hyperparameters"),
-                "features_used": list(processed_data.columns),
-                "normalization_params": self.enhanced_data_processor.get_scalers(), # Include scalers for prediction consistency
+                "mape": float(model_info.get("training_metrics", {}).get("mape", 0)) if model_info.get("training_metrics", {}).get("mape") is not None else 0,
+                "mae": float(model_info.get("training_metrics", {}).get("mae", 0)) if model_info.get("training_metrics", {}).get("mae") is not None else 0,
+                "rmse": float(model_info.get("training_metrics", {}).get("rmse", 0)) if model_info.get("training_metrics", {}).get("rmse") is not None else 0,
+                "r2_score": float(model_info.get("training_metrics", {}).get("r2", 0)) if model_info.get("training_metrics", {}).get("r2") is not None else 0,
+                "training_samples": int(len(processed_data)),
+                "hyperparameters": self._serialize_scalers(model_info.get("hyperparameters", {})),
+                "features_used": [str(f) for f in features_used] if features_used else [],
+                "normalization_params": self._serialize_scalers(self.enhanced_data_processor.get_scalers()) or {}, # Include scalers for prediction consistency
+                "product_category": model_info.get("product_category", "unknown"), # Store product category
                 "is_active": True,
                 "is_production": True,
-                "data_quality_score": model_info.get("data_quality_score", 100.0)
+                "data_quality_score": float(model_info.get("data_quality_score", 100.0)) if model_info.get("data_quality_score") is not None else 100.0,
+                "training_start_date": training_start_date,
+                "training_end_date": training_end_date
             }

             model_record = await repos['model'].create_model(model_data)
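
Note: every metric above repeats the same guard, float(...) when the raw value is not None, otherwise a default. A small helper could state that once (purely hypothetical; the commit keeps the checks inline):

    from typing import Optional

    def as_float(value, default: Optional[float] = None) -> Optional[float]:
        # Coerces numpy scalars and ints via float(); None falls through to the default
        return float(value) if value is not None else default

    metrics = {"mape": None, "mae": 1.5}
    print(as_float(metrics.get("mape"), 0.0))  # 0.0
    print(as_float(metrics.get("mae")))        # 1.5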
@@ -533,13 +767,13 @@ class EnhancedBakeryMLTrainer:
                 "model_id": str(model_id),
                 "tenant_id": tenant_id,
                 "inventory_product_id": inventory_product_id,
-                "mae": metrics.get("mae"),
-                "mse": metrics.get("mse"),
-                "rmse": metrics.get("rmse"),
-                "mape": metrics.get("mape"),
-                "r2_score": metrics.get("r2"),
-                "accuracy_percentage": 100 - metrics.get("mape", 0) if metrics.get("mape") else None,
-                "evaluation_samples": metrics.get("data_points", 0)
+                "mae": float(metrics.get("mae")) if metrics.get("mae") is not None else None,
+                "mse": float(metrics.get("mse")) if metrics.get("mse") is not None else None,
+                "rmse": float(metrics.get("rmse")) if metrics.get("rmse") is not None else None,
+                "mape": float(metrics.get("mape")) if metrics.get("mape") is not None else None,
+                "r2_score": float(metrics.get("r2")) if metrics.get("r2") is not None else None,
+                "accuracy_percentage": float(100 - metrics.get("mape", 0)) if metrics.get("mape") is not None else None,
+                "evaluation_samples": int(metrics.get("data_points", 0)) if metrics.get("data_points") is not None else 0
             }

             await repos['performance'].create_performance_metric(metric_data)
@@ -672,7 +906,59 @@ class EnhancedBakeryMLTrainer:
                 sales_df['quantity'] = pd.to_numeric(sales_df['quantity'], errors='coerce')
             except Exception:
                 raise ValueError("Quantity column must be numeric")

+
+    async def _categorize_all_products(
+        self,
+        sales_df: pd.DataFrame,
+        processed_data: Dict[str, pd.DataFrame]
+    ) -> Dict[str, ProductCategory]:
+        """
+        Categorize all products for category-specific forecasting.
+
+        Args:
+            sales_df: Raw sales data with product names
+            processed_data: Processed data by product ID
+
+        Returns:
+            Dict mapping inventory_product_id to ProductCategory
+        """
+        product_categories = {}
+
+        for inventory_product_id in processed_data.keys():
+            try:
+                # Get product name from sales data (if available)
+                product_sales = sales_df[sales_df['inventory_product_id'] == inventory_product_id]
+
+                # Extract product name (try multiple possible column names)
+                product_name = "unknown"
+                for name_col in ['product_name', 'name', 'item_name']:
+                    if name_col in product_sales.columns and not product_sales[name_col].empty:
+                        product_name = product_sales[name_col].iloc[0]
+                        break
+
+                # Prepare sales data for pattern analysis
+                sales_for_analysis = product_sales[['date', 'quantity']].copy() if 'date' in product_sales.columns else None
+
+                # Categorize product
+                category = self.product_categorizer.categorize_product(
+                    product_name=str(product_name),
+                    product_id=inventory_product_id,
+                    sales_data=sales_for_analysis
+                )
+
+                product_categories[inventory_product_id] = category
+
+                logger.debug("Product categorized",
+                             inventory_product_id=inventory_product_id,
+                             product_name=product_name,
+                             category=category.value)
+
+            except Exception as e:
+                logger.warning(f"Failed to categorize product {inventory_product_id}: {e}")
+                product_categories[inventory_product_id] = ProductCategory.UNKNOWN
+
+        return product_categories
+
     async def evaluate_model_performance_enhanced(self,
                                                   tenant_id: str,
                                                   inventory_product_id: str,
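
Note: the name lookup in _categorize_all_products tries several column names and keeps "unknown" when none match. The fallback in isolation, with a made-up one-row frame:

    import pandas as pd

    product_sales = pd.DataFrame({"item_name": ["Sourdough Loaf"], "quantity": [12]})
    product_name = "unknown"
    for name_col in ["product_name", "name", "item_name"]:
        if name_col in product_sales.columns and not product_sales[name_col].empty:
            product_name = product_sales[name_col].iloc[0]
            break
    print(product_name)  # Sourdough Loaf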