Refactor the traffic fetching system
services/data/app/external/apis/madrid_traffic_client.py (vendored, 1788 lines changed)
File diff suppressed because it is too large.
services/data/app/external/clients/__init__.py (new file, vendored, 12 lines)
@@ -0,0 +1,12 @@
# ================================================================
# services/data/app/external/clients/__init__.py
# ================================================================
"""
HTTP clients package
"""

from .madrid_client import MadridTrafficAPIClient

__all__ = [
    'MadridTrafficAPIClient'
]
services/data/app/external/clients/madrid_client.py (new file, vendored, 155 lines)
@@ -0,0 +1,155 @@
# ================================================================
# services/data/app/external/clients/madrid_client.py
# ================================================================
"""
Pure HTTP client for Madrid traffic APIs
Handles only HTTP communication and response decoding
"""

import httpx
import structlog
from datetime import datetime
from typing import Optional, Dict, Any

from ..base_client import BaseAPIClient


class MadridTrafficAPIClient(BaseAPIClient):
    """Pure HTTP client for Madrid traffic APIs"""

    TRAFFIC_ENDPOINT = "https://datos.madrid.es/egob/catalogo/202468-10-intensidad-trafico.xml"
    MEASUREMENT_POINTS_URL = "https://datos.madrid.es/egob/catalogo/202468-263-intensidad-trafico.csv"

    def __init__(self):
        super().__init__(base_url="https://datos.madrid.es")
        self.logger = structlog.get_logger()

    def _decode_response_content(self, response) -> Optional[str]:
        """Decode response content with multiple encoding attempts"""
        try:
            return response.text
        except UnicodeDecodeError:
            # Try manual encoding for Spanish content
            for encoding in ['utf-8', 'latin-1', 'windows-1252', 'iso-8859-1']:
                try:
                    content = response.content.decode(encoding)
                    if content and len(content) > 100:
                        self.logger.debug("Successfully decoded with encoding", encoding=encoding)
                        return content
                except UnicodeDecodeError:
                    continue
            return None

    def _build_historical_url(self, year: int, month: int) -> str:
        """Build historical ZIP URL for given year and month"""
        # Madrid historical data URL pattern
        base_url = "https://datos.madrid.es/egob/catalogo/208627"

        # URL numbering pattern (this may need adjustment based on actual URLs)
        if year == 2023:
            url_number = 116 + (month - 1)  # 116-127 for 2023
        elif year == 2024:
            url_number = 128 + (month - 1)  # 128-139 for 2024
        else:
            url_number = 116  # Fallback

        return f"{base_url}-{url_number}-transporte-ptomedida-historico.zip"

    async def fetch_current_traffic_xml(self, endpoint: Optional[str] = None) -> Optional[str]:
        """Fetch current traffic XML data"""
        endpoint = endpoint or self.TRAFFIC_ENDPOINT

        try:
            headers = {
                'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36',
                'Accept': 'application/xml,text/xml,*/*',
                'Accept-Language': 'es-ES,es;q=0.9,en;q=0.8',
                'Accept-Encoding': 'gzip, deflate, br',
                'Cache-Control': 'no-cache',
                'Referer': 'https://datos.madrid.es/'
            }

            response = await self.get(endpoint, headers=headers, timeout=30)

            if not response or response.status_code != 200:
                self.logger.warning("Failed to fetch XML data",
                                    endpoint=endpoint,
                                    status=response.status_code if response else None)
                return None

            # Get XML content with encoding handling
            xml_content = self._decode_response_content(response)
            if not xml_content:
                self.logger.debug("No XML content received", endpoint=endpoint)
                return None

            self.logger.debug("Madrid XML content fetched",
                              length=len(xml_content),
                              endpoint=endpoint)

            return xml_content

        except Exception as e:
            self.logger.error("Error fetching traffic XML data",
                              endpoint=endpoint,
                              error=str(e))
            return None

    async def fetch_measurement_points_csv(self, url: Optional[str] = None) -> Optional[str]:
        """Fetch measurement points CSV data"""
        url = url or self.MEASUREMENT_POINTS_URL

        try:
            async with httpx.AsyncClient(
                timeout=30.0,
                headers={
                    'User-Agent': 'MadridTrafficClient/2.0',
                    'Accept': 'text/csv,application/csv,*/*'
                },
                follow_redirects=True
            ) as client:

                self.logger.debug("Fetching measurement points registry", url=url)
                response = await client.get(url)

                if response.status_code == 200:
                    return response.text
                else:
                    self.logger.warning("Failed to fetch measurement points",
                                        status=response.status_code, url=url)
                    return None

        except Exception as e:
            self.logger.error("Error fetching measurement points registry",
                              url=url, error=str(e))
            return None

    async def fetch_historical_zip(self, zip_url: str) -> Optional[bytes]:
        """Fetch historical traffic ZIP file"""
        try:
            async with httpx.AsyncClient(
                timeout=120.0,  # Longer timeout for large files
                headers={
                    'User-Agent': 'MadridTrafficClient/2.0',
                    'Accept': 'application/zip,*/*'
                },
                follow_redirects=True
            ) as client:

                self.logger.debug("Fetching historical ZIP", url=zip_url)
                response = await client.get(zip_url)

                if response.status_code == 200:
                    self.logger.debug("Historical ZIP fetched",
                                      url=zip_url,
                                      size=len(response.content))
                    return response.content
                else:
                    self.logger.warning("Failed to fetch historical ZIP",
                                        status=response.status_code, url=zip_url)
                    return None

        except Exception as e:
            self.logger.error("Error fetching historical ZIP",
                              url=zip_url, error=str(e))
            return None
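Usage note (not part of the diff): a minimal sketch of how the new client might be driven from a service layer, assuming an asyncio entry point; the import path and the run() wrapper are illustrative, not code from this commit.

# Illustrative only - assumes an asyncio entry point and the client defined above.
import asyncio

from app.external.clients import MadridTrafficAPIClient  # package path assumed from the repo layout

async def run():
    client = MadridTrafficAPIClient()

    # Live snapshot: XML with one <pm> element per measurement point
    xml_content = await client.fetch_current_traffic_xml()

    # Static registry: CSV describing the measurement points themselves
    csv_content = await client.fetch_measurement_points_csv()

    # One month of history as a ZIP of CSV files
    zip_bytes = await client.fetch_historical_zip(client._build_historical_url(2024, 3))

    return xml_content, csv_content, zip_bytes

if __name__ == "__main__":
    asyncio.run(run())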
services/data/app/external/models/__init__.py (new file, vendored, 20 lines)
@@ -0,0 +1,20 @@
# ================================================================
# services/data/app/external/models/__init__.py
# ================================================================
"""
Madrid traffic models package
"""

from .madrid_models import (
    TrafficServiceLevel,
    CongestionLevel,
    MeasurementPoint,
    TrafficRecord
)

__all__ = [
    'TrafficServiceLevel',
    'CongestionLevel',
    'MeasurementPoint',
    'TrafficRecord'
]
services/data/app/external/models/madrid_models.py (new file, vendored, 66 lines)
@@ -0,0 +1,66 @@
# ================================================================
# services/data/app/external/models/madrid_models.py
# ================================================================
"""
Data structures, enums, and dataclasses for Madrid traffic system
"""

from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Optional


class TrafficServiceLevel(Enum):
    """Madrid traffic service levels"""
    FLUID = 0
    DENSE = 1
    CONGESTED = 2
    BLOCKED = 3


class CongestionLevel(Enum):
    """Standardized congestion levels"""
    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    BLOCKED = "blocked"


@dataclass
class MeasurementPoint:
    """Madrid measurement point data structure"""
    id: str
    latitude: float
    longitude: float
    distance: float
    name: str
    type: str


@dataclass
class TrafficRecord:
    """Standardized traffic record with pedestrian inference"""
    date: datetime
    traffic_volume: int
    occupation_percentage: int
    load_percentage: int
    average_speed: int
    congestion_level: str
    pedestrian_count: int
    measurement_point_id: str
    measurement_point_name: str
    road_type: str
    source: str
    district: Optional[str] = None

    # Madrid-specific data
    intensidad_raw: Optional[int] = None
    ocupacion_raw: Optional[int] = None
    carga_raw: Optional[int] = None
    vmed_raw: Optional[int] = None

    # Pedestrian inference metadata
    pedestrian_multiplier: Optional[float] = None
    time_pattern_factor: Optional[float] = None
    district_factor: Optional[float] = None
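For orientation (not part of the diff), a sketch of how a TrafficRecord maps onto the raw Madrid fields; the literal values and point name are made up for illustration.

# Hypothetical values; field names match the dataclass above.
from datetime import datetime, timezone

record = TrafficRecord(
    date=datetime(2024, 3, 15, 13, 0, tzinfo=timezone.utc),
    traffic_volume=420,            # from 'intensidad'
    occupation_percentage=35,      # from 'ocupacion'
    load_percentage=48,            # from 'carga'
    average_speed=28,              # from 'vmed'
    congestion_level=CongestionLevel.MEDIUM.value,
    pedestrian_count=0,            # filled in later by the analyzer
    measurement_point_id="3840",
    measurement_point_name="Gran Vía - Callao",
    road_type="URB",
    source="madrid_opendata_xml",
    intensidad_raw=420,
    ocupacion_raw=35,
    carga_raw=48,
    vmed_raw=28,
)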
services/data/app/external/processors/__init__.py (new file, vendored, 14 lines)
@@ -0,0 +1,14 @@
# ================================================================
# services/data/app/external/processors/__init__.py
# ================================================================
"""
Data processors package
"""

from .madrid_processor import MadridTrafficDataProcessor
from .madrid_business_logic import MadridTrafficAnalyzer

__all__ = [
    'MadridTrafficDataProcessor',
    'MadridTrafficAnalyzer'
]
services/data/app/external/processors/madrid_business_logic.py (new file, vendored, 346 lines)
@@ -0,0 +1,346 @@
# ================================================================
# services/data/app/external/processors/madrid_business_logic.py
# ================================================================
"""
Business rules, inference, and domain logic for Madrid traffic data
Handles pedestrian inference, district mapping, road classification, and validation
"""

import math
import re
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple
import structlog

from ..models.madrid_models import TrafficRecord, CongestionLevel


class MadridTrafficAnalyzer:
    """Handles business logic for Madrid traffic analysis"""

    # Madrid district characteristics for pedestrian patterns
    DISTRICT_MULTIPLIERS = {
        'Centro': 2.5,              # Historic center, high pedestrian activity
        'Salamanca': 2.0,           # Shopping area, high foot traffic
        'Chamberí': 1.8,            # Business district
        'Retiro': 2.2,              # Near park, high leisure activity
        'Chamartín': 1.6,           # Business/residential
        'Tetuán': 1.4,              # Mixed residential/commercial
        'Fuencarral': 1.3,          # Residential with commercial areas
        'Moncloa': 1.7,             # University area
        'Latina': 1.5,              # Residential area
        'Carabanchel': 1.2,         # Residential periphery
        'Usera': 1.1,               # Industrial/residential
        'Villaverde': 1.0,          # Industrial area
        'Villa de Vallecas': 1.0,   # Peripheral residential
        'Vicálvaro': 0.9,           # Peripheral
        'San Blas': 1.1,            # Residential
        'Barajas': 0.8,             # Airport area, low pedestrian activity
        'Hortaleza': 1.2,           # Mixed area
        'Ciudad Lineal': 1.3,       # Linear development
        'Puente de Vallecas': 1.2,  # Working class area
        'Moratalaz': 1.1,           # Residential
        'Arganzuela': 1.6,          # Near center, growing area
    }

    # Time-based patterns (hour of day)
    TIME_PATTERNS = {
        'morning_peak': {'hours': [7, 8, 9], 'multiplier': 2.0},
        'lunch_peak': {'hours': [12, 13, 14], 'multiplier': 2.5},
        'evening_peak': {'hours': [18, 19, 20], 'multiplier': 2.2},
        'afternoon': {'hours': [15, 16, 17], 'multiplier': 1.8},
        'late_evening': {'hours': [21, 22], 'multiplier': 1.5},
        'night': {'hours': [23, 0, 1, 2, 3, 4, 5, 6], 'multiplier': 0.3},
        'morning': {'hours': [10, 11], 'multiplier': 1.4}
    }

    # Road type specific patterns
    ROAD_TYPE_BASE = {
        'URB': 250,  # Urban streets - high pedestrian activity
        'M30': 50,   # Ring road - minimal pedestrians
        'C30': 75,   # Secondary ring - some pedestrian access
        'A': 25,     # Highways - very low pedestrians
        'R': 40      # Radial roads - low to moderate
    }

    # Weather impact on pedestrian activity
    WEATHER_IMPACT = {
        'rain': 0.6,          # 40% reduction in rain
        'hot_weather': 0.8,   # 20% reduction when very hot
        'cold_weather': 0.7,  # 30% reduction when very cold
        'normal': 1.0         # No impact
    }

    def __init__(self):
        self.logger = structlog.get_logger()

    def calculate_pedestrian_flow(
        self,
        traffic_record: TrafficRecord,
        location_context: Optional[Dict[str, Any]] = None
    ) -> Tuple[int, Dict[str, float]]:
        """
        Calculate pedestrian flow estimate with detailed metadata

        Returns:
            Tuple of (pedestrian_count, inference_metadata)
        """
        # Base calculation from road type
        road_type = traffic_record.road_type or 'URB'
        base_pedestrians = self.ROAD_TYPE_BASE.get(road_type, 200)

        # Time pattern adjustment
        hour = traffic_record.date.hour
        time_factor = self._get_time_pattern_factor(hour)

        # District adjustment (if available)
        district_factor = 1.0
        district = traffic_record.district or self.infer_district_from_location(location_context)
        if district:
            district_factor = self.DISTRICT_MULTIPLIERS.get(district, 1.0)

        # Traffic correlation adjustment
        traffic_factor = self._calculate_traffic_correlation(traffic_record)

        # Weather adjustment (if data available)
        weather_factor = self._get_weather_factor(traffic_record.date, location_context)

        # Weekend adjustment
        weekend_factor = self._get_weekend_factor(traffic_record.date)

        # Combined calculation
        pedestrian_count = int(
            base_pedestrians *
            time_factor *
            district_factor *
            traffic_factor *
            weather_factor *
            weekend_factor
        )

        # Ensure reasonable bounds
        pedestrian_count = max(10, min(2000, pedestrian_count))

        # Metadata for model training
        inference_metadata = {
            'base_pedestrians': base_pedestrians,
            'time_factor': time_factor,
            'district_factor': district_factor,
            'traffic_factor': traffic_factor,
            'weather_factor': weather_factor,
            'weekend_factor': weekend_factor,
            'inferred_district': district,
            'hour': hour,
            'road_type': road_type
        }

        return pedestrian_count, inference_metadata

    def _get_time_pattern_factor(self, hour: int) -> float:
        """Get time-based pedestrian activity multiplier"""
        for pattern, config in self.TIME_PATTERNS.items():
            if hour in config['hours']:
                return config['multiplier']
        return 1.0  # Default multiplier

    def _calculate_traffic_correlation(self, traffic_record: TrafficRecord) -> float:
        """
        Calculate pedestrian correlation with traffic patterns
        Higher traffic in urban areas often correlates with more pedestrians
        """
        if traffic_record.road_type == 'URB':
            # Urban areas: moderate traffic indicates commercial activity
            if 30 <= traffic_record.load_percentage <= 70:
                return 1.3  # Sweet spot for pedestrian activity
            elif traffic_record.load_percentage > 70:
                return 0.9  # Too congested, pedestrians avoid
            else:
                return 1.0  # Normal correlation
        else:
            # Highway/ring roads: more traffic = fewer pedestrians
            if traffic_record.load_percentage > 60:
                return 0.5
            else:
                return 0.8

    def _get_weather_factor(self, date: datetime, location_context: Optional[Dict] = None) -> float:
        """Estimate weather impact on pedestrian activity"""
        # Simplified weather inference based on season and typical Madrid patterns
        month = date.month

        # Madrid seasonal patterns
        if month in [12, 1, 2]:        # Winter - cold weather impact
            return self.WEATHER_IMPACT['cold_weather']
        elif month in [7, 8]:          # Summer - hot weather impact
            return self.WEATHER_IMPACT['hot_weather']
        elif month in [10, 11, 3, 4]:  # Rainy seasons - moderate impact
            return 0.85
        else:                          # Spring/early summer - optimal weather
            return 1.1

    def _get_weekend_factor(self, date: datetime) -> float:
        """Weekend vs weekday pedestrian patterns"""
        weekday = date.weekday()
        hour = date.hour

        if weekday >= 5:  # Weekend
            if 11 <= hour <= 16:    # Weekend shopping/leisure hours
                return 1.4
            elif 20 <= hour <= 23:  # Weekend evening activity
                return 1.3
            else:
                return 0.9
        else:  # Weekday
            return 1.0

    def infer_district_from_location(self, location_context: Optional[Dict] = None) -> Optional[str]:
        """
        Infer Madrid district from location context or coordinates
        """
        if not location_context:
            return None

        lat = location_context.get('latitude')
        lon = location_context.get('longitude')

        if not (lat and lon):
            return None

        # Madrid district boundaries (simplified boundaries for inference)
        districts = {
            # Central districts
            'Centro': {'lat_min': 40.405, 'lat_max': 40.425, 'lon_min': -3.720, 'lon_max': -3.690},
            'Arganzuela': {'lat_min': 40.385, 'lat_max': 40.410, 'lon_min': -3.720, 'lon_max': -3.680},
            'Retiro': {'lat_min': 40.405, 'lat_max': 40.425, 'lon_min': -3.690, 'lon_max': -3.660},
            'Salamanca': {'lat_min': 40.420, 'lat_max': 40.445, 'lon_min': -3.690, 'lon_max': -3.660},
            'Chamartín': {'lat_min': 40.445, 'lat_max': 40.480, 'lon_min': -3.690, 'lon_max': -3.660},
            'Tetuán': {'lat_min': 40.445, 'lat_max': 40.470, 'lon_min': -3.720, 'lon_max': -3.690},
            'Chamberí': {'lat_min': 40.425, 'lat_max': 40.450, 'lon_min': -3.720, 'lon_max': -3.690},
            'Fuencarral-El Pardo': {'lat_min': 40.470, 'lat_max': 40.540, 'lon_min': -3.750, 'lon_max': -3.650},
            'Moncloa-Aravaca': {'lat_min': 40.430, 'lat_max': 40.480, 'lon_min': -3.750, 'lon_max': -3.720},
            'Latina': {'lat_min': 40.380, 'lat_max': 40.420, 'lon_min': -3.750, 'lon_max': -3.720},
            'Carabanchel': {'lat_min': 40.350, 'lat_max': 40.390, 'lon_min': -3.750, 'lon_max': -3.720},
            'Usera': {'lat_min': 40.350, 'lat_max': 40.385, 'lon_min': -3.720, 'lon_max': -3.690},
            'Puente de Vallecas': {'lat_min': 40.370, 'lat_max': 40.410, 'lon_min': -3.680, 'lon_max': -3.640},
            'Moratalaz': {'lat_min': 40.400, 'lat_max': 40.430, 'lon_min': -3.650, 'lon_max': -3.620},
            'Ciudad Lineal': {'lat_min': 40.430, 'lat_max': 40.460, 'lon_min': -3.650, 'lon_max': -3.620},
            'Hortaleza': {'lat_min': 40.460, 'lat_max': 40.500, 'lon_min': -3.650, 'lon_max': -3.620},
            'Villaverde': {'lat_min': 40.320, 'lat_max': 40.360, 'lon_min': -3.720, 'lon_max': -3.680},
        }

        # Find matching district
        for district_name, bounds in districts.items():
            if (bounds['lat_min'] <= lat <= bounds['lat_max'] and
                    bounds['lon_min'] <= lon <= bounds['lon_max']):
                return district_name

        # Default for coordinates in Madrid but not matching specific districts
        if 40.3 <= lat <= 40.6 and -3.8 <= lon <= -3.5:
            return 'Other Madrid'

        return None

    def classify_road_type(self, measurement_point_name: str) -> str:
        """Classify road type based on measurement point name"""
        if not measurement_point_name:
            return 'URB'  # Default to urban

        name_upper = measurement_point_name.upper()

        # Highway patterns
        if any(pattern in name_upper for pattern in ['A-', 'AP-', 'AUTOPISTA', 'AUTOVIA']):
            return 'A'

        # M-30 Ring road
        if 'M-30' in name_upper or 'M30' in name_upper:
            return 'M30'

        # Other M roads (ring roads)
        if re.search(r'M-[0-9]', name_upper) or re.search(r'M[0-9]', name_upper):
            return 'C30'

        # Radial roads (R-1, R-2, etc.)
        if re.search(r'R-[0-9]', name_upper) or 'RADIAL' in name_upper:
            return 'R'

        # Default to urban street
        return 'URB'

    def validate_madrid_coordinates(self, lat: float, lon: float) -> bool:
        """Validate coordinates are within Madrid bounds"""
        # Madrid metropolitan area bounds
        return 40.3 <= lat <= 40.6 and -3.8 <= lon <= -3.5

    def get_congestion_level(self, occupation_pct: float) -> str:
        """Convert occupation percentage to congestion level"""
        if occupation_pct >= 80:
            return CongestionLevel.BLOCKED.value
        elif occupation_pct >= 50:
            return CongestionLevel.HIGH.value
        elif occupation_pct >= 25:
            return CongestionLevel.MEDIUM.value
        else:
            return CongestionLevel.LOW.value

    def calculate_distance(self, lat1: float, lon1: float, lat2: float, lon2: float) -> float:
        """Calculate distance between two points in kilometers using Haversine formula"""
        R = 6371  # Earth's radius in kilometers

        dlat = math.radians(lat2 - lat1)
        dlon = math.radians(lon2 - lon1)
        a = (math.sin(dlat / 2) * math.sin(dlat / 2) +
             math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) *
             math.sin(dlon / 2) * math.sin(dlon / 2))
        c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))

        return R * c

    def find_nearest_traffic_point(self, traffic_points: List[Dict[str, Any]],
                                   latitude: float, longitude: float) -> Optional[Dict[str, Any]]:
        """Find the nearest traffic point to given coordinates"""
        if not traffic_points:
            return None

        min_distance = float('inf')
        nearest_point = None

        for point in traffic_points:
            point_lat = point.get('latitude')
            point_lon = point.get('longitude')

            if point_lat and point_lon:
                distance = self.calculate_distance(latitude, longitude, point_lat, point_lon)
                if distance < min_distance:
                    min_distance = distance
                    nearest_point = point

        return nearest_point

    def find_nearest_measurement_points(self, measurement_points: Dict[str, Dict[str, Any]],
                                        latitude: float, longitude: float,
                                        num_points: int = 3,
                                        max_distance_km: Optional[float] = 5.0) -> List[Tuple[str, Dict[str, Any], float]]:
        """Find nearest measurement points for historical data"""
        distances = []

        for point_id, point_data in measurement_points.items():
            point_lat = point_data.get('latitude')
            point_lon = point_data.get('longitude')

            if point_lat and point_lon:
                distance_km = self.calculate_distance(latitude, longitude, point_lat, point_lon)
                distances.append((point_id, point_data, distance_km))

        # Sort by distance and take nearest points
        distances.sort(key=lambda x: x[2])

        # Apply distance filter if specified
        if max_distance_km is not None:
            distances = [p for p in distances if p[2] <= max_distance_km]

        nearest = distances[:num_points]

        self.logger.info("Found nearest measurement points",
                         count=len(nearest),
                         nearest_distance_km=nearest[0][2] if nearest else None)

        return nearest
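To make the pedestrian inference concrete (not part of the diff), a worked example using only the constants defined above: an urban street in Centro at 13:00 on a weekday in March with load_percentage around 50%.

# Worked example with the multipliers defined above (truncation as int() does).
base = 250             # ROAD_TYPE_BASE['URB']
time_factor = 2.5      # 13:00 falls in 'lunch_peak'
district_factor = 2.5  # DISTRICT_MULTIPLIERS['Centro']
traffic_factor = 1.3   # URB with 30 <= load_percentage <= 70
weather_factor = 0.85  # March is treated as a rainy-season month
weekend_factor = 1.0   # weekday

estimate = int(base * time_factor * district_factor * traffic_factor * weather_factor * weekend_factor)
estimate = max(10, min(2000, estimate))  # -> 1726, inside the [10, 2000] bounds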
services/data/app/external/processors/madrid_processor.py (new file, vendored, 478 lines)
@@ -0,0 +1,478 @@
# ================================================================
# services/data/app/external/processors/madrid_processor.py
# ================================================================
"""
Data transformation and parsing for Madrid traffic data
Handles XML parsing, CSV processing, coordinate conversion, and data quality scoring
"""

import csv
import io
import math
import re
import xml.etree.ElementTree as ET
import zipfile
from datetime import datetime, timezone
from typing import Dict, List, Any, Optional, Tuple
import structlog
import pyproj

from ..models.madrid_models import TrafficRecord, MeasurementPoint, CongestionLevel


class MadridTrafficDataProcessor:
    """Handles all data transformation and parsing for Madrid traffic data"""

    def __init__(self):
        self.logger = structlog.get_logger()
        # UTM Zone 30N (Madrid's coordinate system)
        self.utm_proj = pyproj.Proj(proj='utm', zone=30, ellps='WGS84', datum='WGS84')
        self.wgs84_proj = pyproj.Proj(proj='latlong', ellps='WGS84', datum='WGS84')

    def safe_int(self, value: str) -> int:
        """Safely convert string to int"""
        try:
            return int(float(value.replace(',', '.')))
        except (ValueError, TypeError):
            return 0

    def _safe_float(self, value: str) -> float:
        """Safely convert string to float"""
        try:
            return float(value.replace(',', '.'))
        except (ValueError, TypeError):
            return 0.0

    def clean_madrid_xml(self, xml_content: str) -> str:
        """Clean and prepare Madrid XML content for parsing"""
        if not xml_content:
            return ""

        # Remove BOM and extra whitespace
        cleaned = xml_content.strip()
        if cleaned.startswith('\ufeff'):
            cleaned = cleaned[1:]

        # Fix common XML issues: escape bare ampersands that are not already entities
        cleaned = re.sub(r'&(?!amp;|lt;|gt;|quot;|apos;)', '&amp;', cleaned)

        # Ensure proper encoding declaration
        if not cleaned.startswith('<?xml'):
            cleaned = '<?xml version="1.0" encoding="UTF-8"?>\n' + cleaned

        return cleaned

    def convert_utm_to_latlon(self, utm_x: str, utm_y: str) -> Tuple[Optional[float], Optional[float]]:
        """Convert UTM coordinates to latitude/longitude"""
        try:
            utm_x_float = float(utm_x.replace(',', '.'))
            utm_y_float = float(utm_y.replace(',', '.'))

            # Convert from UTM Zone 30N to WGS84
            longitude, latitude = pyproj.transform(self.utm_proj, self.wgs84_proj, utm_x_float, utm_y_float)

            # Validate coordinates are in Madrid area
            if 40.3 <= latitude <= 40.6 and -3.8 <= longitude <= -3.5:
                return latitude, longitude
            else:
                self.logger.debug("Coordinates outside Madrid bounds",
                                  lat=latitude, lon=longitude, utm_x=utm_x, utm_y=utm_y)
                return None, None

        except Exception as e:
            self.logger.debug("UTM conversion error",
                              utm_x=utm_x, utm_y=utm_y, error=str(e))
            return None, None

    def parse_traffic_xml(self, xml_content: str) -> List[Dict[str, Any]]:
        """Parse Madrid traffic XML data"""
        traffic_points = []

        try:
            cleaned_xml = self.clean_madrid_xml(xml_content)
            root = ET.fromstring(cleaned_xml)

            self.logger.debug("Madrid XML structure", root_tag=root.tag, children_count=len(list(root)))

            if root.tag == 'pms':
                pm_elements = root.findall('pm')
                self.logger.debug("Found PM elements", count=len(pm_elements))

                for pm in pm_elements:
                    try:
                        traffic_point = self._extract_madrid_pm_element(pm)

                        if self._is_valid_traffic_point(traffic_point):
                            traffic_points.append(traffic_point)

                            # Log first few points for debugging
                            if len(traffic_points) <= 3:
                                self.logger.debug("Sample traffic point",
                                                  id=traffic_point['idelem'],
                                                  lat=traffic_point['latitude'],
                                                  lon=traffic_point['longitude'],
                                                  intensity=traffic_point.get('intensidad'))

                    except Exception as e:
                        self.logger.debug("Error parsing PM element", error=str(e))
                        continue
            else:
                self.logger.warning("Unexpected XML root tag", root_tag=root.tag)

            self.logger.debug("Madrid traffic XML parsing completed", valid_points=len(traffic_points))
            return traffic_points

        except ET.ParseError as e:
            self.logger.warning("Failed to parse Madrid XML", error=str(e))
            return self._extract_traffic_data_regex(xml_content)
        except Exception as e:
            self.logger.error("Error in Madrid traffic XML parsing", error=str(e))
            return []

    def _extract_madrid_pm_element(self, pm_element) -> Dict[str, Any]:
        """Extract traffic data from Madrid <pm> element with coordinate conversion"""
        try:
            point_data = {}
            utm_x = utm_y = None

            # Extract all child elements
            for child in pm_element:
                tag, text = child.tag, child.text.strip() if child.text else ''

                if tag == 'idelem':
                    point_data['idelem'] = text
                elif tag == 'descripcion':
                    point_data['descripcion'] = text
                elif tag == 'intensidad':
                    point_data['intensidad'] = self.safe_int(text)
                elif tag == 'ocupacion':
                    point_data['ocupacion'] = self._safe_float(text)
                elif tag == 'carga':
                    point_data['carga'] = self.safe_int(text)
                elif tag == 'nivelServicio':
                    point_data['nivelServicio'] = self.safe_int(text)
                elif tag == 'st_x':  # UTM X coordinate
                    utm_x = text
                    point_data['utm_x'] = text
                elif tag == 'st_y':  # UTM Y coordinate
                    utm_y = text
                    point_data['utm_y'] = text
                elif tag == 'error':
                    point_data['error'] = text
                elif tag in ['subarea', 'accesoAsociado', 'intensidadSat']:
                    point_data[tag] = text

            # Convert coordinates
            if utm_x and utm_y:
                latitude, longitude = self.convert_utm_to_latlon(utm_x, utm_y)

                if latitude and longitude:
                    point_data.update({
                        'latitude': latitude,
                        'longitude': longitude,
                        'measurement_point_id': point_data.get('idelem'),
                        'measurement_point_name': point_data.get('descripcion'),
                        'timestamp': datetime.now(timezone.utc),
                        'source': 'madrid_opendata_xml'
                    })

                    return point_data
                else:
                    self.logger.debug("Invalid coordinates after conversion",
                                      idelem=point_data.get('idelem'), utm_x=utm_x, utm_y=utm_y)
                    return {}
            else:
                self.logger.debug("Missing UTM coordinates", idelem=point_data.get('idelem'))
                return {}

        except Exception as e:
            self.logger.debug("Error extracting PM element", error=str(e))
            return {}

    def _is_valid_traffic_point(self, traffic_point: Dict[str, Any]) -> bool:
        """Validate traffic point data"""
        required_fields = ['idelem', 'latitude', 'longitude']
        return all(field in traffic_point and traffic_point[field] for field in required_fields)

    def _extract_traffic_data_regex(self, xml_content: str) -> List[Dict[str, Any]]:
        """Fallback regex-based extraction if XML parsing fails"""
        traffic_points = []

        try:
            # Pattern to match PM elements
            pm_pattern = r'<pm>(.*?)</pm>'
            pm_matches = re.findall(pm_pattern, xml_content, re.DOTALL)

            for pm_content in pm_matches:
                traffic_point = {}

                # Extract key fields
                patterns = {
                    'idelem': r'<idelem>(.*?)</idelem>',
                    'descripcion': r'<descripcion>(.*?)</descripcion>',
                    'intensidad': r'<intensidad>(.*?)</intensidad>',
                    'ocupacion': r'<ocupacion>(.*?)</ocupacion>',
                    'st_x': r'<st_x>(.*?)</st_x>',
                    'st_y': r'<st_y>(.*?)</st_y>'
                }

                for field, pattern in patterns.items():
                    match = re.search(pattern, pm_content)
                    if match:
                        traffic_point[field] = match.group(1).strip()

                # Convert coordinates
                if 'st_x' in traffic_point and 'st_y' in traffic_point:
                    latitude, longitude = self.convert_utm_to_latlon(
                        traffic_point['st_x'], traffic_point['st_y']
                    )

                    if latitude and longitude:
                        traffic_point.update({
                            'latitude': latitude,
                            'longitude': longitude,
                            'intensidad': self.safe_int(traffic_point.get('intensidad', '0')),
                            'ocupacion': self._safe_float(traffic_point.get('ocupacion', '0')),
                            'measurement_point_id': traffic_point.get('idelem'),
                            'measurement_point_name': traffic_point.get('descripcion'),
                            'timestamp': datetime.now(timezone.utc),
                            'source': 'madrid_opendata_xml_regex'
                        })

                        traffic_points.append(traffic_point)

            self.logger.debug("Regex extraction completed", points=len(traffic_points))
            return traffic_points

        except Exception as e:
            self.logger.error("Error in regex extraction", error=str(e))
            return []

    def parse_measurement_points_csv(self, csv_content: str) -> Dict[str, Dict[str, Any]]:
        """Parse measurement points CSV into lookup dictionary"""
        measurement_points = {}

        try:
            # Parse CSV with semicolon delimiter
            csv_reader = csv.DictReader(io.StringIO(csv_content), delimiter=';')

            processed_count = 0
            for row in csv_reader:
                try:
                    # Extract point ID and coordinates
                    point_id = row.get('id', '').strip()
                    if not point_id:
                        continue

                    processed_count += 1

                    # Try different coordinate field names
                    lat_str = ''
                    lon_str = ''

                    # Common coordinate field patterns
                    lat_fields = ['lat', 'latitude', 'latitud', 'y', 'utm_y']
                    lon_fields = ['lon', 'lng', 'longitude', 'longitud', 'x', 'utm_x']

                    for field in lat_fields:
                        if field in row and row[field].strip():
                            lat_str = row[field].strip()
                            break

                    for field in lon_fields:
                        if field in row and row[field].strip():
                            lon_str = row[field].strip()
                            break

                    if lat_str and lon_str:
                        try:
                            # Try direct lat/lon first
                            latitude = self._safe_float(lat_str)
                            longitude = self._safe_float(lon_str)

                            # If values look like UTM coordinates, convert them
                            if latitude > 1000 or longitude > 1000:
                                latitude, longitude = self.convert_utm_to_latlon(lon_str, lat_str)
                                if not latitude or not longitude:
                                    continue

                            # Validate Madrid area
                            if not (40.3 <= latitude <= 40.6 and -3.8 <= longitude <= -3.5):
                                continue

                            measurement_points[point_id] = {
                                'id': point_id,
                                'latitude': latitude,
                                'longitude': longitude,
                                'name': row.get('nombre', row.get('descripcion', f"Point {point_id}")),
                                'type': row.get('tipo', 'traffic'),
                                'raw_data': dict(row)  # Keep original data
                            }

                        except Exception as e:
                            self.logger.debug("Error processing point coordinates",
                                              point_id=point_id, error=str(e))
                            continue

                except Exception as e:
                    self.logger.debug("Error processing CSV row", error=str(e))
                    continue

            self.logger.info("Parsed measurement points registry",
                             total_points=len(measurement_points))
            return measurement_points

        except Exception as e:
            self.logger.error("Error parsing measurement points CSV", error=str(e))
            return {}

    def calculate_data_quality_score(self, row: Dict[str, str]) -> float:
        """Calculate data quality score for a traffic record"""
        try:
            score = 1.0

            # Check for missing or invalid values
            intensidad = row.get('intensidad', '').strip()
            if not intensidad or intensidad in ['N', '', '0']:
                score *= 0.7

            ocupacion = row.get('ocupacion', '').strip()
            if not ocupacion or ocupacion in ['N', '', '0']:
                score *= 0.8

            error_status = row.get('error', '').strip()
            if error_status and error_status != 'N':
                score *= 0.6

            # Check for reasonable value ranges
            try:
                intensidad_val = self.safe_int(intensidad)
                if intensidad_val < 0 or intensidad_val > 5000:  # Unrealistic traffic volume
                    score *= 0.7

                ocupacion_val = self.safe_int(ocupacion)
                if ocupacion_val < 0 or ocupacion_val > 100:  # Invalid percentage
                    score *= 0.5

            except Exception:
                score *= 0.6

            return max(0.1, score)  # Minimum quality score

        except Exception as e:
            self.logger.debug("Error calculating quality score", error=str(e))
            return 0.5  # Default medium quality

    async def process_csv_content_chunked(self, text_content: str, csv_filename: str,
                                          nearest_ids: set, nearest_points: list) -> list:
        """Process CSV content in chunks to prevent memory issues"""
        import gc

        try:
            csv_reader = csv.DictReader(io.StringIO(text_content), delimiter=';')

            chunk_size = 10000
            chunk_records = []
            all_records = []
            processed_count = 0
            total_rows_seen = 0

            for row in csv_reader:
                total_rows_seen += 1
                measurement_point_id = row.get('id', '').strip()

                if measurement_point_id not in nearest_ids:
                    continue

                try:
                    record_data = await self.parse_historical_csv_row(row, nearest_points)

                    if record_data:
                        chunk_records.append(record_data)
                        processed_count += 1

                        if len(chunk_records) >= chunk_size:
                            all_records.extend(chunk_records)
                            chunk_records = []
                            gc.collect()

                except Exception as e:
                    if processed_count < 5:
                        self.logger.error("Row parsing exception",
                                          row_num=total_rows_seen,
                                          measurement_point_id=measurement_point_id,
                                          error=str(e))
                    continue

            # Process remaining records
            if chunk_records:
                all_records.extend(chunk_records)
                chunk_records = []
                gc.collect()

            self.logger.info("Processed CSV file",
                             filename=csv_filename,
                             total_rows_read=total_rows_seen,
                             processed_records=processed_count)

            return all_records

        except Exception as e:
            self.logger.error("Error processing CSV content",
                              filename=csv_filename, error=str(e))
            return []

    async def parse_historical_csv_row(self, row: dict, nearest_points: list) -> Optional[dict]:
        """Parse a single row from Madrid's historical traffic CSV"""
        try:
            # Extract date
            fecha_str = row.get('fecha', '').strip()
            if not fecha_str:
                return None

            try:
                date_obj = datetime.strptime(fecha_str, '%Y-%m-%d %H:%M:%S')
                date_obj = date_obj.replace(tzinfo=timezone.utc)
            except Exception:
                return None

            measurement_point_id = row.get('id', '').strip()

            # Find point data
            point_match = next((p for p in nearest_points if p[0] == measurement_point_id), None)
            if not point_match:
                return None

            point_data = point_match[1]
            distance_km = point_match[2]

            # Extract traffic data
            intensidad = self.safe_int(row.get('intensidad', '0'))
            ocupacion = self.safe_int(row.get('ocupacion', '0'))
            carga = self.safe_int(row.get('carga', '0'))
            vmed = self.safe_int(row.get('vmed', '0'))

            # Build basic result (business logic will be applied elsewhere)
            result = {
                'date': date_obj,
                'measurement_point_id': measurement_point_id,
                'point_data': point_data,
                'distance_km': distance_km,
                'traffic_data': {
                    'intensidad': intensidad,
                    'ocupacion': ocupacion,
                    'carga': carga,
                    'vmed': vmed
                },
                'data_quality_score': self.calculate_data_quality_score(row),
                'raw_row': row
            }

            return result

        except Exception as e:
            self.logger.debug("Error parsing historical CSV row", error=str(e))
            return None
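A sketch (not part of the diff) of how the pieces introduced by this commit are meant to compose: the client fetches, the processor parses, the analyzer applies the business rules. The orchestration function and import paths are hypothetical.

# Hypothetical orchestration; the three classes are the ones added in this commit.
from app.external.clients import MadridTrafficAPIClient
from app.external.processors import MadridTrafficDataProcessor, MadridTrafficAnalyzer

async def current_congestion_near(lat: float, lon: float):
    client = MadridTrafficAPIClient()
    processor = MadridTrafficDataProcessor()
    analyzer = MadridTrafficAnalyzer()

    xml_content = await client.fetch_current_traffic_xml()
    if not xml_content:
        return None

    points = processor.parse_traffic_xml(xml_content)      # dicts with lat/lon and raw counters
    nearest = analyzer.find_nearest_traffic_point(points, lat, lon)
    if not nearest:
        return None

    return analyzer.get_congestion_level(nearest.get('ocupacion', 0.0))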
@@ -639,10 +639,17 @@ class EnhancedForecastingService:
             if precipitation > 2.0:
                 adjustment_factor *= 0.7
 
-            # Apply adjustments
+            # Apply adjustments to prediction
             adjusted_prediction = max(0, base_prediction * adjustment_factor)
-            adjusted_lower = max(0, lower_bound * adjustment_factor)
-            adjusted_upper = max(0, upper_bound * adjustment_factor)
+
+            # For confidence bounds, preserve relative interval width while respecting minimum bounds
+            original_interval = upper_bound - lower_bound
+            adjusted_interval = original_interval * adjustment_factor
+
+            # Ensure minimum reasonable lower bound (at least 20% of prediction or 5, whichever is larger)
+            min_lower_bound = max(adjusted_prediction * 0.2, 5.0)
+            adjusted_lower = max(min_lower_bound, adjusted_prediction - (adjusted_interval / 2))
+            adjusted_upper = max(adjusted_lower + 10, adjusted_prediction + (adjusted_interval / 2))
 
             return {
                 "prediction": adjusted_prediction,
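A quick numeric check of the new bound logic in the hunk above, with illustrative numbers chosen so the change matters (a small prediction with a wide interval):

# Illustrative numbers only - tracing the adjusted-bound arithmetic from the hunk above.
base_prediction, lower_bound, upper_bound = 20.0, 2.0, 50.0
adjustment_factor = 0.7

adjusted_prediction = max(0, base_prediction * adjustment_factor)                        # 14.0
original_interval = upper_bound - lower_bound                                            # 48.0
adjusted_interval = original_interval * adjustment_factor                                # 33.6

min_lower_bound = max(adjusted_prediction * 0.2, 5.0)                                    # 5.0
adjusted_lower = max(min_lower_bound, adjusted_prediction - (adjusted_interval / 2))     # 5.0
adjusted_upper = max(adjusted_lower + 10, adjusted_prediction + (adjusted_interval / 2)) # 30.8
# The removed lines would have scaled both bounds directly, giving lower = 1.4 and upper = 35.0;
# the new logic keeps the interval centered on the prediction with a sane floor on the lower bound.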
@@ -162,7 +162,8 @@ class BakeryProphetManager:
                 'seasonality_mode': 'additive',
                 'daily_seasonality': False,
                 'weekly_seasonality': True,
-                'yearly_seasonality': False
+                'yearly_seasonality': False,
+                'uncertainty_samples': 100  # ✅ FIX: Minimal uncertainty sampling for very sparse data
             }
         elif zero_ratio > 0.6:
             logger.info(f"Moderate sparsity for {product_name}, using conservative optimization")
@@ -174,7 +175,8 @@ class BakeryProphetManager:
                 'seasonality_mode': 'additive',
                 'daily_seasonality': False,
                 'weekly_seasonality': True,
-                'yearly_seasonality': len(df) > 365  # Only if we have enough data
+                'yearly_seasonality': len(df) > 365,  # Only if we have enough data
+                'uncertainty_samples': 200  # ✅ FIX: Conservative uncertainty sampling for moderately sparse data
             }
 
         # Use unique seed for each product to avoid identical results
@@ -196,6 +198,16 @@ class BakeryProphetManager:
             changepoint_scale_range = (0.001, 0.5)
             seasonality_scale_range = (0.01, 10.0)
 
+            # ✅ FIX: Determine appropriate uncertainty samples range based on product category
+            if product_category == 'high_volume':
+                uncertainty_range = (300, 800)  # More samples for stable high-volume products
+            elif product_category == 'medium_volume':
+                uncertainty_range = (200, 500)  # Moderate samples for medium volume
+            elif product_category == 'low_volume':
+                uncertainty_range = (150, 300)  # Fewer samples for low volume
+            else:  # intermittent
+                uncertainty_range = (100, 200)  # Minimal samples for intermittent demand
+
             params = {
                 'changepoint_prior_scale': trial.suggest_float(
                     'changepoint_prior_scale',
@@ -214,7 +226,8 @@ class BakeryProphetManager:
                 'seasonality_mode': 'additive' if product_category == 'high_volume' else trial.suggest_categorical('seasonality_mode', ['additive', 'multiplicative']),
                 'daily_seasonality': trial.suggest_categorical('daily_seasonality', [True, False]),
                 'weekly_seasonality': True,  # Always keep weekly
-                'yearly_seasonality': trial.suggest_categorical('yearly_seasonality', [True, False])
+                'yearly_seasonality': trial.suggest_categorical('yearly_seasonality', [True, False]),
+                'uncertainty_samples': trial.suggest_int('uncertainty_samples', uncertainty_range[0], uncertainty_range[1])  # ✅ FIX: Adaptive uncertainty sampling
             }
 
             # Simple 2-fold cross-validation for speed
@@ -229,8 +242,10 @@ class BakeryProphetManager:
                     continue
 
                 try:
-                    # Create and train model
-                    model = Prophet(**params, interval_width=0.8, uncertainty_samples=100)
+                    # Create and train model with adaptive uncertainty sampling
+                    uncertainty_samples = params.get('uncertainty_samples', 200)  # ✅ FIX: Use adaptive uncertainty samples
+                    model = Prophet(**{k: v for k, v in params.items() if k != 'uncertainty_samples'},
+                                    interval_width=0.8, uncertainty_samples=uncertainty_samples)
 
                     for regressor in regressor_columns:
                         if regressor in train_data.columns:
@@ -291,6 +306,12 @@ class BakeryProphetManager:
         logger.info(f"Optimization completed for {product_name}. Best score: {best_score:.2f}%. "
                     f"Parameters: {best_params}")
 
+        # ✅ FIX: Log uncertainty sampling configuration for debugging confidence intervals
+        uncertainty_samples = best_params.get('uncertainty_samples', 500)
+        logger.info(f"Prophet model will use {uncertainty_samples} uncertainty samples for {product_name} "
+                    f"(category: {product_category}, zero_ratio: {zero_ratio:.2f})")
+
         return best_params
 
     def _classify_product(self, product_name: str, sales_data: pd.DataFrame) -> str:
@@ -329,9 +350,12 @@ class BakeryProphetManager:
             return 'intermittent'
 
     def _create_optimized_prophet_model(self, optimized_params: Dict[str, Any], regressor_columns: List[str]) -> Prophet:
-        """Create Prophet model with optimized parameters"""
+        """Create Prophet model with optimized parameters and adaptive uncertainty sampling"""
         holidays = self._get_spanish_holidays()
 
+        # Determine uncertainty samples based on data characteristics
+        uncertainty_samples = optimized_params.get('uncertainty_samples', 500)
+
         model = Prophet(
             holidays=holidays if not holidays.empty else None,
             daily_seasonality=optimized_params.get('daily_seasonality', True),
@@ -344,7 +368,7 @@ class BakeryProphetManager:
             changepoint_range=optimized_params.get('changepoint_range', 0.8),
             interval_width=0.8,
             mcmc_samples=0,
-            uncertainty_samples=1000
+            uncertainty_samples=uncertainty_samples
         )
 
         return model
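For context (not part of the diff): uncertainty_samples controls how many simulated trajectories Prophet draws to form yhat_lower/yhat_upper, so a smaller budget trades interval stability for speed. A minimal sketch of comparing two budgets, assuming a training dataframe df with the usual 'ds' and 'y' columns:

# Minimal sketch: same data, two uncertainty_samples budgets.
from prophet import Prophet

for n_samples in (100, 500):
    m = Prophet(interval_width=0.8, uncertainty_samples=n_samples)
    m.fit(df)                                     # df assumed to have 'ds' and 'y' columns
    future = m.make_future_dataframe(periods=14)
    forecast = m.predict(future)
    width = (forecast['yhat_upper'] - forecast['yhat_lower']).mean()
    # Fewer samples -> noisier interval estimates but faster predict() calls.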