Add improvements 2
This commit is contained in:
@@ -1,82 +0,0 @@
|
||||
# services/forecasting/app/clients/inventory_client.py
|
||||
"""
|
||||
Simple client for inventory service integration
|
||||
Used when product names are not available locally
|
||||
"""
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
from typing import Optional, Dict, Any
|
||||
import os
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
class InventoryServiceClient:
    """Minimal async HTTP client for product-name lookups in the inventory service.

    Both lookups are best-effort: any non-200 response or raised exception is
    logged at debug level and reported to the caller as "no data" (``None`` or
    an empty dict) instead of propagating.
    """

    def __init__(self, base_url: Optional[str] = None):
        """Create a client.

        Args:
            base_url: Root URL of the inventory service. When omitted (or
                empty), the ``INVENTORY_SERVICE_URL`` environment variable is
                used, falling back to ``http://inventory-service:8000``.
        """
        resolved_url = base_url or os.getenv("INVENTORY_SERVICE_URL", "http://inventory-service:8000")
        # Drop trailing slashes so the URL f-strings below never produce "//api/...".
        self.base_url = resolved_url.rstrip("/")
        self.timeout = aiohttp.ClientTimeout(total=5)  # 5 second timeout

    async def get_product_name(self, tenant_id: str, inventory_product_id: str) -> Optional[str]:
        """Get a single product name from the inventory service.

        Args:
            tenant_id: Sent as the ``X-Tenant-ID`` request header.
            inventory_product_id: Product identifier placed in the URL path.

        Returns:
            The ``name`` field of a 200 response, or ``Product-<id>`` when the
            200 response has no ``name`` key. ``None`` when the service answers
            with any other status or the request fails.
        """
        try:
            async with aiohttp.ClientSession(timeout=self.timeout) as session:
                url = f"{self.base_url}/api/v1/products/{inventory_product_id}"
                headers = {"X-Tenant-ID": tenant_id}

                async with session.get(url, headers=headers) as response:
                    if response.status == 200:
                        data = await response.json()
                        return data.get("name", f"Product-{inventory_product_id}")

                    logger.debug("Product not found in inventory service",
                                 inventory_product_id=inventory_product_id,
                                 status=response.status)
                    return None

        except Exception as e:
            # Deliberate best-effort boundary: callers treat None as "name unavailable".
            logger.debug("Failed to get product name from inventory service",
                         inventory_product_id=inventory_product_id,
                         error=str(e))
            return None

    async def get_multiple_product_names(self, tenant_id: str, product_ids: list) -> Dict[str, str]:
        """Get several product names with one batch request.

        Args:
            tenant_id: Sent as the ``X-Tenant-ID`` request header.
            product_ids: Identifiers posted as ``{"product_ids": [...]}``.

        Returns:
            Mapping of product id -> product name built from the ``products``
            list of a 200 response. An empty dict when ``product_ids`` is
            empty, the service answers with any other status, or the request
            fails.
        """
        # Nothing to look up: skip the session setup and network round trip.
        if not product_ids:
            return {}

        try:
            async with aiohttp.ClientSession(timeout=self.timeout) as session:
                url = f"{self.base_url}/api/v1/products/batch"
                headers = {"X-Tenant-ID": tenant_id}
                payload = {"product_ids": product_ids}

                async with session.post(url, json=payload, headers=headers) as response:
                    if response.status == 200:
                        data = await response.json()
                        return {item["id"]: item["name"] for item in data.get("products", [])}

                    logger.debug("Batch product lookup failed",
                                 product_count=len(product_ids),
                                 status=response.status)
                    return {}

        except Exception as e:
            # Deliberate best-effort boundary: callers treat {} as "no names available".
            logger.debug("Failed to get product names from inventory service",
                         product_count=len(product_ids),
                         error=str(e))
            return {}
|
||||
|
||||
# Process-wide client, created lazily on the first get_inventory_client() call
_inventory_client = None


def get_inventory_client() -> InventoryServiceClient:
    """Return the shared InventoryServiceClient, constructing it on first use."""
    global _inventory_client
    if _inventory_client is not None:
        return _inventory_client
    _inventory_client = InventoryServiceClient()
    return _inventory_client
|
||||
@@ -12,7 +12,6 @@ from datetime import datetime, timedelta
|
||||
import structlog
|
||||
|
||||
from shared.messaging import UnifiedEventPublisher
|
||||
from app.clients.inventory_client import get_inventory_client
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
@@ -30,11 +30,11 @@ async def trigger_inventory_alerts(
|
||||
- Expiring ingredients
|
||||
- Overstock situations
|
||||
|
||||
Security: Protected by X-Internal-Service header check.
|
||||
Security: Protected by x-internal-service header check.
|
||||
"""
|
||||
try:
|
||||
# Verify internal service header
|
||||
if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
|
||||
if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
|
||||
logger.warning("Unauthorized internal API call", tenant_id=str(tenant_id))
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
|
||||
@@ -350,7 +350,7 @@ async def generate_safety_stock_insights_internal(
|
||||
This endpoint is called by the demo-session service after cloning data.
|
||||
It uses the same ML logic as the public endpoint but with optimized defaults.
|
||||
|
||||
Security: Protected by X-Internal-Service header check.
|
||||
Security: Protected by x-internal-service header check.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant UUID
|
||||
@@ -365,7 +365,7 @@ async def generate_safety_stock_insights_internal(
|
||||
}
|
||||
"""
|
||||
# Verify internal service header
|
||||
if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
|
||||
if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
|
||||
logger.warning("Unauthorized internal API call", tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
|
||||
@@ -29,7 +29,7 @@ async def trigger_delivery_tracking(
|
||||
This endpoint is called by the demo session cloning process after POs are seeded
|
||||
to generate realistic delivery alerts (arriving soon, overdue, etc.).
|
||||
|
||||
Security: Protected by X-Internal-Service header check.
|
||||
Security: Protected by x-internal-service header check.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID to check deliveries for
|
||||
@@ -49,7 +49,7 @@ async def trigger_delivery_tracking(
|
||||
"""
|
||||
try:
|
||||
# Verify internal service header
|
||||
if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
|
||||
if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
|
||||
logger.warning("Unauthorized internal API call", tenant_id=str(tenant_id))
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
|
||||
@@ -566,7 +566,7 @@ async def generate_price_insights_internal(
|
||||
This endpoint is called by the demo-session service after cloning data.
|
||||
It uses the same ML logic as the public endpoint but with optimized defaults.
|
||||
|
||||
Security: Protected by X-Internal-Service header check.
|
||||
Security: Protected by x-internal-service header check.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant UUID
|
||||
@@ -581,7 +581,7 @@ async def generate_price_insights_internal(
|
||||
}
|
||||
"""
|
||||
# Verify internal service header
|
||||
if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
|
||||
if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
|
||||
logger.warning("Unauthorized internal API call", tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
|
||||
@@ -1,42 +1,45 @@
|
||||
"""
|
||||
FastAPI Dependencies for Procurement Service
|
||||
Uses shared authentication infrastructure with UUID validation
|
||||
"""
|
||||
|
||||
from fastapi import Header, HTTPException, status
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from uuid import UUID
|
||||
from typing import Optional
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from .database import get_db
|
||||
from shared.auth.decorators import get_current_tenant_id_dep
|
||||
|
||||
|
||||
async def get_current_tenant_id(
|
||||
x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID")
|
||||
tenant_id: Optional[str] = Depends(get_current_tenant_id_dep)
|
||||
) -> UUID:
|
||||
"""
|
||||
Extract and validate tenant ID from request header.
|
||||
Extract and validate tenant ID from request using shared infrastructure.
|
||||
Adds UUID validation to ensure tenant ID format is correct.
|
||||
|
||||
Args:
|
||||
x_tenant_id: Tenant ID from X-Tenant-ID header
|
||||
tenant_id: Tenant ID from shared dependency
|
||||
|
||||
Returns:
|
||||
UUID: Validated tenant ID
|
||||
|
||||
Raises:
|
||||
HTTPException: If tenant ID is missing or invalid
|
||||
HTTPException: If tenant ID is missing or invalid UUID format
|
||||
"""
|
||||
if not x_tenant_id:
|
||||
if not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="X-Tenant-ID header is required"
|
||||
detail="x-tenant-id header is required"
|
||||
)
|
||||
|
||||
try:
|
||||
return UUID(x_tenant_id)
|
||||
return UUID(tenant_id)
|
||||
except (ValueError, AttributeError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Invalid tenant ID format: {x_tenant_id}"
|
||||
detail=f"Invalid tenant ID format: {tenant_id}"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -31,11 +31,11 @@ async def trigger_production_alerts(
|
||||
- Equipment maintenance alerts
|
||||
- Batch start delays
|
||||
|
||||
Security: Protected by X-Internal-Service header check.
|
||||
Security: Protected by x-internal-service header check.
|
||||
"""
|
||||
try:
|
||||
# Verify internal service header
|
||||
if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
|
||||
if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
|
||||
logger.warning("Unauthorized internal API call", tenant_id=str(tenant_id))
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
|
||||
@@ -331,7 +331,7 @@ async def generate_yield_insights_internal(
|
||||
This endpoint is called by the demo-session service after cloning data.
|
||||
It uses the same ML logic as the public endpoint but with optimized defaults.
|
||||
|
||||
Security: Protected by X-Internal-Service header check.
|
||||
Security: Protected by x-internal-service header check.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant UUID
|
||||
@@ -346,7 +346,7 @@ async def generate_yield_insights_internal(
|
||||
}
|
||||
"""
|
||||
# Verify internal service header
|
||||
if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
|
||||
if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
|
||||
logger.warning("Unauthorized internal API call", tenant_id=tenant_id)
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
|
||||
@@ -204,7 +204,7 @@ class TenantMemberRepository(TenantBaseRepository):
|
||||
f"{auth_service_url}/api/v1/auth/users/batch",
|
||||
json={"user_ids": user_ids},
|
||||
timeout=10.0,
|
||||
headers={"X-Internal-Service": "tenant-service"}
|
||||
headers={"x-internal-service": "tenant-service"}
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
@@ -226,7 +226,7 @@ class TenantMemberRepository(TenantBaseRepository):
|
||||
response = await client.get(
|
||||
f"{auth_service_url}/api/v1/auth/users/{user_id}",
|
||||
timeout=5.0,
|
||||
headers={"X-Internal-Service": "tenant-service"}
|
||||
headers={"x-internal-service": "tenant-service"}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
user_data = response.json()
|
||||
@@ -243,7 +243,7 @@ class TenantMemberRepository(TenantBaseRepository):
|
||||
response = await client.get(
|
||||
f"{auth_service_url}/api/v1/auth/users/{user_id}",
|
||||
timeout=5.0,
|
||||
headers={"X-Internal-Service": "tenant-service"}
|
||||
headers={"x-internal-service": "tenant-service"}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
user_data = response.json()
|
||||
|
||||
@@ -216,17 +216,24 @@ class HybridProphetXGBoost:
|
||||
Get Prophet predictions for given dataframe.
|
||||
|
||||
Args:
|
||||
prophet_result: Prophet model result from training
|
||||
prophet_result: Prophet model result from training (contains model_path)
|
||||
df: DataFrame with 'ds' column
|
||||
|
||||
Returns:
|
||||
Array of predictions
|
||||
"""
|
||||
# Get the Prophet model from result
|
||||
prophet_model = prophet_result.get('model')
|
||||
# Get the model path from result instead of expecting the model object directly
|
||||
model_path = prophet_result.get('model_path')
|
||||
|
||||
if prophet_model is None:
|
||||
raise ValueError("Prophet model not found in result")
|
||||
if model_path is None:
|
||||
raise ValueError("Prophet model path not found in result")
|
||||
|
||||
# Load the actual Prophet model from the stored path
|
||||
try:
|
||||
import joblib
|
||||
prophet_model = joblib.load(model_path)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load Prophet model from path {model_path}: {str(e)}")
|
||||
|
||||
# Prepare dataframe for prediction
|
||||
pred_df = df[['ds']].copy()
|
||||
@@ -273,7 +280,8 @@ class HybridProphetXGBoost:
|
||||
'reg_lambda': 1.0, # L2 regularization
|
||||
'objective': 'reg:squarederror',
|
||||
'random_state': 42,
|
||||
'n_jobs': -1
|
||||
'n_jobs': -1,
|
||||
'early_stopping_rounds': 10
|
||||
}
|
||||
|
||||
# Initialize model
|
||||
@@ -285,7 +293,6 @@ class HybridProphetXGBoost:
|
||||
model.fit,
|
||||
X_train, y_train,
|
||||
eval_set=[(X_val, y_val)],
|
||||
early_stopping_rounds=10,
|
||||
verbose=False
|
||||
)
|
||||
|
||||
@@ -303,109 +310,86 @@ class HybridProphetXGBoost:
|
||||
train_prophet_pred: np.ndarray,
|
||||
val_prophet_pred: np.ndarray,
|
||||
prophet_result: Dict[str, Any]
|
||||
) -> Dict[str, float]:
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Evaluate hybrid model vs Prophet-only on validation set.
|
||||
|
||||
Args:
|
||||
train_df: Training data
|
||||
val_df: Validation data
|
||||
train_prophet_pred: Prophet predictions on training set
|
||||
val_prophet_pred: Prophet predictions on validation set
|
||||
prophet_result: Prophet training result
|
||||
|
||||
Returns:
|
||||
Dictionary of metrics
|
||||
Evaluate the overall performance of the hybrid model using threading for metrics.
|
||||
"""
|
||||
# Get actual values
|
||||
train_actual = train_df['y'].values
|
||||
val_actual = val_df['y'].values
|
||||
|
||||
# Get XGBoost predictions on residuals
|
||||
import asyncio
|
||||
|
||||
# Get XGBoost predictions on training and validation
|
||||
X_train = train_df[self.feature_columns].values
|
||||
X_val = val_df[self.feature_columns].values
|
||||
|
||||
# ✅ FIX: Run blocking predict() in thread pool to avoid blocking event loop
|
||||
import asyncio
|
||||
|
||||
train_xgb_pred = await asyncio.to_thread(self.xgb_model.predict, X_train)
|
||||
val_xgb_pred = await asyncio.to_thread(self.xgb_model.predict, X_val)
|
||||
|
||||
# Hybrid predictions = Prophet + XGBoost residual correction
|
||||
|
||||
# Hybrid prediction = Prophet prediction + XGBoost residual prediction
|
||||
train_hybrid_pred = train_prophet_pred + train_xgb_pred
|
||||
val_hybrid_pred = val_prophet_pred + val_xgb_pred
|
||||
|
||||
# Calculate metrics for Prophet-only
|
||||
prophet_train_mae = mean_absolute_error(train_actual, train_prophet_pred)
|
||||
prophet_val_mae = mean_absolute_error(val_actual, val_prophet_pred)
|
||||
prophet_train_mape = mean_absolute_percentage_error(train_actual, train_prophet_pred) * 100
|
||||
prophet_val_mape = mean_absolute_percentage_error(val_actual, val_prophet_pred) * 100
|
||||
|
||||
# Calculate metrics for Hybrid
|
||||
hybrid_train_mae = mean_absolute_error(train_actual, train_hybrid_pred)
|
||||
hybrid_val_mae = mean_absolute_error(val_actual, val_hybrid_pred)
|
||||
hybrid_train_mape = mean_absolute_percentage_error(train_actual, train_hybrid_pred) * 100
|
||||
hybrid_val_mape = mean_absolute_percentage_error(val_actual, val_hybrid_pred) * 100
|
||||
|
||||
actual_train = train_df['y'].values
|
||||
actual_val = val_df['y'].values
|
||||
|
||||
# Basic RMSE calculation
|
||||
train_rmse = float(np.sqrt(np.mean((actual_train - train_hybrid_pred)**2)))
|
||||
val_rmse = float(np.sqrt(np.mean((actual_val - val_hybrid_pred)**2)))
|
||||
|
||||
# MAE
|
||||
train_mae = float(np.mean(np.abs(actual_train - train_hybrid_pred)))
|
||||
val_mae = float(np.mean(np.abs(actual_val - val_hybrid_pred)))
|
||||
|
||||
# MAPE (with safety for zero sales)
|
||||
train_mape = float(np.mean(np.abs((actual_train - train_hybrid_pred) / np.maximum(actual_train, 1))))
|
||||
val_mape = float(np.mean(np.abs((actual_val - val_hybrid_pred) / np.maximum(actual_val, 1))))
|
||||
|
||||
# Calculate improvement
|
||||
mae_improvement = ((prophet_val_mae - hybrid_val_mae) / prophet_val_mae) * 100
|
||||
mape_improvement = ((prophet_val_mape - hybrid_val_mape) / prophet_val_mape) * 100
|
||||
prophet_metrics = prophet_result.get("metrics", {})
|
||||
prophet_val_mae = prophet_metrics.get("val_mae", val_mae) # Fallback to hybrid if missing
|
||||
prophet_val_mape = prophet_metrics.get("val_mape", val_mape)
|
||||
|
||||
improvement_pct = 0.0
|
||||
if prophet_val_mape > 0:
|
||||
improvement_pct = ((prophet_val_mape - val_mape) / prophet_val_mape) * 100
|
||||
|
||||
metrics = {
|
||||
'prophet_train_mae': float(prophet_train_mae),
|
||||
'prophet_val_mae': float(prophet_val_mae),
|
||||
'prophet_train_mape': float(prophet_train_mape),
|
||||
'prophet_val_mape': float(prophet_val_mape),
|
||||
'hybrid_train_mae': float(hybrid_train_mae),
|
||||
'hybrid_val_mae': float(hybrid_val_mae),
|
||||
'hybrid_train_mape': float(hybrid_train_mape),
|
||||
'hybrid_val_mape': float(hybrid_val_mape),
|
||||
'mae_improvement_pct': float(mae_improvement),
|
||||
'mape_improvement_pct': float(mape_improvement),
|
||||
'improvement_percentage': float(mape_improvement) # Primary metric
|
||||
"train_rmse": train_rmse,
|
||||
"val_rmse": val_rmse,
|
||||
"train_mae": train_mae,
|
||||
"val_mae": val_mae,
|
||||
"train_mape": train_mape,
|
||||
"val_mape": val_mape,
|
||||
"prophet_val_mape": prophet_val_mape,
|
||||
"hybrid_val_mape": val_mape,
|
||||
"improvement_percentage": float(improvement_pct),
|
||||
"prophet_metrics": prophet_metrics
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"Hybrid model evaluation complete",
|
||||
val_rmse=val_rmse,
|
||||
val_mae=val_mae,
|
||||
val_mape=val_mape,
|
||||
improvement=improvement_pct
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
def _package_hybrid_model(
|
||||
self,
|
||||
prophet_result: Dict[str, Any],
|
||||
metrics: Dict[str, float],
|
||||
metrics: Dict[str, Any],
|
||||
tenant_id: str,
|
||||
inventory_product_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Package hybrid model for storage.
|
||||
|
||||
Args:
|
||||
prophet_result: Prophet model result
|
||||
metrics: Hybrid model metrics
|
||||
tenant_id: Tenant ID
|
||||
inventory_product_id: Product ID
|
||||
|
||||
Returns:
|
||||
Model package dictionary
|
||||
"""
|
||||
return {
|
||||
'model_type': 'hybrid_prophet_xgboost',
|
||||
'prophet_model': prophet_result.get('model'),
|
||||
'prophet_model_path': prophet_result.get('model_path'),
|
||||
'xgboost_model': self.xgb_model,
|
||||
'feature_columns': self.feature_columns,
|
||||
'prophet_metrics': {
|
||||
'train_mae': metrics['prophet_train_mae'],
|
||||
'val_mae': metrics['prophet_val_mae'],
|
||||
'train_mape': metrics['prophet_train_mape'],
|
||||
'val_mape': metrics['prophet_val_mape']
|
||||
},
|
||||
'hybrid_metrics': {
|
||||
'train_mae': metrics['hybrid_train_mae'],
|
||||
'val_mae': metrics['hybrid_val_mae'],
|
||||
'train_mape': metrics['hybrid_train_mape'],
|
||||
'val_mape': metrics['hybrid_val_mape']
|
||||
},
|
||||
'improvement_metrics': {
|
||||
'mae_improvement_pct': metrics['mae_improvement_pct'],
|
||||
'mape_improvement_pct': metrics['mape_improvement_pct']
|
||||
},
|
||||
'metrics': metrics,
|
||||
'tenant_id': tenant_id,
|
||||
'inventory_product_id': inventory_product_id,
|
||||
'trained_at': datetime.now(timezone.utc).isoformat()
|
||||
@@ -426,8 +410,18 @@ class HybridProphetXGBoost:
|
||||
Returns:
|
||||
DataFrame with predictions
|
||||
"""
|
||||
# Step 1: Get Prophet predictions
|
||||
prophet_model = model_data['prophet_model']
|
||||
# Step 1: Get Prophet model from path and make predictions
|
||||
prophet_model_path = model_data.get('prophet_model_path')
|
||||
if prophet_model_path is None:
|
||||
raise ValueError("Prophet model path not found in model data")
|
||||
|
||||
# Load the Prophet model from the stored path
|
||||
try:
|
||||
import joblib
|
||||
prophet_model = joblib.load(prophet_model_path)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to load Prophet model from path {prophet_model_path}: {str(e)}")
|
||||
|
||||
# ✅ FIX: Run blocking predict() in thread pool to avoid blocking event loop
|
||||
import asyncio
|
||||
prophet_forecast = await asyncio.to_thread(prophet_model.predict, future_df)
|
||||
|
||||
@@ -43,86 +43,79 @@ class POIFeatureIntegrator:
|
||||
force_refresh: bool = False
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Fetch POI features for tenant location.
|
||||
Fetch POI features for tenant location (optimized for training).
|
||||
|
||||
First checks if POI context exists, if not, triggers detection.
|
||||
First checks if POI context exists. If not, returns None without triggering detection.
|
||||
POI detection should be triggered during tenant registration, not during training.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant UUID
|
||||
latitude: Bakery latitude
|
||||
longitude: Bakery longitude
|
||||
force_refresh: Force re-detection
|
||||
force_refresh: Force re-detection (only use if POI context already exists)
|
||||
|
||||
Returns:
|
||||
Dictionary with POI features or None if detection fails
|
||||
Dictionary with POI features or None if not available
|
||||
"""
|
||||
try:
|
||||
# Try to get existing POI context first
|
||||
if not force_refresh:
|
||||
existing_context = await self.external_client.get_poi_context(tenant_id)
|
||||
if existing_context:
|
||||
poi_context = existing_context.get("poi_context", {})
|
||||
ml_features = poi_context.get("ml_features", {})
|
||||
existing_context = await self.external_client.get_poi_context(tenant_id)
|
||||
|
||||
# Check if stale
|
||||
is_stale = existing_context.get("is_stale", False)
|
||||
if not is_stale:
|
||||
if existing_context:
|
||||
poi_context = existing_context.get("poi_context", {})
|
||||
ml_features = poi_context.get("ml_features", {})
|
||||
|
||||
# Check if stale and force_refresh is requested
|
||||
is_stale = existing_context.get("is_stale", False)
|
||||
|
||||
if not is_stale or not force_refresh:
|
||||
logger.info(
|
||||
"Using existing POI context",
|
||||
tenant_id=tenant_id,
|
||||
is_stale=is_stale,
|
||||
feature_count=len(ml_features)
|
||||
)
|
||||
return ml_features
|
||||
else:
|
||||
logger.info(
|
||||
"POI context is stale and force_refresh=True, refreshing",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
# Only refresh if explicitly requested and context exists
|
||||
detection_result = await self.external_client.detect_poi_for_tenant(
|
||||
tenant_id=tenant_id,
|
||||
latitude=latitude,
|
||||
longitude=longitude,
|
||||
force_refresh=True
|
||||
)
|
||||
|
||||
if detection_result:
|
||||
poi_context = detection_result.get("poi_context", {})
|
||||
ml_features = poi_context.get("ml_features", {})
|
||||
logger.info(
|
||||
"Using existing POI context",
|
||||
tenant_id=tenant_id
|
||||
"POI refresh completed",
|
||||
tenant_id=tenant_id,
|
||||
feature_count=len(ml_features)
|
||||
)
|
||||
return ml_features
|
||||
else:
|
||||
logger.info(
|
||||
"POI context is stale, refreshing",
|
||||
logger.warning(
|
||||
"POI refresh failed, returning existing features",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
force_refresh = True
|
||||
else:
|
||||
logger.info(
|
||||
"No existing POI context, will detect",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# Detect or refresh POIs
|
||||
logger.info(
|
||||
"Detecting POIs for tenant",
|
||||
tenant_id=tenant_id,
|
||||
location=(latitude, longitude)
|
||||
)
|
||||
|
||||
detection_result = await self.external_client.detect_poi_for_tenant(
|
||||
tenant_id=tenant_id,
|
||||
latitude=latitude,
|
||||
longitude=longitude,
|
||||
force_refresh=force_refresh
|
||||
)
|
||||
|
||||
if detection_result:
|
||||
poi_context = detection_result.get("poi_context", {})
|
||||
ml_features = poi_context.get("ml_features", {})
|
||||
|
||||
logger.info(
|
||||
"POI detection completed",
|
||||
tenant_id=tenant_id,
|
||||
total_pois=poi_context.get("total_pois_detected", 0),
|
||||
feature_count=len(ml_features)
|
||||
)
|
||||
|
||||
return ml_features
|
||||
return ml_features
|
||||
else:
|
||||
logger.error(
|
||||
"POI detection failed",
|
||||
logger.info(
|
||||
"No existing POI context found - POI detection should be triggered during tenant registration",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Unexpected error fetching POI features",
|
||||
logger.warning(
|
||||
"Error fetching POI features - returning None",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
exc_info=True
|
||||
error=str(e)
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@@ -29,16 +29,15 @@ class DataClient:
|
||||
self.sales_client = get_sales_client(settings, "training")
|
||||
self.external_client = get_external_client(settings, "training")
|
||||
|
||||
# ExternalServiceClient always has get_stored_traffic_data_for_training method
|
||||
self.supports_stored_traffic_data = True
|
||||
|
||||
# Configure timeouts for HTTP clients
|
||||
self._configure_timeouts()
|
||||
|
||||
# Initialize circuit breakers for external services
|
||||
self._init_circuit_breakers()
|
||||
|
||||
# Check if the new method is available for stored traffic data
|
||||
if hasattr(self.external_client, 'get_stored_traffic_data_for_training'):
|
||||
self.supports_stored_traffic_data = True
|
||||
|
||||
def _configure_timeouts(self):
|
||||
"""Configure appropriate timeouts for HTTP clients"""
|
||||
timeout = httpx.Timeout(
|
||||
@@ -49,14 +48,12 @@ class DataClient:
|
||||
)
|
||||
|
||||
# Apply timeout to clients if they have httpx clients
|
||||
# Note: BaseServiceClient manages its own HTTP client internally
|
||||
if hasattr(self.sales_client, 'client') and isinstance(self.sales_client.client, httpx.AsyncClient):
|
||||
self.sales_client.client.timeout = timeout
|
||||
|
||||
if hasattr(self.external_client, 'client') and isinstance(self.external_client.client, httpx.AsyncClient):
|
||||
self.external_client.client.timeout = timeout
|
||||
else:
|
||||
self.supports_stored_traffic_data = False
|
||||
logger.warning("Stored traffic data method not available in external client")
|
||||
|
||||
def _init_circuit_breakers(self):
|
||||
"""Initialize circuit breakers for external service calls"""
|
||||
|
||||
@@ -404,22 +404,32 @@ class TrainingDataOrchestrator:
|
||||
tenant_id: str
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Collect POI features for bakery location.
|
||||
Collect POI features for bakery location (non-blocking).
|
||||
|
||||
POI features are static (location-based, not time-varying).
|
||||
This method is non-blocking with a short timeout to prevent training delays.
|
||||
If POI detection hasn't been run yet, training continues without POI features.
|
||||
|
||||
Returns:
|
||||
Dictionary with POI features or empty dict if unavailable
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
"Collecting POI features",
|
||||
"Collecting POI features (non-blocking)",
|
||||
tenant_id=tenant_id,
|
||||
location=(lat, lon)
|
||||
)
|
||||
|
||||
poi_features = await self.poi_feature_integrator.fetch_poi_features(
|
||||
tenant_id=tenant_id,
|
||||
latitude=lat,
|
||||
longitude=lon,
|
||||
force_refresh=False
|
||||
# Set a short timeout to prevent blocking training
|
||||
# POI detection should have been triggered during tenant registration
|
||||
poi_features = await asyncio.wait_for(
|
||||
self.poi_feature_integrator.fetch_poi_features(
|
||||
tenant_id=tenant_id,
|
||||
latitude=lat,
|
||||
longitude=lon,
|
||||
force_refresh=False
|
||||
),
|
||||
timeout=15.0 # 15 second timeout - POI should be cached from registration
|
||||
)
|
||||
|
||||
if poi_features:
|
||||
@@ -430,18 +440,24 @@ class TrainingDataOrchestrator:
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"No POI features collected (service may be unavailable)",
|
||||
"No POI features collected (service may be unavailable or not yet detected)",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
return poi_features or {}
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"POI collection timeout (15s) - continuing training without POI features. "
|
||||
"POI detection should be triggered during tenant registration for best results.",
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
return {}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to collect POI features, continuing without them",
|
||||
logger.warning(
|
||||
"Failed to collect POI features (non-blocking) - continuing training without them",
|
||||
tenant_id=tenant_id,
|
||||
error=str(e),
|
||||
exc_info=True
|
||||
error=str(e)
|
||||
)
|
||||
return {}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user