Improve the traffic fetching system

services/data/app/core/performance.py (new file, 312 lines)
@@ -0,0 +1,312 @@
# ================================================================
# services/data/app/core/performance.py
# ================================================================
"""
Performance optimization utilities for async operations
"""

import asyncio
import functools
from typing import Any, Callable, Dict, Optional, TypeVar
from datetime import datetime, timedelta, timezone
import hashlib
import json
import structlog

logger = structlog.get_logger()

T = TypeVar('T')


class AsyncCache:
    """Simple in-memory async cache with TTL"""

    def __init__(self, default_ttl: int = 300):
        self.cache: Dict[str, Dict[str, Any]] = {}
        self.default_ttl = default_ttl

    def _generate_key(self, *args, **kwargs) -> str:
        """Generate cache key from arguments"""
        key_data = {
            'args': args,
            'kwargs': sorted(kwargs.items())
        }
        key_string = json.dumps(key_data, sort_keys=True, default=str)
        return hashlib.md5(key_string.encode()).hexdigest()

    def _is_expired(self, entry: Dict[str, Any]) -> bool:
        """Check if cache entry is expired"""
        expires_at = entry.get('expires_at')
        if not expires_at:
            return True
        return datetime.now(timezone.utc) > expires_at

    async def get(self, key: str) -> Optional[Any]:
        """Get value from cache"""
        if key in self.cache:
            entry = self.cache[key]
            if not self._is_expired(entry):
                logger.debug("Cache hit", cache_key=key)
                return entry['value']
            else:
                # Clean up expired entry
                del self.cache[key]
                logger.debug("Cache expired", cache_key=key)

        logger.debug("Cache miss", cache_key=key)
        return None

    async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
        """Set value in cache"""
        ttl = ttl or self.default_ttl
        expires_at = datetime.now(timezone.utc) + timedelta(seconds=ttl)

        self.cache[key] = {
            'value': value,
            'expires_at': expires_at,
            'created_at': datetime.now(timezone.utc)
        }

        logger.debug("Cache set", cache_key=key, ttl=ttl)

    async def clear(self) -> None:
        """Clear all cache entries"""
        self.cache.clear()
        logger.info("Cache cleared")

    async def cleanup_expired(self) -> int:
        """Clean up expired entries"""
        expired_keys = [
            key for key, entry in self.cache.items()
            if self._is_expired(entry)
        ]

        for key in expired_keys:
            del self.cache[key]

        if expired_keys:
            logger.info("Cleaned up expired cache entries", count=len(expired_keys))

        return len(expired_keys)


def async_cache(ttl: int = 300, cache_instance: Optional[AsyncCache] = None):
    """Decorator for caching async function results"""

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        _cache = cache_instance or AsyncCache(ttl)

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # Generate cache key
            cache_key = _cache._generate_key(func.__name__, *args, **kwargs)

            # Try to get from cache
            cached_result = await _cache.get(cache_key)
            if cached_result is not None:
                return cached_result

            # Execute function and cache result
            result = await func(*args, **kwargs)
            await _cache.set(cache_key, result, ttl)

            return result

        # Add cache management methods
        wrapper.cache_clear = _cache.clear
        wrapper.cache_cleanup = _cache.cleanup_expired

        return wrapper

    return decorator
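
For reference, a minimal usage sketch of the async_cache decorator with the definitions above in scope (the decorated function and its TTL are illustrative, not part of this commit):

@async_cache(ttl=60)  # cache results for 60 seconds
async def fetch_city_summary(city: str) -> dict:
    await asyncio.sleep(0.5)  # stand-in for an expensive lookup
    return {"city": city, "records": 42}

async def main():
    first = await fetch_city_summary("madrid")   # executes the body
    second = await fetch_city_summary("madrid")  # served from AsyncCache within the TTL
    await fetch_city_summary.cache_clear()       # helper attached by the decorator

asyncio.run(main())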


class ConnectionPool:
    """Simple connection pool for HTTP clients"""

    def __init__(self, max_connections: int = 10):
        self.max_connections = max_connections
        self.semaphore = asyncio.Semaphore(max_connections)
        self._active_connections = 0

    async def acquire(self):
        """Acquire a connection slot"""
        await self.semaphore.acquire()
        self._active_connections += 1
        logger.debug("Connection acquired", active=self._active_connections, max=self.max_connections)

    async def release(self):
        """Release a connection slot"""
        self.semaphore.release()
        self._active_connections = max(0, self._active_connections - 1)
        logger.debug("Connection released", active=self._active_connections, max=self.max_connections)

    async def __aenter__(self):
        await self.acquire()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.release()


def rate_limit(calls: int, period: int):
    """Rate limiting decorator"""

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        call_times = []
        lock = asyncio.Lock()

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            async with lock:
                now = datetime.now(timezone.utc)

                # Remove old call times
                cutoff = now - timedelta(seconds=period)
                call_times[:] = [t for t in call_times if t > cutoff]

                # Check rate limit
                if len(call_times) >= calls:
                    sleep_time = (call_times[0] + timedelta(seconds=period) - now).total_seconds()
                    if sleep_time > 0:
                        logger.warning("Rate limit reached, sleeping", sleep_time=sleep_time)
                        await asyncio.sleep(sleep_time)

                # Record this call
                call_times.append(now)

            return await func(*args, **kwargs)

        return wrapper

    return decorator
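
A short sketch of how ConnectionPool and rate_limit compose (the limits and the placeholder request are illustrative):

pool = ConnectionPool(max_connections=5)

@rate_limit(calls=30, period=60)  # at most 30 calls per rolling 60-second window
async def fetch_resource(url: str) -> str:
    async with pool:  # the semaphore caps concurrent requests at 5
        await asyncio.sleep(0.1)  # placeholder for a real HTTP call
        return url

Calls beyond the limit are not rejected; the wrapper sleeps until the oldest timestamp falls outside the window and then proceeds.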


async def batch_process(
    items: list,
    process_func: Callable,
    batch_size: int = 10,
    max_concurrency: int = 5
) -> list:
    """Process items in batches with controlled concurrency"""

    results = []
    semaphore = asyncio.Semaphore(max_concurrency)

    async def process_batch(batch):
        async with semaphore:
            return await process_func(batch)

    # Create batches
    batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

    logger.info("Processing items in batches",
                total_items=len(items),
                batches=len(batches),
                batch_size=batch_size,
                max_concurrency=max_concurrency)

    # Process batches concurrently
    batch_results = await asyncio.gather(
        *[process_batch(batch) for batch in batches],
        return_exceptions=True
    )

    # Flatten results
    for batch_result in batch_results:
        if isinstance(batch_result, Exception):
            logger.error("Batch processing error", error=str(batch_result))
            continue

        if isinstance(batch_result, list):
            results.extend(batch_result)
        else:
            results.append(batch_result)

    logger.info("Batch processing completed",
                processed_items=len(results),
                total_batches=len(batches))

    return results
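
An illustrative call, assuming process_func accepts a whole batch (a list) and returns a list, which is how batch_process invokes it:

async def store_batch(batch: list) -> list:
    # pretend to persist a batch and return one result per item
    return [f"stored:{item}" for item in batch]

ids = list(range(100))
# 10 batches of 10 items, at most 5 batches in flight at once
stored = await batch_process(ids, store_batch, batch_size=10, max_concurrency=5)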


class PerformanceMonitor:
    """Simple performance monitoring for async functions"""

    def __init__(self):
        self.metrics = {}

    def record_execution(self, func_name: str, duration: float, success: bool = True):
        """Record function execution metrics"""
        if func_name not in self.metrics:
            self.metrics[func_name] = {
                'call_count': 0,
                'success_count': 0,
                'error_count': 0,
                'total_duration': 0.0,
                'min_duration': float('inf'),
                'max_duration': 0.0
            }

        metric = self.metrics[func_name]
        metric['call_count'] += 1
        metric['total_duration'] += duration
        metric['min_duration'] = min(metric['min_duration'], duration)
        metric['max_duration'] = max(metric['max_duration'], duration)

        if success:
            metric['success_count'] += 1
        else:
            metric['error_count'] += 1

    def get_metrics(self, func_name: str = None) -> dict:
        """Get performance metrics"""
        if func_name:
            metric = self.metrics.get(func_name, {})
            if metric and metric['call_count'] > 0:
                metric['avg_duration'] = metric['total_duration'] / metric['call_count']
                metric['success_rate'] = metric['success_count'] / metric['call_count']
            return metric

        return self.metrics


def monitor_performance(monitor: Optional[PerformanceMonitor] = None):
    """Decorator to monitor function performance"""

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        _monitor = monitor or PerformanceMonitor()

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            start_time = datetime.now(timezone.utc)
            success = True

            try:
                result = await func(*args, **kwargs)
                return result
            except Exception as e:
                success = False
                raise
            finally:
                end_time = datetime.now(timezone.utc)
                duration = (end_time - start_time).total_seconds()
                _monitor.record_execution(func.__name__, duration, success)

                logger.debug("Function performance",
                             function=func.__name__,
                             duration=duration,
                             success=success)

        # Add metrics access
        wrapper.get_metrics = lambda: _monitor.get_metrics(func.__name__)

        return wrapper

    return decorator


# Global instances
global_cache = AsyncCache(default_ttl=300)
global_connection_pool = ConnectionPool(max_connections=20)
global_performance_monitor = PerformanceMonitor()
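
A sketch of wiring a function to the shared monitor (function name and timings are made up):

@monitor_performance(global_performance_monitor)
async def sync_traffic_data() -> int:
    await asyncio.sleep(0.2)  # stand-in for real work
    return 1

# After a few calls, aggregates are available either via the attached helper
# sync_traffic_data.get_metrics() or via
# global_performance_monitor.get_metrics("sync_traffic_data"),
# e.g. {'call_count': 3, 'avg_duration': 0.2, 'success_rate': 1.0, ...}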
services/data/app/external/apis/__init__.py (new file, vendored, 10 lines)
@@ -0,0 +1,10 @@
# ================================================================
# services/data/app/external/apis/__init__.py
# ================================================================
"""
External API clients module - Scalable architecture for multiple cities
"""

from .traffic import TrafficAPIClientFactory

__all__ = ["TrafficAPIClientFactory"]
services/data/app/external/apis/madrid_traffic_client.py (new file, vendored, 1689 lines)
File diff suppressed because it is too large.

services/data/app/external/apis/traffic.py (new file, vendored, 257 lines)
@@ -0,0 +1,257 @@
# ================================================================
# services/data/app/external/apis/traffic.py
# ================================================================
"""
Traffic API abstraction layer for multiple cities
"""

import asyncio
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from typing import Dict, List, Any, Optional, Tuple
import structlog

logger = structlog.get_logger()


class SupportedCity(Enum):
    """Supported cities for traffic data collection"""
    MADRID = "madrid"
    BARCELONA = "barcelona"
    VALENCIA = "valencia"


class BaseTrafficClient(ABC):
    """
    Abstract base class for city-specific traffic clients
    Defines the contract that all traffic clients must implement
    """

    def __init__(self, city: SupportedCity):
        self.city = city
        self.logger = structlog.get_logger().bind(city=city.value)

    @abstractmethod
    async def get_current_traffic(self, latitude: float, longitude: float) -> Optional[Dict[str, Any]]:
        """Get current traffic data for location"""
        pass

    @abstractmethod
    async def get_historical_traffic(self, latitude: float, longitude: float,
                                     start_date: datetime, end_date: datetime) -> List[Dict[str, Any]]:
        """Get historical traffic data"""
        pass

    @abstractmethod
    async def get_events(self, latitude: float, longitude: float, radius_km: float = 5.0) -> List[Dict[str, Any]]:
        """Get traffic incidents and events"""
        pass

    @abstractmethod
    def supports_location(self, latitude: float, longitude: float) -> bool:
        """Check if this client supports the given location"""
        pass


class TrafficAPIClientFactory:
    """
    Factory class to create appropriate traffic clients based on location
    """

    # City geographical bounds
    CITY_BOUNDS = {
        SupportedCity.MADRID: {
            'lat_min': 40.31, 'lat_max': 40.56,
            'lon_min': -3.89, 'lon_max': -3.51
        },
        SupportedCity.BARCELONA: {
            'lat_min': 41.32, 'lat_max': 41.47,
            'lon_min': 2.05, 'lon_max': 2.25
        },
        SupportedCity.VALENCIA: {
            'lat_min': 39.42, 'lat_max': 39.52,
            'lon_min': -0.42, 'lon_max': -0.32
        }
    }

    @classmethod
    def get_client_for_location(cls, latitude: float, longitude: float) -> Optional[BaseTrafficClient]:
        """
        Get appropriate traffic client for given location

        Args:
            latitude: Query location latitude
            longitude: Query location longitude

        Returns:
            BaseTrafficClient instance or None if location not supported
        """
        try:
            # Check each city's bounds
            for city, bounds in cls.CITY_BOUNDS.items():
                if (bounds['lat_min'] <= latitude <= bounds['lat_max'] and
                        bounds['lon_min'] <= longitude <= bounds['lon_max']):

                    logger.info("Location matched to city",
                                city=city.value, lat=latitude, lon=longitude)
                    return cls._create_client(city)

            # If no specific city matches, try to find closest supported city
            closest_city = cls._find_closest_city(latitude, longitude)
            if closest_city:
                logger.info("Using closest city for location",
                            closest_city=closest_city.value, lat=latitude, lon=longitude)
                return cls._create_client(closest_city)

            logger.warning("No traffic client available for location",
                           lat=latitude, lon=longitude)
            return None

        except Exception as e:
            logger.error("Error getting traffic client for location",
                         lat=latitude, lon=longitude, error=str(e))
            return None

    @classmethod
    def _create_client(cls, city: SupportedCity) -> BaseTrafficClient:
        """Create traffic client for specific city"""
        if city == SupportedCity.MADRID:
            from .madrid_traffic_client import MadridTrafficClient
            return MadridTrafficClient()
        elif city == SupportedCity.BARCELONA:
            # Future implementation
            raise NotImplementedError(f"Traffic client for {city.value} not yet implemented")
        elif city == SupportedCity.VALENCIA:
            # Future implementation
            raise NotImplementedError(f"Traffic client for {city.value} not yet implemented")
        else:
            raise ValueError(f"Unsupported city: {city}")

    @classmethod
    def _find_closest_city(cls, latitude: float, longitude: float) -> Optional[SupportedCity]:
        """Find closest supported city to given coordinates"""
        import math

        def distance(lat1, lon1, lat2, lon2):
            """Calculate distance between two coordinates"""
            R = 6371  # Earth's radius in km
            dlat = math.radians(lat2 - lat1)
            dlon = math.radians(lon2 - lon1)
            a = (math.sin(dlat/2) * math.sin(dlat/2) +
                 math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) *
                 math.sin(dlon/2) * math.sin(dlon/2))
            c = 2 * math.atan2(math.sqrt(a), math.sqrt(1-a))
            return R * c

        min_distance = float('inf')
        closest_city = None

        # City centers for distance calculation
        city_centers = {
            SupportedCity.MADRID: (40.4168, -3.7038),
            SupportedCity.BARCELONA: (41.3851, 2.1734),
            SupportedCity.VALENCIA: (39.4699, -0.3763)
        }

        for city, (city_lat, city_lon) in city_centers.items():
            dist = distance(latitude, longitude, city_lat, city_lon)
            if dist < min_distance and dist < 100:  # Within 100km
                min_distance = dist
                closest_city = city

        return closest_city

    @classmethod
    def get_supported_cities(cls) -> List[Dict[str, Any]]:
        """Get list of supported cities with their bounds"""
        cities = []
        for city, bounds in cls.CITY_BOUNDS.items():
            cities.append({
                "city": city.value,
                "bounds": bounds,
                "status": "active" if city == SupportedCity.MADRID else "planned"
            })
        return cities
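
A minimal sketch of resolving a client through the factory (the coordinates are the Madrid centre used in city_centers above):

client = TrafficAPIClientFactory.get_client_for_location(40.4168, -3.7038)
if client is not None:  # MadridTrafficClient today; None outside supported bounds
    current = await client.get_current_traffic(40.4168, -3.7038)

for entry in TrafficAPIClientFactory.get_supported_cities():
    print(entry["city"], entry["status"])  # madrid/active, barcelona/planned, valencia/planned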


class UniversalTrafficClient:
    """
    Universal traffic client that delegates to appropriate city-specific clients
    This is the main interface that external services should use
    """

    def __init__(self):
        self.factory = TrafficAPIClientFactory()
        self.client_cache = {}  # Cache clients for performance

    async def get_current_traffic(self, latitude: float, longitude: float) -> Optional[Dict[str, Any]]:
        """Get current traffic data for any supported location"""
        try:
            client = self._get_client_for_location(latitude, longitude)
            if client:
                return await client.get_current_traffic(latitude, longitude)
            else:
                logger.warning("No traffic data available for location",
                               lat=latitude, lon=longitude)
                return None
        except Exception as e:
            logger.error("Error getting current traffic",
                         lat=latitude, lon=longitude, error=str(e))
            return None

    async def get_historical_traffic(self, latitude: float, longitude: float,
                                     start_date: datetime, end_date: datetime) -> List[Dict[str, Any]]:
        """Get historical traffic data for any supported location"""
        try:
            client = self._get_client_for_location(latitude, longitude)
            if client:
                return await client.get_historical_traffic(latitude, longitude, start_date, end_date)
            else:
                logger.warning("No historical traffic data available for location",
                               lat=latitude, lon=longitude)
                return []
        except Exception as e:
            logger.error("Error getting historical traffic",
                         lat=latitude, lon=longitude, error=str(e))
            return []

    async def get_events(self, latitude: float, longitude: float, radius_km: float = 5.0) -> List[Dict[str, Any]]:
        """Get traffic events for any supported location"""
        try:
            client = self._get_client_for_location(latitude, longitude)
            if client:
                return await client.get_events(latitude, longitude, radius_km)
            else:
                return []
        except Exception as e:
            logger.error("Error getting traffic events",
                         lat=latitude, lon=longitude, error=str(e))
            return []

    def _get_client_for_location(self, latitude: float, longitude: float) -> Optional[BaseTrafficClient]:
        """Get cached or create new client for location"""
        cache_key = f"{latitude:.4f},{longitude:.4f}"

        if cache_key not in self.client_cache:
            client = self.factory.get_client_for_location(latitude, longitude)
            self.client_cache[cache_key] = client

        return self.client_cache[cache_key]

    def get_location_info(self, latitude: float, longitude: float) -> Dict[str, Any]:
        """Get information about traffic data availability for location"""
        client = self._get_client_for_location(latitude, longitude)
        if client:
            return {
                "supported": True,
                "city": client.city.value,
                "features": ["current_traffic", "historical_traffic", "events"]
            }
        else:
            return {
                "supported": False,
                "city": None,
                "features": [],
                "message": "No traffic data available for this location"
            }
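
How a caller is expected to use the facade (coordinates illustrative). Note that for Barcelona and Valencia the factory currently catches the NotImplementedError raised by _create_client and returns None, so get_location_info reports supported: False for them even though they fall inside CITY_BOUNDS:

traffic = UniversalTrafficClient()

info = traffic.get_location_info(40.4168, -3.7038)
# {'supported': True, 'city': 'madrid', 'features': ['current_traffic', 'historical_traffic', 'events']}

events = await traffic.get_events(40.4168, -3.7038, radius_km=2.0)  # [] on any error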
services/data/app/external/base_client.py (vendored, 28 lines changed)
@@ -54,6 +54,19 @@ class BaseAPIClient:
            logger.error("Unexpected error", error=str(e), url=url)
            return None

    async def get(self, url: str, headers: Optional[Dict] = None, timeout: Optional[int] = None) -> httpx.Response:
        """
        Public GET method for direct HTTP requests
        Returns the raw httpx Response object for maximum flexibility
        """
        request_headers = headers or {}
        request_timeout = httpx.Timeout(timeout if timeout else 30.0)

        async with httpx.AsyncClient(timeout=request_timeout, follow_redirects=True) as client:
            response = await client.get(url, headers=request_headers)
            response.raise_for_status()
            return response

    async def _fetch_url_directly(self, url: str, headers: Optional[Dict] = None) -> Optional[Dict[str, Any]]:
        """Fetch data directly from a full URL (for AEMET datos URLs)"""
        try:
@@ -123,4 +136,17 @@ class BaseAPIClient:
            return None
        except Exception as e:
            logger.error("Unexpected error", error=str(e), url=url)
            return None
        return None

    async def get(self, url: str, headers: Optional[Dict] = None, timeout: Optional[int] = None) -> httpx.Response:
        """
        Public GET method for direct HTTP requests
        Returns the raw httpx Response object for maximum flexibility
        """
        request_headers = headers or {}
        request_timeout = httpx.Timeout(timeout if timeout else 30.0)

        async with httpx.AsyncClient(timeout=request_timeout, follow_redirects=True) as client:
            response = await client.get(url, headers=request_headers)
            response.raise_for_status()
            return response
services/data/app/external/madrid_opendata.py (vendored, 1409 lines changed)
File diff suppressed because it is too large.

services/data/app/models/traffic.py
@@ -1,30 +1,294 @@
# ================================================================
# services/data/app/models/traffic.py
# services/data/app/models/traffic.py - Enhanced for Multiple Cities
# ================================================================
"""Traffic data models"""
"""
Flexible traffic data models supporting multiple cities and extensible schemas
"""

from sqlalchemy import Column, String, DateTime, Float, Integer, Text, Index
from sqlalchemy import Column, String, DateTime, Float, Integer, Text, Index, Boolean, JSON
from sqlalchemy.dialects.postgresql import UUID
import uuid
from datetime import datetime, timezone
from typing import Dict, Any, Optional

from shared.database.base import Base


class TrafficData(Base):
    """
    Flexible traffic data model supporting multiple cities
    Designed to accommodate varying data structures across different cities
    """
    __tablename__ = "traffic_data"

    # Primary identification
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    location_id = Column(String(100), nullable=False, index=True)

    # Location and temporal data
    location_id = Column(String(100), nullable=False, index=True)  # "lat,lon" or city-specific ID
    city = Column(String(50), nullable=False, index=True)  # madrid, barcelona, valencia, etc.
    date = Column(DateTime(timezone=True), nullable=False, index=True)
    traffic_volume = Column(Integer, nullable=True)  # vehicles per hour
    pedestrian_count = Column(Integer, nullable=True)  # pedestrians per hour
    congestion_level = Column(String(20), nullable=True)  # low/medium/high
    average_speed = Column(Float, nullable=True)  # km/h
    source = Column(String(50), nullable=False, default="madrid_opendata")
    raw_data = Column(Text, nullable=True)

    # Core standardized traffic metrics (common across all cities)
    traffic_volume = Column(Integer, nullable=True)  # Vehicle count or intensity
    congestion_level = Column(String(20), nullable=True)  # low, medium, high, blocked
    average_speed = Column(Float, nullable=True)  # Average speed in km/h

    # Enhanced metrics (may not be available for all cities)
    occupation_percentage = Column(Float, nullable=True)  # Road occupation %
    load_percentage = Column(Float, nullable=True)  # Traffic load %
    pedestrian_count = Column(Integer, nullable=True)  # Estimated pedestrian count

    # Measurement point information
    measurement_point_id = Column(String(100), nullable=True, index=True)
    measurement_point_name = Column(String(500), nullable=True)
    measurement_point_type = Column(String(50), nullable=True)  # URB, M30, A, etc.

    # Geographic data
    latitude = Column(Float, nullable=True)
    longitude = Column(Float, nullable=True)
    district = Column(String(100), nullable=True)  # City district/area
    zone = Column(String(100), nullable=True)  # Traffic zone or sector

    # Data source and quality
    source = Column(String(50), nullable=False, default="unknown")  # madrid_opendata, synthetic, etc.
    data_quality_score = Column(Float, nullable=True)  # Quality score 0-100
    is_synthetic = Column(Boolean, default=False)
    has_pedestrian_inference = Column(Boolean, default=False)

    # City-specific data (flexible JSON storage)
    city_specific_data = Column(JSON, nullable=True)  # Store city-specific fields

    # Raw data backup
    raw_data = Column(Text, nullable=True)  # Original data for debugging

    # Audit fields
    tenant_id = Column(UUID(as_uuid=True), nullable=True, index=True)  # For multi-tenancy
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime(timezone=True),
                        default=lambda: datetime.now(timezone.utc),
                        onupdate=lambda: datetime.now(timezone.utc))

    # Performance-optimized indexes
    __table_args__ = (
        # Core query patterns
        Index('idx_traffic_location_date', 'location_id', 'date'),
        Index('idx_traffic_city_date', 'city', 'date'),
        Index('idx_traffic_tenant_date', 'tenant_id', 'date'),

        # Advanced query patterns
        Index('idx_traffic_city_location', 'city', 'location_id'),
        Index('idx_traffic_measurement_point', 'city', 'measurement_point_id'),
        Index('idx_traffic_district_date', 'city', 'district', 'date'),

        # Training data queries
        Index('idx_traffic_training', 'tenant_id', 'city', 'date', 'is_synthetic'),
        Index('idx_traffic_quality', 'city', 'data_quality_score', 'date'),
    )

    def to_dict(self) -> Dict[str, Any]:
        """Convert model to dictionary for API responses"""
        result = {
            'id': str(self.id),
            'location_id': self.location_id,
            'city': self.city,
            'date': self.date.isoformat() if self.date else None,
            'traffic_volume': self.traffic_volume,
            'congestion_level': self.congestion_level,
            'average_speed': self.average_speed,
            'occupation_percentage': self.occupation_percentage,
            'load_percentage': self.load_percentage,
            'pedestrian_count': self.pedestrian_count,
            'measurement_point_id': self.measurement_point_id,
            'measurement_point_name': self.measurement_point_name,
            'measurement_point_type': self.measurement_point_type,
            'latitude': self.latitude,
            'longitude': self.longitude,
            'district': self.district,
            'zone': self.zone,
            'source': self.source,
            'data_quality_score': self.data_quality_score,
            'is_synthetic': self.is_synthetic,
            'has_pedestrian_inference': self.has_pedestrian_inference,
            'created_at': self.created_at.isoformat() if self.created_at else None
        }

        # Add city-specific data if present
        if self.city_specific_data:
            result['city_specific_data'] = self.city_specific_data

        return result

    def get_city_specific_field(self, field_name: str, default: Any = None) -> Any:
        """Safely get city-specific field value"""
        if self.city_specific_data and isinstance(self.city_specific_data, dict):
            return self.city_specific_data.get(field_name, default)
        return default

    def set_city_specific_field(self, field_name: str, value: Any) -> None:
        """Set city-specific field value"""
        if not self.city_specific_data:
            self.city_specific_data = {}
        if not isinstance(self.city_specific_data, dict):
            self.city_specific_data = {}
        self.city_specific_data[field_name] = value
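
A brief sketch of the flexible JSON column in use (the city-specific field name is invented for illustration):

row = TrafficData(
    location_id="40.4168,-3.7038",
    city="madrid",
    date=datetime.now(timezone.utc),
    traffic_volume=850,
    congestion_level="medium",
)
row.set_city_specific_field("intensidad", 72)          # hypothetical Madrid-only field
row.get_city_specific_field("intensidad", default=0)   # -> 72
payload = row.to_dict()  # includes 'city_specific_data' only when it is set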


class TrafficMeasurementPoint(Base):
    """
    Registry of traffic measurement points across all cities
    Supports different city-specific measurement point schemas
    """
    __tablename__ = "traffic_measurement_points"

    # Primary identification
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)

    # Location and identification
    city = Column(String(50), nullable=False, index=True)
    measurement_point_id = Column(String(100), nullable=False, index=True)  # City-specific ID
    name = Column(String(500), nullable=True)
    description = Column(Text, nullable=True)

    # Geographic information
    latitude = Column(Float, nullable=False)
    longitude = Column(Float, nullable=False)
    district = Column(String(100), nullable=True)
    zone = Column(String(100), nullable=True)

    # Classification
    road_type = Column(String(50), nullable=True)  # URB, M30, A, etc.
    measurement_type = Column(String(50), nullable=True)  # intensity, speed, etc.
    point_category = Column(String(50), nullable=True)  # urban, highway, ring_road

    # Status and metadata
    is_active = Column(Boolean, default=True)
    installation_date = Column(DateTime(timezone=True), nullable=True)
    last_data_received = Column(DateTime(timezone=True), nullable=True)
    data_quality_rating = Column(Float, nullable=True)  # Average quality 0-100

    # City-specific point data
    city_specific_metadata = Column(JSON, nullable=True)

    # Audit fields
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime(timezone=True),
                        default=lambda: datetime.now(timezone.utc),
                        onupdate=lambda: datetime.now(timezone.utc))

    __table_args__ = (
        Index('idx_traffic_location_date', 'location_id', 'date'),
        # Ensure unique measurement points per city
        Index('idx_unique_city_point', 'city', 'measurement_point_id', unique=True),

        # Geographic queries
        Index('idx_points_city_location', 'city', 'latitude', 'longitude'),
        Index('idx_points_district', 'city', 'district'),
        Index('idx_points_road_type', 'city', 'road_type'),

        # Status queries
        Index('idx_points_active', 'city', 'is_active', 'last_data_received'),
    )

    def to_dict(self) -> Dict[str, Any]:
        """Convert measurement point to dictionary"""
        return {
            'id': str(self.id),
            'city': self.city,
            'measurement_point_id': self.measurement_point_id,
            'name': self.name,
            'description': self.description,
            'latitude': self.latitude,
            'longitude': self.longitude,
            'district': self.district,
            'zone': self.zone,
            'road_type': self.road_type,
            'measurement_type': self.measurement_type,
            'point_category': self.point_category,
            'is_active': self.is_active,
            'installation_date': self.installation_date.isoformat() if self.installation_date else None,
            'last_data_received': self.last_data_received.isoformat() if self.last_data_received else None,
            'data_quality_rating': self.data_quality_rating,
            'city_specific_metadata': self.city_specific_metadata,
            'created_at': self.created_at.isoformat() if self.created_at else None
        }


class TrafficDataBackgroundJob(Base):
    """
    Track background data collection jobs for multiple cities
    Supports scheduling and monitoring of data fetching processes
    """
    __tablename__ = "traffic_background_jobs"

    # Primary identification
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)

    # Job configuration
    job_type = Column(String(50), nullable=False)  # historical_fetch, cleanup, etc.
    city = Column(String(50), nullable=False, index=True)
    location_pattern = Column(String(200), nullable=True)  # Location pattern or specific coords

    # Scheduling
    scheduled_at = Column(DateTime(timezone=True), nullable=False)
    started_at = Column(DateTime(timezone=True), nullable=True)
    completed_at = Column(DateTime(timezone=True), nullable=True)

    # Status tracking
    status = Column(String(20), nullable=False, default='pending')  # pending, running, completed, failed
    progress_percentage = Column(Float, default=0.0)
    records_processed = Column(Integer, default=0)
    records_stored = Column(Integer, default=0)

    # Date range for data jobs
    data_start_date = Column(DateTime(timezone=True), nullable=True)
    data_end_date = Column(DateTime(timezone=True), nullable=True)

    # Results and error handling
    success_count = Column(Integer, default=0)
    error_count = Column(Integer, default=0)
    error_message = Column(Text, nullable=True)
    job_metadata = Column(JSON, nullable=True)  # Additional job-specific data

    # Tenant association
    tenant_id = Column(UUID(as_uuid=True), nullable=True, index=True)

    # Audit fields
    created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime(timezone=True),
                        default=lambda: datetime.now(timezone.utc),
                        onupdate=lambda: datetime.now(timezone.utc))

    __table_args__ = (
        # Job monitoring
        Index('idx_jobs_city_status', 'city', 'status', 'scheduled_at'),
        Index('idx_jobs_tenant_status', 'tenant_id', 'status', 'scheduled_at'),
        Index('idx_jobs_type_city', 'job_type', 'city', 'scheduled_at'),

        # Cleanup queries
        Index('idx_jobs_completed', 'status', 'completed_at'),
    )

    def to_dict(self) -> Dict[str, Any]:
        """Convert job to dictionary"""
        return {
            'id': str(self.id),
            'job_type': self.job_type,
            'city': self.city,
            'location_pattern': self.location_pattern,
            'scheduled_at': self.scheduled_at.isoformat() if self.scheduled_at else None,
            'started_at': self.started_at.isoformat() if self.started_at else None,
            'completed_at': self.completed_at.isoformat() if self.completed_at else None,
            'status': self.status,
            'progress_percentage': self.progress_percentage,
            'records_processed': self.records_processed,
            'records_stored': self.records_stored,
            'data_start_date': self.data_start_date.isoformat() if self.data_start_date else None,
            'data_end_date': self.data_end_date.isoformat() if self.data_end_date else None,
            'success_count': self.success_count,
            'error_count': self.error_count,
            'error_message': self.error_message,
            'job_metadata': self.job_metadata,
            'tenant_id': str(self.tenant_id) if self.tenant_id else None,
            'created_at': self.created_at.isoformat() if self.created_at else None,
            'updated_at': self.updated_at.isoformat() if self.updated_at else None
        }
services/data/app/repositories/traffic_repository.py (new file, 874 lines)
@@ -0,0 +1,874 @@
# ================================================================
# services/data/app/repositories/traffic_repository.py
# ================================================================
"""
Traffic Repository - Enhanced for multiple cities with comprehensive data access patterns
Follows existing repository architecture while adding city-specific functionality
"""

from typing import Optional, List, Dict, Any, Type, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, func, desc, asc, text, update, delete
from sqlalchemy.orm import selectinload
from datetime import datetime, timezone, timedelta
import structlog

from .base import DataBaseRepository
from app.models.traffic import TrafficData, TrafficMeasurementPoint, TrafficDataBackgroundJob
from app.schemas.traffic import TrafficDataCreate, TrafficDataResponse
from shared.database.exceptions import DatabaseError, ValidationError

logger = structlog.get_logger()


class TrafficRepository(DataBaseRepository[TrafficData, TrafficDataCreate, Dict]):
    """
    Enhanced repository for traffic data operations across multiple cities
    Provides city-aware queries and advanced traffic analytics
    """

    def __init__(self, model_class: Type, session: AsyncSession, cache_ttl: Optional[int] = 300):
        super().__init__(model_class, session, cache_ttl)

    # ================================================================
    # CORE TRAFFIC DATA OPERATIONS
    # ================================================================

    async def get_by_location_and_date_range(
        self,
        latitude: float,
        longitude: float,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        city: Optional[str] = None,
        tenant_id: Optional[str] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[TrafficData]:
        """Get traffic data by location and date range with city filtering"""
        try:
            location_id = f"{latitude:.4f},{longitude:.4f}"

            # Build base query
            query = select(self.model).where(self.model.location_id == location_id)

            # Add city filter if specified
            if city:
                query = query.where(self.model.city == city)

            # Add tenant filter if specified
            if tenant_id:
                query = query.where(self.model.tenant_id == tenant_id)

            # Add date range filters
            if start_date:
                start_date = self._ensure_utc_datetime(start_date)
                query = query.where(self.model.date >= start_date)

            if end_date:
                end_date = self._ensure_utc_datetime(end_date)
                query = query.where(self.model.date <= end_date)

            # Order by date descending (most recent first)
            query = query.order_by(desc(self.model.date))

            # Apply pagination
            query = query.offset(skip).limit(limit)

            result = await self.session.execute(query)
            return result.scalars().all()

        except Exception as e:
            logger.error("Failed to get traffic data by location and date range",
                         latitude=latitude, longitude=longitude,
                         city=city, error=str(e))
            raise DatabaseError(f"Failed to get traffic data: {str(e)}")

    async def get_by_city_and_date_range(
        self,
        city: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        district: Optional[str] = None,
        measurement_point_ids: Optional[List[str]] = None,
        include_synthetic: bool = True,
        tenant_id: Optional[str] = None,
        skip: int = 0,
        limit: int = 1000
    ) -> List[TrafficData]:
        """Get traffic data by city with advanced filtering options"""
        try:
            # Build base query
            query = select(self.model).where(self.model.city == city)

            # Add tenant filter if specified
            if tenant_id:
                query = query.where(self.model.tenant_id == tenant_id)

            # Add date range filters
            if start_date:
                start_date = self._ensure_utc_datetime(start_date)
                query = query.where(self.model.date >= start_date)

            if end_date:
                end_date = self._ensure_utc_datetime(end_date)
                query = query.where(self.model.date <= end_date)

            # Add district filter
            if district:
                query = query.where(self.model.district == district)

            # Add measurement point filter
            if measurement_point_ids:
                query = query.where(self.model.measurement_point_id.in_(measurement_point_ids))

            # Filter synthetic data if requested
            if not include_synthetic:
                query = query.where(self.model.is_synthetic == False)

            # Order by date and measurement point
            query = query.order_by(desc(self.model.date), self.model.measurement_point_id)

            # Apply pagination
            query = query.offset(skip).limit(limit)

            result = await self.session.execute(query)
            return result.scalars().all()

        except Exception as e:
            logger.error("Failed to get traffic data by city",
                         city=city, district=district, error=str(e))
            raise DatabaseError(f"Failed to get traffic data: {str(e)}")

    async def get_latest_by_measurement_points(
        self,
        measurement_point_ids: List[str],
        city: str,
        hours_back: int = 24
    ) -> List[TrafficData]:
        """Get latest traffic data for specific measurement points"""
        try:
            cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours_back)

            query = select(self.model).where(
                and_(
                    self.model.city == city,
                    self.model.measurement_point_id.in_(measurement_point_ids),
                    self.model.date >= cutoff_time
                )
            ).order_by(
                self.model.measurement_point_id,
                desc(self.model.date)
            )

            result = await self.session.execute(query)
            all_records = result.scalars().all()

            # Get the latest record for each measurement point
            latest_records = {}
            for record in all_records:
                point_id = record.measurement_point_id
                if point_id not in latest_records:
                    latest_records[point_id] = record

            return list(latest_records.values())

        except Exception as e:
            logger.error("Failed to get latest traffic data by measurement points",
                         city=city, points=len(measurement_point_ids), error=str(e))
            raise DatabaseError(f"Failed to get latest traffic data: {str(e)}")
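
A sketch of how a service layer might call these queries (session construction omitted; the measurement point IDs are made up):

repo = TrafficRepository(TrafficData, session)

recent = await repo.get_by_city_and_date_range(
    city="madrid",
    start_date=datetime.now(timezone.utc) - timedelta(days=7),
    include_synthetic=False,
    limit=500,
)

latest = await repo.get_latest_by_measurement_points(
    measurement_point_ids=["PM-001", "PM-002"],  # illustrative IDs
    city="madrid",
    hours_back=6,
)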

    # ================================================================
    # ANALYTICS AND AGGREGATIONS
    # ================================================================

    async def get_traffic_statistics_by_city(
        self,
        city: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        group_by: str = "daily"
    ) -> List[Dict[str, Any]]:
        """Get aggregated traffic statistics by city"""
        try:
            # Determine date truncation based on group_by
            if group_by == "hourly":
                date_trunc = "hour"
            elif group_by == "daily":
                date_trunc = "day"
            elif group_by == "weekly":
                date_trunc = "week"
            elif group_by == "monthly":
                date_trunc = "month"
            else:
                raise ValidationError(f"Invalid group_by value: {group_by}")

            # Build aggregation query
            if self.session.bind.dialect.name == 'postgresql':
                query = text("""
                    SELECT
                        DATE_TRUNC(:date_trunc, date) as period,
                        city,
                        district,
                        COUNT(*) as record_count,
                        AVG(traffic_volume) as avg_traffic_volume,
                        MAX(traffic_volume) as max_traffic_volume,
                        AVG(pedestrian_count) as avg_pedestrian_count,
                        AVG(average_speed) as avg_speed,
                        COUNT(CASE WHEN congestion_level = 'high' THEN 1 END) as high_congestion_count,
                        COUNT(CASE WHEN is_synthetic = false THEN 1 END) as real_data_count,
                        COUNT(CASE WHEN has_pedestrian_inference = true THEN 1 END) as pedestrian_inference_count
                    FROM traffic_data
                    WHERE city = :city
                """)
            else:
                # SQLite fallback
                query = text("""
                    SELECT
                        DATE(date) as period,
                        city,
                        district,
                        COUNT(*) as record_count,
                        AVG(traffic_volume) as avg_traffic_volume,
                        MAX(traffic_volume) as max_traffic_volume,
                        AVG(pedestrian_count) as avg_pedestrian_count,
                        AVG(average_speed) as avg_speed,
                        SUM(CASE WHEN congestion_level = 'high' THEN 1 ELSE 0 END) as high_congestion_count,
                        SUM(CASE WHEN is_synthetic = 0 THEN 1 ELSE 0 END) as real_data_count,
                        SUM(CASE WHEN has_pedestrian_inference = 1 THEN 1 ELSE 0 END) as pedestrian_inference_count
                    FROM traffic_data
                    WHERE city = :city
                """)

            params = {
                "city": city,
                "date_trunc": date_trunc
            }

            # Add date filters
            if start_date:
                query = text(str(query) + " AND date >= :start_date")
                params["start_date"] = self._ensure_utc_datetime(start_date)

            if end_date:
                query = text(str(query) + " AND date <= :end_date")
                params["end_date"] = self._ensure_utc_datetime(end_date)

            # Add GROUP BY and ORDER BY
            query = text(str(query) + " GROUP BY period, city, district ORDER BY period DESC")

            result = await self.session.execute(query, params)
            rows = result.fetchall()

            # Convert to list of dictionaries
            statistics = []
            for row in rows:
                statistics.append({
                    "period": group_by,
                    "date": row.period,
                    "city": row.city,
                    "district": row.district,
                    "record_count": row.record_count,
                    "avg_traffic_volume": float(row.avg_traffic_volume or 0),
                    "max_traffic_volume": row.max_traffic_volume or 0,
                    "avg_pedestrian_count": float(row.avg_pedestrian_count or 0),
                    "avg_speed": float(row.avg_speed or 0),
                    "high_congestion_count": row.high_congestion_count or 0,
                    "real_data_percentage": round((row.real_data_count or 0) / max(1, row.record_count) * 100, 2),
                    "pedestrian_inference_percentage": round((row.pedestrian_inference_count or 0) / max(1, row.record_count) * 100, 2)
                })

            return statistics

        except Exception as e:
            logger.error("Failed to get traffic statistics by city",
                         city=city, group_by=group_by, error=str(e))
            raise DatabaseError(f"Traffic statistics query failed: {str(e)}")
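
For orientation, one element of the list returned by get_traffic_statistics_by_city looks like this (keys taken from the code above, values invented):

{
    "period": "daily",
    "date": "2024-05-13",
    "city": "madrid",
    "district": "Centro",
    "record_count": 288,
    "avg_traffic_volume": 412.7,
    "max_traffic_volume": 1890,
    "avg_pedestrian_count": 96.3,
    "avg_speed": 23.4,
    "high_congestion_count": 31,
    "real_data_percentage": 87.5,
    "pedestrian_inference_percentage": 64.2,
}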
|
||||
async def get_congestion_heatmap_data(
|
||||
self,
|
||||
city: str,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
time_granularity: str = "hour"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get congestion data for heatmap visualization"""
|
||||
try:
|
||||
if time_granularity == "hour":
|
||||
time_extract = "EXTRACT(hour FROM date)"
|
||||
elif time_granularity == "day_of_week":
|
||||
time_extract = "EXTRACT(dow FROM date)"
|
||||
else:
|
||||
time_extract = "EXTRACT(hour FROM date)"
|
||||
|
||||
query = text(f"""
|
||||
SELECT
|
||||
{time_extract} as time_period,
|
||||
district,
|
||||
measurement_point_id,
|
||||
latitude,
|
||||
longitude,
|
||||
AVG(CASE
|
||||
WHEN congestion_level = 'low' THEN 1
|
||||
WHEN congestion_level = 'medium' THEN 2
|
||||
WHEN congestion_level = 'high' THEN 3
|
||||
WHEN congestion_level = 'blocked' THEN 4
|
||||
ELSE 1
|
||||
END) as avg_congestion_score,
|
||||
COUNT(*) as data_points,
|
||||
AVG(traffic_volume) as avg_traffic_volume,
|
||||
AVG(pedestrian_count) as avg_pedestrian_count
|
||||
FROM traffic_data
|
||||
WHERE city = :city
|
||||
AND date >= :start_date
|
||||
AND date <= :end_date
|
||||
AND latitude IS NOT NULL
|
||||
AND longitude IS NOT NULL
|
||||
GROUP BY time_period, district, measurement_point_id, latitude, longitude
|
||||
ORDER BY time_period, district, avg_congestion_score DESC
|
||||
""")
|
||||
|
||||
params = {
|
||||
"city": city,
|
||||
"start_date": self._ensure_utc_datetime(start_date),
|
||||
"end_date": self._ensure_utc_datetime(end_date)
|
||||
}
|
||||
|
||||
result = await self.session.execute(query, params)
|
||||
rows = result.fetchall()
|
||||
|
||||
heatmap_data = []
|
||||
for row in rows:
|
||||
heatmap_data.append({
|
||||
"time_period": int(row.time_period or 0),
|
||||
"district": row.district,
|
||||
"measurement_point_id": row.measurement_point_id,
|
||||
"latitude": float(row.latitude),
|
||||
"longitude": float(row.longitude),
|
||||
"avg_congestion_score": float(row.avg_congestion_score),
|
||||
"data_points": row.data_points,
|
||||
"avg_traffic_volume": float(row.avg_traffic_volume or 0),
|
||||
"avg_pedestrian_count": float(row.avg_pedestrian_count or 0)
|
||||
})
|
||||
|
||||
return heatmap_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get congestion heatmap data",
|
||||
city=city, error=str(e))
|
||||
raise DatabaseError(f"Congestion heatmap query failed: {str(e)}")
|
||||
|
||||
# ================================================================
|
||||
# BULK OPERATIONS AND DATA MANAGEMENT
|
||||
# ================================================================
|
||||
|
||||
async def create_bulk_traffic_data(
|
||||
self,
|
||||
traffic_records: List[Dict[str, Any]],
|
||||
city: str,
|
||||
tenant_id: Optional[str] = None
|
||||
) -> List[TrafficData]:
|
||||
"""Create multiple traffic records in bulk with enhanced validation"""
|
||||
try:
|
||||
# Ensure all records have city and tenant_id
|
||||
for record in traffic_records:
|
||||
record["city"] = city
|
||||
if tenant_id:
|
||||
record["tenant_id"] = tenant_id
|
||||
# Ensure dates are timezone-aware
|
||||
if "date" in record and record["date"]:
|
||||
record["date"] = self._ensure_utc_datetime(record["date"])
|
||||
|
||||
# Enhanced validation
|
||||
validated_records = []
|
||||
for record in traffic_records:
|
||||
if self._validate_traffic_record(record):
|
||||
validated_records.append(record)
|
||||
else:
|
||||
logger.warning("Invalid traffic record skipped",
|
||||
city=city, record_keys=list(record.keys()))
|
||||
|
||||
if not validated_records:
|
||||
logger.warning("No valid traffic records to create", city=city)
|
||||
return []
|
||||
|
||||
# Use bulk create with deduplication
|
||||
created_records = await self.bulk_create_with_deduplication(validated_records)
|
||||
|
||||
logger.info("Bulk traffic data creation completed",
|
||||
city=city, requested=len(traffic_records),
|
||||
validated=len(validated_records), created=len(created_records))
|
||||
|
||||
return created_records
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create bulk traffic data",
|
||||
city=city, record_count=len(traffic_records), error=str(e))
|
||||
raise DatabaseError(f"Bulk traffic creation failed: {str(e)}")
|
||||
|
||||
async def bulk_create_with_deduplication(
|
||||
self,
|
||||
records: List[Dict[str, Any]]
|
||||
) -> List[TrafficData]:
|
||||
"""Bulk create with automatic deduplication based on location, city, and date"""
|
||||
try:
|
||||
if not records:
|
||||
return []
|
||||
|
||||
# Extract unique keys for deduplication check
|
||||
unique_keys = []
|
||||
for record in records:
|
||||
key = (
|
||||
record.get('location_id'),
|
||||
record.get('city'),
|
||||
record.get('date'),
|
||||
record.get('measurement_point_id')
|
||||
)
|
||||
unique_keys.append(key)
|
||||
|
||||
# Check for existing records
|
||||
location_ids = [key[0] for key in unique_keys if key[0]]
|
||||
cities = [key[1] for key in unique_keys if key[1]]
|
||||
dates = [key[2] for key in unique_keys if key[2]]
|
||||
|
||||
# For large datasets, use chunked deduplication to avoid memory issues
|
||||
if len(location_ids) > 1000:
|
||||
logger.info(f"Large dataset detected ({len(records)} records), using chunked deduplication")
|
||||
new_records = []
|
||||
chunk_size = 1000
|
||||
|
||||
for i in range(0, len(records), chunk_size):
|
||||
chunk_records = records[i:i + chunk_size]
|
||||
chunk_keys = unique_keys[i:i + chunk_size]
|
||||
|
||||
# Get unique values for this chunk
|
||||
chunk_location_ids = list(set(key[0] for key in chunk_keys if key[0]))
|
||||
chunk_cities = list(set(key[1] for key in chunk_keys if key[1]))
|
||||
chunk_dates = list(set(key[2] for key in chunk_keys if key[2]))
|
||||
|
||||
if chunk_location_ids and chunk_cities and chunk_dates:
|
||||
existing_query = select(
|
||||
self.model.location_id,
|
||||
self.model.city,
|
||||
self.model.date,
|
||||
self.model.measurement_point_id
|
||||
).where(
|
||||
and_(
|
||||
self.model.location_id.in_(chunk_location_ids),
|
||||
self.model.city.in_(chunk_cities),
|
||||
self.model.date.in_(chunk_dates)
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.session.execute(existing_query)
|
||||
chunk_existing_keys = set(result.fetchall())
|
||||
|
||||
# Filter chunk duplicates
|
||||
for j, record in enumerate(chunk_records):
|
||||
key = chunk_keys[j]
|
||||
if key not in chunk_existing_keys:
|
||||
new_records.append(record)
|
||||
else:
|
||||
new_records.extend(chunk_records)
|
||||
|
||||
logger.debug("Chunked deduplication completed",
|
||||
total_records=len(records),
|
||||
new_records=len(new_records))
|
||||
records = new_records
|
||||
|
||||
elif location_ids and cities and dates:
|
||||
existing_query = select(
|
||||
self.model.location_id,
|
||||
self.model.city,
|
||||
self.model.date,
|
||||
self.model.measurement_point_id
|
||||
).where(
|
||||
and_(
|
||||
self.model.location_id.in_(location_ids),
|
||||
self.model.city.in_(cities),
|
||||
self.model.date.in_(dates)
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.session.execute(existing_query)
|
||||
existing_keys = set(result.fetchall())
|
||||
|
||||
# Filter out duplicates
|
||||
new_records = []
|
||||
for i, record in enumerate(records):
|
||||
key = unique_keys[i]
|
||||
if key not in existing_keys:
|
||||
new_records.append(record)
|
||||
|
||||
logger.debug("Standard deduplication completed",
|
||||
total_records=len(records),
|
||||
existing_records=len(existing_keys),
|
||||
new_records=len(new_records))
|
||||
|
||||
records = new_records
|
||||
|
||||
# Proceed with bulk creation
|
||||
return await self.bulk_create(records)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed bulk create with deduplication", error=str(e))
|
||||
raise DatabaseError(f"Bulk create with deduplication failed: {str(e)}")
|
||||
|
||||
def _validate_traffic_record(self, record: Dict[str, Any]) -> bool:
|
||||
"""Enhanced validation for traffic records"""
|
||||
required_fields = ['date', 'city']
|
||||
|
||||
# Check required fields
|
||||
for field in required_fields:
|
||||
if not record.get(field):
|
||||
return False
|
||||
|
||||
# Validate city
|
||||
city = record.get('city', '').lower()
|
||||
if city not in ['madrid', 'barcelona', 'valencia', 'test']: # Extendable list
|
||||
return False
|
||||
|
||||
# Validate data ranges
|
||||
traffic_volume = record.get('traffic_volume')
|
||||
if traffic_volume is not None and (traffic_volume < 0 or traffic_volume > 50000):
|
||||
return False
|
||||
|
||||
pedestrian_count = record.get('pedestrian_count')
|
||||
if pedestrian_count is not None and (pedestrian_count < 0 or pedestrian_count > 10000):
|
||||
return False
|
||||
|
||||
average_speed = record.get('average_speed')
|
||||
if average_speed is not None and (average_speed < 0 or average_speed > 200):
|
||||
return False
|
||||
|
||||
congestion_level = record.get('congestion_level')
|
||||
if congestion_level and congestion_level not in ['low', 'medium', 'high', 'blocked']:
|
||||
return False
|
||||
|
||||
return True
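
    # Illustrative helper (hypothetical, not in the original file): a bulk ingest path
    # would typically drop invalid records via _validate_traffic_record before handing
    # the rest to bulk_create (or the deduplicating wrapper defined above).
    async def _example_ingest_validated(self, raw_batch: List[Dict[str, Any]]) -> Any:
        valid_records = [r for r in raw_batch if self._validate_traffic_record(r)]
        if not valid_records:
            return []
        return await self.bulk_create(valid_records)
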
# ================================================================
|
||||
# TRAINING DATA SPECIFIC OPERATIONS
|
||||
# ================================================================
|
||||
|
||||
async def get_training_data_by_location(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
tenant_id: Optional[str] = None,
|
||||
include_pedestrian_inference: bool = True
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Get optimized training data for ML models"""
|
||||
try:
|
||||
location_id = f"{latitude:.4f},{longitude:.4f}"
|
||||
|
||||
query = select(self.model).where(
|
||||
and_(
|
||||
self.model.location_id == location_id,
|
||||
self.model.date >= self._ensure_utc_datetime(start_date),
|
||||
self.model.date <= self._ensure_utc_datetime(end_date)
|
||||
)
|
||||
)
|
||||
|
||||
if tenant_id:
|
||||
query = query.where(self.model.tenant_id == tenant_id)
|
||||
|
||||
if include_pedestrian_inference:
|
||||
# Prefer records with pedestrian inference
|
||||
query = query.order_by(
|
||||
desc(self.model.has_pedestrian_inference),
|
||||
desc(self.model.data_quality_score),
|
||||
self.model.date
|
||||
)
|
||||
else:
|
||||
query = query.order_by(
|
||||
desc(self.model.data_quality_score),
|
||||
self.model.date
|
||||
)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
records = result.scalars().all()
|
||||
|
||||
# Convert to training format with enhanced features
|
||||
training_data = []
|
||||
for record in records:
|
||||
training_record = {
|
||||
'date': record.date,
|
||||
'traffic_volume': record.traffic_volume or 0,
|
||||
'pedestrian_count': record.pedestrian_count or 0,
|
||||
'congestion_level': record.congestion_level or 'medium',
|
||||
'average_speed': record.average_speed or 25.0,
|
||||
'city': record.city,
|
||||
'district': record.district,
|
||||
'measurement_point_id': record.measurement_point_id,
|
||||
'source': record.source,
|
||||
'is_synthetic': record.is_synthetic or False,
|
||||
'has_pedestrian_inference': record.has_pedestrian_inference or False,
|
||||
'data_quality_score': record.data_quality_score or 50.0,
|
||||
|
||||
# Enhanced features for training
|
||||
'hour_of_day': record.date.hour if record.date else 12,
|
||||
'day_of_week': record.date.weekday() if record.date else 0,
|
||||
'month': record.date.month if record.date else 1,
|
||||
|
||||
# City-specific features
|
||||
'city_specific_data': record.city_specific_data or {}
|
||||
}
|
||||
|
||||
training_data.append(training_record)
|
||||
|
||||
logger.info("Retrieved training data",
|
||||
location_id=location_id, records=len(training_data),
|
||||
with_pedestrian_inference=sum(1 for r in training_data if r['has_pedestrian_inference']))
|
||||
|
||||
return training_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get training data",
|
||||
latitude=latitude, longitude=longitude, error=str(e))
|
||||
raise DatabaseError(f"Training data retrieval failed: {str(e)}")
|
||||
|
||||
async def get_historical_data_by_location(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
tenant_id: Optional[str] = None
|
||||
) -> List[TrafficData]:
|
||||
"""Get historical traffic data for a specific location and date range"""
|
||||
return await self.get_by_location_and_date_range(
|
||||
latitude=latitude,
|
||||
longitude=longitude,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
tenant_id=tenant_id,
|
||||
limit=1000000 # Large limit for historical data
|
||||
)
|
||||
|
||||
async def count_records_in_period(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
city: Optional[str] = None,
|
||||
tenant_id: Optional[str] = None
|
||||
) -> int:
|
||||
"""Count traffic records for a specific location and time period"""
|
||||
try:
|
||||
location_id = f"{latitude:.4f},{longitude:.4f}"
|
||||
|
||||
query = select(func.count(self.model.id)).where(
|
||||
and_(
|
||||
self.model.location_id == location_id,
|
||||
self.model.date >= self._ensure_utc_datetime(start_date),
|
||||
self.model.date <= self._ensure_utc_datetime(end_date)
|
||||
)
|
||||
)
|
||||
|
||||
if city:
|
||||
query = query.where(self.model.city == city)
|
||||
|
||||
if tenant_id:
|
||||
query = query.where(self.model.tenant_id == tenant_id)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
count = result.scalar()
|
||||
|
||||
return count or 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to count records in period",
|
||||
latitude=latitude, longitude=longitude, error=str(e))
|
||||
raise DatabaseError(f"Record count failed: {str(e)}")
|
||||
|
||||
# ================================================================
|
||||
# DATA QUALITY AND MAINTENANCE
|
||||
# ================================================================
|
||||
|
||||
async def update_data_quality_scores(self, city: str) -> int:
|
||||
"""Update data quality scores based on various criteria"""
|
||||
try:
|
||||
# Calculate quality scores based on data completeness and consistency
|
||||
query = text("""
|
||||
UPDATE traffic_data
|
||||
SET data_quality_score = (
|
||||
CASE
|
||||
WHEN traffic_volume IS NOT NULL THEN 20 ELSE 0 END +
|
||||
CASE
|
||||
WHEN pedestrian_count IS NOT NULL THEN 20 ELSE 0 END +
|
||||
CASE
|
||||
WHEN average_speed IS NOT NULL AND average_speed > 0 THEN 20 ELSE 0 END +
|
||||
CASE
|
||||
WHEN congestion_level IS NOT NULL THEN 15 ELSE 0 END +
|
||||
CASE
|
||||
WHEN measurement_point_id IS NOT NULL THEN 10 ELSE 0 END +
|
||||
CASE
|
||||
WHEN district IS NOT NULL THEN 10 ELSE 0 END +
|
||||
CASE
|
||||
WHEN has_pedestrian_inference = true THEN 5 ELSE 0 END
|
||||
),
|
||||
updated_at = :updated_at
|
||||
WHERE city = :city AND data_quality_score IS NULL
|
||||
""")
|
||||
|
||||
params = {
|
||||
"city": city,
|
||||
"updated_at": datetime.now(timezone.utc)
|
||||
}
|
||||
|
||||
result = await self.session.execute(query, params)
|
||||
updated_count = result.rowcount
|
||||
await self.session.commit()
|
||||
|
||||
logger.info("Updated data quality scores",
|
||||
city=city, updated_count=updated_count)
|
||||
|
||||
return updated_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update data quality scores",
|
||||
city=city, error=str(e))
|
||||
await self.session.rollback()
|
||||
raise DatabaseError(f"Data quality update failed: {str(e)}")
async def cleanup_old_synthetic_data(
|
||||
self,
|
||||
city: str,
|
||||
days_to_keep: int = 90
|
||||
) -> int:
|
||||
"""Clean up old synthetic data while preserving real data"""
|
||||
try:
|
||||
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
|
||||
|
||||
query = delete(self.model).where(
|
||||
and_(
|
||||
self.model.city == city,
|
||||
self.model.is_synthetic == True,
|
||||
self.model.date < cutoff_date
|
||||
)
|
||||
)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
deleted_count = result.rowcount
|
||||
await self.session.commit()
|
||||
|
||||
logger.info("Cleaned up old synthetic data",
|
||||
city=city, deleted_count=deleted_count, days_kept=days_to_keep)
|
||||
|
||||
return deleted_count
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to cleanup old synthetic data",
|
||||
city=city, error=str(e))
|
||||
await self.session.rollback()
|
||||
raise DatabaseError(f"Synthetic data cleanup failed: {str(e)}")
|
||||
|
||||
|
||||
class TrafficMeasurementPointRepository(DataBaseRepository[TrafficMeasurementPoint, Dict, Dict]):
|
||||
"""Repository for traffic measurement points across cities"""
|
||||
|
||||
async def get_points_near_location(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
city: str,
|
||||
radius_km: float = 10.0,
|
||||
limit: int = 20
|
||||
) -> List[TrafficMeasurementPoint]:
|
||||
"""Get measurement points near a location using spatial query"""
|
||||
try:
|
||||
            # Haversine distance computed in SQL (for more precise spatial queries, use PostGIS).
            # The computed alias is filtered via a subquery so the radius condition is portable.
            query = text("""
                SELECT * FROM (
                    SELECT *,
                        (6371 * acos(
                            cos(radians(:lat)) * cos(radians(latitude)) *
                            cos(radians(longitude) - radians(:lon)) +
                            sin(radians(:lat)) * sin(radians(latitude))
                        )) AS distance_km
                    FROM traffic_measurement_points
                    WHERE city = :city
                      AND is_active = true
                ) AS nearby
                WHERE distance_km <= :radius_km
                ORDER BY distance_km
                LIMIT :limit
            """)
|
||||
|
||||
params = {
|
||||
"lat": latitude,
|
||||
"lon": longitude,
|
||||
"city": city,
|
||||
"radius_km": radius_km,
|
||||
"limit": limit
|
||||
}
|
||||
|
||||
result = await self.session.execute(query, params)
|
||||
rows = result.fetchall()
|
||||
|
||||
# Convert rows to model instances
|
||||
points = []
|
||||
for row in rows:
|
||||
point = TrafficMeasurementPoint()
|
||||
for key, value in row._mapping.items():
|
||||
if hasattr(point, key) and key != 'distance_km':
|
||||
setattr(point, key, value)
|
||||
points.append(point)
|
||||
|
||||
return points
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get measurement points near location",
|
||||
latitude=latitude, longitude=longitude, city=city, error=str(e))
|
||||
raise DatabaseError(f"Measurement points query failed: {str(e)}")
class TrafficBackgroundJobRepository(DataBaseRepository[TrafficDataBackgroundJob, Dict, Dict]):
|
||||
"""Repository for managing background traffic data jobs"""
|
||||
|
||||
async def get_pending_jobs_by_city(self, city: str) -> List[TrafficDataBackgroundJob]:
|
||||
"""Get pending background jobs for a specific city"""
|
||||
try:
|
||||
query = select(self.model).where(
|
||||
and_(
|
||||
self.model.city == city,
|
||||
self.model.status == 'pending'
|
||||
)
|
||||
).order_by(self.model.scheduled_at)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get pending jobs by city", city=city, error=str(e))
|
||||
raise DatabaseError(f"Background jobs query failed: {str(e)}")
|
||||
|
||||
async def update_job_progress(
|
||||
self,
|
||||
job_id: str,
|
||||
progress_percentage: float,
|
||||
records_processed: int = 0,
|
||||
records_stored: int = 0
|
||||
) -> bool:
|
||||
"""Update job progress"""
|
||||
try:
|
||||
query = update(self.model).where(
|
||||
self.model.id == job_id
|
||||
).values(
|
||||
progress_percentage=progress_percentage,
|
||||
records_processed=records_processed,
|
||||
records_stored=records_stored,
|
||||
updated_at=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
result = await self.session.execute(query)
|
||||
await self.session.commit()
|
||||
|
||||
return result.rowcount > 0
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to update job progress", job_id=job_id, error=str(e))
|
||||
await self.session.rollback()
|
||||
raise DatabaseError(f"Job progress update failed: {str(e)}")
@@ -3,7 +3,7 @@
|
||||
# ================================================================
|
||||
"""Sales data schemas"""
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from uuid import UUID
|
||||
@@ -20,7 +20,8 @@ class SalesDataCreate(BaseModel):
|
||||
source: str = Field(default="manual", max_length=50)
|
||||
notes: Optional[str] = Field(None, max_length=500)
|
||||
|
||||
@validator('product_name')
|
||||
@field_validator('product_name')
|
||||
@classmethod
|
||||
def normalize_product_name(cls, v):
|
||||
return v.strip().lower()
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# ================================================================
|
||||
"""Traffic data schemas"""
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from uuid import UUID
|
||||
@@ -14,7 +14,7 @@ class TrafficDataBase(BaseModel):
|
||||
date: datetime = Field(..., description="Date and time of traffic measurement")
|
||||
traffic_volume: Optional[int] = Field(None, ge=0, description="Vehicles per hour")
|
||||
pedestrian_count: Optional[int] = Field(None, ge=0, description="Pedestrians per hour")
|
||||
congestion_level: Optional[str] = Field(None, regex="^(low|medium|high)$", description="Traffic congestion level")
|
||||
congestion_level: Optional[str] = Field(None, pattern="^(low|medium|high)$", description="Traffic congestion level")
|
||||
average_speed: Optional[float] = Field(None, ge=0, le=200, description="Average speed in km/h")
|
||||
source: str = Field("madrid_opendata", max_length=50, description="Data source")
|
||||
raw_data: Optional[str] = Field(None, description="Raw data from source")
|
||||
@@ -27,7 +27,7 @@ class TrafficDataUpdate(BaseModel):
|
||||
"""Schema for updating traffic data"""
|
||||
traffic_volume: Optional[int] = Field(None, ge=0)
|
||||
pedestrian_count: Optional[int] = Field(None, ge=0)
|
||||
congestion_level: Optional[str] = Field(None, regex="^(low|medium|high)$")
|
||||
congestion_level: Optional[str] = Field(None, pattern="^(low|medium|high)$")
|
||||
average_speed: Optional[float] = Field(None, ge=0, le=200)
|
||||
raw_data: Optional[str] = None
|
||||
|
||||
@@ -37,7 +37,8 @@ class TrafficDataResponse(TrafficDataBase):
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
updated_at: datetime = Field(..., description="Last update timestamp")
|
||||
|
||||
@validator('id', pre=True)
|
||||
@field_validator('id', mode='before')
|
||||
@classmethod
|
||||
def convert_uuid_to_string(cls, v):
|
||||
if isinstance(v, UUID):
|
||||
return str(v)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
# ================================================================
|
||||
"""Weather data schemas"""
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
from uuid import UUID
|
||||
@@ -41,7 +41,8 @@ class WeatherDataResponse(WeatherDataBase):
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
updated_at: datetime = Field(..., description="Last update timestamp")
|
||||
|
||||
@validator('id', pre=True)
|
||||
@field_validator('id', mode='before')
|
||||
@classmethod
|
||||
def convert_uuid_to_string(cls, v):
|
||||
if isinstance(v, UUID):
|
||||
return str(v)
|
||||
@@ -76,7 +77,8 @@ class WeatherForecastResponse(WeatherForecastBase):
|
||||
created_at: datetime = Field(..., description="Creation timestamp")
|
||||
updated_at: datetime = Field(..., description="Last update timestamp")
|
||||
|
||||
@validator('id', pre=True)
|
||||
@field_validator('id', mode='before')
|
||||
@classmethod
|
||||
def convert_uuid_to_string(cls, v):
|
||||
if isinstance(v, UUID):
|
||||
return str(v)
|
@@ -1,122 +1,283 @@
|
||||
# ================================================================
|
||||
# services/data/app/services/traffic_service.py - FIXED VERSION
|
||||
# services/data/app/services/traffic_service.py
|
||||
# ================================================================
|
||||
"""Traffic data service with improved error handling"""
|
||||
"""
|
||||
Abstracted Traffic Service - Universal interface for traffic data across multiple cities
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy import select, and_
|
||||
import structlog
|
||||
|
||||
from app.external.apis.traffic import UniversalTrafficClient
|
||||
from app.models.traffic import TrafficData
|
||||
from app.external.madrid_opendata import MadridOpenDataClient
|
||||
from app.schemas.external import TrafficDataResponse
|
||||
|
||||
import uuid
|
||||
from app.core.performance import (
|
||||
async_cache,
|
||||
monitor_performance,
|
||||
global_connection_pool,
|
||||
global_performance_monitor,
|
||||
batch_process
|
||||
)
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TrafficService:
|
||||
"""
|
||||
Abstracted traffic service providing unified interface for traffic data
|
||||
Routes requests to appropriate city-specific clients automatically
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.madrid_client = MadridOpenDataClient()
|
||||
self.universal_client = UniversalTrafficClient()
|
||||
self.logger = structlog.get_logger(__name__)
|
||||
|
||||
async def get_current_traffic(self, latitude: float, longitude: float) -> Optional[TrafficDataResponse]:
|
||||
"""Get current traffic data for location"""
|
||||
@async_cache(ttl=300) # Cache for 5 minutes
|
||||
@monitor_performance(monitor=global_performance_monitor)
|
||||
async def get_current_traffic(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
tenant_id: Optional[str] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get current traffic data for any supported location
|
||||
|
||||
Args:
|
||||
latitude: Query location latitude
|
||||
longitude: Query location longitude
|
||||
tenant_id: Optional tenant identifier for logging/analytics
|
||||
|
||||
Returns:
|
||||
Dict with current traffic data or None if not available
|
||||
"""
|
||||
try:
|
||||
logger.debug("Getting current traffic", lat=latitude, lon=longitude)
|
||||
traffic_data = await self.madrid_client.get_current_traffic(latitude, longitude)
|
||||
self.logger.info("Getting current traffic data",
|
||||
lat=latitude, lon=longitude, tenant_id=tenant_id)
|
||||
|
||||
# Delegate to universal client
|
||||
traffic_data = await self.universal_client.get_current_traffic(latitude, longitude)
|
||||
|
||||
if traffic_data:
|
||||
logger.debug("Traffic data received", source=traffic_data.get('source'))
|
||||
# Add service metadata
|
||||
traffic_data['service_metadata'] = {
|
||||
'request_timestamp': datetime.now().isoformat(),
|
||||
'tenant_id': tenant_id,
|
||||
'service_version': '2.0',
|
||||
'query_location': {'latitude': latitude, 'longitude': longitude}
|
||||
}
|
||||
|
||||
# Validate and clean traffic data before creating response
|
||||
# Use keyword arguments instead of unpacking
|
||||
response = TrafficDataResponse(
|
||||
date=traffic_data.get("date", datetime.now()),
|
||||
traffic_volume=int(traffic_data.get("traffic_volume", 100)),
|
||||
pedestrian_count=int(traffic_data.get("pedestrian_count", 150)),
|
||||
congestion_level=str(traffic_data.get("congestion_level", "medium")),
|
||||
average_speed=float(traffic_data.get("average_speed", 25.0)), # Fixed: use float, not int
|
||||
source=str(traffic_data.get("source", "unknown"))
|
||||
)
|
||||
self.logger.info("Successfully retrieved current traffic data",
|
||||
lat=latitude, lon=longitude,
|
||||
source=traffic_data.get('source', 'unknown'))
|
||||
|
||||
logger.debug("Successfully created traffic response",
|
||||
traffic_volume=response.traffic_volume,
|
||||
congestion_level=response.congestion_level)
|
||||
return response
|
||||
return traffic_data
|
||||
else:
|
||||
logger.warning("No traffic data received from Madrid client")
|
||||
self.logger.warning("No current traffic data available",
|
||||
lat=latitude, lon=longitude)
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get current traffic", error=str(e), lat=latitude, lon=longitude)
|
||||
# Log the full traceback for debugging
|
||||
import traceback
|
||||
logger.error("Traffic service traceback", traceback=traceback.format_exc())
|
||||
self.logger.error("Error getting current traffic data",
|
||||
lat=latitude, lon=longitude, error=str(e))
|
||||
return None
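
    # Illustrative usage (hypothetical): with @async_cache(ttl=300) applied above, the first
    # call goes through the universal client and a repeat call for the same coordinates within
    # 300 seconds is served from the in-memory cache, e.g.:
    #
    #     service = TrafficService()
    #     first = await service.get_current_traffic(40.4168, -3.7038, tenant_id="demo-tenant")
    #     cached = await service.get_current_traffic(40.4168, -3.7038, tenant_id="demo-tenant")
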
async def get_historical_traffic(self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
db: AsyncSession) -> List[TrafficDataResponse]:
|
||||
"""Get historical traffic data with enhanced storage for re-training"""
|
||||
@async_cache(ttl=1800) # Cache for 30 minutes (historical data changes less frequently)
|
||||
@monitor_performance(monitor=global_performance_monitor)
|
||||
async def get_historical_traffic(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
start_date: datetime,
|
||||
end_date: datetime,
|
||||
tenant_id: Optional[str] = None,
|
||||
db: Optional[AsyncSession] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get historical traffic data for any supported location with database storage
|
||||
|
||||
Args:
|
||||
latitude: Query location latitude
|
||||
longitude: Query location longitude
|
||||
start_date: Start date for historical data
|
||||
end_date: End date for historical data
|
||||
tenant_id: Optional tenant identifier
|
||||
db: Optional database session for storage
|
||||
|
||||
Returns:
|
||||
List of historical traffic data dictionaries
|
||||
"""
|
||||
try:
|
||||
logger.debug("Getting historical traffic",
|
||||
lat=latitude, lon=longitude,
|
||||
start=start_date, end=end_date)
|
||||
self.logger.info("Getting historical traffic data",
|
||||
lat=latitude, lon=longitude,
|
||||
start=start_date, end=end_date, tenant_id=tenant_id)
|
||||
|
||||
# Validate date range
|
||||
if start_date >= end_date:
|
||||
self.logger.warning("Invalid date range", start=start_date, end=end_date)
|
||||
return []
|
||||
|
||||
# Check database first
|
||||
location_id = f"{latitude:.4f},{longitude:.4f}"
|
||||
stmt = select(TrafficData).where(
|
||||
and_(
|
||||
TrafficData.location_id == location_id,
|
||||
TrafficData.date >= start_date,
|
||||
TrafficData.date <= end_date
|
||||
)
|
||||
).order_by(TrafficData.date)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
db_records = result.scalars().all()
|
||||
# Check database first if session provided
|
||||
if db:
|
||||
stmt = select(TrafficData).where(
|
||||
and_(
|
||||
TrafficData.location_id == location_id,
|
||||
TrafficData.date >= start_date,
|
||||
TrafficData.date <= end_date
|
||||
)
|
||||
).order_by(TrafficData.date)
|
||||
|
||||
result = await db.execute(stmt)
|
||||
db_records = result.scalars().all()
|
||||
|
||||
if db_records:
|
||||
self.logger.info("Historical traffic data found in database",
|
||||
count=len(db_records))
|
||||
return [self._convert_db_record_to_dict(record) for record in db_records]
|
||||
|
||||
if db_records:
|
||||
logger.debug("Historical traffic data found in database", count=len(db_records))
|
||||
return [TrafficDataResponse(
|
||||
date=record.date,
|
||||
traffic_volume=record.traffic_volume,
|
||||
pedestrian_count=record.pedestrian_count,
|
||||
congestion_level=record.congestion_level,
|
||||
average_speed=record.average_speed,
|
||||
source=record.source
|
||||
) for record in db_records]
|
||||
|
||||
# If not in database, fetch from API and store
|
||||
logger.debug("Fetching historical data from MADRID OPEN DATA")
|
||||
traffic_data = await self.madrid_client.get_historical_traffic(
|
||||
# Delegate to universal client
|
||||
traffic_data = await self.universal_client.get_historical_traffic(
|
||||
latitude, longitude, start_date, end_date
|
||||
)
|
||||
|
||||
if traffic_data:
|
||||
# Enhanced storage with better error handling and validation
|
||||
stored_count = await self._store_traffic_data_batch(
|
||||
traffic_data, location_id, db
|
||||
)
|
||||
logger.info("Traffic data stored for re-training",
|
||||
fetched=len(traffic_data), stored=stored_count, location=location_id)
|
||||
|
||||
return [TrafficDataResponse(**item) for item in traffic_data]
|
||||
# Add service metadata to each record
|
||||
for record in traffic_data:
|
||||
record['service_metadata'] = {
|
||||
'request_timestamp': datetime.now().isoformat(),
|
||||
'tenant_id': tenant_id,
|
||||
'service_version': '2.0',
|
||||
'query_location': {'latitude': latitude, 'longitude': longitude},
|
||||
'date_range': {
|
||||
'start': start_date.isoformat(),
|
||||
'end': end_date.isoformat()
|
||||
}
|
||||
}
|
||||
|
||||
# Store in database if session provided
|
||||
if db:
|
||||
stored_count = await self._store_traffic_data_batch(
|
||||
traffic_data, location_id, db
|
||||
)
|
||||
self.logger.info("Traffic data stored for re-training",
|
||||
fetched=len(traffic_data), stored=stored_count,
|
||||
location=location_id)
|
||||
|
||||
self.logger.info("Successfully retrieved historical traffic data",
|
||||
lat=latitude, lon=longitude, records=len(traffic_data))
|
||||
|
||||
return traffic_data
|
||||
else:
|
||||
logger.warning("No historical traffic data received")
|
||||
self.logger.info("No historical traffic data available",
|
||||
lat=latitude, lon=longitude)
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to get historical traffic", error=str(e))
|
||||
self.logger.error("Error getting historical traffic data",
|
||||
lat=latitude, lon=longitude, error=str(e))
|
||||
return []
|
||||
|
||||
def _convert_db_record_to_dict(self, record: TrafficData) -> Dict[str, Any]:
|
||||
"""Convert database record to dictionary format"""
|
||||
return {
|
||||
'date': record.date,
|
||||
'traffic_volume': record.traffic_volume,
|
||||
'pedestrian_count': record.pedestrian_count,
|
||||
'congestion_level': record.congestion_level,
|
||||
'average_speed': record.average_speed,
|
||||
'source': record.source,
|
||||
'location_id': record.location_id,
|
||||
'raw_data': record.raw_data
|
||||
}
|
||||
|
||||
async def get_traffic_events(
|
||||
self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
radius_km: float = 5.0,
|
||||
tenant_id: Optional[str] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get traffic events and incidents for any supported location
|
||||
|
||||
Args:
|
||||
latitude: Query location latitude
|
||||
longitude: Query location longitude
|
||||
radius_km: Search radius in kilometers
|
||||
tenant_id: Optional tenant identifier
|
||||
|
||||
Returns:
|
||||
List of traffic events
|
||||
"""
|
||||
try:
|
||||
self.logger.info("Getting traffic events",
|
||||
lat=latitude, lon=longitude, radius=radius_km, tenant_id=tenant_id)
|
||||
|
||||
# Delegate to universal client
|
||||
events = await self.universal_client.get_events(latitude, longitude, radius_km)
|
||||
|
||||
# Add metadata to events
|
||||
for event in events:
|
||||
event['service_metadata'] = {
|
||||
'request_timestamp': datetime.now().isoformat(),
|
||||
'tenant_id': tenant_id,
|
||||
'service_version': '2.0',
|
||||
'query_location': {'latitude': latitude, 'longitude': longitude},
|
||||
'search_radius_km': radius_km
|
||||
}
|
||||
|
||||
self.logger.info("Retrieved traffic events",
|
||||
lat=latitude, lon=longitude, events=len(events))
|
||||
|
||||
return events
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("Error getting traffic events",
|
||||
lat=latitude, lon=longitude, error=str(e))
|
||||
return []
|
||||
|
||||
def get_location_info(self, latitude: float, longitude: float) -> Dict[str, Any]:
|
||||
"""
|
||||
Get information about traffic data availability for location
|
||||
|
||||
Args:
|
||||
latitude: Query location latitude
|
||||
longitude: Query location longitude
|
||||
|
||||
Returns:
|
||||
Dict with location support information
|
||||
"""
|
||||
try:
|
||||
info = self.universal_client.get_location_info(latitude, longitude)
|
||||
|
||||
# Add service layer information
|
||||
info['service_layer'] = {
|
||||
'version': '2.0',
|
||||
'abstraction_level': 'universal',
|
||||
'supported_operations': [
|
||||
'current_traffic',
|
||||
'historical_traffic',
|
||||
'traffic_events',
|
||||
'bulk_requests'
|
||||
]
|
||||
}
|
||||
|
||||
return info
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error("Error getting location info",
|
||||
lat=latitude, lon=longitude, error=str(e))
|
||||
return {
|
||||
'supported': False,
|
||||
'error': str(e),
|
||||
'service_layer': {'version': '2.0'}
|
||||
}
|
||||
|
||||
async def store_traffic_data(self,
|
||||
latitude: float,
|
||||
longitude: float,
|
||||
@@ -176,7 +337,8 @@ class TrafficService:
|
||||
else:
|
||||
existing_dates = set()
|
||||
|
||||
# Store only new records
|
||||
# Prepare batch of new records for bulk insert
|
||||
batch_records = []
|
||||
for data in traffic_data:
|
||||
try:
|
||||
record_date = data.get('date')
|
||||
@@ -188,32 +350,41 @@ class TrafficService:
|
||||
logger.warning("Invalid traffic data, skipping", data=data)
|
||||
continue
|
||||
|
||||
traffic_record = TrafficData(
|
||||
location_id=location_id,
|
||||
date=record_date,
|
||||
traffic_volume=data.get('traffic_volume'),
|
||||
pedestrian_count=data.get('pedestrian_count'),
|
||||
congestion_level=data.get('congestion_level'),
|
||||
average_speed=data.get('average_speed'),
|
||||
source=data.get('source', 'madrid_opendata'),
|
||||
raw_data=str(data)
|
||||
)
|
||||
|
||||
db.add(traffic_record)
|
||||
stored_count += 1
|
||||
|
||||
# Commit in batches to avoid memory issues
|
||||
if stored_count % 100 == 0:
|
||||
await db.commit()
|
||||
logger.debug(f"Committed batch of {stored_count} records")
|
||||
# Prepare record data for bulk insert
|
||||
record_data = {
|
||||
'location_id': location_id,
|
||||
'date': record_date,
|
||||
'traffic_volume': data.get('traffic_volume'),
|
||||
'pedestrian_count': data.get('pedestrian_count'),
|
||||
'congestion_level': data.get('congestion_level'),
|
||||
'average_speed': data.get('average_speed'),
|
||||
'source': data.get('source', 'madrid_opendata'),
|
||||
'raw_data': str(data)
|
||||
}
|
||||
batch_records.append(record_data)
|
||||
|
||||
except Exception as record_error:
|
||||
logger.warning("Failed to store individual traffic record",
|
||||
logger.warning("Failed to prepare traffic record",
|
||||
error=str(record_error), data=data)
|
||||
continue
|
||||
|
||||
# Final commit
|
||||
await db.commit()
|
||||
# Use efficient bulk insert instead of individual records
|
||||
if batch_records:
|
||||
# Process in chunks to avoid memory issues
|
||||
chunk_size = 5000
|
||||
for i in range(0, len(batch_records), chunk_size):
|
||||
chunk = batch_records[i:i + chunk_size]
|
||||
|
||||
# Use SQLAlchemy bulk insert for maximum performance
|
||||
await db.execute(
|
||||
TrafficData.__table__.insert(),
|
||||
chunk
|
||||
)
|
||||
await db.commit()
|
||||
stored_count += len(chunk)
|
||||
|
||||
logger.debug(f"Bulk inserted {len(chunk)} records (total: {stored_count})")
|
||||
|
||||
logger.info(f"Successfully stored {stored_count} traffic records for location {location_id}")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,405 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Updated Madrid Historical Traffic test for pytest inside Docker
|
||||
Configured for June 2025 data availability (last available historical data)
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Dict, Any
|
||||
|
||||
# Import from the actual service
|
||||
from app.external.madrid_opendata import MadridOpenDataClient
|
||||
from app.core.config import settings
|
||||
import structlog
|
||||
|
||||
# Configure pytest for async
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
# Use actual logger
|
||||
logger = structlog.get_logger()
|
||||
|
||||
|
||||
class TestMadridTrafficInside:
|
||||
"""Test class for Madrid traffic functionality inside Docker"""
|
||||
|
||||
@pytest.fixture
|
||||
def client(self):
|
||||
"""Create Madrid client for testing"""
|
||||
return MadridOpenDataClient()
|
||||
|
||||
@pytest.fixture
|
||||
def madrid_coords(self):
|
||||
"""Madrid center coordinates"""
|
||||
return 40.4168, -3.7038
|
||||
|
||||
@pytest.fixture
|
||||
def june_2025_dates(self):
|
||||
"""Date ranges for June 2025 (last available historical data)"""
|
||||
return {
|
||||
"quick": {
|
||||
"start": datetime(2025, 6, 1, 0, 0),
|
||||
"end": datetime(2025, 6, 1, 6, 0) # 6 hours on June 1st
|
||||
},
|
||||
"one_day": {
|
||||
"start": datetime(2025, 6, 15, 0, 0), # Mid-June
|
||||
"end": datetime(2025, 6, 16, 0, 0) # One full day
|
||||
},
|
||||
"three_days": {
|
||||
"start": datetime(2025, 6, 10, 0, 0),
|
||||
"end": datetime(2025, 6, 13, 0, 0) # 3 days in June
|
||||
},
|
||||
"recent_synthetic": {
|
||||
"start": datetime.now() - timedelta(hours=6),
|
||||
"end": datetime.now() # Recent data (will be synthetic)
|
||||
}
|
||||
}
|
||||
|
||||
async def test_quick_historical_traffic_june2025(self, client, madrid_coords, june_2025_dates):
|
||||
"""Test quick historical traffic data from June 2025"""
|
||||
lat, lon = madrid_coords
|
||||
date_range = june_2025_dates["quick"]
|
||||
start_time = date_range["start"]
|
||||
end_time = date_range["end"]
|
||||
|
||||
print(f"\n=== Quick Test (June 2025 - 6 hours) ===")
|
||||
print(f"Location: {lat}, {lon}")
|
||||
print(f"Date range: {start_time.strftime('%Y-%m-%d %H:%M')} to {end_time.strftime('%Y-%m-%d %H:%M')}")
|
||||
print(f"Note: Testing with June 2025 data (last available historical month)")
|
||||
|
||||
# Test the function
|
||||
execution_start = datetime.now()
|
||||
result = await client.get_historical_traffic(lat, lon, start_time, end_time)
|
||||
execution_time = (datetime.now() - execution_start).total_seconds()
|
||||
|
||||
print(f"⏱️ Execution time: {execution_time:.2f} seconds")
|
||||
print(f"📊 Records returned: {len(result)}")
|
||||
|
||||
# Assertions
|
||||
assert isinstance(result, list), "Result should be a list"
|
||||
assert len(result) > 0, "Should return at least some records"
|
||||
assert execution_time < 5000, "Should execute in reasonable time (allowing for ZIP download)"
|
||||
|
||||
# Check first record structure
|
||||
if result:
|
||||
sample = result[0]
|
||||
print(f"📋 Sample record keys: {list(sample.keys())}")
|
||||
print(f"📡 Data source: {sample.get('source', 'unknown')}")
|
||||
|
||||
# Required fields
|
||||
required_fields = ['date', 'traffic_volume', 'congestion_level', 'average_speed', 'source']
|
||||
for field in required_fields:
|
||||
assert field in sample, f"Missing required field: {field}"
|
||||
|
||||
# Data validation
|
||||
assert isinstance(sample['traffic_volume'], int), "Traffic volume should be int"
|
||||
assert 0 <= sample['traffic_volume'] <= 1000, "Traffic volume should be reasonable"
|
||||
assert sample['congestion_level'] in ['low', 'medium', 'high', 'blocked'], "Invalid congestion level"
|
||||
assert 5 <= sample['average_speed'] <= 100, "Speed should be reasonable"
|
||||
assert isinstance(sample['date'], datetime), "Date should be datetime object"
|
||||
|
||||
# Check if we got real Madrid data or synthetic
|
||||
if sample['source'] == 'madrid_opendata_zip':
|
||||
print(f"🎉 SUCCESS: Got real Madrid historical data from ZIP!")
|
||||
else:
|
||||
print(f"ℹ️ Got synthetic data (real data may not be available)")
|
||||
|
||||
print(f"✅ All validations passed")
|
||||
|
||||
async def test_one_day_june2025(self, client, madrid_coords, june_2025_dates):
|
||||
"""Test one day of June 2025 historical traffic data"""
|
||||
lat, lon = madrid_coords
|
||||
date_range = june_2025_dates["one_day"]
|
||||
start_time = date_range["start"]
|
||||
end_time = date_range["end"]
|
||||
|
||||
print(f"\n=== One Day Test (June 15, 2025) ===")
|
||||
print(f"Date range: {start_time.strftime('%Y-%m-%d %H:%M')} to {end_time.strftime('%Y-%m-%d %H:%M')}")
|
||||
|
||||
result = await client.get_historical_traffic(lat, lon, start_time, end_time)
|
||||
|
||||
print(f"📊 Records returned: {len(result)}")
|
||||
|
||||
# Should have roughly 24 records (one per hour)
|
||||
assert len(result) >= 20, "Should have at least 20 hourly records for one day"
|
||||
        assert len(result) <= 5000, "Should not return an excessive number of records for one day"
|
||||
|
||||
# Check data source
|
||||
if result:
|
||||
sources = set(r['source'] for r in result)
|
||||
print(f"📡 Data sources: {', '.join(sources)}")
|
||||
|
||||
# If we got real data, check for realistic measurement point IDs
|
||||
real_data_records = [r for r in result if r['source'] == 'madrid_opendata_zip']
|
||||
if real_data_records:
|
||||
point_ids = set(r['measurement_point_id'] for r in real_data_records)
|
||||
print(f"🏷️ Real measurement points found: {len(point_ids)}")
|
||||
print(f" Sample IDs: {list(point_ids)[:3]}")
|
||||
|
||||
# Check traffic patterns
|
||||
if len(result) >= 24:
|
||||
# Find rush hour records (7-9 AM, 6-8 PM)
|
||||
rush_hour_records = [r for r in result if 7 <= r['date'].hour <= 9 or 18 <= r['date'].hour <= 20]
|
||||
night_records = [r for r in result if r['date'].hour <= 6 or r['date'].hour >= 22]
|
||||
|
||||
if rush_hour_records and night_records:
|
||||
avg_rush_traffic = sum(r['traffic_volume'] for r in rush_hour_records) / len(rush_hour_records)
|
||||
avg_night_traffic = sum(r['traffic_volume'] for r in night_records) / len(night_records)
|
||||
|
||||
print(f"📈 Rush hour avg traffic: {avg_rush_traffic:.1f}")
|
||||
print(f"🌙 Night avg traffic: {avg_night_traffic:.1f}")
|
||||
|
||||
# Rush hour should typically have more traffic than night
|
||||
if avg_rush_traffic > avg_night_traffic:
|
||||
print(f"✅ Traffic patterns look realistic")
|
||||
else:
|
||||
print(f"⚠️ Traffic patterns unusual (not necessarily wrong)")
|
||||
|
||||
async def test_three_days_june2025(self, client, madrid_coords, june_2025_dates):
|
||||
"""Test three days of June 2025 historical traffic data"""
|
||||
lat, lon = madrid_coords
|
||||
date_range = june_2025_dates["three_days"]
|
||||
start_time = date_range["start"]
|
||||
end_time = date_range["end"]
|
||||
|
||||
print(f"\n=== Three Days Test (June 10-13, 2025) ===")
|
||||
print(f"Date range: {start_time.strftime('%Y-%m-%d')} to {end_time.strftime('%Y-%m-%d')}")
|
||||
|
||||
result = await client.get_historical_traffic(lat, lon, start_time, end_time)
|
||||
|
||||
print(f"📊 Records returned: {len(result)}")
|
||||
|
||||
# Should have roughly 72 records (24 hours * 3 days)
|
||||
assert len(result) >= 60, "Should have at least 60 records for 3 days"
|
||||
        assert len(result) <= 5000, "Should not return an excessive number of records for 3 days"
|
||||
|
||||
# Check data sources
|
||||
sources = set(r['source'] for r in result)
|
||||
print(f"📡 Data sources: {', '.join(sources)}")
|
||||
|
||||
# Calculate statistics
|
||||
traffic_volumes = [r['traffic_volume'] for r in result]
|
||||
speeds = [r['average_speed'] for r in result]
|
||||
|
||||
avg_traffic = sum(traffic_volumes) / len(traffic_volumes)
|
||||
max_traffic = max(traffic_volumes)
|
||||
min_traffic = min(traffic_volumes)
|
||||
avg_speed = sum(speeds) / len(speeds)
|
||||
|
||||
print(f"📈 Statistics:")
|
||||
print(f" Average traffic: {avg_traffic:.1f}")
|
||||
print(f" Max traffic: {max_traffic}")
|
||||
print(f" Min traffic: {min_traffic}")
|
||||
print(f" Average speed: {avg_speed:.1f} km/h")
|
||||
|
||||
# Analyze by data source
|
||||
real_data_records = [r for r in result if r['source'] == 'madrid_opendata_zip']
|
||||
synthetic_records = [r for r in result if r['source'] != 'madrid_opendata_zip']
|
||||
|
||||
print(f"🔍 Data breakdown:")
|
||||
print(f" Real Madrid data: {len(real_data_records)} records")
|
||||
print(f" Synthetic data: {len(synthetic_records)} records")
|
||||
|
||||
if real_data_records:
|
||||
# Show measurement points from real data
|
||||
real_points = set(r['measurement_point_id'] for r in real_data_records)
|
||||
print(f" Real measurement points: {len(real_points)}")
|
||||
|
||||
# Sanity checks
|
||||
assert 10 <= avg_traffic <= 500, "Average traffic should be reasonable"
|
||||
assert 10 <= avg_speed <= 60, "Average speed should be reasonable"
|
||||
assert max_traffic >= avg_traffic, "Max should be >= average"
|
||||
assert min_traffic <= avg_traffic, "Min should be <= average"
|
||||
|
||||
async def test_recent_vs_historical_data(self, client, madrid_coords, june_2025_dates):
|
||||
"""Compare recent data (synthetic) vs June 2025 data (potentially real)"""
|
||||
lat, lon = madrid_coords
|
||||
|
||||
print(f"\n=== Recent vs Historical Data Comparison ===")
|
||||
|
||||
# Test recent data (should be synthetic)
|
||||
recent_range = june_2025_dates["recent_synthetic"]
|
||||
recent_result = await client.get_historical_traffic(
|
||||
lat, lon, recent_range["start"], recent_range["end"]
|
||||
)
|
||||
|
||||
# Test June 2025 data (potentially real)
|
||||
june_range = june_2025_dates["quick"]
|
||||
june_result = await client.get_historical_traffic(
|
||||
lat, lon, june_range["start"], june_range["end"]
|
||||
)
|
||||
|
||||
print(f"📊 Recent data: {len(recent_result)} records")
|
||||
print(f"📊 June 2025 data: {len(june_result)} records")
|
||||
|
||||
if recent_result:
|
||||
recent_sources = set(r['source'] for r in recent_result)
|
||||
print(f"📡 Recent sources: {', '.join(recent_sources)}")
|
||||
|
||||
if june_result:
|
||||
june_sources = set(r['source'] for r in june_result)
|
||||
print(f"📡 June sources: {', '.join(june_sources)}")
|
||||
|
||||
# Check if we successfully got real data from June
|
||||
if 'madrid_opendata_zip' in june_sources:
|
||||
print(f"🎉 SUCCESS: Real Madrid data successfully fetched from June 2025!")
|
||||
|
||||
# Show details of real data
|
||||
real_records = [r for r in june_result if r['source'] == 'madrid_opendata_zip']
|
||||
if real_records:
|
||||
sample = real_records[0]
|
||||
print(f"📋 Real data sample:")
|
||||
print(f" Date: {sample['date']}")
|
||||
print(f" Traffic volume: {sample['traffic_volume']}")
|
||||
print(f" Measurement point: {sample['measurement_point_id']}")
|
||||
print(f" Point name: {sample.get('measurement_point_name', 'N/A')}")
|
||||
else:
|
||||
print(f"ℹ️ June data is synthetic (real ZIP may not be accessible)")
|
||||
|
||||
async def test_madrid_zip_month_code(self, client):
|
||||
"""Test the month code calculation for Madrid ZIP files"""
|
||||
print(f"\n=== Madrid ZIP Month Code Test ===")
|
||||
|
||||
# Test the month code calculation function
|
||||
test_cases = [
|
||||
(2025, 6, 145), # Known: June 2025 = 145
|
||||
(2025, 5, 144), # Known: May 2025 = 144
|
||||
(2025, 4, 143), # Known: April 2025 = 143
|
||||
(2025, 7, 146), # Predicted: July 2025 = 146
|
||||
]
|
||||
|
||||
for year, month, expected_code in test_cases:
|
||||
if hasattr(client, '_calculate_madrid_month_code'):
|
||||
calculated_code = client._calculate_madrid_month_code(year, month)
|
||||
status = "✅" if calculated_code == expected_code else "⚠️"
|
||||
print(f"{status} {year}-{month:02d}: Expected {expected_code}, Got {calculated_code}")
|
||||
|
||||
# Generate ZIP URL
|
||||
if calculated_code:
|
||||
zip_url = f"https://datos.madrid.es/egob/catalogo/208627-{calculated_code}-transporte-ptomedida-historico.zip"
|
||||
print(f" ZIP URL: {zip_url}")
|
||||
else:
|
||||
print(f"⚠️ Month code calculation function not available")
async def test_edge_case_large_date_range(self, client, madrid_coords):
|
||||
"""Test edge case: date range too large"""
|
||||
lat, lon = madrid_coords
|
||||
start_time = datetime(2025, 1, 1) # 6+ months range
|
||||
end_time = datetime(2025, 7, 1)
|
||||
|
||||
print(f"\n=== Edge Case: Large Date Range ===")
|
||||
print(f"Testing 6-month range: {start_time.date()} to {end_time.date()}")
|
||||
|
||||
result = await client.get_historical_traffic(lat, lon, start_time, end_time)
|
||||
|
||||
print(f"📊 Records for 6-month range: {len(result)}")
|
||||
|
||||
# Should return empty list for ranges > 90 days
|
||||
assert len(result) == 0, "Should return empty list for date ranges > 90 days"
|
||||
print(f"✅ Correctly handled large date range")
|
||||
|
||||
async def test_edge_case_invalid_coordinates(self, client):
|
||||
"""Test edge case: invalid coordinates"""
|
||||
print(f"\n=== Edge Case: Invalid Coordinates ===")
|
||||
|
||||
start_time = datetime(2025, 6, 1)
|
||||
end_time = datetime(2025, 6, 1, 6, 0)
|
||||
|
||||
# Test with invalid coordinates
|
||||
result = await client.get_historical_traffic(999.0, 999.0, start_time, end_time)
|
||||
|
||||
print(f"📊 Records for invalid coords: {len(result)}")
|
||||
|
||||
# Should either return empty list or synthetic data
|
||||
# The function should not crash
|
||||
assert isinstance(result, list), "Should return list even with invalid coords"
|
||||
print(f"✅ Handled invalid coordinates gracefully")
|
||||
|
||||
async def test_real_madrid_zip_access(self, client):
|
||||
"""Test if we can access the actual Madrid ZIP files"""
|
||||
print(f"\n=== Real Madrid ZIP Access Test ===")
|
||||
|
||||
# Test the known ZIP URLs you provided
|
||||
test_urls = [
|
||||
"https://datos.madrid.es/egob/catalogo/208627-145-transporte-ptomedida-historico.zip", # June 2025
|
||||
"https://datos.madrid.es/egob/catalogo/208627-144-transporte-ptomedida-historico.zip", # May 2025
|
||||
"https://datos.madrid.es/egob/catalogo/208627-143-transporte-ptomedida-historico.zip", # April 2025
|
||||
]
|
||||
|
||||
for i, url in enumerate(test_urls):
|
||||
month_name = ["June 2025", "May 2025", "April 2025"][i]
|
||||
print(f"\nTesting {month_name}: {url}")
|
||||
|
||||
try:
|
||||
if hasattr(client, '_fetch_historical_zip'):
|
||||
zip_data = await client._fetch_historical_zip(url)
|
||||
if zip_data:
|
||||
print(f"✅ Successfully fetched ZIP: {len(zip_data)} bytes")
|
||||
|
||||
# Try to inspect ZIP contents
|
||||
try:
|
||||
import zipfile
|
||||
from io import BytesIO
|
||||
|
||||
with zipfile.ZipFile(BytesIO(zip_data), 'r') as zip_file:
|
||||
files = zip_file.namelist()
|
||||
csv_files = [f for f in files if f.endswith('.csv')]
|
||||
print(f"📁 ZIP contains {len(files)} files, {len(csv_files)} CSV files")
|
||||
|
||||
if csv_files:
|
||||
print(f" CSV files: {csv_files[:2]}{'...' if len(csv_files) > 2 else ''}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ Could not inspect ZIP contents: {e}")
|
||||
else:
|
||||
print(f"❌ Failed to fetch ZIP")
|
||||
else:
|
||||
print(f"⚠️ ZIP fetch function not available")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error testing ZIP access: {e}")
|
||||
|
||||
|
||||
# Additional standalone test functions for manual running
|
||||
async def run_manual_test():
|
||||
"""Manual test function that can be run directly"""
|
||||
print("="*60)
|
||||
print("MADRID TRAFFIC TEST - JUNE 2025 DATA")
|
||||
print("="*60)
|
||||
|
||||
client = MadridOpenDataClient()
|
||||
madrid_lat, madrid_lon = 40.4168, -3.7038
|
||||
|
||||
# Test with June 2025 data (last available)
|
||||
start_time = datetime(2025, 6, 15, 14, 0) # June 15, 2025 at 2 PM
|
||||
end_time = datetime(2025, 6, 15, 18, 0) # Until 6 PM (4 hours)
|
||||
|
||||
print(f"\nTesting June 15, 2025 data (2 PM - 6 PM)...")
|
||||
print(f"This should include afternoon traffic patterns")
|
||||
|
||||
result = await client.get_historical_traffic(madrid_lat, madrid_lon, start_time, end_time)
|
||||
|
||||
print(f"Result: {len(result)} records")
|
||||
|
||||
if result:
|
||||
sources = set(r['source'] for r in result)
|
||||
print(f"Data sources: {', '.join(sources)}")
|
||||
|
||||
if 'madrid_opendata_zip' in sources:
|
||||
print(f"🎉 Successfully got real Madrid data!")
|
||||
|
||||
sample = result[0]
|
||||
print(f"\nSample record:")
|
||||
for key, value in sample.items():
|
||||
if key == "date":
|
||||
print(f" {key}: {value.strftime('%Y-%m-%d %H:%M:%S')}")
|
||||
else:
|
||||
print(f" {key}: {value}")
|
||||
|
||||
print(f"\n✅ Manual test completed!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# If run directly, execute manual test
|
||||
asyncio.run(run_manual_test())
|
||||
@@ -8,7 +8,7 @@ from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any, Tuple
|
||||
from dataclasses import dataclass
|
||||
import asyncio
|
||||
import logging
|
||||
import structlog
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import timezone
|
||||
import pandas as pd
|
||||
@@ -24,7 +24,7 @@ from app.services.messaging import (
|
||||
publish_job_failed
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = structlog.get_logger()
|
||||
|
||||
@dataclass
|
||||
class TrainingDataSet:
|
||||
@@ -39,15 +39,14 @@ class TrainingDataOrchestrator:
|
||||
"""
|
||||
Enhanced orchestrator for data collection from multiple sources.
|
||||
Ensures date alignment, handles data source constraints, and prepares data for ML training.
|
||||
Uses the new abstracted traffic service layer for multi-city support.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
madrid_client=None,
|
||||
weather_client=None,
|
||||
date_alignment_service: DateAlignmentService = None):
|
||||
self.data_client = DataClient()
|
||||
self.date_alignment_service = date_alignment_service or DateAlignmentService()
|
||||
self.max_concurrent_requests = 3
|
||||
self.max_concurrent_requests = 5 # Increased for better performance
|
||||
|
||||
async def prepare_training_data(
|
||||
self,
|
||||
@@ -281,11 +280,11 @@ class TrainingDataOrchestrator:
|
||||
)
|
||||
tasks.append(("weather", weather_task))
|
||||
|
||||
# Traffic data collection
|
||||
# Enhanced Traffic data collection (supports multiple cities)
|
||||
if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
|
||||
logger.info(f"🚛 Traffic data source available, creating collection task for date range: {aligned_range.start} to {aligned_range.end}")
|
||||
logger.info(f"🚛 Traffic data source available for multiple cities, creating collection task for date range: {aligned_range.start} to {aligned_range.end}")
|
||||
traffic_task = asyncio.create_task(
|
||||
self._collect_traffic_data_with_timeout(lat, lon, aligned_range, tenant_id)
|
||||
self._collect_traffic_data_with_timeout_enhanced(lat, lon, aligned_range, tenant_id)
|
||||
)
|
||||
tasks.append(("traffic", traffic_task))
|
||||
else:
|
||||
@@ -353,28 +352,31 @@ class TrainingDataOrchestrator:
|
||||
logger.warning(f"Weather data collection failed: {e}, using synthetic data")
|
||||
return self._generate_synthetic_weather_data(aligned_range)
|
||||
|
||||
async def _collect_traffic_data_with_timeout(
|
||||
async def _collect_traffic_data_with_timeout_enhanced(
|
||||
self,
|
||||
lat: float,
|
||||
lon: float,
|
||||
aligned_range: AlignedDateRange,
|
||||
tenant_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Collect traffic data with enhanced storage and retrieval for re-training"""
|
||||
"""
|
||||
Enhanced traffic data collection with multi-city support and improved storage
|
||||
Uses the new abstracted traffic service layer
|
||||
"""
|
||||
try:
|
||||
|
||||
# Double-check Madrid constraint before making request
|
||||
# Double-check constraints before making request
|
||||
constraint_violated = self.date_alignment_service.check_madrid_current_month_constraint(aligned_range.end)
|
||||
if constraint_violated:
|
||||
logger.warning(f"🚫 Madrid current month constraint violation: end_date={aligned_range.end}, no traffic data available")
|
||||
logger.warning(f"🚫 Current month constraint violation: end_date={aligned_range.end}, no traffic data available")
|
||||
return []
|
||||
else:
|
||||
logger.info(f"✅ Madrid constraint passed: end_date={aligned_range.end}, proceeding with traffic data request")
|
||||
logger.info(f"✅ Date constraints passed: end_date={aligned_range.end}, proceeding with traffic data request")
|
||||
|
||||
start_date_str = aligned_range.start.isoformat()
|
||||
end_date_str = aligned_range.end.isoformat()
|
||||
|
||||
# Fetch traffic data - this will automatically store it for future re-training
|
||||
# Enhanced: Fetch traffic data using new abstracted service
|
||||
# This automatically detects the appropriate city and uses the right client
|
||||
traffic_data = await self.data_client.fetch_traffic_data(
|
||||
tenant_id=tenant_id,
|
||||
start_date=start_date_str,
|
||||
@@ -382,39 +384,82 @@ class TrainingDataOrchestrator:
|
||||
latitude=lat,
|
||||
longitude=lon)
|
||||
|
||||
# Validate traffic data
|
||||
if self._validate_traffic_data(traffic_data):
|
||||
logger.info(f"Collected and stored {len(traffic_data)} valid traffic records for re-training")
|
||||
# Enhanced validation including pedestrian inference data
|
||||
if self._validate_traffic_data_enhanced(traffic_data):
|
||||
logger.info(f"Collected and stored {len(traffic_data)} valid enhanced traffic records for re-training")
|
||||
|
||||
# Log storage success for audit purposes
|
||||
self._log_traffic_data_storage(lat, lon, aligned_range, len(traffic_data))
|
||||
# Log storage success with enhanced metadata
|
||||
self._log_enhanced_traffic_data_storage(lat, lon, aligned_range, len(traffic_data), traffic_data)
|
||||
|
||||
return traffic_data
|
||||
else:
|
||||
logger.warning("Invalid traffic data received")
|
||||
logger.warning("Invalid enhanced traffic data received")
|
||||
return []
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(f"Traffic data collection timed out")
|
||||
logger.warning(f"Enhanced traffic data collection timed out")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Traffic data collection failed: {e}")
|
||||
logger.warning(f"Enhanced traffic data collection failed: {e}")
|
||||
return []
|
||||
|
||||
# Keep original method for backwards compatibility
|
||||
async def _collect_traffic_data_with_timeout(
|
||||
self,
|
||||
lat: float,
|
||||
lon: float,
|
||||
aligned_range: AlignedDateRange,
|
||||
tenant_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Legacy traffic data collection method - redirects to enhanced version"""
|
||||
return await self._collect_traffic_data_with_timeout_enhanced(lat, lon, aligned_range, tenant_id)
|
||||
|
||||
def _log_enhanced_traffic_data_storage(self,
|
||||
lat: float,
|
||||
lon: float,
|
||||
aligned_range: AlignedDateRange,
|
||||
record_count: int,
|
||||
traffic_data: List[Dict[str, Any]]):
|
||||
"""Enhanced logging for traffic data storage with detailed metadata"""
|
||||
# Analyze the stored data for additional insights
|
||||
cities_detected = set()
|
||||
has_pedestrian_data = 0
|
||||
data_sources = set()
|
||||
districts_covered = set()
|
||||
|
||||
for record in traffic_data:
|
||||
if 'city' in record and record['city']:
|
||||
cities_detected.add(record['city'])
|
||||
if 'pedestrian_count' in record and record['pedestrian_count'] is not None:
|
||||
has_pedestrian_data += 1
|
||||
if 'source' in record and record['source']:
|
||||
data_sources.add(record['source'])
|
||||
if 'district' in record and record['district']:
|
||||
districts_covered.add(record['district'])
|
||||
|
||||
logger.info(
|
||||
"Enhanced traffic data stored for re-training",
|
||||
location=f"{lat:.4f},{lon:.4f}",
|
||||
date_range=f"{aligned_range.start.isoformat()} to {aligned_range.end.isoformat()}",
|
||||
records_stored=record_count,
|
||||
cities_detected=list(cities_detected),
|
||||
pedestrian_inference_coverage=f"{has_pedestrian_data}/{record_count}",
|
||||
data_sources=list(data_sources),
|
||||
districts_covered=list(districts_covered),
|
||||
storage_timestamp=datetime.now().isoformat(),
|
||||
purpose="enhanced_model_training_and_retraining",
|
||||
architecture_version="2.0_abstracted"
|
||||
)
|
||||
|
||||
def _log_traffic_data_storage(self,
|
||||
lat: float,
|
||||
lon: float,
|
||||
aligned_range: AlignedDateRange,
|
||||
record_count: int):
|
||||
"""Log traffic data storage for audit and re-training tracking"""
|
||||
logger.info(
|
||||
"Traffic data stored for re-training",
|
||||
location=f"{lat:.4f},{lon:.4f}",
|
||||
date_range=f"{aligned_range.start.isoformat()} to {aligned_range.end.isoformat()}",
|
||||
records_stored=record_count,
|
||||
storage_timestamp=datetime.now().isoformat(),
|
||||
purpose="model_training_and_retraining"
|
||||
)
|
||||
"""Legacy logging method - redirects to enhanced version"""
|
||||
# Create minimal traffic data structure for enhanced logging
|
||||
minimal_traffic_data = [{"city": "madrid", "source": "legacy"}] * min(record_count, 1)
|
||||
self._log_enhanced_traffic_data_storage(lat, lon, aligned_range, record_count, minimal_traffic_data)
|
||||
|
||||
    async def retrieve_stored_traffic_for_retraining(
        self,
@@ -491,32 +536,73 @@ class TrainingDataOrchestrator:

        return is_valid

    def _validate_traffic_data(self, traffic_data: List[Dict[str, Any]]) -> bool:
        """Validate traffic data quality"""
    def _validate_traffic_data_enhanced(self, traffic_data: List[Dict[str, Any]]) -> bool:
        """Enhanced validation for traffic data including pedestrian inference and city-specific fields"""
        if not traffic_data:
            return False

        required_fields = ['date']
        traffic_fields = ['traffic_volume', 'traffic_intensity', 'intensidad', 'trafico']
        enhanced_fields = ['pedestrian_count', 'congestion_level', 'source']
        city_specific_fields = ['city', 'measurement_point_id', 'district']

        valid_records = 0
        enhanced_records = 0
        city_aware_records = 0

        for record in traffic_data:
            # Check required fields
            if not all(field in record for field in required_fields):
                continue
            record_score = 0

            # Check at least one traffic field exists
            # Check required fields
            if all(field in record and record[field] is not None for field in required_fields):
                record_score += 1

            # Check traffic data fields
            if any(field in record and record[field] is not None for field in traffic_fields):
                record_score += 1

            # Check enhanced fields (pedestrian inference, etc.)
            enhanced_count = sum(1 for field in enhanced_fields
                                 if field in record and record[field] is not None)
            if enhanced_count >= 2:  # At least 2 enhanced fields
                enhanced_records += 1
                record_score += 1

            # Check city-specific awareness
            city_count = sum(1 for field in city_specific_fields
                             if field in record and record[field] is not None)
            if city_count >= 1:  # At least some city awareness
                city_aware_records += 1

            # Record is valid if it has basic requirements
            if record_score >= 2:
                valid_records += 1

        # Consider valid if at least 30% of records are valid (traffic data is often sparse)
        total_records = len(traffic_data)
        validity_threshold = 0.3
        is_valid = (valid_records / len(traffic_data)) >= validity_threshold
        enhancement_threshold = 0.2  # Lower threshold for enhanced features

        if not is_valid:
            logger.warning(f"Traffic data validation failed: {valid_records}/{len(traffic_data)} valid records")
        basic_validity = (valid_records / total_records) >= validity_threshold
        has_enhancements = (enhanced_records / total_records) >= enhancement_threshold
        has_city_awareness = (city_aware_records / total_records) >= enhancement_threshold

        return is_valid
        logger.info("Enhanced traffic data validation results",
                    total_records=total_records,
                    valid_records=valid_records,
                    enhanced_records=enhanced_records,
                    city_aware_records=city_aware_records,
                    basic_validity=basic_validity,
                    has_enhancements=has_enhancements,
                    has_city_awareness=has_city_awareness)

        if not basic_validity:
            logger.warning(f"Traffic data basic validation failed: {valid_records}/{total_records} valid records")

        return basic_validity

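    # --- Editor's illustrative sketch, not part of this commit -----------------
    # How the scoring above treats one hypothetical record (field values invented
    # purely for illustration):
    #
    #     {"date": "2024-03-01",          # required field present        -> +1
    #      "intensidad": 420,             # one traffic field present     -> +1
    #      "pedestrian_count": 150,       # with "source" below, >= 2
    #      "source": "madrid_opendata",   # enhanced fields               -> +1
    #      "city": "madrid"}              # counts toward city awareness
    #
    # record_score == 3 >= 2, so the record counts as valid, enhanced and
    # city-aware. At the batch level, 4 valid records out of 10 gives
    # 0.4 >= validity_threshold (0.3), so basic_validity is True; 2 enhanced
    # records out of 10 gives 0.2 >= enhancement_threshold (0.2).
    # ----------------------------------------------------------------------------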
    def _validate_traffic_data(self, traffic_data: List[Dict[str, Any]]) -> bool:
        """Legacy validation method - redirects to enhanced version"""
        return self._validate_traffic_data_enhanced(traffic_data)

    def _validate_data_sources(
        self,
