Files
bakery-ia/services/external/app/services/poi_feature_selector.py

185 lines
6.7 KiB
Python

"""
POI Feature Selector
Determines which POI features are relevant for ML model inclusion.
Filters out low-signal features to prevent model noise and overfitting.
"""
from typing import Any, Dict, List, Optional

import structlog

from app.core.poi_config import RELEVANCE_THRESHOLDS
# Module-level structlog logger used by POIFeatureSelector for selection
# summaries and missing-threshold warnings (called with keyword context).
logger = structlog.get_logger()
class POIFeatureSelector:
    """
    Feature relevance engine for POI-based ML features.

    Applies research-based thresholds to filter out irrelevant POI features
    that would add noise to bakery-specific demand forecasting models.
    """

    def __init__(self, thresholds: Optional[Dict[str, Dict[str, float]]] = None):
        """
        Initialize feature selector.

        Args:
            thresholds: Custom relevance thresholds keyed by POI category.
                When None, RELEVANCE_THRESHOLDS is used. An explicitly passed
                mapping (even an empty one) is used as-is.
        """
        # Compare against None instead of relying on truthiness so an
        # explicitly supplied empty mapping is not silently replaced.
        self.thresholds = RELEVANCE_THRESHOLDS if thresholds is None else thresholds

    def select_relevant_features(
        self,
        poi_detection_results: Dict[str, Any],
        tenant_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Filter POI features based on relevance thresholds.

        Only includes features for POI categories that pass relevance tests.
        This prevents adding noise to ML models for bakeries where certain
        POI categories are not significant. Categories with no configured
        thresholds are logged, skipped, and reported as not relevant with
        reason "No thresholds defined".

        Args:
            poi_detection_results: Full POI detection results, mapping each
                category key to a dict that may contain a "features" dict.
                None is treated as an empty mapping.
            tenant_id: Optional tenant ID for logging.

        Returns:
            Dictionary with keys:
                "features": ML features named "poi_<category>_<feature>",
                    with booleans converted to 0/1.
                "relevant_categories": category keys that passed all checks.
                "relevance_report": one entry per input category.
                "total_features": number of ML features produced.
                "total_relevant_categories": number of passing categories.
        """
        poi_detection_results = poi_detection_results or {}
        relevant_features: Dict[str, Any] = {}
        relevance_report: List[Dict[str, Any]] = []
        relevant_categories: List[str] = []

        for category_key, data in poi_detection_results.items():
            features = self._extract_features(data)
            thresholds = self.thresholds.get(category_key, {})

            if not thresholds:
                logger.warning(
                    f"No thresholds defined for category {category_key}",
                    tenant_id=tenant_id,
                    category=category_key,
                )
                # Record the skip so the report covers every category that
                # the log below counts as rejected.
                relevance_report.append(
                    self._build_report_entry(
                        category_key, features, False, "No thresholds defined"
                    )
                )
                continue

            is_relevant, rejection_reason = self._check_relevance(
                features, thresholds, category_key
            )

            if is_relevant:
                relevant_features.update(
                    self._to_ml_features(category_key, features)
                )
                relevant_categories.append(category_key)
                relevance_report.append(
                    self._build_report_entry(
                        category_key, features, True,
                        "Passes all relevance thresholds",
                    )
                )
            else:
                relevance_report.append(
                    self._build_report_entry(
                        category_key, features, False, rejection_reason
                    )
                )

        logger.info(
            "POI feature selection complete",
            tenant_id=tenant_id,
            total_categories=len(poi_detection_results),
            relevant_categories=len(relevant_categories),
            rejected_categories=len(poi_detection_results) - len(relevant_categories)
        )

        return {
            "features": relevant_features,
            "relevant_categories": relevant_categories,
            "relevance_report": relevance_report,
            "total_features": len(relevant_features),
            "total_relevant_categories": len(relevant_categories)
        }

    @staticmethod
    def _extract_features(data: Any) -> Dict[str, Any]:
        """Return the "features" dict of a category entry, or {} if absent/None."""
        if not data:
            return {}
        return data.get("features") or {}

    @staticmethod
    def _feature_value(features: Dict[str, Any], key: str, default: Any) -> Any:
        """Return features[key], falling back to default when missing or None."""
        value = features.get(key)
        return default if value is None else value

    @staticmethod
    def _to_ml_features(category_key: str, features: Dict[str, Any]) -> Dict[str, Any]:
        """Prefix feature names with the category and convert booleans to 0/1."""
        # bool is checked explicitly so True/False become 1/0 for ML input.
        return {
            f"poi_{category_key}_{feature_name}": (
                int(value) if isinstance(value, bool) else value
            )
            for feature_name, value in features.items()
        }

    @staticmethod
    def _build_report_entry(
        category_key: str,
        features: Dict[str, Any],
        relevant: bool,
        reason: str
    ) -> Dict[str, Any]:
        """Build one relevance_report entry for a category."""
        return {
            "category": category_key,
            "relevant": relevant,
            "reason": reason,
            "proximity_score": features.get("proximity_score", 0),
            "count": features.get("total_count", 0),
            "distance_to_nearest_m": features.get("distance_to_nearest_m", 9999)
        }

    def _check_relevance(
        self,
        features: Dict[str, Any],
        thresholds: Dict[str, float],
        category_key: str
    ) -> tuple[bool, str]:
        """
        Check if POI category passes relevance thresholds.

        Criteria are evaluated in order (proximity score, distance to
        nearest, count) and the first failure is returned. Missing or None
        feature values fall back to 0 / 9999 / 0 respectively; missing
        thresholds fall back to 0 / 9999 / 0.

        Args:
            features: Feature values for one category.
            thresholds: Threshold values for that category.
            category_key: Category identifier (not used by the checks).

        Returns:
            Tuple of (is_relevant, rejection_reason)
        """
        # Criterion 1: Proximity score
        min_proximity = thresholds.get("min_proximity_score", 0)
        actual_proximity = self._feature_value(features, "proximity_score", 0)
        if actual_proximity < min_proximity:
            return False, f"Proximity score too low ({actual_proximity:.2f} < {min_proximity})"

        # Criterion 2: Distance to nearest
        max_distance = thresholds.get("max_distance_to_nearest_m", 9999)
        actual_distance = self._feature_value(features, "distance_to_nearest_m", 9999)
        if actual_distance > max_distance:
            return False, f"Nearest POI too far ({actual_distance:.0f}m > {max_distance}m)"

        # Criterion 3: Count threshold
        min_count = thresholds.get("min_count", 0)
        actual_count = self._feature_value(features, "total_count", 0)
        if actual_count < min_count:
            return False, f"Count too low ({actual_count} < {min_count})"

        return True, "Passes all thresholds"

    def get_feature_importance_summary(
        self,
        poi_detection_results: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """
        Generate summary of feature importance for all categories.

        Useful for understanding POI landscape around a bakery.

        Args:
            poi_detection_results: Full POI detection results (None is
                treated as empty).

        Returns:
            One entry per category, relevant categories first, then by
            descending proximity score (None scores sort as 0).
        """
        summary = []

        for category_key, data in (poi_detection_results or {}).items():
            features = self._extract_features(data)
            thresholds = self.thresholds.get(category_key, {})

            if thresholds:
                is_relevant, reason = self._check_relevance(
                    features, thresholds, category_key
                )
            else:
                is_relevant, reason = False, "No thresholds defined"

            summary.append({
                "category": category_key,
                "is_relevant": is_relevant,
                "proximity_score": features.get("proximity_score", 0),
                "weighted_score": features.get("weighted_proximity_score", 0),
                "total_count": features.get("total_count", 0),
                "distance_to_nearest_m": features.get("distance_to_nearest_m", 9999),
                "has_within_100m": features.get("has_within_100m", False),
                "rejection_reason": None if is_relevant else reason
            })

        # Sort by relevance and proximity score; `or 0` keeps None scores
        # from breaking the comparison.
        summary.sort(
            key=lambda x: (x["is_relevant"], x["proximity_score"] or 0),
            reverse=True
        )

        return summary