Add improvements 2

This commit is contained in:
Urtzi Alfaro
2026-01-12 22:15:11 +01:00
parent 230bbe6a19
commit b931a5c45e
40 changed files with 1820 additions and 887 deletions

View File

@@ -1,82 +0,0 @@
# services/forecasting/app/clients/inventory_client.py
"""
Simple client for inventory service integration
Used when product names are not available locally
"""
import aiohttp
import structlog
from typing import Optional, Dict, Any
import os
logger = structlog.get_logger()
class InventoryServiceClient:
    """Best-effort HTTP client for product-name lookups in the inventory service.

    Both lookup methods swallow every failure (connection errors, timeouts,
    non-200 responses, malformed bodies), log it at debug level, and return
    an "empty" value instead of raising, so callers can fall back to
    whatever local naming they have.
    """

    def __init__(self, base_url: Optional[str] = None):
        """
        Args:
            base_url: Root URL of the inventory service. When omitted (or
                empty), the INVENTORY_SERVICE_URL environment variable is
                used, falling back to "http://inventory-service:8000".
        """
        self.base_url = base_url or os.getenv("INVENTORY_SERVICE_URL", "http://inventory-service:8000")
        self.timeout = aiohttp.ClientTimeout(total=5)  # 5 second timeout

    async def get_product_name(self, tenant_id: str, inventory_product_id: str) -> Optional[str]:
        """
        Get a single product name from the inventory service.

        Issues GET {base_url}/api/v1/products/{inventory_product_id} with the
        tenant passed in the X-Tenant-ID header. A new ClientSession is
        opened for each call.

        Args:
            tenant_id: Tenant identifier sent as the X-Tenant-ID header.
            inventory_product_id: Identifier of the product to look up.

        Returns:
            The product's "name" field on a 200 response, or the placeholder
            "Product-<inventory_product_id>" when the 200 body has no "name"
            key. None on any non-200 status or on any error.
        """
        try:
            async with aiohttp.ClientSession(timeout=self.timeout) as session:
                url = f"{self.base_url}/api/v1/products/{inventory_product_id}"
                headers = {"X-Tenant-ID": tenant_id}
                async with session.get(url, headers=headers) as response:
                    if response.status == 200:
                        data = await response.json()
                        return data.get("name", f"Product-{inventory_product_id}")

                    logger.debug("Product not found in inventory service",
                                 inventory_product_id=inventory_product_id,
                                 status=response.status)
                    return None
        except Exception as e:
            # Deliberate best-effort: a failed lookup must never break the caller.
            logger.debug("Failed to get product name from inventory service",
                         inventory_product_id=inventory_product_id,
                         error=str(e))
            return None

    async def get_multiple_product_names(self, tenant_id: str, product_ids: list) -> Dict[str, str]:
        """
        Get multiple product names with a single batch request.

        Issues POST {base_url}/api/v1/products/batch with the JSON body
        {"product_ids": [...]} and the tenant in the X-Tenant-ID header.

        Args:
            tenant_id: Tenant identifier sent as the X-Tenant-ID header.
            product_ids: Product identifiers to resolve.

        Returns:
            A mapping of product_id -> product_name built from the
            "products" list of a 200 response. Entries lacking an "id" or
            "name" key are skipped. An empty dict is returned when
            product_ids is empty, on any non-200 status, or on any error.
        """
        # Nothing to resolve: skip the network round-trip entirely.
        if not product_ids:
            return {}

        try:
            async with aiohttp.ClientSession(timeout=self.timeout) as session:
                url = f"{self.base_url}/api/v1/products/batch"
                headers = {"X-Tenant-ID": tenant_id}
                payload = {"product_ids": product_ids}
                async with session.post(url, json=payload, headers=headers) as response:
                    if response.status == 200:
                        data = await response.json()
                        products = data.get("products") or []
                        # Skip malformed entries so one bad item does not
                        # discard every valid name in the batch.
                        return {
                            item["id"]: item["name"]
                            for item in products
                            if isinstance(item, dict) and "id" in item and "name" in item
                        }

                    logger.debug("Batch product lookup failed",
                                 product_count=len(product_ids),
                                 status=response.status)
                    return {}
        except Exception as e:
            # Deliberate best-effort: a failed lookup must never break the caller.
            logger.debug("Failed to get product names from inventory service",
                         product_count=len(product_ids),
                         error=str(e))
            return {}
# Module-wide client, created lazily on first request
_inventory_client = None


def get_inventory_client() -> InventoryServiceClient:
    """Return the shared InventoryServiceClient, building it on first use."""
    global _inventory_client
    if _inventory_client is not None:
        return _inventory_client
    _inventory_client = InventoryServiceClient()
    return _inventory_client

View File

@@ -12,7 +12,6 @@ from datetime import datetime, timedelta
import structlog
from shared.messaging import UnifiedEventPublisher
from app.clients.inventory_client import get_inventory_client
logger = structlog.get_logger()

View File

@@ -30,11 +30,11 @@ async def trigger_inventory_alerts(
- Expiring ingredients
- Overstock situations
Security: Protected by X-Internal-Service header check.
Security: Protected by x-internal-service header check.
"""
try:
# Verify internal service header
if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
logger.warning("Unauthorized internal API call", tenant_id=str(tenant_id))
raise HTTPException(
status_code=403,

View File

@@ -350,7 +350,7 @@ async def generate_safety_stock_insights_internal(
This endpoint is called by the demo-session service after cloning data.
It uses the same ML logic as the public endpoint but with optimized defaults.
Security: Protected by X-Internal-Service header check.
Security: Protected by x-internal-service header check.
Args:
tenant_id: The tenant UUID
@@ -365,7 +365,7 @@ async def generate_safety_stock_insights_internal(
}
"""
# Verify internal service header
if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
logger.warning("Unauthorized internal API call", tenant_id=tenant_id)
raise HTTPException(
status_code=403,

View File

@@ -29,7 +29,7 @@ async def trigger_delivery_tracking(
This endpoint is called by the demo session cloning process after POs are seeded
to generate realistic delivery alerts (arriving soon, overdue, etc.).
Security: Protected by X-Internal-Service header check.
Security: Protected by x-internal-service header check.
Args:
tenant_id: Tenant UUID to check deliveries for
@@ -49,7 +49,7 @@ async def trigger_delivery_tracking(
"""
try:
# Verify internal service header
if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
logger.warning("Unauthorized internal API call", tenant_id=str(tenant_id))
raise HTTPException(
status_code=403,

View File

@@ -566,7 +566,7 @@ async def generate_price_insights_internal(
This endpoint is called by the demo-session service after cloning data.
It uses the same ML logic as the public endpoint but with optimized defaults.
Security: Protected by X-Internal-Service header check.
Security: Protected by x-internal-service header check.
Args:
tenant_id: The tenant UUID
@@ -581,7 +581,7 @@ async def generate_price_insights_internal(
}
"""
# Verify internal service header
if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
logger.warning("Unauthorized internal API call", tenant_id=tenant_id)
raise HTTPException(
status_code=403,

View File

@@ -1,42 +1,45 @@
"""
FastAPI Dependencies for Procurement Service
Uses shared authentication infrastructure with UUID validation
"""
from fastapi import Header, HTTPException, status
from fastapi import Depends, HTTPException, status
from uuid import UUID
from typing import Optional
from sqlalchemy.ext.asyncio import AsyncSession
from .database import get_db
from shared.auth.decorators import get_current_tenant_id_dep
async def get_current_tenant_id(
x_tenant_id: Optional[str] = Header(None, alias="X-Tenant-ID")
tenant_id: Optional[str] = Depends(get_current_tenant_id_dep)
) -> UUID:
"""
Extract and validate tenant ID from request header.
Extract and validate tenant ID from request using shared infrastructure.
Adds UUID validation to ensure tenant ID format is correct.
Args:
x_tenant_id: Tenant ID from X-Tenant-ID header
tenant_id: Tenant ID from shared dependency
Returns:
UUID: Validated tenant ID
Raises:
HTTPException: If tenant ID is missing or invalid
HTTPException: If tenant ID is missing or invalid UUID format
"""
if not x_tenant_id:
if not tenant_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="X-Tenant-ID header is required"
detail="x-tenant-id header is required"
)
try:
return UUID(x_tenant_id)
return UUID(tenant_id)
except (ValueError, AttributeError):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid tenant ID format: {x_tenant_id}"
detail=f"Invalid tenant ID format: {tenant_id}"
)

View File

@@ -31,11 +31,11 @@ async def trigger_production_alerts(
- Equipment maintenance alerts
- Batch start delays
Security: Protected by X-Internal-Service header check.
Security: Protected by x-internal-service header check.
"""
try:
# Verify internal service header
if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
logger.warning("Unauthorized internal API call", tenant_id=str(tenant_id))
raise HTTPException(
status_code=403,

View File

@@ -331,7 +331,7 @@ async def generate_yield_insights_internal(
This endpoint is called by the demo-session service after cloning data.
It uses the same ML logic as the public endpoint but with optimized defaults.
Security: Protected by X-Internal-Service header check.
Security: Protected by x-internal-service header check.
Args:
tenant_id: The tenant UUID
@@ -346,7 +346,7 @@ async def generate_yield_insights_internal(
}
"""
# Verify internal service header
if not request or request.headers.get("X-Internal-Service") not in ["demo-session", "internal"]:
if not request or request.headers.get("x-internal-service") not in ["demo-session", "internal"]:
logger.warning("Unauthorized internal API call", tenant_id=tenant_id)
raise HTTPException(
status_code=403,

View File

@@ -204,7 +204,7 @@ class TenantMemberRepository(TenantBaseRepository):
f"{auth_service_url}/api/v1/auth/users/batch",
json={"user_ids": user_ids},
timeout=10.0,
headers={"X-Internal-Service": "tenant-service"}
headers={"x-internal-service": "tenant-service"}
)
if response.status_code == 200:
@@ -226,7 +226,7 @@ class TenantMemberRepository(TenantBaseRepository):
response = await client.get(
f"{auth_service_url}/api/v1/auth/users/{user_id}",
timeout=5.0,
headers={"X-Internal-Service": "tenant-service"}
headers={"x-internal-service": "tenant-service"}
)
if response.status_code == 200:
user_data = response.json()
@@ -243,7 +243,7 @@ class TenantMemberRepository(TenantBaseRepository):
response = await client.get(
f"{auth_service_url}/api/v1/auth/users/{user_id}",
timeout=5.0,
headers={"X-Internal-Service": "tenant-service"}
headers={"x-internal-service": "tenant-service"}
)
if response.status_code == 200:
user_data = response.json()

View File

@@ -216,17 +216,24 @@ class HybridProphetXGBoost:
Get Prophet predictions for given dataframe.
Args:
prophet_result: Prophet model result from training
prophet_result: Prophet model result from training (contains model_path)
df: DataFrame with 'ds' column
Returns:
Array of predictions
"""
# Get the Prophet model from result
prophet_model = prophet_result.get('model')
# Get the model path from result instead of expecting the model object directly
model_path = prophet_result.get('model_path')
if prophet_model is None:
raise ValueError("Prophet model not found in result")
if model_path is None:
raise ValueError("Prophet model path not found in result")
# Load the actual Prophet model from the stored path
try:
import joblib
prophet_model = joblib.load(model_path)
except Exception as e:
raise ValueError(f"Failed to load Prophet model from path {model_path}: {str(e)}")
# Prepare dataframe for prediction
pred_df = df[['ds']].copy()
@@ -273,7 +280,8 @@ class HybridProphetXGBoost:
'reg_lambda': 1.0, # L2 regularization
'objective': 'reg:squarederror',
'random_state': 42,
'n_jobs': -1
'n_jobs': -1,
'early_stopping_rounds': 10
}
# Initialize model
@@ -285,7 +293,6 @@ class HybridProphetXGBoost:
model.fit,
X_train, y_train,
eval_set=[(X_val, y_val)],
early_stopping_rounds=10,
verbose=False
)
@@ -303,109 +310,86 @@ class HybridProphetXGBoost:
train_prophet_pred: np.ndarray,
val_prophet_pred: np.ndarray,
prophet_result: Dict[str, Any]
) -> Dict[str, float]:
) -> Dict[str, Any]:
"""
Evaluate hybrid model vs Prophet-only on validation set.
Args:
train_df: Training data
val_df: Validation data
train_prophet_pred: Prophet predictions on training set
val_prophet_pred: Prophet predictions on validation set
prophet_result: Prophet training result
Returns:
Dictionary of metrics
Evaluate the overall performance of the hybrid model using threading for metrics.
"""
# Get actual values
train_actual = train_df['y'].values
val_actual = val_df['y'].values
# Get XGBoost predictions on residuals
import asyncio
# Get XGBoost predictions on training and validation
X_train = train_df[self.feature_columns].values
X_val = val_df[self.feature_columns].values
# ✅ FIX: Run blocking predict() in thread pool to avoid blocking event loop
import asyncio
train_xgb_pred = await asyncio.to_thread(self.xgb_model.predict, X_train)
val_xgb_pred = await asyncio.to_thread(self.xgb_model.predict, X_val)
# Hybrid predictions = Prophet + XGBoost residual correction
# Hybrid prediction = Prophet prediction + XGBoost residual prediction
train_hybrid_pred = train_prophet_pred + train_xgb_pred
val_hybrid_pred = val_prophet_pred + val_xgb_pred
# Calculate metrics for Prophet-only
prophet_train_mae = mean_absolute_error(train_actual, train_prophet_pred)
prophet_val_mae = mean_absolute_error(val_actual, val_prophet_pred)
prophet_train_mape = mean_absolute_percentage_error(train_actual, train_prophet_pred) * 100
prophet_val_mape = mean_absolute_percentage_error(val_actual, val_prophet_pred) * 100
# Calculate metrics for Hybrid
hybrid_train_mae = mean_absolute_error(train_actual, train_hybrid_pred)
hybrid_val_mae = mean_absolute_error(val_actual, val_hybrid_pred)
hybrid_train_mape = mean_absolute_percentage_error(train_actual, train_hybrid_pred) * 100
hybrid_val_mape = mean_absolute_percentage_error(val_actual, val_hybrid_pred) * 100
actual_train = train_df['y'].values
actual_val = val_df['y'].values
# Basic RMSE calculation
train_rmse = float(np.sqrt(np.mean((actual_train - train_hybrid_pred)**2)))
val_rmse = float(np.sqrt(np.mean((actual_val - val_hybrid_pred)**2)))
# MAE
train_mae = float(np.mean(np.abs(actual_train - train_hybrid_pred)))
val_mae = float(np.mean(np.abs(actual_val - val_hybrid_pred)))
# MAPE (with safety for zero sales)
train_mape = float(np.mean(np.abs((actual_train - train_hybrid_pred) / np.maximum(actual_train, 1))))
val_mape = float(np.mean(np.abs((actual_val - val_hybrid_pred) / np.maximum(actual_val, 1))))
# Calculate improvement
mae_improvement = ((prophet_val_mae - hybrid_val_mae) / prophet_val_mae) * 100
mape_improvement = ((prophet_val_mape - hybrid_val_mape) / prophet_val_mape) * 100
prophet_metrics = prophet_result.get("metrics", {})
prophet_val_mae = prophet_metrics.get("val_mae", val_mae) # Fallback to hybrid if missing
prophet_val_mape = prophet_metrics.get("val_mape", val_mape)
improvement_pct = 0.0
if prophet_val_mape > 0:
improvement_pct = ((prophet_val_mape - val_mape) / prophet_val_mape) * 100
metrics = {
'prophet_train_mae': float(prophet_train_mae),
'prophet_val_mae': float(prophet_val_mae),
'prophet_train_mape': float(prophet_train_mape),
'prophet_val_mape': float(prophet_val_mape),
'hybrid_train_mae': float(hybrid_train_mae),
'hybrid_val_mae': float(hybrid_val_mae),
'hybrid_train_mape': float(hybrid_train_mape),
'hybrid_val_mape': float(hybrid_val_mape),
'mae_improvement_pct': float(mae_improvement),
'mape_improvement_pct': float(mape_improvement),
'improvement_percentage': float(mape_improvement) # Primary metric
"train_rmse": train_rmse,
"val_rmse": val_rmse,
"train_mae": train_mae,
"val_mae": val_mae,
"train_mape": train_mape,
"val_mape": val_mape,
"prophet_val_mape": prophet_val_mape,
"hybrid_val_mape": val_mape,
"improvement_percentage": float(improvement_pct),
"prophet_metrics": prophet_metrics
}
logger.info(
"Hybrid model evaluation complete",
val_rmse=val_rmse,
val_mae=val_mae,
val_mape=val_mape,
improvement=improvement_pct
)
return metrics
def _package_hybrid_model(
self,
prophet_result: Dict[str, Any],
metrics: Dict[str, float],
metrics: Dict[str, Any],
tenant_id: str,
inventory_product_id: str
) -> Dict[str, Any]:
"""
Package hybrid model for storage.
Args:
prophet_result: Prophet model result
metrics: Hybrid model metrics
tenant_id: Tenant ID
inventory_product_id: Product ID
Returns:
Model package dictionary
"""
return {
'model_type': 'hybrid_prophet_xgboost',
'prophet_model': prophet_result.get('model'),
'prophet_model_path': prophet_result.get('model_path'),
'xgboost_model': self.xgb_model,
'feature_columns': self.feature_columns,
'prophet_metrics': {
'train_mae': metrics['prophet_train_mae'],
'val_mae': metrics['prophet_val_mae'],
'train_mape': metrics['prophet_train_mape'],
'val_mape': metrics['prophet_val_mape']
},
'hybrid_metrics': {
'train_mae': metrics['hybrid_train_mae'],
'val_mae': metrics['hybrid_val_mae'],
'train_mape': metrics['hybrid_train_mape'],
'val_mape': metrics['hybrid_val_mape']
},
'improvement_metrics': {
'mae_improvement_pct': metrics['mae_improvement_pct'],
'mape_improvement_pct': metrics['mape_improvement_pct']
},
'metrics': metrics,
'tenant_id': tenant_id,
'inventory_product_id': inventory_product_id,
'trained_at': datetime.now(timezone.utc).isoformat()
@@ -426,8 +410,18 @@ class HybridProphetXGBoost:
Returns:
DataFrame with predictions
"""
# Step 1: Get Prophet predictions
prophet_model = model_data['prophet_model']
# Step 1: Get Prophet model from path and make predictions
prophet_model_path = model_data.get('prophet_model_path')
if prophet_model_path is None:
raise ValueError("Prophet model path not found in model data")
# Load the Prophet model from the stored path
try:
import joblib
prophet_model = joblib.load(prophet_model_path)
except Exception as e:
raise ValueError(f"Failed to load Prophet model from path {prophet_model_path}: {str(e)}")
# ✅ FIX: Run blocking predict() in thread pool to avoid blocking event loop
import asyncio
prophet_forecast = await asyncio.to_thread(prophet_model.predict, future_df)

View File

@@ -43,86 +43,79 @@ class POIFeatureIntegrator:
force_refresh: bool = False
) -> Optional[Dict[str, Any]]:
"""
Fetch POI features for tenant location.
Fetch POI features for tenant location (optimized for training).
First checks if POI context exists, if not, triggers detection.
First checks if POI context exists. If not, returns None without triggering detection.
POI detection should be triggered during tenant registration, not during training.
Args:
tenant_id: Tenant UUID
latitude: Bakery latitude
longitude: Bakery longitude
force_refresh: Force re-detection
force_refresh: Force re-detection (only use if POI context already exists)
Returns:
Dictionary with POI features or None if detection fails
Dictionary with POI features or None if not available
"""
try:
# Try to get existing POI context first
if not force_refresh:
existing_context = await self.external_client.get_poi_context(tenant_id)
if existing_context:
poi_context = existing_context.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})
existing_context = await self.external_client.get_poi_context(tenant_id)
# Check if stale
is_stale = existing_context.get("is_stale", False)
if not is_stale:
if existing_context:
poi_context = existing_context.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})
# Check if stale and force_refresh is requested
is_stale = existing_context.get("is_stale", False)
if not is_stale or not force_refresh:
logger.info(
"Using existing POI context",
tenant_id=tenant_id,
is_stale=is_stale,
feature_count=len(ml_features)
)
return ml_features
else:
logger.info(
"POI context is stale and force_refresh=True, refreshing",
tenant_id=tenant_id
)
# Only refresh if explicitly requested and context exists
detection_result = await self.external_client.detect_poi_for_tenant(
tenant_id=tenant_id,
latitude=latitude,
longitude=longitude,
force_refresh=True
)
if detection_result:
poi_context = detection_result.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})
logger.info(
"Using existing POI context",
tenant_id=tenant_id
"POI refresh completed",
tenant_id=tenant_id,
feature_count=len(ml_features)
)
return ml_features
else:
logger.info(
"POI context is stale, refreshing",
logger.warning(
"POI refresh failed, returning existing features",
tenant_id=tenant_id
)
force_refresh = True
else:
logger.info(
"No existing POI context, will detect",
tenant_id=tenant_id
)
# Detect or refresh POIs
logger.info(
"Detecting POIs for tenant",
tenant_id=tenant_id,
location=(latitude, longitude)
)
detection_result = await self.external_client.detect_poi_for_tenant(
tenant_id=tenant_id,
latitude=latitude,
longitude=longitude,
force_refresh=force_refresh
)
if detection_result:
poi_context = detection_result.get("poi_context", {})
ml_features = poi_context.get("ml_features", {})
logger.info(
"POI detection completed",
tenant_id=tenant_id,
total_pois=poi_context.get("total_pois_detected", 0),
feature_count=len(ml_features)
)
return ml_features
return ml_features
else:
logger.error(
"POI detection failed",
logger.info(
"No existing POI context found - POI detection should be triggered during tenant registration",
tenant_id=tenant_id
)
return None
except Exception as e:
logger.error(
"Unexpected error fetching POI features",
logger.warning(
"Error fetching POI features - returning None",
tenant_id=tenant_id,
error=str(e),
exc_info=True
error=str(e)
)
return None

View File

@@ -29,16 +29,15 @@ class DataClient:
self.sales_client = get_sales_client(settings, "training")
self.external_client = get_external_client(settings, "training")
# ExternalServiceClient always has get_stored_traffic_data_for_training method
self.supports_stored_traffic_data = True
# Configure timeouts for HTTP clients
self._configure_timeouts()
# Initialize circuit breakers for external services
self._init_circuit_breakers()
# Check if the new method is available for stored traffic data
if hasattr(self.external_client, 'get_stored_traffic_data_for_training'):
self.supports_stored_traffic_data = True
def _configure_timeouts(self):
"""Configure appropriate timeouts for HTTP clients"""
timeout = httpx.Timeout(
@@ -49,14 +48,12 @@ class DataClient:
)
# Apply timeout to clients if they have httpx clients
# Note: BaseServiceClient manages its own HTTP client internally
if hasattr(self.sales_client, 'client') and isinstance(self.sales_client.client, httpx.AsyncClient):
self.sales_client.client.timeout = timeout
if hasattr(self.external_client, 'client') and isinstance(self.external_client.client, httpx.AsyncClient):
self.external_client.client.timeout = timeout
else:
self.supports_stored_traffic_data = False
logger.warning("Stored traffic data method not available in external client")
def _init_circuit_breakers(self):
"""Initialize circuit breakers for external service calls"""

View File

@@ -404,22 +404,32 @@ class TrainingDataOrchestrator:
tenant_id: str
) -> Dict[str, Any]:
"""
Collect POI features for bakery location.
Collect POI features for bakery location (non-blocking).
POI features are static (location-based, not time-varying).
This method is non-blocking with a short timeout to prevent training delays.
If POI detection hasn't been run yet, training continues without POI features.
Returns:
Dictionary with POI features or empty dict if unavailable
"""
try:
logger.info(
"Collecting POI features",
"Collecting POI features (non-blocking)",
tenant_id=tenant_id,
location=(lat, lon)
)
poi_features = await self.poi_feature_integrator.fetch_poi_features(
tenant_id=tenant_id,
latitude=lat,
longitude=lon,
force_refresh=False
# Set a short timeout to prevent blocking training
# POI detection should have been triggered during tenant registration
poi_features = await asyncio.wait_for(
self.poi_feature_integrator.fetch_poi_features(
tenant_id=tenant_id,
latitude=lat,
longitude=lon,
force_refresh=False
),
timeout=15.0 # 15 second timeout - POI should be cached from registration
)
if poi_features:
@@ -430,18 +440,24 @@ class TrainingDataOrchestrator:
)
else:
logger.warning(
"No POI features collected (service may be unavailable)",
"No POI features collected (service may be unavailable or not yet detected)",
tenant_id=tenant_id
)
return poi_features or {}
except asyncio.TimeoutError:
logger.warning(
"POI collection timeout (15s) - continuing training without POI features. "
"POI detection should be triggered during tenant registration for best results.",
tenant_id=tenant_id
)
return {}
except Exception as e:
logger.error(
"Failed to collect POI features, continuing without them",
logger.warning(
"Failed to collect POI features (non-blocking) - continuing training without them",
tenant_id=tenant_id,
error=str(e),
exc_info=True
error=str(e)
)
return {}