Bug fixes for training

@@ -5,11 +5,12 @@ Integrates POI features into ML training pipeline.
 Fetches POI context from External service and merges features into training data.
 """
 
-import httpx
 from typing import Dict, Any, Optional, List
 import structlog
 import pandas as pd
 
+from shared.clients.external_client import ExternalServiceClient
+
 logger = structlog.get_logger()
 
 
@@ -21,15 +22,18 @@ class POIFeatureIntegrator:
 
     to training dataframes for location-based demand forecasting.
     """
 
-    def __init__(self, external_service_url: str = "http://external-service:8000"):
+    def __init__(self, external_client: ExternalServiceClient = None):
        """
        Initialize POI feature integrator.
 
        Args:
-            external_service_url: Base URL for external service
+            external_client: External service client instance (optional)
        """
-        self.external_service_url = external_service_url.rstrip("/")
-        self.poi_context_endpoint = f"{self.external_service_url}/poi-context"
+        if external_client is None:
+            from app.core.config import settings
+            self.external_client = ExternalServiceClient(settings, "training-service")
+        else:
+            self.external_client = external_client
 
    async def fetch_poi_features(
        self,
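
The constructor change above is a standard dependency-injection move: accept an optional client and build a default from configuration only when none is supplied, so tests can pass a mock while production code keeps working unchanged. A minimal sketch of the pattern (the `FakeClient` and `make_default_client` names are illustrative stand-ins, not names from this repo):

    from typing import Optional

    class FakeClient:
        def __init__(self, base_url: str):
            self.base_url = base_url

    def make_default_client() -> FakeClient:
        # In the diff, this role is played by
        # ExternalServiceClient(settings, "training-service")
        return FakeClient("http://external-service:8000")

    class Integrator:
        def __init__(self, client: Optional[FakeClient] = None):
            # Use the injected client when given (e.g., a mock in tests),
            # otherwise fall back to a default built from configuration.
            self.client = client if client is not None else make_default_client()
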
@@ -53,57 +57,49 @@ class POIFeatureIntegrator:
             Dictionary with POI features or None if detection fails
         """
         try:
-            async with httpx.AsyncClient(timeout=60.0) as client:
-                # Try to get existing POI context first
-                if not force_refresh:
-                    try:
-                        response = await client.get(
-                            f"{self.poi_context_endpoint}/{tenant_id}"
-                        )
-                        if response.status_code == 200:
-                            data = response.json()
-                            poi_context = data.get("poi_context", {})
-
-                            # Check if stale
-                            if not data.get("is_stale", False):
-                                logger.info(
-                                    "Using existing POI context",
-                                    tenant_id=tenant_id
-                                )
-                                return poi_context.get("ml_features", {})
-                            else:
-                                logger.info(
-                                    "POI context is stale, refreshing",
-                                    tenant_id=tenant_id
-                                )
-                                force_refresh = True
-                    except httpx.HTTPStatusError as e:
-                        if e.response.status_code != 404:
-                            raise
-                        logger.info(
-                            "No existing POI context, will detect",
-                            tenant_id=tenant_id
-                        )
-
-                # Detect or refresh POIs
-                logger.info(
-                    "Detecting POIs for tenant",
-                    tenant_id=tenant_id,
-                    location=(latitude, longitude)
-                )
-
-                response = await client.post(
-                    f"{self.poi_context_endpoint}/{tenant_id}/detect",
-                    params={
-                        "latitude": latitude,
-                        "longitude": longitude,
-                        "force_refresh": force_refresh
-                    }
-                )
-                response.raise_for_status()
-
-                result = response.json()
-                poi_context = result.get("poi_context", {})
+            # Try to get existing POI context first
+            if not force_refresh:
+                existing_context = await self.external_client.get_poi_context(tenant_id)
+                if existing_context:
+                    poi_context = existing_context.get("poi_context", {})
+                    ml_features = poi_context.get("ml_features", {})
+
+                    # Check if stale
+                    is_stale = existing_context.get("is_stale", False)
+                    if not is_stale:
+                        logger.info(
+                            "Using existing POI context",
+                            tenant_id=tenant_id
+                        )
+                        return ml_features
+                    else:
+                        logger.info(
+                            "POI context is stale, refreshing",
+                            tenant_id=tenant_id
+                        )
+                        force_refresh = True
+                else:
+                    logger.info(
+                        "No existing POI context, will detect",
+                        tenant_id=tenant_id
+                    )
+
+            # Detect or refresh POIs
+            logger.info(
+                "Detecting POIs for tenant",
+                tenant_id=tenant_id,
+                location=(latitude, longitude)
+            )
+
+            detection_result = await self.external_client.detect_poi_for_tenant(
+                tenant_id=tenant_id,
+                latitude=latitude,
+                longitude=longitude,
+                force_refresh=force_refresh
+            )
+
+            if detection_result:
+                poi_context = detection_result.get("poi_context", {})
+                ml_features = poi_context.get("ml_features", {})
 
             logger.info(
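
The rewritten fetch path is a cache-with-staleness flow: return cached ML features while fresh, force a detection round-trip when the context is stale or missing. A hedged, runnable sketch of that flow with a stubbed client (the `StubClient` class and its canned data are hypothetical; the real calls are `external_client.get_poi_context(...)` and `external_client.detect_poi_for_tenant(...)` as shown in the diff):

    import asyncio
    from typing import Any, Dict, Optional

    class StubClient:
        async def get_poi_context(self, tenant_id: str) -> Optional[Dict[str, Any]]:
            # Pretend a fresh context is cached for this tenant.
            return {"poi_context": {"ml_features": {"poi_density": 12}}, "is_stale": False}

        async def detect_poi_for_tenant(self, tenant_id: str, latitude: float,
                                        longitude: float, force_refresh: bool) -> Optional[Dict[str, Any]]:
            return {"poi_context": {"ml_features": {"poi_density": 15}}}

    async def fetch_features(client: StubClient, tenant_id: str,
                             lat: float, lon: float,
                             force_refresh: bool = False) -> Optional[Dict[str, Any]]:
        if not force_refresh:
            existing = await client.get_poi_context(tenant_id)
            if existing and not existing.get("is_stale", False):
                # Fresh cached context: no detection round-trip needed.
                return existing.get("poi_context", {}).get("ml_features", {})
            force_refresh = existing is not None  # a stale context forces a refresh
        result = await client.detect_poi_for_tenant(tenant_id, lat, lon, force_refresh)
        return result.get("poi_context", {}).get("ml_features", {}) if result else None

    print(asyncio.run(fetch_features(StubClient(), "tenant-1", 52.5, 13.4)))
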
@@ -114,15 +110,13 @@ class POIFeatureIntegrator:
                 )
 
                 return ml_features
+            else:
+                logger.error(
+                    "POI detection failed",
+                    tenant_id=tenant_id
+                )
+                return None
 
-        except httpx.HTTPError as e:
-            logger.error(
-                "Failed to fetch POI features",
-                tenant_id=tenant_id,
-                error=str(e),
-                exc_info=True
-            )
-            return None
        except Exception as e:
            logger.error(
                "Unexpected error fetching POI features",
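
With the transport-specific `httpx.HTTPError` handler gone, the contract is: a falsy client result logs an error and returns None, and a broad `except Exception` remains as a backstop, so feature-fetch failures degrade gracefully instead of breaking training. A small sketch of that "log and degrade to None" contract (the `flaky` coroutine is a hypothetical stand-in for the client call):

    import asyncio
    import logging
    from typing import Any, Dict, Optional

    logging.basicConfig(level=logging.INFO)
    log = logging.getLogger("poi")

    async def safe_fetch(fetch, tenant_id: str) -> Optional[Dict[str, Any]]:
        try:
            result = await fetch(tenant_id)
            if not result:
                log.error("POI detection failed tenant_id=%s", tenant_id)
                return None  # caller trains without POI features
            return result
        except Exception:
            log.exception("Unexpected error fetching POI features tenant_id=%s", tenant_id)
            return None

    async def flaky(tenant_id: str):
        raise RuntimeError("boom")

    print(asyncio.run(safe_fetch(flaky, "tenant-1")))  # logs the error, then prints None
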
@@ -185,17 +179,18 @@ class POIFeatureIntegrator:
 
    async def check_poi_service_health(self) -> bool:
        """
-        Check if POI service is accessible.
+        Check if POI service is accessible through the external client.
 
        Returns:
            True if service is healthy, False otherwise
        """
        try:
-            async with httpx.AsyncClient(timeout=5.0) as client:
-                response = await client.get(
-                    f"{self.poi_context_endpoint}/health"
-                )
-                return response.status_code == 200
+            # We can test the external service health by attempting to get POI context for a dummy tenant
+            # This will go through the proper authentication and routing
+            dummy_context = await self.external_client.get_poi_context("test-tenant")
+            # If we can successfully make a request (even if it returns None for missing tenant),
+            # it means the service is accessible
+            return True
        except Exception as e:
            logger.error(
                "POI service health check failed",
 
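
The health check now probes through the client instead of hitting a raw `/health` URL, so authentication and routing are exercised by the probe as well; any completed round-trip counts as healthy, even if the probe tenant does not exist. A runnable sketch of that idea (the stub client is hypothetical; `get_poi_context` follows the diff, and the probe tenant id is arbitrary):

    import asyncio

    class _StubClient:
        async def get_poi_context(self, tenant_id):
            return None  # tenant unknown, but the call itself succeeded

    async def check_health(client) -> bool:
        try:
            # Any completed round-trip (even "no such tenant") proves the
            # service is reachable through the client's auth and routing.
            await client.get_poi_context("test-tenant")
            return True
        except Exception:
            return False

    print(asyncio.run(check_health(_StubClient())))  # True
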
@@ -375,40 +375,143 @@ class EnhancedBakeryMLTrainer:
         try:
-            async with self.database_manager.get_session() as db_session:
-                repos = await self._get_repositories(db_session)
+            # Use provided session or create new one to prevent nested sessions and deadlocks
+            should_create_session = session is None
+            db_session = session if session is not None else None
+
+            if should_create_session:
+                # Only create a session if one wasn't provided
+                async with self.database_manager.get_session() as db_session:
+                    repos = await self._get_repositories(db_session)
+
+                    # Validate input data
+                    if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS:
+                        raise ValueError(f"Insufficient training data: need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(training_data)}")
+
+                    # Validate required columns
+                    required_columns = ['ds', 'y']
+                    missing_cols = [col for col in required_columns if col not in training_data.columns]
+                    if missing_cols:
+                        raise ValueError(f"Missing required columns in training data: {missing_cols}")
+
+                    # Create a simple progress tracker for single product
+                    from app.services.progress_tracker import ParallelProductProgressTracker
+                    progress_tracker = ParallelProductProgressTracker(
+                        job_id=job_id,
+                        tenant_id=tenant_id,
+                        total_products=1
+                    )
+
+                    # Ensure training data has proper data types before training
+                    if 'ds' in training_data.columns:
+                        training_data['ds'] = pd.to_datetime(training_data['ds'])
+                    if 'y' in training_data.columns:
+                        training_data['y'] = pd.to_numeric(training_data['y'], errors='coerce')
+
+                    # Remove any rows with NaN values
+                    training_data = training_data.dropna()
+
+                    # Train the model using the existing _train_single_product method
+                    product_id, result = await self._train_single_product(
+                        tenant_id=tenant_id,
+                        inventory_product_id=inventory_product_id,
+                        product_data=training_data,
+                        job_id=job_id,
+                        repos=repos,
+                        progress_tracker=progress_tracker,
+                        session=db_session  # Pass the session to prevent nested sessions
+                    )
+
+                    logger.info("Single product training completed",
+                                job_id=job_id,
+                                inventory_product_id=inventory_product_id,
+                                result_status=result.get('status'))
+
+                    # Write training result to database (create model record)
+                    if result.get('status') == 'success':
+                        model_info = result.get('model_info')
+                        product_data = result.get('product_data')
+
+                        if model_info and product_data is not None:
+                            # Create model record in database
+                            model_record = await self._create_model_record(
+                                repos, tenant_id, inventory_product_id, model_info, job_id, product_data
+                            )
+
+                            # Create performance metrics
+                            if model_info.get('training_metrics') and model_record:
+                                await self._create_performance_metrics(
+                                    repos, model_record.id,
+                                    tenant_id, inventory_product_id, model_info['training_metrics']
+                                )
+
+                            # Update result with model_record_id
+                            result['model_record_id'] = str(model_record.id) if model_record else None
+
+                    # Get training metrics and filter out non-numeric values
+                    raw_metrics = result.get('model_info', {}).get('training_metrics', {})
+                    # Filter metrics to only include numeric values (per Pydantic schema requirement)
+                    filtered_metrics = {}
+                    for key, value in raw_metrics.items():
+                        if key == 'product_category':
+                            # Skip product_category as it's a string value, not a numeric metric
+                            continue
+                        try:
+                            # Try to convert to float for validation
+                            filtered_metrics[key] = float(value) if value is not None else 0.0
+                        except (ValueError, TypeError):
+                            # Skip non-numeric values
+                            continue
+
+                    # Return appropriate result format
+                    result_dict = {
+                        "job_id": job_id,
+                        "tenant_id": tenant_id,
+                        "inventory_product_id": inventory_product_id,
+                        "status": result.get('status', 'success'),
+                        "model_id": str(result.get('model_record_id', '')) if result.get('model_record_id') else None,
+                        "training_metrics": filtered_metrics,
+                        "training_time": result.get('training_time_seconds', 0),
+                        "data_points": result.get('data_points', 0),
+                        "message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}"
+                    }
+
+                    # Only commit if this is our own session (not a parent session)
+                    # Commit after we're done with all database operations
+                    await db_session.commit()
+                    logger.info("Committed single product model record to database",
+                                inventory_product_id=inventory_product_id,
+                                model_record_id=result.get('model_record_id'))
+
+                    return result_dict
+            else:
+                # Use the provided session
+                repos = await self._get_repositories(session)
 
                 # Validate input data
                 if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS:
                     raise ValueError(f"Insufficient training data: need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(training_data)}")
 
                 # Validate required columns
                 required_columns = ['ds', 'y']
                 missing_cols = [col for col in required_columns if col not in training_data.columns]
                 if missing_cols:
                     raise ValueError(f"Missing required columns in training data: {missing_cols}")
 
                 # Create a simple progress tracker for single product
                 from app.services.progress_tracker import ParallelProductProgressTracker
                 progress_tracker = ParallelProductProgressTracker(
                     job_id=job_id,
                     tenant_id=tenant_id,
                     total_products=1
                 )
 
                 # Ensure training data has proper data types before training
                 if 'ds' in training_data.columns:
                     training_data['ds'] = pd.to_datetime(training_data['ds'])
                 if 'y' in training_data.columns:
                     training_data['y'] = pd.to_numeric(training_data['y'], errors='coerce')
 
                 # Remove any rows with NaN values
                 training_data = training_data.dropna()
 
                 # Train the model using the existing _train_single_product method
                 product_id, result = await self._train_single_product(
                     tenant_id=tenant_id,
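
The core of this hunk is a session-ownership rule: the method commits only a session it opened itself, while a caller-provided session is left for the caller to commit as part of its larger transaction, which prevents the nested-session deadlocks the comment mentions. A minimal runnable sketch of the rule (`Session` here is a stand-in for an SQLAlchemy AsyncSession, not the repo's actual type):

    import asyncio
    from typing import Optional

    class Session:
        async def commit(self):
            print("committed")

    async def train_one(session: Optional[Session] = None):
        owns_session = session is None
        db = session if session is not None else Session()
        # ... do work with db ...
        if owns_session:
            await db.commit()  # we opened it, we finish it
        # else: the caller commits as part of its larger transaction

    asyncio.run(train_one())           # prints "committed"
    asyncio.run(train_one(Session()))  # nothing printed; caller commits later
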
@@ -417,14 +520,35 @@ class EnhancedBakeryMLTrainer:
                     job_id=job_id,
                     repos=repos,
                     progress_tracker=progress_tracker,
-                    session=db_session  # Pass the session to prevent nested sessions
+                    session=session  # Pass the provided session
                 )
 
                 logger.info("Single product training completed",
                             job_id=job_id,
                             inventory_product_id=inventory_product_id,
                             result_status=result.get('status'))
 
+                # Write training result to database (create model record)
+                if result.get('status') == 'success':
+                    model_info = result.get('model_info')
+                    product_data = result.get('product_data')
+
+                    if model_info and product_data is not None:
+                        # Create model record in database
+                        model_record = await self._create_model_record(
+                            repos, tenant_id, inventory_product_id, model_info, job_id, product_data
+                        )
+
+                        # Create performance metrics
+                        if model_info.get('training_metrics') and model_record:
+                            await self._create_performance_metrics(
+                                repos, model_record.id,
+                                tenant_id, inventory_product_id, model_info['training_metrics']
+                            )
+
+                        # Update result with model_record_id
+                        result['model_record_id'] = str(model_record.id) if model_record else None
+
                 # Get training metrics and filter out non-numeric values
                 raw_metrics = result.get('model_info', {}).get('training_metrics', {})
                 # Filter metrics to only include numeric values (per Pydantic schema requirement)
@@ -439,7 +563,7 @@ class EnhancedBakeryMLTrainer:
                     except (ValueError, TypeError):
                         # Skip non-numeric values
                         continue
 
                 # Return appropriate result format
                 result_dict = {
                     "job_id": job_id,
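
The filter loop surrounding this hunk keeps only metrics that `float()` can digest and drops known string fields such as 'product_category', so the result satisfies a numeric-only Pydantic schema. The same logic as a standalone, runnable function:

    from typing import Any, Dict

    def filter_numeric_metrics(raw: Dict[str, Any]) -> Dict[str, float]:
        out: Dict[str, float] = {}
        for key, value in raw.items():
            if key == 'product_category':
                continue  # string label, not a numeric metric
            try:
                out[key] = float(value) if value is not None else 0.0
            except (ValueError, TypeError):
                continue  # skip anything float() rejects
        return out

    print(filter_numeric_metrics({"mape": "12.5", "rmse": 3.2,
                                  "product_category": "bread", "notes": "n/a"}))
    # -> {'mape': 12.5, 'rmse': 3.2}
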
@@ -452,7 +576,13 @@ class EnhancedBakeryMLTrainer:
                     "data_points": result.get('data_points', 0),
                     "message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}"
                 }
 
+                # For provided sessions, do NOT commit here - let the calling method handle commits
+                # This prevents committing a parent transaction prematurely
+                logger.info("Single product model processed (commit handled by caller)",
+                            inventory_product_id=inventory_product_id,
+                            model_record_id=result.get('model_record_id'))
+
                 return result_dict
 
         except Exception as e:
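
For the provided-session path, a caller is expected to open one transaction, run one or more trainings through it, and commit once at the end. A hypothetical caller sketch (the method name `train_single_product_model` and the exact argument names are assumed from the surrounding context, not verified against the repo):

    async def retrain_products(trainer, database_manager, tenant_id, jobs):
        # One transaction for the whole batch; the trainer will not commit
        # a session it did not create.
        async with database_manager.get_session() as session:
            for job_id, product_id, df in jobs:
                await trainer.train_single_product_model(
                    tenant_id=tenant_id,
                    inventory_product_id=product_id,
                    training_data=df,
                    job_id=job_id,
                    session=session,  # trainer logs, but leaves the commit to us
                )
            await session.commit()  # single commit for all trained products
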