"""
POI Feature Selector

Determines which POI features are relevant for ML model inclusion.
Filters out low-signal features to prevent model noise and overfitting.
"""
|
|
|
|
from typing import Any, Dict, List, Optional

import structlog

from app.core.poi_config import RELEVANCE_THRESHOLDS
|
|
|
|
# Module-level structlog logger used for this module's log events.
logger = structlog.get_logger()
|
|
|
|
|
|
class POIFeatureSelector:
    """
    Feature relevance engine for POI-based ML features.

    Applies per-category relevance thresholds to filter out POI features
    that would add noise to bakery-specific demand forecasting models.

    A category is considered relevant when its features satisfy every
    threshold configured for it (minimum proximity score, maximum distance
    to the nearest POI, minimum POI count). Metrics that are absent from a
    category's features, or present with a ``None`` value, are replaced by
    neutral defaults (0 for scores/counts, 9999 for distances) so that
    incomplete detection results are rejected by the thresholds instead of
    raising ``TypeError``.
    """

    # Fallback values for metrics that are missing (or None) in the features.
    _DEFAULT_SCORE = 0
    _DEFAULT_COUNT = 0
    _DEFAULT_DISTANCE_M = 9999

    def __init__(self, thresholds: Optional[Dict[str, Dict[str, float]]] = None):
        """
        Initialize feature selector.

        Args:
            thresholds: Custom relevance thresholds keyed by category.
                When omitted (or empty), RELEVANCE_THRESHOLDS is used.
        """
        self.thresholds = thresholds or RELEVANCE_THRESHOLDS

    @staticmethod
    def _metric(features: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``features[key]``, or ``default`` when it is absent or None."""
        value = features.get(key)
        return default if value is None else value

    def _report_entry(
        self,
        category_key: str,
        features: Dict[str, Any],
        relevant: bool,
        reason: str
    ) -> Dict[str, Any]:
        """Build one relevance-report row for a category."""
        return {
            "category": category_key,
            "relevant": relevant,
            "reason": reason,
            "proximity_score": self._metric(
                features, "proximity_score", self._DEFAULT_SCORE
            ),
            "count": self._metric(features, "total_count", self._DEFAULT_COUNT),
            "distance_to_nearest_m": self._metric(
                features, "distance_to_nearest_m", self._DEFAULT_DISTANCE_M
            ),
        }

    def select_relevant_features(
        self,
        poi_detection_results: Dict[str, Any],
        tenant_id: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Filter POI features based on relevance thresholds.

        Only includes features for POI categories that pass relevance tests.
        This prevents adding noise to ML models for bakeries where certain
        POI categories are not significant.

        Categories without configured thresholds are skipped with a warning
        and do not appear in the relevance report.

        Args:
            poi_detection_results: Mapping of category key to a dict that
                holds that category's metrics under the ``"features"`` key.
            tenant_id: Optional tenant ID for logging

        Returns:
            Dictionary with:
                - ``features``: flat ``poi_<category>_<feature>`` mapping
                  (booleans converted to 0/1) for relevant categories only
                - ``relevant_categories``: keys of the categories that passed
                - ``relevance_report``: one entry per evaluated category
                - ``total_features`` / ``total_relevant_categories``: counts
        """
        relevant_features: Dict[str, Any] = {}
        relevance_report: List[Dict[str, Any]] = []
        relevant_categories: List[str] = []

        for category_key, data in poi_detection_results.items():
            # "features" may be present but None; treat that like missing.
            features = data.get("features") or {}
            thresholds = self.thresholds.get(category_key, {})

            if not thresholds:
                logger.warning(
                    f"No thresholds defined for category {category_key}",
                    category=category_key,
                    tenant_id=tenant_id
                )
                continue

            is_relevant, rejection_reason = self._check_relevance(
                features, thresholds, category_key
            )

            if is_relevant:
                # Include features with category prefix
                for feature_name, value in features.items():
                    # Convert boolean to int for ML
                    if isinstance(value, bool):
                        value = int(value)
                    relevant_features[f"poi_{category_key}_{feature_name}"] = value

                relevant_categories.append(category_key)
                reason = "Passes all relevance thresholds"
            else:
                reason = rejection_reason

            relevance_report.append(
                self._report_entry(category_key, features, is_relevant, reason)
            )

        # NOTE: rejected_categories counts every non-relevant category,
        # including those skipped above for lacking thresholds.
        logger.info(
            "POI feature selection complete",
            tenant_id=tenant_id,
            total_categories=len(poi_detection_results),
            relevant_categories=len(relevant_categories),
            rejected_categories=len(poi_detection_results) - len(relevant_categories)
        )

        return {
            "features": relevant_features,
            "relevant_categories": relevant_categories,
            "relevance_report": relevance_report,
            "total_features": len(relevant_features),
            "total_relevant_categories": len(relevant_categories)
        }

    def _check_relevance(
        self,
        features: Dict[str, Any],
        thresholds: Dict[str, float],
        category_key: str
    ) -> tuple[bool, str]:
        """
        Check if POI category passes relevance thresholds.

        The criteria are evaluated in order (proximity score, distance to
        nearest POI, count) and the first failing one determines the
        rejection reason. Missing or None metrics fall back to 0 (score,
        count) or 9999 (distance).

        Args:
            features: Metrics detected for the category.
            thresholds: Threshold values configured for the category.
            category_key: Category identifier (currently unused; kept for
                interface compatibility).

        Returns:
            Tuple of (is_relevant, rejection_reason)
        """
        # Criterion 1: Proximity score
        min_proximity = thresholds.get("min_proximity_score", 0)
        actual_proximity = self._metric(
            features, "proximity_score", self._DEFAULT_SCORE
        )
        if actual_proximity < min_proximity:
            return False, f"Proximity score too low ({actual_proximity:.2f} < {min_proximity})"

        # Criterion 2: Distance to nearest
        max_distance = thresholds.get("max_distance_to_nearest_m", 9999)
        actual_distance = self._metric(
            features, "distance_to_nearest_m", self._DEFAULT_DISTANCE_M
        )
        if actual_distance > max_distance:
            return False, f"Nearest POI too far ({actual_distance:.0f}m > {max_distance}m)"

        # Criterion 3: Count threshold
        min_count = thresholds.get("min_count", 0)
        actual_count = self._metric(features, "total_count", self._DEFAULT_COUNT)
        if actual_count < min_count:
            return False, f"Count too low ({actual_count} < {min_count})"

        return True, "Passes all thresholds"

    def get_feature_importance_summary(
        self,
        poi_detection_results: Dict[str, Any]
    ) -> List[Dict[str, Any]]:
        """
        Generate summary of feature importance for all categories.

        Useful for understanding POI landscape around a bakery.

        Args:
            poi_detection_results: Mapping of category key to a dict that
                holds that category's metrics under the ``"features"`` key.

        Returns:
            One entry per category, relevant categories first and, within
            each group, ordered by descending proximity score. Categories
            without configured thresholds are reported as not relevant with
            the reason "No thresholds defined".
        """
        summary = []

        for category_key, data in poi_detection_results.items():
            # "features" may be present but None; treat that like missing.
            features = data.get("features") or {}
            thresholds = self.thresholds.get(category_key, {})

            if thresholds:
                is_relevant, reason = self._check_relevance(
                    features, thresholds, category_key
                )
            else:
                is_relevant, reason = False, "No thresholds defined"

            summary.append({
                "category": category_key,
                "is_relevant": is_relevant,
                "proximity_score": self._metric(
                    features, "proximity_score", self._DEFAULT_SCORE
                ),
                "weighted_score": self._metric(
                    features, "weighted_proximity_score", self._DEFAULT_SCORE
                ),
                "total_count": self._metric(
                    features, "total_count", self._DEFAULT_COUNT
                ),
                "distance_to_nearest_m": self._metric(
                    features, "distance_to_nearest_m", self._DEFAULT_DISTANCE_M
                ),
                "has_within_100m": features.get("has_within_100m", False),
                "rejection_reason": None if is_relevant else reason
            })

        # Sort by relevance and proximity score (proximity is never None
        # here, so the tuple comparison is always well-defined).
        summary.sort(
            key=lambda x: (x["is_relevant"], x["proximity_score"]),
            reverse=True
        )

        return summary
|