Bug fixes for training

This commit is contained in:
Urtzi Alfaro
2025-11-14 20:27:39 +01:00
parent 71f9ca9d65
commit c349b845a6
11 changed files with 606 additions and 408 deletions


@@ -732,297 +732,293 @@ class EnhancedTrainingService:
                                             current_step: str = None,
                                             error_message: str = None,
                                             results: Dict = None,
                                             tenant_id: str = None,
                                             session = None):
        """Update job status using repository pattern

        Args:
            session: Optional database session to reuse. If None, creates a new session.
        """
        try:
            # Use the provided session or create a new one
            should_create_session = session is None

            if should_create_session:
                async with self.database_manager.get_session() as session:
                    await self._init_repositories(session)
                    await self._update_job_status_impl(
                        session, job_id, status, progress, current_step,
                        error_message, results, tenant_id
                    )
            else:
                # Reuse the provided session (don't commit - let the caller control the transaction)
                await self._init_repositories(session)
                await self._update_job_status_impl(
                    session, job_id, status, progress, current_step,
                    error_message, results, tenant_id, auto_commit=False
                )
        except Exception as e:
            logger.error("Failed to update job status using repository",
                         job_id=job_id,
                         error=str(e))
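    # Usage sketch (illustrative, not part of this commit): callers that already
    # hold an open transaction pass their session so every write shares one
    # connection, while fire-and-forget callers omit it:
    #
    #   await self._update_job_status_repository(job_id, "running", progress=10,
    #                                            session=session)  # caller commits
    #   await self._update_job_status_repository(job_id, "running", progress=10)
    #                                            # opens and commits its own session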
    async def _update_job_status_impl(self,
                                      session,
                                      job_id: str,
                                      status: str,
                                      progress: int = None,
                                      current_step: str = None,
                                      error_message: str = None,
                                      results: Dict = None,
                                      tenant_id: str = None,
                                      auto_commit: bool = True):
        """Implementation of job status update"""
        # Check if log exists, create if not
        existing_log = await self.training_log_repo.get_log_by_job_id(job_id)

        if not existing_log:
            # Create initial log entry
            if not tenant_id:
                # Extract tenant_id from job_id if not provided
                # Format: enhanced_training_{tenant_id}_{job_suffix}
                try:
                    parts = job_id.split('_')
                    if len(parts) >= 3 and parts[0] == 'enhanced' and parts[1] == 'training':
                        tenant_id = parts[2]
                except Exception:
                    logger.warning(f"Could not extract tenant_id from job_id {job_id}")

            if tenant_id:
                log_data = {
                    "job_id": job_id,
                    "tenant_id": tenant_id,
                    "status": status or "pending",
                    "progress": progress or 0,
                    "current_step": current_step or "initializing",
                    "start_time": datetime.now(timezone.utc)
                }
                if error_message:
                    log_data["error_message"] = error_message
                if results:
                    # Ensure results are JSON-serializable before storing
                    log_data["results"] = make_json_serializable(results)

                try:
                    await self.training_log_repo.create_training_log(log_data)
                    if auto_commit:
                        await session.commit()  # Explicit commit so other sessions can see it
                    logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id)
                except Exception as create_error:
                    # Handle race condition: another session may have created the log
                    if "unique constraint" in str(create_error).lower() or "duplicate" in str(create_error).lower():
                        logger.debug("Training log already exists (race condition), querying again", job_id=job_id)
                        await session.rollback()
                        # Query again to get the existing log
                        existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
                        if existing_log:
                            # Update the existing log instead
                            await self.training_log_repo.update_log_progress(
                                job_id=job_id,
                                progress=progress,
                                current_step=current_step,
                                status=status
                            )
                            if auto_commit:
                                await session.commit()
                    else:
                        raise
            else:
                logger.error("Cannot create training log without tenant_id", job_id=job_id)
                return
        else:
            # Update existing log
            await self.training_log_repo.update_log_progress(
                job_id=job_id,
                progress=progress,
                current_step=current_step,
                status=status
            )

        # Update additional fields if provided. Guard on existing_log: a freshly
        # created log already contains these fields, and dereferencing
        # existing_log.id when it is None would raise AttributeError.
        if existing_log and (error_message or results):
            update_data = {}
            if error_message:
                update_data["error_message"] = error_message
            if results:
                # Ensure results are JSON-serializable before storing
                update_data["results"] = make_json_serializable(results)
            if status in ["completed", "failed"]:
                update_data["end_time"] = datetime.now(timezone.utc)
            if update_data:
                await self.training_log_repo.update(existing_log.id, update_data)
                if auto_commit:
                    await session.commit()  # Explicit commit after updates
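    # Note on the job_id convention assumed above (values are illustrative):
    #   job_id = "enhanced_training_acme_1a2b3c4d"  ->  tenant_id = "acme"
    # The split('_') extraction only recovers the first underscore-free token,
    # so tenant ids that themselves contain underscores should be passed explicitly.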
    async def start_single_product_training(self,
                                            tenant_id: str,
                                            inventory_product_id: str,
                                            job_id: str,
                                            bakery_location: tuple = (40.4168, -3.7038)) -> Dict[str, Any]:
        """Start enhanced single product training using repository pattern with single session"""
        # Create a single database session for all operations to avoid connection pool exhaustion
        async with self.database_manager.get_session() as session:
            await self._init_repositories(session)
            try:
                logger.info("Starting enhanced single product training",
                            tenant_id=tenant_id,
                            inventory_product_id=inventory_product_id,
                            job_id=job_id)

                # Create initial training log (using shared session)
                await self._update_job_status_repository(
                    job_id=job_id,
                    status="running",
                    progress=0,
                    current_step="Fetching training data",
                    tenant_id=tenant_id,
                    session=session
                )
                await session.commit()  # Commit after initial log creation

                # Prepare training data for all products to get weather/traffic data,
                # then filter down to the specific product
                training_dataset = await self.orchestrator.prepare_training_data(
                    tenant_id=tenant_id,
                    bakery_location=bakery_location,
                    job_id=job_id + "_temp"
                )
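                # Assumed shape of training_dataset (inferred from the usage below):
                # it exposes sales_data, weather_data and traffic_data as lists of
                # records convertible via pd.DataFrame(...), plus an optional
                # poi_features attribute collected by the orchestrator.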
                # Use the enhanced data processor to prepare training data with all features (POI, weather, traffic)
                # Filter sales data to the specific product first
                sales_df = pd.DataFrame(training_dataset.sales_data)
                product_sales_df = sales_df[sales_df['inventory_product_id'] == inventory_product_id]
                if product_sales_df.empty:
                    raise ValueError(f"No sales data available for product {inventory_product_id}")

                # Get weather and traffic data as DataFrames
                weather_df = pd.DataFrame(training_dataset.weather_data)
                traffic_df = pd.DataFrame(training_dataset.traffic_data)

                # Get POI features from the training dataset (already collected by the orchestrator)
                poi_features = training_dataset.poi_features if hasattr(training_dataset, 'poi_features') else None

                # Use the enhanced data processor to merge all features properly
                # This will include POI, weather and traffic features along with ds and y
                from app.ml.data_processor import EnhancedBakeryDataProcessor
                data_processor = EnhancedBakeryDataProcessor(self.database_manager)
                product_data = await data_processor.prepare_training_data(
                    sales_data=product_sales_df,
                    weather_data=weather_df,
                    traffic_data=traffic_df,
                    inventory_product_id=inventory_product_id,
                    poi_features=poi_features,
                    tenant_id=tenant_id,
                    job_id=job_id
                )
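                # Expected output (an assumption based on the ds/y checks below):
                # a Prophet-style frame with a datetime 'ds' column, a numeric 'y'
                # target, and any merged exogenous feature columns, e.g.
                #   ds          y     temperature  traffic_level  poi_score
                #   2025-01-02  41.0  6.3          0.42           0.8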
                if product_data.empty:
                    raise ValueError(f"Data processor returned empty data for product {inventory_product_id}")

                logger.info("Prepared training data for single product",
                            inventory_product_id=inventory_product_id,
                            data_points=len(product_data),
                            features=list(product_data.columns),
                            date_range=f"{product_data['ds'].min()} to {product_data['ds'].max()}")
                # Update progress (using shared session)
                await self._update_job_status_repository(
                    job_id=job_id,
                    status="running",
                    progress=30,
                    current_step="Training model",
                    tenant_id=tenant_id,
                    session=session
                )
                await session.commit()  # Commit progress update

                # Run the actual training (passing the session to avoid nested session creation)
                try:
                    model_info = await self.trainer.train_single_product_model(
                        tenant_id=tenant_id,
                        inventory_product_id=inventory_product_id,
                        training_data=product_data,
                        job_id=job_id,
                        session=session  # ✅ CRITICAL FIX: Pass session to prevent deadlock
                    )
                except Exception as e:
                    import traceback
                    logger.error(f"Training failed with error: {e}")
                    logger.error(f"Full traceback: {traceback.format_exc()}")
                    raise

                # Update progress (using shared session)
                await self._update_job_status_repository(
                    job_id=job_id,
                    status="running",
                    progress=80,
                    current_step="Saving model",
                    tenant_id=tenant_id,
                    session=session
                )
                await session.commit()  # Commit progress update
                # The model should already be saved by train_single_product_model
                # Return appropriate response
                return {
                    "job_id": job_id,
                    "tenant_id": tenant_id,
                    "inventory_product_id": inventory_product_id,
                    "status": "completed",
                    "message": "Enhanced single product training completed successfully",
                    "created_at": datetime.now(timezone.utc),
                    "estimated_duration_minutes": 15,  # Default estimate for a single product
                    "training_results": {
                        "total_products": 1,
                        "successful_trainings": 1,
                        "failed_trainings": 0,
                        "products": [{
                            "inventory_product_id": inventory_product_id,
                            "status": "completed",
                            "model_id": str(model_info.get('model_id', f"model_{inventory_product_id}_{job_id[:8]}")) if model_info.get('model_id') else None,
                            "data_points": len(product_data) if product_data is not None else 0,
                            # Filter metrics to ensure only numeric values are included
                            "metrics": {
                                k: float(v) if not isinstance(v, (int, float)) else v
                                for k, v in model_info.get('training_metrics', {"mape": 0.0, "mae": 0.0, "rmse": 0.0, "r2_score": 0.0}).items()
                                if k != 'product_category' and v is not None
                            }
                        }],
                        "overall_training_time_seconds": model_info.get('training_time', 45.2)
                    },
                    "enhanced_features": True,
                    "repository_integration": True,
                    "completed_at": datetime.now(timezone.utc).isoformat()
                }
            except Exception as e:
                logger.error("Enhanced single product training failed",
                             inventory_product_id=inventory_product_id,
                             error=str(e))
                # Update status to failed (using shared session)
                await self._update_job_status_repository(
                    job_id=job_id,
                    status="failed",
                    progress=0,
                    current_step="Training failed",
                    error_message=str(e),
                    tenant_id=tenant_id,
                    session=session
                )
                await session.commit()  # Commit failure status
                raise
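    # Usage sketch (illustrative, names are hypothetical): a route handler would
    # typically schedule this coroutine as a background task, e.g.
    #   job_id = f"enhanced_training_{tenant_id}_{uuid4().hex[:8]}"
    #   asyncio.create_task(service.start_single_product_training(
    #       tenant_id=tenant_id,
    #       inventory_product_id=product_id,
    #       job_id=job_id,
    #   ))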
    def _create_detailed_training_response(self, final_result: Dict[str, Any]) -> Dict[str, Any]:
        """Convert final result to detailed training response"""