Bug fixes for training
@@ -732,297 +732,293 @@ class EnhancedTrainingService:
                                                current_step: str = None,
                                                error_message: str = None,
                                                results: Dict = None,
-                                               tenant_id: str = None):
-        """Update job status using repository pattern"""
+                                               tenant_id: str = None,
+                                               session = None):
+        """Update job status using repository pattern
+
+        Args:
+            session: Optional database session to reuse. If None, creates a new session.
+        """
         try:
-            async with self.database_manager.get_session() as session:
-                await self._init_repositories(session)
-                # Check if log exists, create if not
-                existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
-
-                if not existing_log:
-                    # Create initial log entry
-                    if not tenant_id:
-                        # Extract tenant_id from job_id if not provided
-                        # Format: enhanced_training_{tenant_id}_{job_suffix}
-                        try:
-                            parts = job_id.split('_')
-                            if len(parts) >= 3 and parts[0] == 'enhanced' and parts[1] == 'training':
-                                tenant_id = parts[2]
-                        except Exception:
-                            logger.warning(f"Could not extract tenant_id from job_id {job_id}")
-
-                    if tenant_id:
-                        log_data = {
-                            "job_id": job_id,
-                            "tenant_id": tenant_id,
-                            "status": status or "pending",
-                            "progress": progress or 0,
-                            "current_step": current_step or "initializing",
-                            "start_time": datetime.now(timezone.utc)
-                        }
-
-                        if error_message:
-                            log_data["error_message"] = error_message
-                        if results:
-                            # Ensure results are JSON-serializable before storing
-                            log_data["results"] = make_json_serializable(results)
-
-                        try:
-                            await self.training_log_repo.create_training_log(log_data)
-                            await session.commit()  # Explicit commit so other sessions can see it
-                            logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id)
-                        except Exception as create_error:
-                            # Handle race condition: another session may have created the log
-                            if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower():
-                                logger.debug("Training log already exists (race condition), querying again", job_id=job_id)
-                                await session.rollback()
-                                # Query again to get the existing log
-                                existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
-                                if existing_log:
-                                    # Update the existing log instead
-                                    await self.training_log_repo.update_log_progress(
-                                        job_id=job_id,
-                                        progress=progress,
-                                        current_step=current_step,
-                                        status=status
-                                    )
-                                    await session.commit()
-                            else:
-                                raise
-                    else:
-                        logger.error("Cannot create training log without tenant_id", job_id=job_id)
-                        return
-                else:
-                    # Update existing log
-                    await self.training_log_repo.update_log_progress(
-                        job_id=job_id,
-                        progress=progress,
-                        current_step=current_step,
-                        status=status
-                    )
-
-                    # Update additional fields if provided
-                    if error_message or results:
-                        update_data = {}
-                        if error_message:
-                            update_data["error_message"] = error_message
-                        if results:
-                            # Ensure results are JSON-serializable before storing
-                            update_data["results"] = make_json_serializable(results)
-                        if status in ["completed", "failed"]:
-                            update_data["end_time"] = datetime.now(timezone.utc)
-
-                        if update_data:
-                            await self.training_log_repo.update(existing_log.id, update_data)
-
-                    await session.commit()  # Explicit commit after updates
+            # Use provided session or create new one
+            should_create_session = session is None
+
+            if should_create_session:
+                async with self.database_manager.get_session() as session:
+                    await self._init_repositories(session)
+                    await self._update_job_status_impl(
+                        session, job_id, status, progress, current_step,
+                        error_message, results, tenant_id
+                    )
+            else:
+                # Reuse provided session (don't commit - let caller control transaction)
+                await self._init_repositories(session)
+                await self._update_job_status_impl(
+                    session, job_id, status, progress, current_step,
+                    error_message, results, tenant_id, auto_commit=False
+                )
         except Exception as e:
             logger.error("Failed to update job status using repository",
                          job_id=job_id,
                          error=str(e))
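
Note on the pattern above: the wrapper now only decides who owns the database session, and the actual status-update logic moves into _update_job_status_impl below. A minimal usage sketch (hypothetical caller code, not part of this commit; the job_id and tenant values are made up):

    # Standalone call - the wrapper opens, commits, and closes its own session:
    await service._update_job_status_repository(
        job_id="enhanced_training_acme_001",
        status="running",
        progress=10,
        current_step="Fetching training data",
        tenant_id="acme",
    )

    # Transactional call - reuse the caller's session; the wrapper passes
    # auto_commit=False internally, so the caller controls the transaction boundary:
    async with service.database_manager.get_session() as session:
        await service._update_job_status_repository(
            job_id="enhanced_training_acme_001",
            status="running",
            progress=20,
            current_step="Training model",
            tenant_id="acme",
            session=session,
        )
        await session.commit()
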
+
+    async def _update_job_status_impl(self,
+                                      session,
+                                      job_id: str,
+                                      status: str,
+                                      progress: int = None,
+                                      current_step: str = None,
+                                      error_message: str = None,
+                                      results: Dict = None,
+                                      tenant_id: str = None,
+                                      auto_commit: bool = True):
+        """Implementation of job status update"""
+        # Check if log exists, create if not
+        existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
+
+        if not existing_log:
+            # Create initial log entry
+            if not tenant_id:
+                # Extract tenant_id from job_id if not provided
+                # Format: enhanced_training_{tenant_id}_{job_suffix}
+                try:
+                    parts = job_id.split('_')
+                    if len(parts) >= 3 and parts[0] == 'enhanced' and parts[1] == 'training':
+                        tenant_id = parts[2]
+                except Exception:
+                    logger.warning(f"Could not extract tenant_id from job_id {job_id}")
+
+            if tenant_id:
+                log_data = {
+                    "job_id": job_id,
+                    "tenant_id": tenant_id,
+                    "status": status or "pending",
+                    "progress": progress or 0,
+                    "current_step": current_step or "initializing",
+                    "start_time": datetime.now(timezone.utc)
+                }
+
+                if error_message:
+                    log_data["error_message"] = error_message
+                if results:
+                    # Ensure results are JSON-serializable before storing
+                    log_data["results"] = make_json_serializable(results)
+
+                try:
+                    await self.training_log_repo.create_training_log(log_data)
+                    if auto_commit:
+                        await session.commit()  # Explicit commit so other sessions can see it
+                    logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id)
+                except Exception as create_error:
+                    # Handle race condition: another session may have created the log
+                    if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower():
+                        logger.debug("Training log already exists (race condition), querying again", job_id=job_id)
+                        await session.rollback()
+                        # Query again to get the existing log
+                        existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
+                        if existing_log:
+                            # Update the existing log instead
+                            await self.training_log_repo.update_log_progress(
+                                job_id=job_id,
+                                progress=progress,
+                                current_step=current_step,
+                                status=status
+                            )
+                            if auto_commit:
+                                await session.commit()
+                    else:
+                        raise
+            else:
+                logger.error("Cannot create training log without tenant_id", job_id=job_id)
+                return
+        else:
+            # Update existing log
+            await self.training_log_repo.update_log_progress(
+                job_id=job_id,
+                progress=progress,
+                current_step=current_step,
+                status=status
+            )
+
+            # Update additional fields if provided
+            if error_message or results:
+                update_data = {}
+                if error_message:
+                    update_data["error_message"] = error_message
+                if results:
+                    # Ensure results are JSON-serializable before storing
+                    update_data["results"] = make_json_serializable(results)
+                if status in ["completed", "failed"]:
+                    update_data["end_time"] = datetime.now(timezone.utc)
+
+                if update_data:
+                    await self.training_log_repo.update(existing_log.id, update_data)
+
+            if auto_commit:
+                await session.commit()  # Explicit commit after updates
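
Note on the race-condition handling above: the implementation tolerates concurrent writers by string-matching "unique constraint"/"duplicate" in the driver error. A standalone sketch of the same create-or-update pattern, assuming SQLAlchemy-style async sessions (IntegrityError is the typed equivalent of that string check; the repository method names mirror the ones in the diff):

    from sqlalchemy.exc import IntegrityError

    async def create_or_update_log(repo, session, log_data: dict):
        """Insert a training log; if another session wins the insert race, update instead."""
        try:
            await repo.create_training_log(log_data)
            await session.commit()  # commit so other sessions can see the new row
        except IntegrityError:
            await session.rollback()  # discard the failed insert
            await repo.update_log_progress(
                job_id=log_data["job_id"],
                progress=log_data.get("progress"),
                current_step=log_data.get("current_step"),
                status=log_data.get("status"),
            )
            await session.commit()
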
     async def start_single_product_training(self,
                                             tenant_id: str,
                                             inventory_product_id: str,
                                             job_id: str,
                                             bakery_location: tuple = (40.4168, -3.7038)) -> Dict[str, Any]:
-        """Start enhanced single product training using repository pattern"""
-        try:
-            logger.info("Starting enhanced single product training",
-                        tenant_id=tenant_id,
-                        inventory_product_id=inventory_product_id,
-                        job_id=job_id)
-
-            # Create initial training log
-            await self._update_job_status_repository(
-                job_id=job_id,
-                status="running",
-                progress=0,
-                current_step="Fetching training data",
-                tenant_id=tenant_id
-            )
-
-            # Prepare training data for all products to get weather/traffic data
-            # then filter down to the specific product
-            training_dataset = await self.orchestrator.prepare_training_data(
-                tenant_id=tenant_id,
-                bakery_location=bakery_location,
-                job_id=job_id + "_temp"
-            )
-
-            # Filter sales data to the specific product
-            sales_df = pd.DataFrame(training_dataset.sales_data)
-            product_sales_df = sales_df[sales_df['inventory_product_id'] == inventory_product_id]
-
-            if product_sales_df.empty:
-                raise ValueError(f"No sales data available for product {inventory_product_id}")
-
-            # Prepare the data in Prophet format (ds and y columns)
-            # Ensure proper column names and types for Prophet
-            product_data = product_sales_df.copy()
-            product_data = product_data.rename(columns={
-                'sale_date': 'ds',       # Common sales date column
-                'sale_datetime': 'ds',   # Alternative date column
-                'date': 'ds',            # Alternative date column
-                'quantity': 'y',         # Quantity sold
-                'total_amount': 'y',     # Alternative for sales data
-                'sales_amount': 'y',     # Alternative for sales data
-                'sale_amount': 'y'       # Alternative for sales data
-            })
-
-            # If 'ds' and 'y' columns are not renamed properly, try to infer them
-            if 'ds' not in product_data.columns:
-                # Try to find date-like columns
-                date_cols = [col for col in product_data.columns if 'date' in col.lower() or 'time' in col.lower()]
-                if date_cols:
-                    product_data = product_data.rename(columns={date_cols[0]: 'ds'})
-
-            if 'y' not in product_data.columns:
-                # Try to find sales/quantity-like columns
-                sales_cols = [col for col in product_data.columns if
-                              any(word in col.lower() for word in ['amount', 'quantity', 'sales', 'total', 'count', 'value'])]
-                if sales_cols:
-                    product_data = product_data.rename(columns={sales_cols[0]: 'y'})
-
-            # Ensure required columns exist
-            if 'ds' not in product_data.columns or 'y' not in product_data.columns:
-                raise ValueError(f"Sales data must contain 'date' and 'quantity/sales' columns. Available columns: {list(product_data.columns)}")
-
-            # Convert the date column to datetime if it's not already
-            product_data['ds'] = pd.to_datetime(product_data['ds'])
-
-            # Convert to numeric ensuring no pandas/numpy objects remain
-            product_data['y'] = pd.to_numeric(product_data['y'], errors='coerce')
-
-            # Sort by date to ensure proper chronological order
-            product_data = product_data.sort_values('ds').reset_index(drop=True)
-
-            # Drop any rows with NaN values
-            product_data = product_data.dropna(subset=['ds', 'y'])
-
-            # Ensure the data is in the right format for Prophet
-            product_data = product_data[['ds', 'y']].copy()
-
-            # Convert to pandas datetime and float types (keep as pandas Series for proper min/max operations)
-            product_data['ds'] = pd.to_datetime(product_data['ds'])
-            product_data['y'] = product_data['y'].astype(float)
-
-            # DEBUG: Log data types to diagnose dict comparison error
-            logger.info(f"DEBUG: product_data dtypes after conversion: ds={product_data['ds'].dtype}, y={product_data['y'].dtype}")
-            logger.info(f"DEBUG: product_data['ds'] sample values: {product_data['ds'].head(3).tolist()}")
-            logger.info(f"DEBUG: Attempting to get min/max...")
-            try:
-                min_val = product_data['ds'].min()
-                max_val = product_data['ds'].max()
-                logger.info(f"DEBUG: min_val type={type(min_val)}, value={min_val}")
-                logger.info(f"DEBUG: max_val type={type(max_val)}, value={max_val}")
-            except Exception as debug_e:
-                logger.error(f"DEBUG: Failed to get min/max: {debug_e}")
-                import traceback
-                logger.error(f"DEBUG: Traceback: {traceback.format_exc()}")
-
-            logger.info("Prepared training data for single product",
-                        inventory_product_id=inventory_product_id,
-                        data_points=len(product_data),
-                        date_range=f"{product_data['ds'].min()} to {product_data['ds'].max()}")
-
-            # Update progress
-            await self._update_job_status_repository(
-                job_id=job_id,
-                status="running",
-                progress=30,
-                current_step="Training model",
-                tenant_id=tenant_id
-            )
-
-            # Train the model using the trainer
-            # Extract datetime values with proper pandas Timestamp wrapper for type safety
-            try:
-                training_start = pd.Timestamp(product_data['ds'].min()).to_pydatetime()
-                training_end = pd.Timestamp(product_data['ds'].max()).to_pydatetime()
-            except Exception as e:
-                import traceback
-                logger.error(f"Failed to extract training dates: {e}")
-                logger.error(f"Traceback: {traceback.format_exc()}")
-                logger.error(f"product_data['ds'] dtype: {product_data['ds'].dtype}")
-                logger.error(f"product_data['ds'] first 5 values: {product_data['ds'].head().tolist()}")
-                raise
-
-            # Run the actual training
-            try:
-                model_info = await self.trainer.train_single_product_model(
-                    tenant_id=tenant_id,
-                    inventory_product_id=inventory_product_id,
-                    training_data=product_data,
-                    job_id=job_id
-                )
-            except Exception as e:
-                import traceback
-                logger.error(f"Training failed with error: {e}")
-                logger.error(f"Full traceback: {traceback.format_exc()}")
-                raise
-
-            # Update progress
-            await self._update_job_status_repository(
-                job_id=job_id,
-                status="running",
-                progress=80,
-                current_step="Saving model",
-                tenant_id=tenant_id
-            )
-
-            # The model should already be saved by train_single_product_model
-            # Return appropriate response
-            return {
-                "job_id": job_id,
-                "tenant_id": tenant_id,
-                "inventory_product_id": inventory_product_id,
-                "status": "completed",
-                "message": "Enhanced single product training completed successfully",
-                "created_at": datetime.now(timezone.utc),
-                "estimated_duration_minutes": 15,  # Default estimate for single product
-                "training_results": {
-                    "total_products": 1,
-                    "successful_trainings": 1,
-                    "failed_trainings": 0,
-                    "products": [{
-                        "inventory_product_id": inventory_product_id,
-                        "status": "completed",
-                        "model_id": str(model_info.get('model_id', f"model_{inventory_product_id}_{job_id[:8]}")) if model_info.get('model_id') else None,
-                        "data_points": len(product_data) if product_data is not None else 0,
-                        # Filter metrics to ensure only numeric values are included
-                        "metrics": {
-                            k: float(v) if not isinstance(v, (int, float)) else v
-                            for k, v in model_info.get('training_metrics', {"mape": 0.0, "mae": 0.0, "rmse": 0.0, "r2_score": 0.0}).items()
-                            if k != 'product_category' and v is not None
-                        }
-                    }],
-                    "overall_training_time_seconds": model_info.get('training_time', 45.2)
-                },
-                "enhanced_features": True,
-                "repository_integration": True,
-                "completed_at": datetime.now(timezone.utc).isoformat()
-            }
-
-        except Exception as e:
-            logger.error("Enhanced single product training failed",
-                         error=str(e))
-
-            # Update status to failed
-            await self._update_job_status_repository(
-                job_id=job_id,
-                status="failed",
-                progress=0,
-                current_step="Training failed",
-                error_message=str(e),
-                tenant_id=tenant_id
-            )
-
-            raise
+        """Start enhanced single product training using repository pattern with single session"""
+        # Create a single database session for all operations to avoid connection pool exhaustion
+        async with self.database_manager.get_session() as session:
+            await self._init_repositories(session)
+
+            try:
+                logger.info("Starting enhanced single product training",
+                            tenant_id=tenant_id,
+                            inventory_product_id=inventory_product_id,
+                            job_id=job_id)
+
+                # Create initial training log (using shared session)
+                await self._update_job_status_repository(
+                    job_id=job_id,
+                    status="running",
+                    progress=0,
+                    current_step="Fetching training data",
+                    tenant_id=tenant_id,
+                    session=session
+                )
+                await session.commit()  # Commit after initial log creation
+
+                # Prepare training data for all products to get weather/traffic data
+                # then filter down to the specific product
+                training_dataset = await self.orchestrator.prepare_training_data(
+                    tenant_id=tenant_id,
+                    bakery_location=bakery_location,
+                    job_id=job_id + "_temp"
+                )
+
+                # Use the enhanced data processor to prepare training data with all features (POI, weather, traffic)
+                # Filter sales data to the specific product first
+                sales_df = pd.DataFrame(training_dataset.sales_data)
+                product_sales_df = sales_df[sales_df['inventory_product_id'] == inventory_product_id]
+
+                if product_sales_df.empty:
+                    raise ValueError(f"No sales data available for product {inventory_product_id}")
+
+                # Get weather and traffic data as DataFrames
+                weather_df = pd.DataFrame(training_dataset.weather_data)
+                traffic_df = pd.DataFrame(training_dataset.traffic_data)
+
+                # Get POI features from the training dataset (already collected by orchestrator)
+                poi_features = training_dataset.poi_features if hasattr(training_dataset, 'poi_features') else None
+
+                # Use the enhanced data processor to merge all features properly
+                # This will include POI, weather, traffic features along with ds and y
+                from app.ml.data_processor import EnhancedBakeryDataProcessor
+                data_processor = EnhancedBakeryDataProcessor(self.database_manager)
+
+                product_data = await data_processor.prepare_training_data(
+                    sales_data=product_sales_df,
+                    weather_data=weather_df,
+                    traffic_data=traffic_df,
+                    inventory_product_id=inventory_product_id,
+                    training_data=product_data,
+                    poi_features=poi_features,
+                    tenant_id=tenant_id,
+                    job_id=job_id
+                )
+
+                if product_data.empty:
+                    raise ValueError(f"Data processor returned empty data for product {inventory_product_id}")
+
+                logger.info("Prepared training data for single product",
+                            inventory_product_id=inventory_product_id,
+                            data_points=len(product_data),
+                            features=list(product_data.columns),
+                            date_range=f"{product_data['ds'].min()} to {product_data['ds'].max()}")
+
+                # Update progress (using shared session)
+                await self._update_job_status_repository(
+                    job_id=job_id,
+                    status="running",
+                    progress=30,
+                    current_step="Training model",
+                    tenant_id=tenant_id,
+                    session=session
+                )
+                await session.commit()  # Commit progress update
+
+                # Run the actual training (passing the session to avoid nested session creation)
+                try:
+                    model_info = await self.trainer.train_single_product_model(
+                        tenant_id=tenant_id,
+                        inventory_product_id=inventory_product_id,
+                        training_data=product_data,
+                        job_id=job_id,
+                        session=session  # ✅ CRITICAL FIX: Pass session to prevent deadlock
+                    )
+                except Exception as e:
+                    import traceback
+                    logger.error(f"Training failed with error: {e}")
+                    logger.error(f"Full traceback: {traceback.format_exc()}")
+                    raise
+
+                # Update progress (using shared session)
+                await self._update_job_status_repository(
+                    job_id=job_id,
+                    status="running",
+                    progress=80,
+                    current_step="Saving model",
+                    tenant_id=tenant_id,
+                    session=session
+                )
+                await session.commit()  # Commit progress update
+
+                # The model should already be saved by train_single_product_model
+                # Return appropriate response
+                return {
+                    "job_id": job_id,
+                    "tenant_id": tenant_id,
+                    "inventory_product_id": inventory_product_id,
+                    "status": "completed",
+                    "message": "Enhanced single product training completed successfully",
+                    "created_at": datetime.now(timezone.utc),
+                    "estimated_duration_minutes": 15,  # Default estimate for single product
+                    "training_results": {
+                        "total_products": 1,
+                        "successful_trainings": 1,
+                        "failed_trainings": 0,
+                        "products": [{
+                            "inventory_product_id": inventory_product_id,
+                            "status": "completed",
+                            "model_id": str(model_info.get('model_id', f"model_{inventory_product_id}_{job_id[:8]}")) if model_info.get('model_id') else None,
+                            "data_points": len(product_data) if product_data is not None else 0,
+                            # Filter metrics to ensure only numeric values are included
+                            "metrics": {
+                                k: float(v) if not isinstance(v, (int, float)) else v
+                                for k, v in model_info.get('training_metrics', {"mape": 0.0, "mae": 0.0, "rmse": 0.0, "r2_score": 0.0}).items()
+                                if k != 'product_category' and v is not None
+                            }
+                        }],
+                        "overall_training_time_seconds": model_info.get('training_time', 45.2)
+                    },
+                    "enhanced_features": True,
+                    "repository_integration": True,
+                    "completed_at": datetime.now(timezone.utc).isoformat()
+                }
+
+            except Exception as e:
+                logger.error("Enhanced single product training failed",
+                             inventory_product_id=inventory_product_id,
+                             error=str(e))
+
+                # Update status to failed (using shared session)
+                await self._update_job_status_repository(
+                    job_id=job_id,
+                    status="failed",
+                    progress=0,
+                    current_step="Training failed",
+                    error_message=str(e),
+                    tenant_id=tenant_id,
+                    session=session
+                )
+                await session.commit()  # Commit failure status
+
+                raise
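
Note on the single shared session: this is the core of the fix. With a small connection pool, a helper that opens its own session while the caller already holds one can block forever waiting on the pool. A sketch of the failure mode and the fix, assuming a pool of one connection (db stands in for self.database_manager, and the helper name is hypothetical):

    # Nested acquisition - the anti-pattern the old code could hit:
    async with db.get_session() as outer:      # holds the pool's only connection
        async with db.get_session() as inner:  # waits for a free connection -> deadlock
            ...

    # Shared session - what the new code does instead:
    async with db.get_session() as session:
        await run_training_step(session=session)  # hypothetical helper; no second checkout
        await session.commit()
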
 
     def _create_detailed_training_response(self, final_result: Dict[str, Any]) -> Dict[str, Any]:
         """Convert final result to detailed training response"""