Bug fixes for training

This commit is contained in:
Urtzi Alfaro
2025-11-14 20:27:39 +01:00
parent 71f9ca9d65
commit c349b845a6
11 changed files with 606 additions and 408 deletions


@@ -5,11 +5,12 @@ Integrates POI features into ML training pipeline.
Fetches POI context from External service and merges features into training data.
"""
import httpx
from typing import Dict, Any, Optional, List
import structlog
import pandas as pd
+from shared.clients.external_client import ExternalServiceClient
logger = structlog.get_logger()
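
The module docstring above says the integrator merges POI features into training data. As a rough illustration (not code from this commit), here is a sketch of such a merge, assuming ml_features is a flat dict of tenant-wide scalars that can be broadcast as constant columns onto a ds/y training frame; merge_poi_features and the sample values are hypothetical:

```python
import pandas as pd

def merge_poi_features(training_data: pd.DataFrame, ml_features: dict) -> pd.DataFrame:
    """Broadcast tenant-wide POI scalars as constant feature columns."""
    merged = training_data.copy()
    for name, value in ml_features.items():
        merged[f"poi_{name}"] = value  # same value for every row of this tenant
    return merged

# Example: a minimal ds/y frame gains poi_* regressor columns.
df = pd.DataFrame({"ds": pd.date_range("2025-01-01", periods=3), "y": [10, 12, 9]})
print(merge_poi_features(df, {"restaurant_count": 4, "school_nearby": 1}))
```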
@@ -21,15 +22,18 @@ class POIFeatureIntegrator:
to training dataframes for location-based demand forecasting.
"""
-    def __init__(self, external_service_url: str = "http://external-service:8000"):
+    def __init__(self, external_client: Optional[ExternalServiceClient] = None):
        """
        Initialize POI feature integrator.
        Args:
-            external_service_url: Base URL for external service
+            external_client: External service client instance (optional)
        """
-        self.external_service_url = external_service_url.rstrip("/")
-        self.poi_context_endpoint = f"{self.external_service_url}/poi-context"
+        if external_client is None:
+            from app.core.config import settings
+            self.external_client = ExternalServiceClient(settings, "training-service")
+        else:
+            self.external_client = external_client
async def fetch_poi_features(
self,
@@ -53,57 +57,49 @@ class POIFeatureIntegrator:
Dictionary with POI features or None if detection fails
"""
try:
-            async with httpx.AsyncClient(timeout=60.0) as client:
-                # Try to get existing POI context first
-                if not force_refresh:
-                    try:
-                        response = await client.get(
-                            f"{self.poi_context_endpoint}/{tenant_id}"
-                        )
-                        if response.status_code == 200:
-                            data = response.json()
-                            poi_context = data.get("poi_context", {})
-                            # Check if stale
-                            if not data.get("is_stale", False):
-                                logger.info(
-                                    "Using existing POI context",
-                                    tenant_id=tenant_id
-                                )
-                                return poi_context.get("ml_features", {})
-                            else:
-                                logger.info(
-                                    "POI context is stale, refreshing",
-                                    tenant_id=tenant_id
-                                )
-                                force_refresh = True
-                    except httpx.HTTPStatusError as e:
-                        if e.response.status_code != 404:
-                            raise
-                        logger.info(
-                            "No existing POI context, will detect",
-                            tenant_id=tenant_id
-                        )
-                # Detect or refresh POIs
-                logger.info(
-                    "Detecting POIs for tenant",
-                    tenant_id=tenant_id,
-                    location=(latitude, longitude)
-                )
-                response = await client.post(
-                    f"{self.poi_context_endpoint}/{tenant_id}/detect",
-                    params={
-                        "latitude": latitude,
-                        "longitude": longitude,
-                        "force_refresh": force_refresh
-                    }
-                )
-                response.raise_for_status()
-                result = response.json()
-                poi_context = result.get("poi_context", {})
+            # Try to get existing POI context first
+            if not force_refresh:
+                existing_context = await self.external_client.get_poi_context(tenant_id)
+                if existing_context:
+                    poi_context = existing_context.get("poi_context", {})
+                    ml_features = poi_context.get("ml_features", {})
+                    # Check if stale
+                    is_stale = existing_context.get("is_stale", False)
+                    if not is_stale:
+                        logger.info(
+                            "Using existing POI context",
+                            tenant_id=tenant_id
+                        )
+                        return ml_features
+                    else:
+                        logger.info(
+                            "POI context is stale, refreshing",
+                            tenant_id=tenant_id
+                        )
+                        force_refresh = True
+                else:
+                    logger.info(
+                        "No existing POI context, will detect",
+                        tenant_id=tenant_id
+                    )
+            # Detect or refresh POIs
+            logger.info(
+                "Detecting POIs for tenant",
+                tenant_id=tenant_id,
+                location=(latitude, longitude)
+            )
+            detection_result = await self.external_client.detect_poi_for_tenant(
+                tenant_id=tenant_id,
+                latitude=latitude,
+                longitude=longitude,
+                force_refresh=force_refresh
+            )
+            if detection_result:
+                poi_context = detection_result.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})
logger.info(
@@ -114,15 +110,13 @@ class POIFeatureIntegrator:
)
return ml_features
+            else:
+                logger.error(
+                    "POI detection failed",
+                    tenant_id=tenant_id
+                )
+                return None
-        except httpx.HTTPError as e:
-            logger.error(
-                "Failed to fetch POI features",
-                tenant_id=tenant_id,
-                error=str(e),
-                exc_info=True
-            )
-            return None
except Exception as e:
logger.error(
"Unexpected error fetching POI features",
@@ -185,17 +179,18 @@ class POIFeatureIntegrator:
async def check_poi_service_health(self) -> bool:
"""
-        Check if POI service is accessible.
+        Check if POI service is accessible through the external client.
Returns:
True if service is healthy, False otherwise
"""
try:
-            async with httpx.AsyncClient(timeout=5.0) as client:
-                response = await client.get(
-                    f"{self.poi_context_endpoint}/health"
-                )
-                return response.status_code == 200
+            # Probe the service by requesting POI context for a dummy tenant;
+            # the call goes through the client's normal authentication and
+            # routing, so any response (even None for a missing tenant) means
+            # the service is reachable.
+            await self.external_client.get_poi_context("test-tenant")
+            return True
except Exception as e:
logger.error(
"POI service health check failed",


@@ -375,40 +375,143 @@ class EnhancedBakeryMLTrainer:
try:
# Use provided session or create new one to prevent nested sessions and deadlocks
should_create_session = session is None
db_session = session  # None when we need to create our own session below
if should_create_session:
# Only create a session if one wasn't provided
async with self.database_manager.get_session() as db_session:
repos = await self._get_repositories(db_session)
# Validate input data
if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS:
raise ValueError(f"Insufficient training data: need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(training_data)}")
# Validate required columns
required_columns = ['ds', 'y']
missing_cols = [col for col in required_columns if col not in training_data.columns]
if missing_cols:
raise ValueError(f"Missing required columns in training data: {missing_cols}")
# Create a simple progress tracker for single product
from app.services.progress_tracker import ParallelProductProgressTracker
progress_tracker = ParallelProductProgressTracker(
job_id=job_id,
tenant_id=tenant_id,
total_products=1
)
# Ensure training data has proper data types before training
if 'ds' in training_data.columns:
training_data['ds'] = pd.to_datetime(training_data['ds'])
if 'y' in training_data.columns:
training_data['y'] = pd.to_numeric(training_data['y'], errors='coerce')
# Remove any rows with NaN values
training_data = training_data.dropna()
# Train the model using the existing _train_single_product method
product_id, result = await self._train_single_product(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
product_data=training_data,
job_id=job_id,
repos=repos,
progress_tracker=progress_tracker,
session=db_session # Pass the session to prevent nested sessions
)
logger.info("Single product training completed",
job_id=job_id,
inventory_product_id=inventory_product_id,
result_status=result.get('status'))
# Write training result to database (create model record)
if result.get('status') == 'success':
model_info = result.get('model_info')
product_data = result.get('product_data')
if model_info and product_data is not None:
# Create model record in database
model_record = await self._create_model_record(
repos, tenant_id, inventory_product_id, model_info, job_id, product_data
)
# Create performance metrics
if model_info.get('training_metrics') and model_record:
await self._create_performance_metrics(
repos, model_record.id,
tenant_id, inventory_product_id, model_info['training_metrics']
)
# Update result with model_record_id
result['model_record_id'] = str(model_record.id) if model_record else None
# Get training metrics and filter out non-numeric values
raw_metrics = result.get('model_info', {}).get('training_metrics', {})
# Filter metrics to only include numeric values (per Pydantic schema requirement)
filtered_metrics = {}
for key, value in raw_metrics.items():
if key == 'product_category':
# Skip product_category as it's a string value, not a numeric metric
continue
try:
# Try to convert to float for validation
filtered_metrics[key] = float(value) if value is not None else 0.0
except (ValueError, TypeError):
# Skip non-numeric values
continue
# Return appropriate result format
result_dict = {
"job_id": job_id,
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"status": result.get('status', 'success'),
"model_id": str(result.get('model_record_id', '')) if result.get('model_record_id') else None,
"training_metrics": filtered_metrics,
"training_time": result.get('training_time_seconds', 0),
"data_points": result.get('data_points', 0),
"message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}"
}
# Only commit if this is our own session (not a parent session)
# Commit after we're done with all database operations
await db_session.commit()
logger.info("Committed single product model record to database",
inventory_product_id=inventory_product_id,
model_record_id=result.get('model_record_id'))
return result_dict
else:
# Use the provided session
repos = await self._get_repositories(session)
# Validate input data
if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS:
raise ValueError(f"Insufficient training data: need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(training_data)}")
# Validate required columns
required_columns = ['ds', 'y']
missing_cols = [col for col in required_columns if col not in training_data.columns]
if missing_cols:
raise ValueError(f"Missing required columns in training data: {missing_cols}")
# Create a simple progress tracker for single product
from app.services.progress_tracker import ParallelProductProgressTracker
progress_tracker = ParallelProductProgressTracker(
job_id=job_id,
tenant_id=tenant_id,
total_products=1
)
# Ensure training data has proper data types before training
if 'ds' in training_data.columns:
training_data['ds'] = pd.to_datetime(training_data['ds'])
if 'y' in training_data.columns:
training_data['y'] = pd.to_numeric(training_data['y'], errors='coerce')
# Remove any rows with NaN values
training_data = training_data.dropna()
# Train the model using the existing _train_single_product method
product_id, result = await self._train_single_product(
tenant_id=tenant_id,
@@ -417,14 +520,35 @@ class EnhancedBakeryMLTrainer:
job_id=job_id,
repos=repos,
progress_tracker=progress_tracker,
session=session # Pass the provided session
)
logger.info("Single product training completed",
job_id=job_id,
inventory_product_id=inventory_product_id,
result_status=result.get('status'))
# Write training result to database (create model record)
if result.get('status') == 'success':
model_info = result.get('model_info')
product_data = result.get('product_data')
if model_info and product_data is not None:
# Create model record in database
model_record = await self._create_model_record(
repos, tenant_id, inventory_product_id, model_info, job_id, product_data
)
# Create performance metrics
if model_info.get('training_metrics') and model_record:
await self._create_performance_metrics(
repos, model_record.id,
tenant_id, inventory_product_id, model_info['training_metrics']
)
# Update result with model_record_id
result['model_record_id'] = str(model_record.id) if model_record else None
# Get training metrics and filter out non-numeric values
raw_metrics = result.get('model_info', {}).get('training_metrics', {})
# Filter metrics to only include numeric values (per Pydantic schema requirement)
@@ -439,7 +563,7 @@ class EnhancedBakeryMLTrainer:
except (ValueError, TypeError):
# Skip non-numeric values
continue
# Return appropriate result format
result_dict = {
"job_id": job_id,
@@ -452,7 +576,13 @@ class EnhancedBakeryMLTrainer:
"data_points": result.get('data_points', 0),
"message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}"
}
# For provided sessions, do NOT commit here - let the calling method handle commits
# This prevents committing a parent transaction prematurely
logger.info("Single product model processed (commit handled by caller)",
inventory_product_id=inventory_product_id,
model_record_id=result.get('model_record_id'))
return result_dict
except Exception as e:
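
The two branches above implement one invariant: the method commits only a session it created itself; with a caller-provided session it must not commit, otherwise it could commit a parent transaction prematurely. A condensed sketch of that ownership pattern, assuming SQLAlchemy's async sessions (the project's real session type, session factory, and training work are stand-ins here):

```python
from typing import Optional

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker


async def _do_training_work(db_session: AsyncSession) -> dict:
    # Stand-in for the real work: validate data, train, write model records.
    return {"status": "success"}


async def train_single_product(
    session_factory: async_sessionmaker[AsyncSession],
    session: Optional[AsyncSession] = None,
) -> dict:
    should_create_session = session is None
    if should_create_session:
        # We own this session: do the work, then commit before returning.
        async with session_factory() as db_session:
            result = await _do_training_work(db_session)
            await db_session.commit()
            return result
    # Caller-owned session: never commit here; the caller decides when to
    # commit or roll back the enclosing transaction.
    return await _do_training_work(session)
```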