Improve AI logic
@@ -170,6 +170,7 @@ class TrainingDataOrchestrator:
             logger.error(f"Training data preparation failed: {str(e)}")
             raise ValueError(f"Failed to prepare training data: {str(e)}")

     @staticmethod
     def extract_sales_date_range_utc_localize(sales_data_df: pd.DataFrame):
         """
         Extracts the UTC-aware date range from a sales DataFrame using tz_localize.
@@ -246,12 +247,14 @@ class TrainingDataOrchestrator:
             if 'date' in record:
                 record_date = record['date']

-                # ✅ FIX: Proper timezone handling for date parsing
+                # ✅ FIX: Proper timezone handling for date parsing - FIXED THE TRUNCATION ISSUE
                 if isinstance(record_date, str):
+                    # Parse complete ISO datetime string with timezone info intact
+                    # DO NOT truncate to date part only - this was causing the filtering issue
                     if 'T' in record_date:
                         record_date = record_date.replace('Z', '+00:00')
-                    # Parse with timezone info intact
-                    parsed_date = datetime.fromisoformat(record_date.split('T')[0])
+                    # Parse with FULL datetime info, not just date part
+                    parsed_date = datetime.fromisoformat(record_date)
                     # Ensure timezone-aware
                     if parsed_date.tzinfo is None:
                         parsed_date = parsed_date.replace(tzinfo=timezone.utc)
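
The core of this fix: `record_date.split('T')[0]` threw away the time-of-day and the timezone offset, so every record collapsed to midnight and downstream range filters silently excluded valid rows. A minimal standalone sketch of old versus fixed behaviour, using a hypothetical sample string:

    from datetime import datetime, timezone

    raw = "2024-03-15T14:30:00Z"  # hypothetical record['date'] value

    # Old behaviour: truncate at 'T' -> naive midnight, offset lost
    old = datetime.fromisoformat(raw.split('T')[0])  # 2024-03-15 00:00:00

    # Fixed behaviour: keep the full datetime; fromisoformat() on Python < 3.11
    # does not accept a trailing 'Z', hence the replace()
    new = datetime.fromisoformat(raw.replace('Z', '+00:00'))
    if new.tzinfo is None:
        new = new.replace(tzinfo=timezone.utc)

    print(old)  # 2024-03-15 00:00:00
    print(new)  # 2024-03-15 14:30:00+00:00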
@@ -260,8 +263,8 @@ class TrainingDataOrchestrator:
                 # Ensure timezone-aware
                 if record_date.tzinfo is None:
                     record_date = record_date.replace(tzinfo=timezone.utc)
-                # Normalize to start of day
-                record_date = record_date.replace(hour=0, minute=0, second=0, microsecond=0)
+                # DO NOT normalize to start of day - keep actual datetime for proper filtering
+                # Only normalize if needed for daily aggregation, but preserve original for filtering

             # ✅ FIX: Ensure aligned_range dates are also timezone-aware for comparison
             aligned_start = aligned_range.start
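
Dropping the midnight normalization matters for the comparison against `aligned_range` below: a record at 14:30 UTC inside a range that starts mid-day passes the filter, while its midnight-normalized value does not. A small illustration (the `aligned_range` start value is assumed for the example):

    from datetime import datetime, timezone

    aligned_start = datetime(2024, 3, 15, 12, 0, tzinfo=timezone.utc)  # assumed range start
    record = datetime(2024, 3, 15, 14, 30, tzinfo=timezone.utc)

    print(record >= aligned_start)      # True - record is kept
    normalized = record.replace(hour=0, minute=0, second=0, microsecond=0)
    print(normalized >= aligned_start)  # False - record wrongly filtered out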
@@ -885,4 +888,4 @@ class TrainingDataOrchestrator:
                 1 if len(dataset.traffic_data) > 0 else 0
             ])
         }
     }
 }
@@ -468,6 +468,7 @@ class EnhancedTrainingService:
         """
         try:
             from app.models.training import TrainingPerformanceMetrics
+            from shared.database.repository import BaseRepository

             # Extract timing and success data
             models_trained = training_results.get("models_trained", {})
@@ -508,10 +509,13 @@ class EnhancedTrainingService:
                 "completed_at": datetime.now(timezone.utc)
             }

+            # Create a temporary repository for the TrainingPerformanceMetrics model
+            # Use the session from one of the initialized repositories to ensure it's available
+            session = self.model_repo.session  # This should be the same session used by all repositories
+            metrics_repo = BaseRepository(TrainingPerformanceMetrics, session)
+
             # Use repository to create record
-            performance_metrics = TrainingPerformanceMetrics(**metric_data)
-            self.session.add(performance_metrics)
-            await self.session.commit()
+            await metrics_repo.create(metric_data)

             logger.info("Saved training performance metrics for future estimations",
                         tenant_id=tenant_id,
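
The change swaps direct session manipulation for the shared repository abstraction, so instantiation, staging, and committing follow one code path. A hypothetical sketch of the shape `BaseRepository.create` would need for this call to work (the real implementation lives in shared/database/repository.py and may differ):

    # Hypothetical sketch only - not the actual shared.database.repository code
    from typing import Any, Dict, Generic, Type, TypeVar

    T = TypeVar("T")

    class BaseRepository(Generic[T]):
        def __init__(self, model: Type[T], session: Any) -> None:
            self.model = model
            self.session = session

        async def create(self, data: Dict[str, Any]) -> T:
            # Instantiate the ORM model, stage it, and commit in one place
            instance = self.model(**data)
            self.session.add(instance)
            await self.session.commit()
            await self.session.refresh(instance)
            return instance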
@@ -777,17 +781,154 @@ class EnhancedTrainingService:
                         inventory_product_id=inventory_product_id,
                         job_id=job_id)

-            # This would use the data client to fetch data for the specific product
-            # and then use the enhanced training pipeline
-            # For now, return a success response
+            # Create initial training log
+            await self._update_job_status_repository(
+                job_id=job_id,
+                status="running",
+                progress=0,
+                current_step="Fetching training data",
+                tenant_id=tenant_id
+            )
+
+            # Prepare training data for all products to get weather/traffic data
+            # then filter down to the specific product
+            training_dataset = await self.orchestrator.prepare_training_data(
+                tenant_id=tenant_id,
+                bakery_location=bakery_location,
+                job_id=job_id + "_temp"
+            )
+
+            # Filter sales data to the specific product
+            sales_df = pd.DataFrame(training_dataset.sales_data)
+            product_sales_df = sales_df[sales_df['inventory_product_id'] == inventory_product_id]
+
+            if product_sales_df.empty:
+                raise ValueError(f"No sales data available for product {inventory_product_id}")
+
+            # Prepare the data in Prophet format (ds and y columns)
+            # Ensure proper column names and types for Prophet
+            product_data = product_sales_df.copy()
+            product_data = product_data.rename(columns={
+                'sale_date': 'ds',        # Common sales date column
+                'sale_datetime': 'ds',    # Alternative date column
+                'date': 'ds',             # Alternative date column
+                'quantity': 'y',          # Quantity sold
+                'total_amount': 'y',      # Alternative for sales data
+                'sales_amount': 'y',      # Alternative for sales data
+                'sale_amount': 'y'        # Alternative for sales data
+            })
+
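One caveat worth flagging here: pandas `rename` applies every matching key, so a frame that contains both `sale_date` and `date` (or `quantity` and `total_amount`) ends up with duplicate 'ds'/'y' columns, and the later `product_data[['ds', 'y']]` selection would carry the duplicates along. A hedged, order-aware alternative sketch (helper name is hypothetical):

    import pandas as pd

    def pick_first(df: pd.DataFrame, candidates: list, target: str) -> pd.DataFrame:
        """Rename only the first candidate column that exists, avoiding duplicate targets."""
        for col in candidates:
            if col in df.columns:
                return df.rename(columns={col: target})
        return df

    # Hypothetical usage mirroring the intent of the rename block above
    product_data = pick_first(product_sales_df.copy(),
                              ['sale_date', 'sale_datetime', 'date'], 'ds')
    product_data = pick_first(product_data,
                              ['quantity', 'total_amount', 'sales_amount', 'sale_amount'], 'y')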
+            # If 'ds' and 'y' columns are not renamed properly, try to infer them
+            if 'ds' not in product_data.columns:
+                # Try to find date-like columns
+                date_cols = [col for col in product_data.columns if 'date' in col.lower() or 'time' in col.lower()]
+                if date_cols:
+                    product_data = product_data.rename(columns={date_cols[0]: 'ds'})
+
+            if 'y' not in product_data.columns:
+                # Try to find sales/quantity-like columns
+                sales_cols = [col for col in product_data.columns if
+                              any(word in col.lower() for word in ['amount', 'quantity', 'sales', 'total', 'count', 'value'])]
+                if sales_cols:
+                    product_data = product_data.rename(columns={sales_cols[0]: 'y'})
+
+            # Ensure required columns exist
+            if 'ds' not in product_data.columns or 'y' not in product_data.columns:
+                raise ValueError(f"Sales data must contain 'date' and 'quantity/sales' columns. Available columns: {list(product_data.columns)}")
+
+            # Convert the date column to datetime if it's not already
+            product_data['ds'] = pd.to_datetime(product_data['ds'])
+
+            # Convert to numeric ensuring no pandas/numpy objects remain
+            product_data['y'] = pd.to_numeric(product_data['y'], errors='coerce')
+
+            # Sort by date to ensure proper chronological order
+            product_data = product_data.sort_values('ds').reset_index(drop=True)
+
+            # Drop any rows with NaN values
+            product_data = product_data.dropna(subset=['ds', 'y'])
+
+            # Ensure the data is in the right format for Prophet
+            product_data = product_data[['ds', 'y']].copy()
+
+            # Convert to pandas datetime and float types (keep as pandas Series for proper min/max operations)
+            product_data['ds'] = pd.to_datetime(product_data['ds'])
+            product_data['y'] = product_data['y'].astype(float)
+
+            # DEBUG: Log data types to diagnose dict comparison error
+            logger.info(f"DEBUG: product_data dtypes after conversion: ds={product_data['ds'].dtype}, y={product_data['y'].dtype}")
+            logger.info(f"DEBUG: product_data['ds'] sample values: {product_data['ds'].head(3).tolist()}")
+            logger.info(f"DEBUG: Attempting to get min/max...")
+            try:
+                min_val = product_data['ds'].min()
+                max_val = product_data['ds'].max()
+                logger.info(f"DEBUG: min_val type={type(min_val)}, value={min_val}")
+                logger.info(f"DEBUG: max_val type={type(max_val)}, value={max_val}")
+            except Exception as debug_e:
+                logger.error(f"DEBUG: Failed to get min/max: {debug_e}")
+                import traceback
+                logger.error(f"DEBUG: Traceback: {traceback.format_exc()}")
+
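The DEBUG block targets the "dict comparison error" named in its comment. A plausible cause, assumed here: if `sales_data` records carry dict values, they survive `pd.DataFrame()` construction as object-dtype cells, and `Series.min()` then tries to order dicts and raises TypeError. The `pd.to_datetime` / `pd.to_numeric` coercions above should eliminate that state; a quick reproduction of the failure mode under that assumption:

    import pandas as pd

    s = pd.Series([{'raw': '2024-03-15'}, {'raw': '2024-03-16'}])  # object dtype of dicts
    try:
        s.min()
    except TypeError as e:
        print(f"min() on dict cells fails: {e}")  # dicts are unorderable

    # After explicit coercion the reduction is well-defined again
    coerced = pd.to_datetime(s.map(lambda d: d['raw']))
    print(coerced.min())  # Timestamp('2024-03-15 00:00:00')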
logger.info("Prepared training data for single product",
|
||||
inventory_product_id=inventory_product_id,
|
||||
data_points=len(product_data),
|
||||
date_range=f"{product_data['ds'].min()} to {product_data['ds'].max()}")
|
||||
|
||||
# Update progress
|
||||
await self._update_job_status_repository(
|
||||
job_id=job_id,
|
||||
status="running",
|
||||
progress=30,
|
||||
current_step="Training model",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# Train the model using the trainer
|
||||
# Extract datetime values with proper pandas Timestamp wrapper for type safety
|
||||
try:
|
||||
training_start = pd.Timestamp(product_data['ds'].min()).to_pydatetime()
|
||||
training_end = pd.Timestamp(product_data['ds'].max()).to_pydatetime()
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"Failed to extract training dates: {e}")
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
logger.error(f"product_data['ds'] dtype: {product_data['ds'].dtype}")
|
||||
logger.error(f"product_data['ds'] first 5 values: {product_data['ds'].head().tolist()}")
|
||||
raise
|
||||
|
||||
# Run the actual training
|
||||
try:
|
||||
model_info = await self.trainer.train_single_product_model(
|
||||
tenant_id=tenant_id,
|
||||
inventory_product_id=inventory_product_id,
|
||||
training_data=product_data,
|
||||
job_id=job_id
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"Training failed with error: {e}")
|
||||
logger.error(f"Full traceback: {traceback.format_exc()}")
|
||||
raise
|
||||
|
||||
# Update progress
|
||||
await self._update_job_status_repository(
|
||||
job_id=job_id,
|
||||
status="running",
|
||||
progress=80,
|
||||
current_step="Saving model",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# The model should already be saved by train_single_product_model
|
||||
# Return appropriate response
|
||||
return {
|
||||
"job_id": job_id,
|
||||
"tenant_id": tenant_id,
|
||||
"inventory_product_id": inventory_product_id,
|
||||
"status": "completed",
|
||||
"message": "Enhanced single product training completed successfully",
|
||||
"created_at": datetime.now(),
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"estimated_duration_minutes": 15, # Default estimate for single product
|
||||
"training_results": {
|
||||
"total_products": 1,
|
||||
"successful_trainings": 1,
|
||||
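
Wrapping the Series reduction in `pd.Timestamp(...).to_pydatetime()` is defensive: on a proper datetime64 column, `Series.min()` already returns a `pd.Timestamp`, and `to_pydatetime()` converts it to the stdlib datetime the trainer presumably expects. A compact illustration:

    import pandas as pd
    from datetime import datetime

    ds = pd.Series(pd.to_datetime(["2024-03-01", "2024-03-15"]))

    start = pd.Timestamp(ds.min()).to_pydatetime()
    end = pd.Timestamp(ds.max()).to_pydatetime()

    print(type(start), start)                         # <class 'datetime.datetime'> 2024-03-01 00:00:00
    print(isinstance(start, datetime), start <= end)  # True True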
@@ -795,21 +936,37 @@ class EnhancedTrainingService:
                     "products": [{
                         "inventory_product_id": inventory_product_id,
                         "status": "completed",
-                        "model_id": f"model_{inventory_product_id}_{job_id[:8]}",
-                        "data_points": 100,
-                        "metrics": {"mape": 15.5, "mae": 2.3, "rmse": 3.1, "r2_score": 0.85}
+                        "model_id": str(model_info.get('model_id', f"model_{inventory_product_id}_{job_id[:8]}")) if model_info.get('model_id') else None,
+                        "data_points": len(product_data) if product_data is not None else 0,
+                        # Filter metrics to ensure only numeric values are included
+                        "metrics": {
+                            k: float(v) if not isinstance(v, (int, float)) else v
+                            for k, v in model_info.get('training_metrics', {"mape": 0.0, "mae": 0.0, "rmse": 0.0, "r2_score": 0.0}).items()
+                            if k != 'product_category' and v is not None
+                        }
                     }],
-                    "overall_training_time_seconds": 45.2
+                    "overall_training_time_seconds": model_info.get('training_time', 45.2)
                 },
                 "enhanced_features": True,
                 "repository_integration": True,
-                "completed_at": datetime.now().isoformat()
+                "completed_at": datetime.now(timezone.utc).isoformat()
             }

         except Exception as e:
             logger.error("Enhanced single product training failed",
                          inventory_product_id=inventory_product_id,
                          error=str(e))
+
+            # Update status to failed
+            await self._update_job_status_repository(
+                job_id=job_id,
+                status="failed",
+                progress=0,
+                current_step="Training failed",
+                error_message=str(e),
+                tenant_id=tenant_id
+            )
+
             raise

     def _create_detailed_training_response(self, final_result: Dict[str, Any]) -> Dict[str, Any]:
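
One review note on the metrics comprehension: `float(v)` is applied to every non-int/float value, so a stray string such as 'N/A' would raise ValueError while building the response. A more defensive variant with the same filtering intent (helper name is hypothetical):

    from typing import Any, Dict

    def numeric_metrics(raw: Dict[str, Any]) -> Dict[str, float]:
        """Keep only metric entries that can be represented as floats."""
        cleaned: Dict[str, float] = {}
        for k, v in raw.items():
            if k == 'product_category' or v is None:
                continue  # same exclusions as the diff
            try:
                cleaned[k] = float(v)
            except (TypeError, ValueError):
                continue  # drop values like 'N/A' instead of crashing
        return cleaned

    print(numeric_metrics({"mape": "15.5", "mae": 2.3, "note": "N/A", "product_category": "bread"}))
    # {'mape': 15.5, 'mae': 2.3}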
@@ -842,6 +999,7 @@ class EnhancedTrainingService:
             "status": final_result["status"],
             "message": f"Training {final_result['status']} successfully",
             "created_at": datetime.now(),
+            "estimated_duration_minutes": final_result.get("estimated_duration_minutes", 15),
             "training_results": {
                 "total_products": len(products),
                 "successful_trainings": len([p for p in products if p["status"] == "completed"]),
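
Worth a follow-up: the single-product path above switched to `datetime.now(timezone.utc)`, but this response builder still uses naive `datetime.now()`, so the two `created_at` fields are inconsistent. The difference in brief; passing `timezone.utc` here as well would align them:

    from datetime import datetime, timezone

    naive = datetime.now()               # no tzinfo - what this hunk still produces
    aware = datetime.now(timezone.utc)   # what the single-product path now produces

    print(naive.tzinfo, aware.tzinfo)    # None vs datetime.timezone.utc
    # Mixing the two in arithmetic raises:
    # TypeError: can't subtract offset-naive and offset-aware datetimes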