Add POI feature and improve the overall backend implementation

This commit is contained in:
Urtzi Alfaro
2025-11-12 15:34:10 +01:00
parent e8096cd979
commit 5783c7ed05
173 changed files with 16862 additions and 9078 deletions

View File

@@ -33,18 +33,22 @@ class ParallelProductProgressTracker:
def __init__(self, job_id: str, tenant_id: str, total_products: int):
    """Track per-product progress for a parallel training job.

    Args:
        job_id: Identifier of the training job being tracked.
        tenant_id: Tenant that owns the job.
        total_products: Number of products trained in parallel; clamped
            to a minimum of 1 so the progress math never divides by zero.
    """
    self.job_id = job_id
    self.tenant_id = tenant_id
    self.total_products = max(total_products, 1)  # Ensure at least 1 to avoid division by zero
    self.products_completed = 0
    self._lock = asyncio.Lock()
    self.start_time = datetime.now(timezone.utc)
    # Each completed product advances overall progress by an equal share of
    # the training range (PROGRESS_TRAINING_RANGE_START to PROGRESS_TRAINING_RANGE_END).
    # The clamp above guarantees self.total_products >= 1, so no zero check is needed.
    self.progress_per_product = PROGRESS_TRAINING_RANGE_WIDTH / self.total_products
    if total_products == 0:
        logger.warning("ParallelProductProgressTracker initialized with zero products",
                       job_id=job_id)
    logger.info("ParallelProductProgressTracker initialized",
                job_id=job_id,
                total_products=self.total_products,
                progress_per_product=f"{self.progress_per_product:.2f}%")
async def mark_product_completed(self, product_name: str) -> int:
@@ -87,7 +91,10 @@ class ParallelProductProgressTracker:
# Calculate overall progress (PROGRESS_TRAINING_RANGE_START% base + progress from completed products)
# This calculation is done on the frontend/consumer side based on the event data
overall_progress = PROGRESS_TRAINING_RANGE_START + int((current_progress / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH)
if self.total_products > 0:
overall_progress = PROGRESS_TRAINING_RANGE_START + int((current_progress / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH)
else:
overall_progress = PROGRESS_TRAINING_RANGE_START
logger.info("Product training completed",
job_id=self.job_id,
@@ -101,8 +108,13 @@ class ParallelProductProgressTracker:
def get_progress(self) -> dict:
    """Get current progress summary.

    Returns:
        dict with products_completed, total_products, and the overall
        progress_percentage mapped into the training progress range.
    """
    # Defensive zero check; __init__ clamps total_products to >= 1, but this
    # keeps the method safe even if that invariant ever changes.
    if self.total_products > 0:
        progress_percentage = PROGRESS_TRAINING_RANGE_START + int(
            (self.products_completed / self.total_products) * PROGRESS_TRAINING_RANGE_WIDTH
        )
    else:
        progress_percentage = PROGRESS_TRAINING_RANGE_START
    return {
        "products_completed": self.products_completed,
        "total_products": self.total_products,
        "progress_percentage": progress_percentage
    }

View File

@@ -15,7 +15,7 @@ import pandas as pd
from app.services.data_client import DataClient
from app.services.date_alignment_service import DateAlignmentService, DateRange, DataSourceType, AlignedDateRange
from app.ml.poi_feature_integrator import POIFeatureIntegrator
from app.services.training_events import publish_training_failed
logger = structlog.get_logger()
@@ -26,6 +26,7 @@ class TrainingDataSet:
sales_data: List[Dict[str, Any]]
weather_data: List[Dict[str, Any]]
traffic_data: List[Dict[str, Any]]
poi_features: Dict[str, Any] # POI features for location-based forecasting
date_range: AlignedDateRange
metadata: Dict[str, Any]
@@ -36,10 +37,12 @@ class TrainingDataOrchestrator:
Uses the new abstracted traffic service layer for multi-city support.
"""
def __init__(self,
             date_alignment_service: DateAlignmentService = None,
             poi_feature_integrator: POIFeatureIntegrator = None):
    """Initialize the orchestrator with optional injected collaborators.

    Args:
        date_alignment_service: Service used to align date ranges across
            data sources; a default instance is created when omitted.
        poi_feature_integrator: Integrator used to fetch static POI
            features; a default instance is created when omitted.
    """
    self.data_client = DataClient()
    self.date_alignment_service = date_alignment_service or DateAlignmentService()
    self.poi_feature_integrator = poi_feature_integrator or POIFeatureIntegrator()
    # Cap on concurrent outbound data-collection requests.
    self.max_concurrent_requests = 5  # Increased for better performance
async def prepare_training_data(
@@ -122,20 +125,21 @@ class TrainingDataOrchestrator:
# Step 5: Collect external data sources concurrently
logger.info("Collecting external data sources...")
weather_data, traffic_data = await self._collect_external_data(
weather_data, traffic_data, poi_features = await self._collect_external_data(
aligned_range, bakery_location, tenant_id
)
# Step 6: Validate data quality
data_quality_results = self._validate_data_sources(
filtered_sales, weather_data, traffic_data, aligned_range
)
# Step 7: Create comprehensive training dataset
training_dataset = TrainingDataSet(
sales_data=filtered_sales,
weather_data=weather_data,
traffic_data=traffic_data,
poi_features=poi_features or {}, # POI features (static, location-based)
date_range=aligned_range,
metadata={
"tenant_id": tenant_id,
@@ -148,7 +152,8 @@ class TrainingDataOrchestrator:
"original_sales_range": {
"start": sales_date_range.start.isoformat(),
"end": sales_date_range.end.isoformat()
}
},
"poi_features_count": len(poi_features) if poi_features else 0
}
)
@@ -160,6 +165,7 @@ class TrainingDataOrchestrator:
logger.info(f" - Sales records: {len(filtered_sales)}")
logger.info(f" - Weather records: {len(weather_data)}")
logger.info(f" - Traffic records: {len(traffic_data)}")
logger.info(f" - POI features: {len(poi_features) if poi_features else 0}")
logger.info(f" - Data quality score: {final_validation.get('data_quality_score', 'N/A')}")
return training_dataset
@@ -329,21 +335,21 @@ class TrainingDataOrchestrator:
aligned_range: AlignedDateRange,
bakery_location: Tuple[float, float],
tenant_id: str
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""Collect weather and traffic data concurrently with enhanced error handling"""
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], Dict[str, Any]]:
"""Collect weather, traffic, and POI data concurrently with enhanced error handling"""
lat, lon = bakery_location
# Create collection tasks with timeout
tasks = []
# Weather data collection
if DataSourceType.WEATHER_FORECAST in aligned_range.available_sources:
weather_task = asyncio.create_task(
self._collect_weather_data_with_timeout(lat, lon, aligned_range, tenant_id)
)
tasks.append(("weather", weather_task))
# Enhanced Traffic data collection (supports multiple cities)
if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
logger.info(f"🚛 Traffic data source available for multiple cities, creating collection task for date range: {aligned_range.start} to {aligned_range.end}")
@@ -353,7 +359,13 @@ class TrainingDataOrchestrator:
tasks.append(("traffic", traffic_task))
else:
logger.warning(f"🚫 Traffic data source NOT available in sources: {[s.value for s in aligned_range.available_sources]}")
# POI features collection (static, location-based)
poi_task = asyncio.create_task(
self._collect_poi_features(lat, lon, tenant_id)
)
tasks.append(("poi", poi_task))
# Execute tasks concurrently with proper error handling
results = {}
if tasks:
@@ -362,24 +374,76 @@ class TrainingDataOrchestrator:
*[task for _, task in tasks],
return_exceptions=True
)
for i, (task_name, _) in enumerate(tasks):
result = completed_tasks[i]
if isinstance(result, Exception):
logger.warning(f"{task_name} data collection failed: {result}")
results[task_name] = []
results[task_name] = [] if task_name != "poi" else {}
else:
results[task_name] = result
logger.info(f"{task_name} data collection completed: {len(result)} records")
if task_name == "poi":
logger.info(f"{task_name} features collected: {len(result) if result else 0} features")
else:
logger.info(f"{task_name} data collection completed: {len(result)} records")
except Exception as e:
logger.error(f"Error in concurrent data collection: {str(e)}")
results = {"weather": [], "traffic": []}
results = {"weather": [], "traffic": [], "poi": {}}
weather_data = results.get("weather", [])
traffic_data = results.get("traffic", [])
return weather_data, traffic_data
poi_features = results.get("poi", {})
return weather_data, traffic_data, poi_features
async def _collect_poi_features(
    self,
    lat: float,
    lon: float,
    tenant_id: str
) -> Dict[str, Any]:
    """
    Collect POI features for bakery location.

    POI features are static (location-based, not time-varying). Any
    failure is logged and swallowed so training can continue without them.
    """
    try:
        logger.info(
            "Collecting POI features",
            tenant_id=tenant_id,
            location=(lat, lon)
        )
        features = await self.poi_feature_integrator.fetch_poi_features(
            tenant_id=tenant_id,
            latitude=lat,
            longitude=lon,
            force_refresh=False
        )
        # Empty/None result is treated as "service unavailable", not an error.
        if not features:
            logger.warning(
                "No POI features collected (service may be unavailable)",
                tenant_id=tenant_id
            )
            return {}
        logger.info(
            "POI features collected successfully",
            tenant_id=tenant_id,
            feature_count=len(features)
        )
        return features
    except Exception as exc:
        logger.error(
            "Failed to collect POI features, continuing without them",
            tenant_id=tenant_id,
            error=str(exc),
            exc_info=True
        )
        return {}
async def _collect_weather_data_with_timeout(
self,