Bug fixes for training

Urtzi Alfaro
2025-11-14 20:27:39 +01:00
parent 71f9ca9d65
commit c349b845a6
11 changed files with 606 additions and 408 deletions

View File

@@ -41,8 +41,3 @@ nodes:
- containerPort: 30800
hostPort: 8000
protocol: TCP
sysctls:
# Increase fs.inotify limits to prevent "too many open files" errors
fs.inotify.max_user_watches: 524288
fs.inotify.max_user_instances: 256
fs.inotify.max_queued_events: 32768

View File

@@ -100,14 +100,19 @@ The **External Service** integrates real-world data from Spanish sources to enha
## API Endpoints (Key Routes)
### POI Detection & Context
- `POST /poi-context/{tenant_id}/detect` - Detect POIs for tenant location (lat, long, force_refresh params)
- `GET /poi-context/{tenant_id}` - Get cached POI context for tenant
- `POST /poi-context/{tenant_id}/refresh` - Force refresh POI detection
- `DELETE /poi-context/{tenant_id}` - Delete POI context for tenant
- `GET /poi-context/{tenant_id}/feature-importance` - Get POI feature importance summary
- `GET /poi-context/{tenant_id}/competitor-analysis` - Get competitive analysis
- `GET /poi-context/health` - Check POI service and Overpass API health
- `GET /poi-context/cache/stats` - Get POI cache statistics
The following direct-access routes bypass gateway authentication:
- `POST /api/v1/tenants/{tenant_id}/poi-context/detect` - Detect POIs for tenant location (lat, long, force_refresh params)
- `GET /api/v1/tenants/{tenant_id}/poi-context` - Get cached POI context for tenant
- `POST /api/v1/tenants/{tenant_id}/poi-context/refresh` - Force refresh POI detection
- `DELETE /api/v1/tenants/{tenant_id}/poi-context` - Delete POI context for tenant
- `GET /api/v1/tenants/{tenant_id}/poi-context/feature-importance` - Get POI feature importance summary
- `GET /api/v1/tenants/{tenant_id}/poi-context/competitor-analysis` - Get competitive analysis
- `GET /api/v1/tenants/poi-context/health` - Check POI service and Overpass API health
- `GET /api/v1/tenants/poi-context/cache/stats` - Get POI cache statistics
### Recommended Access Pattern
- Services should call `/api/v1/tenants/{tenant_id}/external/poi-context` (cached POI context) and `/api/v1/tenants/{tenant_id}/external/poi-context/detect` (detection) through the shared ExternalServiceClient, which routes via the API gateway for authentication and authorization.
**Note**: When using the shared ExternalServiceClient, pass relative paths such as `poi-context` and `poi-context/detect`; the client constructs the full tenant-scoped path automatically, as in the sketch below.
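A minimal usage sketch of this pattern; the `load_poi_features` helper and the settings import path are illustrative, while `ExternalServiceClient`, its constructor arguments, and `get_poi_context` come from this commit:
```python
from shared.clients.external_client import ExternalServiceClient
from app.core.config import settings  # service-local settings, as used in this commit

async def load_poi_features(tenant_id: str) -> dict:
    client = ExternalServiceClient(settings, "forecasting-service")
    # The relative path "poi-context" is expanded by the client to
    # /api/v1/tenants/{tenant_id}/external/poi-context and routed via the gateway.
    result = await client.get_poi_context(tenant_id)
    poi_context = (result or {}).get("poi_context", {})
    return poi_context.get("ml_features", {})
```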
### Weather Data (AEMET)
- `GET /api/v1/external/weather/current` - Current weather for location

View File

@@ -18,13 +18,17 @@ from app.services.poi_refresh_service import POIRefreshService
from app.repositories.poi_context_repository import POIContextRepository
from app.cache.poi_cache_service import POICacheService
from app.core.redis_client import get_redis_client
from shared.routing.route_builder import RouteBuilder
logger = structlog.get_logger()
router = APIRouter(prefix="/tenants", tags=["POI Context"])
route_builder = RouteBuilder('external')
router = APIRouter(tags=["POI Context"])
@router.post("/{tenant_id}/poi-context/detect")
@router.post(
route_builder.build_base_route("poi-context/detect")
)
async def detect_pois_for_tenant(
tenant_id: str,
latitude: float = Query(..., description="Bakery latitude"),
@@ -297,7 +301,9 @@ async def detect_pois_for_tenant(
)
@router.get("/{tenant_id}/poi-context")
@router.get(
route_builder.build_base_route("poi-context")
)
async def get_poi_context(
tenant_id: str,
db: AsyncSession = Depends(get_db)
@@ -331,7 +337,9 @@ async def get_poi_context(
}
@router.post("/{tenant_id}/poi-context/refresh")
@router.post(
route_builder.build_base_route("poi-context/refresh")
)
async def refresh_poi_context(
tenant_id: str,
db: AsyncSession = Depends(get_db)
@@ -365,7 +373,9 @@ async def refresh_poi_context(
)
@router.delete("/{tenant_id}/poi-context")
@router.delete(
route_builder.build_base_route("poi-context")
)
async def delete_poi_context(
tenant_id: str,
db: AsyncSession = Depends(get_db)
@@ -393,7 +403,9 @@ async def delete_poi_context(
}
@router.get("/{tenant_id}/poi-context/feature-importance")
@router.get(
route_builder.build_base_route("poi-context/feature-importance")
)
async def get_feature_importance(
tenant_id: str,
db: AsyncSession = Depends(get_db)
@@ -430,7 +442,9 @@ async def get_feature_importance(
}
@router.get("/{tenant_id}/poi-context/competitor-analysis")
@router.get(
route_builder.build_base_route("poi-context/competitor-analysis")
)
async def get_competitor_analysis(
tenant_id: str,
db: AsyncSession = Depends(get_db)
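For readers unfamiliar with the new routing helper: a hypothetical sketch of the behaviour these decorators assume from `RouteBuilder`. The real implementation lives in `shared.routing.route_builder`; the class below only illustrates the tenant-scoped paths implied by the docs above.
```python
from fastapi import APIRouter

class RouteBuilder:
    """Hypothetical stand-in for shared.routing.route_builder.RouteBuilder."""

    def __init__(self, service_name: str) -> None:
        self.service_name = service_name

    def build_base_route(self, resource: str) -> str:
        # Assumed to yield tenant-scoped paths, e.g.
        # build_base_route("poi-context/detect") -> "/{tenant_id}/poi-context/detect";
        # the /api/v1/tenants prefix is presumably applied when the router is mounted.
        return f"/{{tenant_id}}/{resource}"

router = APIRouter(tags=["POI Context"])

@router.post(RouteBuilder("external").build_base_route("poi-context/detect"))
async def detect_pois_for_tenant(tenant_id: str) -> dict:
    return {"tenant_id": tenant_id}
```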

View File

@@ -497,7 +497,7 @@ poi_features = await poi_service.fetch_poi_features(tenant_id)
- **Accuracy Improvement** - POI features contribute a 5-10% improvement in forecast accuracy
**Endpoint Used:**
- `GET {EXTERNAL_SERVICE_URL}/poi-context/{tenant_id}` - Fetch POI features
- Via shared client: `/api/v1/tenants/{tenant_id}/external/poi-context` (routed through API Gateway)
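The `ml_features` returned here are flat per-tenant scalars (e.g. `poi_retail_total_count`). A hedged sketch of how such features can be merged into a training frame as constant regressor columns; the helper name and merge strategy are illustrative, not the trainer's exact code:
```python
import pandas as pd

def add_poi_features(df: pd.DataFrame, ml_features: dict) -> pd.DataFrame:
    """Attach tenant-level POI features as constant columns on a Prophet frame."""
    out = df.copy()
    for name, value in ml_features.items():
        out[name] = value  # same value for every row of this tenant's history
    return out
```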
## Integration Points

View File

@@ -5,10 +5,11 @@ Fetches POI features for use in demand forecasting predictions.
Ensures feature consistency between training and prediction.
"""
import httpx
from typing import Dict, Any, Optional
import structlog
from shared.clients.external_client import ExternalServiceClient
logger = structlog.get_logger()
@@ -20,15 +21,18 @@ class POIFeatureService:
prediction uses the same features as training.
"""
def __init__(self, external_service_url: str = "http://external-service:8000"):
def __init__(self, external_client: ExternalServiceClient = None):
"""
Initialize POI feature service.
Args:
external_service_url: Base URL for external service
external_client: External service client instance (optional)
"""
self.external_service_url = external_service_url.rstrip("/")
self.poi_context_endpoint = f"{self.external_service_url}/poi-context"
if external_client is None:
from app.core.config import settings
self.external_client = ExternalServiceClient(settings, "forecasting-service")
else:
self.external_client = external_client
async def get_poi_features(
self,
@@ -44,21 +48,10 @@ class POIFeatureService:
Dictionary with POI features or empty dict if not available
"""
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(
f"{self.poi_context_endpoint}/{tenant_id}"
)
result = await self.external_client.get_poi_context(tenant_id)
if response.status_code == 404:
logger.warning(
"No POI context found for tenant",
tenant_id=tenant_id
)
return {}
response.raise_for_status()
data = response.json()
poi_context = data.get("poi_context", {})
if result:
poi_context = result.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})
logger.info(
@@ -68,17 +61,16 @@ class POIFeatureService:
)
return ml_features
except httpx.HTTPError as e:
logger.error(
"Failed to fetch POI features for forecasting",
tenant_id=tenant_id,
error=str(e)
else:
logger.warning(
"No POI context found for tenant",
tenant_id=tenant_id
)
return {}
except Exception as e:
logger.error(
"Unexpected error fetching POI features",
"Failed to fetch POI features for forecasting",
tenant_id=tenant_id,
error=str(e),
exc_info=True
@@ -87,17 +79,18 @@ class POIFeatureService:
async def check_poi_service_health(self) -> bool:
"""
Check if POI service is accessible.
Check if POI service is accessible through the external client.
Returns:
True if service is healthy, False otherwise
"""
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(
f"{self.poi_context_endpoint}/health"
)
return response.status_code == 200
# Test the external service health by attempting to get POI context for a dummy tenant
# This will go through the proper authentication and routing
dummy_context = await self.external_client.get_poi_context("test-tenant")
# If we can successfully make a request (even if it returns None for missing tenant),
# it means the service is accessible
return True
except Exception as e:
logger.error(
"POI service health check failed",

View File

@@ -581,7 +581,7 @@ for feature_name in poi_features.keys():
```
**Endpoint Used:**
- `GET {EXTERNAL_SERVICE_URL}/poi-context/{tenant_id}` - Fetch POI features
- Via shared client: `/api/v1/tenants/{tenant_id}/external/poi-context` (routed through API Gateway)
## Integration Points

View File

@@ -5,11 +5,12 @@ Integrates POI features into ML training pipeline.
Fetches POI context from External service and merges features into training data.
"""
import httpx
from typing import Dict, Any, Optional, List
import structlog
import pandas as pd
from shared.clients.external_client import ExternalServiceClient
logger = structlog.get_logger()
@@ -21,15 +22,18 @@ class POIFeatureIntegrator:
to training dataframes for location-based demand forecasting.
"""
def __init__(self, external_service_url: str = "http://external-service:8000"):
def __init__(self, external_client: ExternalServiceClient = None):
"""
Initialize POI feature integrator.
Args:
external_service_url: Base URL for external service
external_client: External service client instance (optional)
"""
self.external_service_url = external_service_url.rstrip("/")
self.poi_context_endpoint = f"{self.external_service_url}/poi-context"
if external_client is None:
from app.core.config import settings
self.external_client = ExternalServiceClient(settings, "training-service")
else:
self.external_client = external_client
async def fetch_poi_features(
self,
@@ -53,33 +57,28 @@ class POIFeatureIntegrator:
Dictionary with POI features or None if detection fails
"""
try:
async with httpx.AsyncClient(timeout=60.0) as client:
# Try to get existing POI context first
if not force_refresh:
try:
response = await client.get(
f"{self.poi_context_endpoint}/{tenant_id}"
)
if response.status_code == 200:
data = response.json()
poi_context = data.get("poi_context", {})
existing_context = await self.external_client.get_poi_context(tenant_id)
if existing_context:
poi_context = existing_context.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})
# Check if stale
if not data.get("is_stale", False):
is_stale = existing_context.get("is_stale", False)
if not is_stale:
logger.info(
"Using existing POI context",
tenant_id=tenant_id
)
return poi_context.get("ml_features", {})
return ml_features
else:
logger.info(
"POI context is stale, refreshing",
tenant_id=tenant_id
)
force_refresh = True
except httpx.HTTPStatusError as e:
if e.response.status_code != 404:
raise
else:
logger.info(
"No existing POI context, will detect",
tenant_id=tenant_id
@@ -92,18 +91,15 @@ class POIFeatureIntegrator:
location=(latitude, longitude)
)
response = await client.post(
f"{self.poi_context_endpoint}/{tenant_id}/detect",
params={
"latitude": latitude,
"longitude": longitude,
"force_refresh": force_refresh
}
detection_result = await self.external_client.detect_poi_for_tenant(
tenant_id=tenant_id,
latitude=latitude,
longitude=longitude,
force_refresh=force_refresh
)
response.raise_for_status()
result = response.json()
poi_context = result.get("poi_context", {})
if detection_result:
poi_context = detection_result.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})
logger.info(
@@ -114,15 +110,13 @@ class POIFeatureIntegrator:
)
return ml_features
except httpx.HTTPError as e:
else:
logger.error(
"Failed to fetch POI features",
tenant_id=tenant_id,
error=str(e),
exc_info=True
"POI detection failed",
tenant_id=tenant_id
)
return None
except Exception as e:
logger.error(
"Unexpected error fetching POI features",
@@ -185,17 +179,18 @@ class POIFeatureIntegrator:
async def check_poi_service_health(self) -> bool:
"""
Check if POI service is accessible.
Check if POI service is accessible through the external client.
Returns:
True if service is healthy, False otherwise
"""
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(
f"{self.poi_context_endpoint}/health"
)
return response.status_code == 200
# Test the external service health by attempting to get POI context for a dummy tenant
# This will go through the proper authentication and routing
dummy_context = await self.external_client.get_poi_context("test-tenant")
# If we can successfully make a request (even if it returns None for missing tenant),
# it means the service is accessible
return True
except Exception as e:
logger.error(
"POI service health check failed",

View File

@@ -375,7 +375,6 @@ class EnhancedBakeryMLTrainer:
try:
# Use provided session or create new one to prevent nested sessions and deadlocks
should_create_session = session is None
db_session = session if session is not None else None
if should_create_session:
# Only create a session if one wasn't provided
@@ -425,6 +424,27 @@ class EnhancedBakeryMLTrainer:
inventory_product_id=inventory_product_id,
result_status=result.get('status'))
# Write training result to database (create model record)
if result.get('status') == 'success':
model_info = result.get('model_info')
product_data = result.get('product_data')
if model_info and product_data is not None:
# Create model record in database
model_record = await self._create_model_record(
repos, tenant_id, inventory_product_id, model_info, job_id, product_data
)
# Create performance metrics
if model_info.get('training_metrics') and model_record:
await self._create_performance_metrics(
repos, model_record.id,
tenant_id, inventory_product_id, model_info['training_metrics']
)
# Update result with model_record_id
result['model_record_id'] = str(model_record.id) if model_record else None
# Get training metrics and filter out non-numeric values
raw_metrics = result.get('model_info', {}).get('training_metrics', {})
# Filter metrics to only include numeric values (per Pydantic schema requirement)
@@ -453,6 +473,116 @@ class EnhancedBakeryMLTrainer:
"message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}"
}
# Only commit if this is our own session (not a parent session)
# Commit after we're done with all database operations
await db_session.commit()
logger.info("Committed single product model record to database",
inventory_product_id=inventory_product_id,
model_record_id=result.get('model_record_id'))
return result_dict
else:
# Use the provided session
repos = await self._get_repositories(session)
# Validate input data
if training_data.empty or len(training_data) < settings.MIN_TRAINING_DATA_DAYS:
raise ValueError(f"Insufficient training data: need at least {settings.MIN_TRAINING_DATA_DAYS} data points, got {len(training_data)}")
# Validate required columns
required_columns = ['ds', 'y']
missing_cols = [col for col in required_columns if col not in training_data.columns]
if missing_cols:
raise ValueError(f"Missing required columns in training data: {missing_cols}")
# Create a simple progress tracker for single product
from app.services.progress_tracker import ParallelProductProgressTracker
progress_tracker = ParallelProductProgressTracker(
job_id=job_id,
tenant_id=tenant_id,
total_products=1
)
# Ensure training data has proper data types before training
if 'ds' in training_data.columns:
training_data['ds'] = pd.to_datetime(training_data['ds'])
if 'y' in training_data.columns:
training_data['y'] = pd.to_numeric(training_data['y'], errors='coerce')
# Remove any rows with NaN values
training_data = training_data.dropna()
# Train the model using the existing _train_single_product method
product_id, result = await self._train_single_product(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
product_data=training_data,
job_id=job_id,
repos=repos,
progress_tracker=progress_tracker,
session=session # Pass the provided session
)
logger.info("Single product training completed",
job_id=job_id,
inventory_product_id=inventory_product_id,
result_status=result.get('status'))
# Write training result to database (create model record)
if result.get('status') == 'success':
model_info = result.get('model_info')
product_data = result.get('product_data')
if model_info and product_data is not None:
# Create model record in database
model_record = await self._create_model_record(
repos, tenant_id, inventory_product_id, model_info, job_id, product_data
)
# Create performance metrics
if model_info.get('training_metrics') and model_record:
await self._create_performance_metrics(
repos, model_record.id,
tenant_id, inventory_product_id, model_info['training_metrics']
)
# Update result with model_record_id
result['model_record_id'] = str(model_record.id) if model_record else None
# Get training metrics and filter out non-numeric values
raw_metrics = result.get('model_info', {}).get('training_metrics', {})
# Filter metrics to only include numeric values (per Pydantic schema requirement)
filtered_metrics = {}
for key, value in raw_metrics.items():
if key == 'product_category':
# Skip product_category as it's a string value, not a numeric metric
continue
try:
# Try to convert to float for validation
filtered_metrics[key] = float(value) if value is not None else 0.0
except (ValueError, TypeError):
# Skip non-numeric values
continue
# Return appropriate result format
result_dict = {
"job_id": job_id,
"tenant_id": tenant_id,
"inventory_product_id": inventory_product_id,
"status": result.get('status', 'success'),
"model_id": str(result.get('model_record_id', '')) if result.get('model_record_id') else None,
"training_metrics": filtered_metrics,
"training_time": result.get('training_time_seconds', 0),
"data_points": result.get('data_points', 0),
"message": f"Single product model training {'completed' if result.get('status') != 'error' else 'failed'}"
}
# For provided sessions, do NOT commit here - let the calling method handle commits
# This prevents committing a parent transaction prematurely
logger.info("Single product model processed (commit handled by caller)",
inventory_product_id=inventory_product_id,
model_record_id=result.get('model_record_id'))
return result_dict
except Exception as e:

View File

@@ -186,6 +186,7 @@ class TrainedModel(Base):
"training_samples": self.training_samples,
"hyperparameters": self.hyperparameters,
"features_used": self.features_used,
"features": self.features_used, # Alias for frontend compatibility (ModelDetailsModal expects 'features')
"product_category": self.product_category,
"is_active": self.is_active,
"is_production": self.is_production,

View File

@@ -732,12 +732,47 @@ class EnhancedTrainingService:
current_step: str = None,
error_message: str = None,
results: Dict = None,
tenant_id: str = None):
"""Update job status using repository pattern"""
tenant_id: str = None,
session = None):
"""Update job status using repository pattern
Args:
session: Optional database session to reuse. If None, creates a new session.
"""
try:
# Use provided session or create new one
should_create_session = session is None
if should_create_session:
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
await self._update_job_status_impl(
session, job_id, status, progress, current_step,
error_message, results, tenant_id
)
else:
# Reuse provided session (don't commit - let caller control transaction)
await self._init_repositories(session)
await self._update_job_status_impl(
session, job_id, status, progress, current_step,
error_message, results, tenant_id, auto_commit=False
)
except Exception as e:
logger.error("Failed to update job status using repository",
job_id=job_id,
error=str(e))
async def _update_job_status_impl(self,
session,
job_id: str,
status: str,
progress: int = None,
current_step: str = None,
error_message: str = None,
results: Dict = None,
tenant_id: str = None,
auto_commit: bool = True):
"""Implementation of job status update"""
# Check if log exists, create if not
existing_log = await self.training_log_repo.get_log_by_job_id(job_id)
@@ -771,6 +806,7 @@ class EnhancedTrainingService:
try:
await self.training_log_repo.create_training_log(log_data)
if auto_commit:
await session.commit() # Explicit commit so other sessions can see it
logger.info("Created initial training log", job_id=job_id, tenant_id=tenant_id)
except Exception as create_error:
@@ -788,6 +824,7 @@ class EnhancedTrainingService:
current_step=current_step,
status=status
)
if auto_commit:
await session.commit()
else:
raise
@@ -817,33 +854,35 @@ class EnhancedTrainingService:
if update_data:
await self.training_log_repo.update(existing_log.id, update_data)
if auto_commit:
await session.commit() # Explicit commit after updates
except Exception as e:
logger.error("Failed to update job status using repository",
job_id=job_id,
error=str(e))
async def start_single_product_training(self,
tenant_id: str,
inventory_product_id: str,
job_id: str,
bakery_location: tuple = (40.4168, -3.7038)) -> Dict[str, Any]:
"""Start enhanced single product training using repository pattern"""
"""Start enhanced single product training using repository pattern with single session"""
# Create a single database session for all operations to avoid connection pool exhaustion
async with self.database_manager.get_session() as session:
await self._init_repositories(session)
try:
logger.info("Starting enhanced single product training",
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
job_id=job_id)
# Create initial training log
# Create initial training log (using shared session)
await self._update_job_status_repository(
job_id=job_id,
status="running",
progress=0,
current_step="Fetching training data",
tenant_id=tenant_id
tenant_id=tenant_id,
session=session
)
await session.commit() # Commit after initial log creation
# Prepare training data for all products to get weather/traffic data
# then filter down to the specific product
@@ -853,111 +892,64 @@ class EnhancedTrainingService:
job_id=job_id + "_temp"
)
# Filter sales data to the specific product
# Use the enhanced data processor to prepare training data with all features (POI, weather, traffic)
# Filter sales data to the specific product first
sales_df = pd.DataFrame(training_dataset.sales_data)
product_sales_df = sales_df[sales_df['inventory_product_id'] == inventory_product_id]
if product_sales_df.empty:
raise ValueError(f"No sales data available for product {inventory_product_id}")
# Prepare the data in Prophet format (ds and y columns)
# Ensure proper column names and types for Prophet
product_data = product_sales_df.copy()
product_data = product_data.rename(columns={
'sale_date': 'ds', # Common sales date column
'sale_datetime': 'ds', # Alternative date column
'date': 'ds', # Alternative date column
'quantity': 'y', # Quantity sold
'total_amount': 'y', # Alternative for sales data
'sales_amount': 'y', # Alternative for sales data
'sale_amount': 'y' # Alternative for sales data
})
# Get weather and traffic data as DataFrames
weather_df = pd.DataFrame(training_dataset.weather_data)
traffic_df = pd.DataFrame(training_dataset.traffic_data)
# If 'ds' and 'y' columns are not renamed properly, try to infer them
if 'ds' not in product_data.columns:
# Try to find date-like columns
date_cols = [col for col in product_data.columns if 'date' in col.lower() or 'time' in col.lower()]
if date_cols:
product_data = product_data.rename(columns={date_cols[0]: 'ds'})
# Get POI features from the training dataset (already collected by orchestrator)
poi_features = training_dataset.poi_features if hasattr(training_dataset, 'poi_features') else None
if 'y' not in product_data.columns:
# Try to find sales/quantity-like columns
sales_cols = [col for col in product_data.columns if
any(word in col.lower() for word in ['amount', 'quantity', 'sales', 'total', 'count', 'value'])]
if sales_cols:
product_data = product_data.rename(columns={sales_cols[0]: 'y'})
# Use the enhanced data processor to merge all features properly
# This will include POI, weather, traffic features along with ds and y
from app.ml.data_processor import EnhancedBakeryDataProcessor
data_processor = EnhancedBakeryDataProcessor(self.database_manager)
# Ensure required columns exist
if 'ds' not in product_data.columns or 'y' not in product_data.columns:
raise ValueError(f"Sales data must contain 'date' and 'quantity/sales' columns. Available columns: {list(product_data.columns)}")
product_data = await data_processor.prepare_training_data(
sales_data=product_sales_df,
weather_data=weather_df,
traffic_data=traffic_df,
inventory_product_id=inventory_product_id,
poi_features=poi_features,
tenant_id=tenant_id,
job_id=job_id
)
# Convert the date column to datetime if it's not already
product_data['ds'] = pd.to_datetime(product_data['ds'])
# Convert to numeric ensuring no pandas/numpy objects remain
product_data['y'] = pd.to_numeric(product_data['y'], errors='coerce')
# Sort by date to ensure proper chronological order
product_data = product_data.sort_values('ds').reset_index(drop=True)
# Drop any rows with NaN values
product_data = product_data.dropna(subset=['ds', 'y'])
# Ensure the data is in the right format for Prophet
product_data = product_data[['ds', 'y']].copy()
# Convert to pandas datetime and float types (keep as pandas Series for proper min/max operations)
product_data['ds'] = pd.to_datetime(product_data['ds'])
product_data['y'] = product_data['y'].astype(float)
# DEBUG: Log data types to diagnose dict comparison error
logger.info(f"DEBUG: product_data dtypes after conversion: ds={product_data['ds'].dtype}, y={product_data['y'].dtype}")
logger.info(f"DEBUG: product_data['ds'] sample values: {product_data['ds'].head(3).tolist()}")
logger.info(f"DEBUG: Attempting to get min/max...")
try:
min_val = product_data['ds'].min()
max_val = product_data['ds'].max()
logger.info(f"DEBUG: min_val type={type(min_val)}, value={min_val}")
logger.info(f"DEBUG: max_val type={type(max_val)}, value={max_val}")
except Exception as debug_e:
logger.error(f"DEBUG: Failed to get min/max: {debug_e}")
import traceback
logger.error(f"DEBUG: Traceback: {traceback.format_exc()}")
if product_data.empty:
raise ValueError(f"Data processor returned empty data for product {inventory_product_id}")
logger.info("Prepared training data for single product",
inventory_product_id=inventory_product_id,
data_points=len(product_data),
features=list(product_data.columns),
date_range=f"{product_data['ds'].min()} to {product_data['ds'].max()}")
# Update progress
# Update progress (using shared session)
await self._update_job_status_repository(
job_id=job_id,
status="running",
progress=30,
current_step="Training model",
tenant_id=tenant_id
tenant_id=tenant_id,
session=session
)
await session.commit() # Commit progress update
# Train the model using the trainer
# Extract datetime values with proper pandas Timestamp wrapper for type safety
try:
training_start = pd.Timestamp(product_data['ds'].min()).to_pydatetime()
training_end = pd.Timestamp(product_data['ds'].max()).to_pydatetime()
except Exception as e:
import traceback
logger.error(f"Failed to extract training dates: {e}")
logger.error(f"Traceback: {traceback.format_exc()}")
logger.error(f"product_data['ds'] dtype: {product_data['ds'].dtype}")
logger.error(f"product_data['ds'] first 5 values: {product_data['ds'].head().tolist()}")
raise
# Run the actual training
# Run the actual training (passing the session to avoid nested session creation)
try:
model_info = await self.trainer.train_single_product_model(
tenant_id=tenant_id,
inventory_product_id=inventory_product_id,
training_data=product_data,
job_id=job_id
job_id=job_id,
session=session # ✅ CRITICAL FIX: Pass session to prevent deadlock
)
except Exception as e:
import traceback
@@ -965,14 +957,16 @@ class EnhancedTrainingService:
logger.error(f"Full traceback: {traceback.format_exc()}")
raise
# Update progress
# Update progress (using shared session)
await self._update_job_status_repository(
job_id=job_id,
status="running",
progress=80,
current_step="Saving model",
tenant_id=tenant_id
tenant_id=tenant_id,
session=session
)
await session.commit() # Commit progress update
# The model should already be saved by train_single_product_model
# Return appropriate response
@@ -1012,15 +1006,17 @@ class EnhancedTrainingService:
inventory_product_id=inventory_product_id,
error=str(e))
# Update status to failed
# Update status to failed (using shared session)
await self._update_job_status_repository(
job_id=job_id,
status="failed",
progress=0,
current_step="Training failed",
error_message=str(e),
tenant_id=tenant_id
tenant_id=tenant_id,
session=session
)
await session.commit() # Commit failure status
raise
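The session changes above follow one idiom: accept an optional session, and commit only when the function owns it. A minimal sketch of that idiom under assumed names (`database_manager` and `_write_status` are stand-ins for the repository plumbing in this file):
```python
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession

async def _write_status(session: AsyncSession, job_id: str, status: str) -> None:
    ...  # stand-in for the repository update in _update_job_status_impl

async def update_status(database_manager, job_id: str, status: str,
                        session: Optional[AsyncSession] = None) -> None:
    if session is None:
        # We own this session, so we also own the transaction boundary.
        async with database_manager.get_session() as own:
            await _write_status(own, job_id, status)
            await own.commit()
    else:
        # Reuse the caller's session; the caller decides when to commit.
        await _write_status(session, job_id, status)
```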

View File

@@ -494,6 +494,75 @@ class ExternalServiceClient(BaseServiceClient):
# POI (POINT OF INTEREST) DATA
# ================================================================
async def detect_poi_for_tenant(
self,
tenant_id: str,
latitude: float,
longitude: float,
force_refresh: bool = False
) -> Optional[Dict[str, Any]]:
"""
Detect POIs for a tenant's location and generate ML features for forecasting.
With the new tenant-based architecture:
- Gateway receives at: /api/v1/tenants/{tenant_id}/external/poi-context/detect
- Gateway proxies to external service at: /api/v1/tenants/{tenant_id}/poi-context/detect
- This client calls: poi-context/detect (base client automatically constructs with tenant)
This triggers POI detection using Overpass API and calculates ML features
for demand forecasting.
Args:
tenant_id: Tenant ID
latitude: Latitude of the bakery location
longitude: Longitude of the bakery location
force_refresh: Whether to force refresh even if POI context exists
Returns:
Dict with POI detection results including:
- ml_features: Dict of POI features for ML models (e.g., poi_retail_total_count)
- poi_detection_results: Full detection results
- location: Latitude/longitude
- total_pois_detected: Count of POIs
"""
logger.info(
"Detecting POIs for tenant",
tenant_id=tenant_id,
location=(latitude, longitude),
force_refresh=force_refresh
)
params = {
"latitude": latitude,
"longitude": longitude,
"force_refresh": force_refresh
}
# Updated endpoint path to follow tenant-based pattern: external/poi-context/detect
result = await self._make_request(
"POST",
"external/poi-context/detect", # Path will become /api/v1/tenants/{tenant_id}/external/poi-context/detect by base client
tenant_id=tenant_id, # Pass tenant_id to include in headers and path construction
params=params,
timeout=60.0 # POI detection can take longer
)
if result:
poi_context = result.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})
logger.info(
"POI detection completed successfully",
tenant_id=tenant_id,
total_pois=poi_context.get("total_pois_detected", 0),
ml_features_count=len(ml_features),
source=result.get("source", "unknown")
)
return result
else:
logger.warning("POI detection failed for tenant", tenant_id=tenant_id)
return None
async def get_poi_context(
self,
tenant_id: str
@@ -504,7 +573,7 @@ class ExternalServiceClient(BaseServiceClient):
With the new tenant-based architecture:
- Gateway receives at: /api/v1/tenants/{tenant_id}/external/poi-context
- Gateway proxies to external service at: /api/v1/tenants/{tenant_id}/poi-context
- This client calls: /tenants/{tenant_id}/poi-context
- This client calls: poi-context (base client automatically constructs with tenant)
This retrieves stored POI detection results and calculated ML features
that should be included in demand forecasting predictions.
@@ -521,11 +590,11 @@ class ExternalServiceClient(BaseServiceClient):
"""
logger.info("Fetching POI context for forecasting", tenant_id=tenant_id)
# Updated endpoint path to follow tenant-based pattern: /tenants/{tenant_id}/poi-context
# Updated endpoint path to follow tenant-based pattern: external/poi-context
result = await self._make_request(
"GET",
f"tenants/{tenant_id}/poi-context", # Updated path: /tenants/{tenant_id}/poi-context
tenant_id=tenant_id, # Pass tenant_id to include in headers for authentication
"external/poi-context", # Path will become /api/v1/tenants/{tenant_id}/external/poi-context by base client
tenant_id=tenant_id, # Pass tenant_id to include in headers and path construction
timeout=5.0
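For reference, the tenant-scoped expansion the base client is assumed to perform on these relative paths (a sketch; the real logic lives in the `BaseServiceClient` request path, and the helper below is hypothetical):
```python
def build_tenant_url(gateway_base: str, tenant_id: str, relative_path: str) -> str:
    # Assumed expansion: relative paths become tenant-scoped gateway URLs.
    return f"{gateway_base}/api/v1/tenants/{tenant_id}/{relative_path}"

assert build_tenant_url(
    "http://gateway:8000", "tenant-123", "external/poi-context"
) == "http://gateway:8000/api/v1/tenants/tenant-123/external/poi-context"
```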
)