Improve the traffic fetching system

Urtzi Alfaro
2025-08-10 17:31:38 +02:00
parent 312fdc8ef3
commit 3c2acc934a
16 changed files with 3866 additions and 1981 deletions


@@ -0,0 +1,312 @@
# ================================================================
# services/data/app/core/performance.py
# ================================================================
"""
Performance optimization utilities for async operations
"""
import asyncio
import functools
from typing import Any, Callable, Dict, Optional, TypeVar
from datetime import datetime, timedelta, timezone
import hashlib
import json
import structlog
logger = structlog.get_logger()
T = TypeVar('T')
class AsyncCache:
"""Simple in-memory async cache with TTL"""
def __init__(self, default_ttl: int = 300):
self.cache: Dict[str, Dict[str, Any]] = {}
self.default_ttl = default_ttl
def _generate_key(self, *args, **kwargs) -> str:
"""Generate cache key from arguments"""
key_data = {
'args': args,
'kwargs': sorted(kwargs.items())
}
key_string = json.dumps(key_data, sort_keys=True, default=str)
return hashlib.md5(key_string.encode()).hexdigest()
def _is_expired(self, entry: Dict[str, Any]) -> bool:
"""Check if cache entry is expired"""
expires_at = entry.get('expires_at')
if not expires_at:
return True
return datetime.now(timezone.utc) > expires_at
async def get(self, key: str) -> Optional[Any]:
"""Get value from cache"""
if key in self.cache:
entry = self.cache[key]
if not self._is_expired(entry):
logger.debug("Cache hit", cache_key=key)
return entry['value']
else:
# Clean up expired entry
del self.cache[key]
logger.debug("Cache expired", cache_key=key)
logger.debug("Cache miss", cache_key=key)
return None
async def set(self, key: str, value: Any, ttl: Optional[int] = None) -> None:
"""Set value in cache"""
ttl = ttl or self.default_ttl
expires_at = datetime.now(timezone.utc) + timedelta(seconds=ttl)
self.cache[key] = {
'value': value,
'expires_at': expires_at,
'created_at': datetime.now(timezone.utc)
}
logger.debug("Cache set", cache_key=key, ttl=ttl)
async def clear(self) -> None:
"""Clear all cache entries"""
self.cache.clear()
logger.info("Cache cleared")
async def cleanup_expired(self) -> int:
"""Clean up expired entries"""
expired_keys = [
key for key, entry in self.cache.items()
if self._is_expired(entry)
]
for key in expired_keys:
del self.cache[key]
if expired_keys:
logger.info("Cleaned up expired cache entries", count=len(expired_keys))
return len(expired_keys)
def async_cache(ttl: int = 300, cache_instance: Optional[AsyncCache] = None):
"""Decorator for caching async function results"""
def decorator(func: Callable[..., T]) -> Callable[..., T]:
_cache = cache_instance or AsyncCache(ttl)
@functools.wraps(func)
async def wrapper(*args, **kwargs):
# Generate cache key
cache_key = _cache._generate_key(func.__name__, *args, **kwargs)
# Try to get from cache
cached_result = await _cache.get(cache_key)
if cached_result is not None:
return cached_result
# Execute function and cache result
result = await func(*args, **kwargs)
await _cache.set(cache_key, result, ttl)
return result
# Add cache management methods
wrapper.cache_clear = _cache.clear
wrapper.cache_cleanup = _cache.cleanup_expired
return wrapper
return decorator
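A minimal usage sketch for the caching decorator, assuming a hypothetical fetch_point_snapshot coroutine; repeated calls with the same arguments inside the TTL window are served from memory:

@async_cache(ttl=60)
async def fetch_point_snapshot(point_id: str) -> Dict[str, Any]:
    # Stand-in for a real network call
    return {"point_id": point_id, "volume": 420}

# await fetch_point_snapshot("PM10001")     # executes the body and caches the result
# await fetch_point_snapshot("PM10001")     # served from the cache for the next 60 s
# await fetch_point_snapshot.cache_clear()  # attached helper empties the cache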
class ConnectionPool:
"""Simple connection pool for HTTP clients"""
def __init__(self, max_connections: int = 10):
self.max_connections = max_connections
self.semaphore = asyncio.Semaphore(max_connections)
self._active_connections = 0
async def acquire(self):
"""Acquire a connection slot"""
await self.semaphore.acquire()
self._active_connections += 1
logger.debug("Connection acquired", active=self._active_connections, max=self.max_connections)
async def release(self):
"""Release a connection slot"""
self.semaphore.release()
self._active_connections = max(0, self._active_connections - 1)
logger.debug("Connection released", active=self._active_connections, max=self.max_connections)
async def __aenter__(self):
await self.acquire()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.release()
def rate_limit(calls: int, period: int):
"""Rate limiting decorator"""
def decorator(func: Callable[..., T]) -> Callable[..., T]:
call_times = []
lock = asyncio.Lock()
@functools.wraps(func)
async def wrapper(*args, **kwargs):
async with lock:
now = datetime.now(timezone.utc)
# Remove old call times
cutoff = now - timedelta(seconds=period)
call_times[:] = [t for t in call_times if t > cutoff]
# Check rate limit
if len(call_times) >= calls:
sleep_time = (call_times[0] + timedelta(seconds=period) - now).total_seconds()
if sleep_time > 0:
logger.warning("Rate limit reached, sleeping", sleep_time=sleep_time)
await asyncio.sleep(sleep_time)
# Record this call
call_times.append(now)
return await func(*args, **kwargs)
return wrapper
return decorator
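A sketch of the rate limiter on an assumed upstream call, capping it at 10 requests per rolling 60-second window; once the budget is used up the wrapper sleeps until the oldest call ages out:

@rate_limit(calls=10, period=60)
async def query_open_data_feed(url: str) -> Dict[str, Any]:
    # Stand-in for the real HTTP request
    return {"url": url, "status": "ok"}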
async def batch_process(
items: list,
process_func: Callable,
batch_size: int = 10,
max_concurrency: int = 5
) -> list:
"""Process items in batches with controlled concurrency"""
results = []
semaphore = asyncio.Semaphore(max_concurrency)
async def process_batch(batch):
async with semaphore:
return await process_func(batch)
# Create batches
batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
logger.info("Processing items in batches",
total_items=len(items),
batches=len(batches),
batch_size=batch_size,
max_concurrency=max_concurrency)
# Process batches concurrently
batch_results = await asyncio.gather(
*[process_batch(batch) for batch in batches],
return_exceptions=True
)
# Flatten results
for batch_result in batch_results:
if isinstance(batch_result, Exception):
logger.error("Batch processing error", error=str(batch_result))
continue
if isinstance(batch_result, list):
results.extend(batch_result)
else:
results.append(batch_result)
logger.info("Batch processing completed",
processed_items=len(results),
total_batches=len(batches))
return results
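An illustrative call to batch_process with an assumed store_batch coroutine; 95 items become ten batches of at most 10 items, with no more than 5 batches in flight at once:

async def store_batch(batch: list) -> list:
    # Pretend to persist one batch and return a result per item
    return [{"stored": item} for item in batch]

# results = await batch_process(list(range(95)), store_batch, batch_size=10, max_concurrency=5)
# len(results) == 95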
class PerformanceMonitor:
"""Simple performance monitoring for async functions"""
def __init__(self):
self.metrics = {}
def record_execution(self, func_name: str, duration: float, success: bool = True):
"""Record function execution metrics"""
if func_name not in self.metrics:
self.metrics[func_name] = {
'call_count': 0,
'success_count': 0,
'error_count': 0,
'total_duration': 0.0,
'min_duration': float('inf'),
'max_duration': 0.0
}
metric = self.metrics[func_name]
metric['call_count'] += 1
metric['total_duration'] += duration
metric['min_duration'] = min(metric['min_duration'], duration)
metric['max_duration'] = max(metric['max_duration'], duration)
if success:
metric['success_count'] += 1
else:
metric['error_count'] += 1
def get_metrics(self, func_name: str = None) -> dict:
"""Get performance metrics"""
if func_name:
metric = self.metrics.get(func_name, {})
if metric and metric['call_count'] > 0:
metric['avg_duration'] = metric['total_duration'] / metric['call_count']
metric['success_rate'] = metric['success_count'] / metric['call_count']
return metric
return self.metrics
def monitor_performance(monitor: Optional[PerformanceMonitor] = None):
"""Decorator to monitor function performance"""
def decorator(func: Callable[..., T]) -> Callable[..., T]:
_monitor = monitor or PerformanceMonitor()
@functools.wraps(func)
async def wrapper(*args, **kwargs):
start_time = datetime.now(timezone.utc)
success = True
try:
result = await func(*args, **kwargs)
return result
except Exception as e:
success = False
raise
finally:
end_time = datetime.now(timezone.utc)
duration = (end_time - start_time).total_seconds()
_monitor.record_execution(func.__name__, duration, success)
logger.debug("Function performance",
function=func.__name__,
duration=duration,
success=success)
# Add metrics access
wrapper.get_metrics = lambda: _monitor.get_metrics(func.__name__)
return wrapper
return decorator
# Global instances
global_cache = AsyncCache(default_ttl=300)
global_connection_pool = ConnectionPool(max_connections=20)
global_performance_monitor = PerformanceMonitor()
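A short sketch of how these utilities are meant to compose; everything except the three globals above is an assumed name. A fetcher is cached and monitored, and the shared pool bounds concurrent requests:

@async_cache(ttl=300, cache_instance=global_cache)
@monitor_performance(monitor=global_performance_monitor)
async def fetch_city_feed(city: str) -> Dict[str, Any]:
    async with global_connection_pool:   # at most 20 concurrent slots
        await asyncio.sleep(0)           # placeholder for the real request
        return {"city": city, "points": []}

# After a few awaited calls, global_performance_monitor.get_metrics("fetch_city_feed")
# reports call counts and durations, and await fetch_city_feed.cache_clear() resets the shared cache.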


@@ -0,0 +1,10 @@
# ================================================================
# services/data/app/external/apis/__init__.py
# ================================================================
"""
External API clients module - Scalable architecture for multiple cities
"""
from .traffic import TrafficAPIClientFactory
__all__ = ["TrafficAPIClientFactory"]

File diff suppressed because it is too large


@@ -0,0 +1,257 @@
# ================================================================
# services/data/app/external/apis/traffic.py
# ================================================================
"""
Traffic API abstraction layer for multiple cities
"""
import asyncio
from abc import ABC, abstractmethod
from datetime import datetime
from enum import Enum
from typing import Dict, List, Any, Optional, Tuple
import structlog
logger = structlog.get_logger()
class SupportedCity(Enum):
"""Supported cities for traffic data collection"""
MADRID = "madrid"
BARCELONA = "barcelona"
VALENCIA = "valencia"
class BaseTrafficClient(ABC):
"""
Abstract base class for city-specific traffic clients
Defines the contract that all traffic clients must implement
"""
def __init__(self, city: SupportedCity):
self.city = city
self.logger = structlog.get_logger().bind(city=city.value)
@abstractmethod
async def get_current_traffic(self, latitude: float, longitude: float) -> Optional[Dict[str, Any]]:
"""Get current traffic data for location"""
pass
@abstractmethod
async def get_historical_traffic(self, latitude: float, longitude: float,
start_date: datetime, end_date: datetime) -> List[Dict[str, Any]]:
"""Get historical traffic data"""
pass
@abstractmethod
async def get_events(self, latitude: float, longitude: float, radius_km: float = 5.0) -> List[Dict[str, Any]]:
"""Get traffic incidents and events"""
pass
@abstractmethod
def supports_location(self, latitude: float, longitude: float) -> bool:
"""Check if this client supports the given location"""
pass
class TrafficAPIClientFactory:
"""
Factory class to create appropriate traffic clients based on location
"""
# City geographical bounds
CITY_BOUNDS = {
SupportedCity.MADRID: {
'lat_min': 40.31, 'lat_max': 40.56,
'lon_min': -3.89, 'lon_max': -3.51
},
SupportedCity.BARCELONA: {
'lat_min': 41.32, 'lat_max': 41.47,
'lon_min': 2.05, 'lon_max': 2.25
},
SupportedCity.VALENCIA: {
'lat_min': 39.42, 'lat_max': 39.52,
'lon_min': -0.42, 'lon_max': -0.32
}
}
@classmethod
def get_client_for_location(cls, latitude: float, longitude: float) -> Optional[BaseTrafficClient]:
"""
Get appropriate traffic client for given location
Args:
latitude: Query location latitude
longitude: Query location longitude
Returns:
BaseTrafficClient instance or None if location not supported
"""
try:
# Check each city's bounds
for city, bounds in cls.CITY_BOUNDS.items():
if (bounds['lat_min'] <= latitude <= bounds['lat_max'] and
bounds['lon_min'] <= longitude <= bounds['lon_max']):
logger.info("Location matched to city",
city=city.value, lat=latitude, lon=longitude)
return cls._create_client(city)
# If no specific city matches, try to find closest supported city
closest_city = cls._find_closest_city(latitude, longitude)
if closest_city:
logger.info("Using closest city for location",
closest_city=closest_city.value, lat=latitude, lon=longitude)
return cls._create_client(closest_city)
logger.warning("No traffic client available for location",
lat=latitude, lon=longitude)
return None
except Exception as e:
logger.error("Error getting traffic client for location",
lat=latitude, lon=longitude, error=str(e))
return None
@classmethod
def _create_client(cls, city: SupportedCity) -> BaseTrafficClient:
"""Create traffic client for specific city"""
if city == SupportedCity.MADRID:
from .madrid_traffic_client import MadridTrafficClient
return MadridTrafficClient()
elif city == SupportedCity.BARCELONA:
# Future implementation
raise NotImplementedError(f"Traffic client for {city.value} not yet implemented")
elif city == SupportedCity.VALENCIA:
# Future implementation
raise NotImplementedError(f"Traffic client for {city.value} not yet implemented")
else:
raise ValueError(f"Unsupported city: {city}")
@classmethod
def _find_closest_city(cls, latitude: float, longitude: float) -> Optional[SupportedCity]:
"""Find closest supported city to given coordinates"""
import math
def distance(lat1, lon1, lat2, lon2):
"""Calculate distance between two coordinates"""
R = 6371 # Earth's radius in km
dlat = math.radians(lat2 - lat1)
dlon = math.radians(lon2 - lon1)
a = (math.sin(dlat/2) * math.sin(dlat/2) +
math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) *
math.sin(dlon/2) * math.sin(dlon/2))
c = 2 * math.atan2(math.sqrt(a), math.sqrt(1-a))
return R * c
min_distance = float('inf')
closest_city = None
# City centers for distance calculation
city_centers = {
SupportedCity.MADRID: (40.4168, -3.7038),
SupportedCity.BARCELONA: (41.3851, 2.1734),
SupportedCity.VALENCIA: (39.4699, -0.3763)
}
for city, (city_lat, city_lon) in city_centers.items():
dist = distance(latitude, longitude, city_lat, city_lon)
if dist < min_distance and dist < 100: # Within 100km
min_distance = dist
closest_city = city
return closest_city
@classmethod
def get_supported_cities(cls) -> List[Dict[str, Any]]:
"""Get list of supported cities with their bounds"""
cities = []
for city, bounds in cls.CITY_BOUNDS.items():
cities.append({
"city": city.value,
"bounds": bounds,
"status": "active" if city == SupportedCity.MADRID else "planned"
})
return cities
class UniversalTrafficClient:
"""
Universal traffic client that delegates to appropriate city-specific clients
This is the main interface that external services should use
"""
def __init__(self):
self.factory = TrafficAPIClientFactory()
self.client_cache = {} # Cache clients for performance
async def get_current_traffic(self, latitude: float, longitude: float) -> Optional[Dict[str, Any]]:
"""Get current traffic data for any supported location"""
try:
client = self._get_client_for_location(latitude, longitude)
if client:
return await client.get_current_traffic(latitude, longitude)
else:
logger.warning("No traffic data available for location",
lat=latitude, lon=longitude)
return None
except Exception as e:
logger.error("Error getting current traffic",
lat=latitude, lon=longitude, error=str(e))
return None
async def get_historical_traffic(self, latitude: float, longitude: float,
start_date: datetime, end_date: datetime) -> List[Dict[str, Any]]:
"""Get historical traffic data for any supported location"""
try:
client = self._get_client_for_location(latitude, longitude)
if client:
return await client.get_historical_traffic(latitude, longitude, start_date, end_date)
else:
logger.warning("No historical traffic data available for location",
lat=latitude, lon=longitude)
return []
except Exception as e:
logger.error("Error getting historical traffic",
lat=latitude, lon=longitude, error=str(e))
return []
async def get_events(self, latitude: float, longitude: float, radius_km: float = 5.0) -> List[Dict[str, Any]]:
"""Get traffic events for any supported location"""
try:
client = self._get_client_for_location(latitude, longitude)
if client:
return await client.get_events(latitude, longitude, radius_km)
else:
return []
except Exception as e:
logger.error("Error getting traffic events",
lat=latitude, lon=longitude, error=str(e))
return []
def _get_client_for_location(self, latitude: float, longitude: float) -> Optional[BaseTrafficClient]:
"""Get cached or create new client for location"""
cache_key = f"{latitude:.4f},{longitude:.4f}"
if cache_key not in self.client_cache:
client = self.factory.get_client_for_location(latitude, longitude)
self.client_cache[cache_key] = client
return self.client_cache[cache_key]
def get_location_info(self, latitude: float, longitude: float) -> Dict[str, Any]:
"""Get information about traffic data availability for location"""
client = self._get_client_for_location(latitude, longitude)
if client:
return {
"supported": True,
"city": client.city.value,
"features": ["current_traffic", "historical_traffic", "events"]
}
else:
return {
"supported": False,
"city": None,
"features": [],
"message": "No traffic data available for this location"
}
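A hedged usage sketch for the universal client; the coordinates are central Madrid, and the surrounding coroutine exists only to provide an async context:

async def example_traffic_lookup() -> None:
    client = UniversalTrafficClient()
    info = client.get_location_info(40.4168, -3.7038)
    if info["supported"]:
        current = await client.get_current_traffic(40.4168, -3.7038)
        events = await client.get_events(40.4168, -3.7038, radius_km=3.0)
        logger.info("Traffic snapshot", city=info["city"],
                    has_current=current is not None, event_count=len(events))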


@@ -54,6 +54,19 @@ class BaseAPIClient:
logger.error("Unexpected error", error=str(e), url=url)
return None
async def get(self, url: str, headers: Optional[Dict] = None, timeout: Optional[int] = None) -> httpx.Response:
"""
Public GET method for direct HTTP requests
Returns the raw httpx Response object for maximum flexibility
"""
request_headers = headers or {}
request_timeout = httpx.Timeout(timeout if timeout else 30.0)
async with httpx.AsyncClient(timeout=request_timeout, follow_redirects=True) as client:
response = await client.get(url, headers=request_headers)
response.raise_for_status()
return response
async def _fetch_url_directly(self, url: str, headers: Optional[Dict] = None) -> Optional[Dict[str, Any]]:
"""Fetch data directly from a full URL (for AEMET datos URLs)"""
try:
@@ -123,4 +136,17 @@ class BaseAPIClient:
return None
except Exception as e:
logger.error("Unexpected error", error=str(e), url=url)
return None
return None
async def get(self, url: str, headers: Optional[Dict] = None, timeout: Optional[int] = None) -> httpx.Response:
"""
Public GET method for direct HTTP requests
Returns the raw httpx Response object for maximum flexibility
"""
request_headers = headers or {}
request_timeout = httpx.Timeout(timeout if timeout else 30.0)
async with httpx.AsyncClient(timeout=request_timeout, follow_redirects=True) as client:
response = await client.get(url, headers=request_headers)
response.raise_for_status()
return response
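Assuming an already-constructed BaseAPIClient instance (its constructor is outside this diff) and a hypothetical endpoint, the new public get method is a thin wrapper around httpx that raises on non-2xx responses:

async def fetch_feed(client: "BaseAPIClient") -> str:
    response = await client.get(
        "https://example.org/traffic/feed.xml",   # hypothetical URL
        headers={"Accept": "application/xml"},
        timeout=60,
    )
    return response.text   # raise_for_status() has already run inside get()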

File diff suppressed because it is too large


@@ -1,30 +1,294 @@
# ================================================================
# services/data/app/models/traffic.py
# services/data/app/models/traffic.py - Enhanced for Multiple Cities
# ================================================================
"""Traffic data models"""
"""
Flexible traffic data models supporting multiple cities and extensible schemas
"""
from sqlalchemy import Column, String, DateTime, Float, Integer, Text, Index
from sqlalchemy import Column, String, DateTime, Float, Integer, Text, Index, Boolean, JSON
from sqlalchemy.dialects.postgresql import UUID
import uuid
from datetime import datetime, timezone
from typing import Dict, Any, Optional
from shared.database.base import Base
class TrafficData(Base):
"""
Flexible traffic data model supporting multiple cities
Designed to accommodate varying data structures across different cities
"""
__tablename__ = "traffic_data"
# Primary identification
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
location_id = Column(String(100), nullable=False, index=True)
# Location and temporal data
location_id = Column(String(100), nullable=False, index=True) # "lat,lon" or city-specific ID
city = Column(String(50), nullable=False, index=True) # madrid, barcelona, valencia, etc.
date = Column(DateTime(timezone=True), nullable=False, index=True)
traffic_volume = Column(Integer, nullable=True) # vehicles per hour
pedestrian_count = Column(Integer, nullable=True) # pedestrians per hour
congestion_level = Column(String(20), nullable=True) # low/medium/high
average_speed = Column(Float, nullable=True) # km/h
source = Column(String(50), nullable=False, default="madrid_opendata")
raw_data = Column(Text, nullable=True)
# Core standardized traffic metrics (common across all cities)
traffic_volume = Column(Integer, nullable=True) # Vehicle count or intensity
congestion_level = Column(String(20), nullable=True) # low, medium, high, blocked
average_speed = Column(Float, nullable=True) # Average speed in km/h
# Enhanced metrics (may not be available for all cities)
occupation_percentage = Column(Float, nullable=True) # Road occupation %
load_percentage = Column(Float, nullable=True) # Traffic load %
pedestrian_count = Column(Integer, nullable=True) # Estimated pedestrian count
# Measurement point information
measurement_point_id = Column(String(100), nullable=True, index=True)
measurement_point_name = Column(String(500), nullable=True)
measurement_point_type = Column(String(50), nullable=True) # URB, M30, A, etc.
# Geographic data
latitude = Column(Float, nullable=True)
longitude = Column(Float, nullable=True)
district = Column(String(100), nullable=True) # City district/area
zone = Column(String(100), nullable=True) # Traffic zone or sector
# Data source and quality
source = Column(String(50), nullable=False, default="unknown") # madrid_opendata, synthetic, etc.
data_quality_score = Column(Float, nullable=True) # Quality score 0-100
is_synthetic = Column(Boolean, default=False)
has_pedestrian_inference = Column(Boolean, default=False)
# City-specific data (flexible JSON storage)
city_specific_data = Column(JSON, nullable=True) # Store city-specific fields
# Raw data backup
raw_data = Column(Text, nullable=True) # Original data for debugging
# Audit fields
tenant_id = Column(UUID(as_uuid=True), nullable=True, index=True) # For multi-tenancy
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc), onupdate=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc))
# Performance-optimized indexes
__table_args__ = (
# Core query patterns
Index('idx_traffic_location_date', 'location_id', 'date'),
Index('idx_traffic_city_date', 'city', 'date'),
Index('idx_traffic_tenant_date', 'tenant_id', 'date'),
# Advanced query patterns
Index('idx_traffic_city_location', 'city', 'location_id'),
Index('idx_traffic_measurement_point', 'city', 'measurement_point_id'),
Index('idx_traffic_district_date', 'city', 'district', 'date'),
# Training data queries
Index('idx_traffic_training', 'tenant_id', 'city', 'date', 'is_synthetic'),
Index('idx_traffic_quality', 'city', 'data_quality_score', 'date'),
)
def to_dict(self) -> Dict[str, Any]:
"""Convert model to dictionary for API responses"""
result = {
'id': str(self.id),
'location_id': self.location_id,
'city': self.city,
'date': self.date.isoformat() if self.date else None,
'traffic_volume': self.traffic_volume,
'congestion_level': self.congestion_level,
'average_speed': self.average_speed,
'occupation_percentage': self.occupation_percentage,
'load_percentage': self.load_percentage,
'pedestrian_count': self.pedestrian_count,
'measurement_point_id': self.measurement_point_id,
'measurement_point_name': self.measurement_point_name,
'measurement_point_type': self.measurement_point_type,
'latitude': self.latitude,
'longitude': self.longitude,
'district': self.district,
'zone': self.zone,
'source': self.source,
'data_quality_score': self.data_quality_score,
'is_synthetic': self.is_synthetic,
'has_pedestrian_inference': self.has_pedestrian_inference,
'created_at': self.created_at.isoformat() if self.created_at else None
}
# Add city-specific data if present
if self.city_specific_data:
result['city_specific_data'] = self.city_specific_data
return result
def get_city_specific_field(self, field_name: str, default: Any = None) -> Any:
"""Safely get city-specific field value"""
if self.city_specific_data and isinstance(self.city_specific_data, dict):
return self.city_specific_data.get(field_name, default)
return default
def set_city_specific_field(self, field_name: str, value: Any) -> None:
"""Set city-specific field value"""
if not self.city_specific_data:
self.city_specific_data = {}
if not isinstance(self.city_specific_data, dict):
self.city_specific_data = {}
self.city_specific_data[field_name] = value
class TrafficMeasurementPoint(Base):
"""
Registry of traffic measurement points across all cities
Supports different city-specific measurement point schemas
"""
__tablename__ = "traffic_measurement_points"
# Primary identification
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Location and identification
city = Column(String(50), nullable=False, index=True)
measurement_point_id = Column(String(100), nullable=False, index=True) # City-specific ID
name = Column(String(500), nullable=True)
description = Column(Text, nullable=True)
# Geographic information
latitude = Column(Float, nullable=False)
longitude = Column(Float, nullable=False)
district = Column(String(100), nullable=True)
zone = Column(String(100), nullable=True)
# Classification
road_type = Column(String(50), nullable=True) # URB, M30, A, etc.
measurement_type = Column(String(50), nullable=True) # intensity, speed, etc.
point_category = Column(String(50), nullable=True) # urban, highway, ring_road
# Status and metadata
is_active = Column(Boolean, default=True)
installation_date = Column(DateTime(timezone=True), nullable=True)
last_data_received = Column(DateTime(timezone=True), nullable=True)
data_quality_rating = Column(Float, nullable=True) # Average quality 0-100
# City-specific point data
city_specific_metadata = Column(JSON, nullable=True)
# Audit fields
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc))
__table_args__ = (
# Ensure unique measurement points per city
Index('idx_unique_city_point', 'city', 'measurement_point_id', unique=True),
# Geographic queries
Index('idx_points_city_location', 'city', 'latitude', 'longitude'),
Index('idx_points_district', 'city', 'district'),
Index('idx_points_road_type', 'city', 'road_type'),
# Status queries
Index('idx_points_active', 'city', 'is_active', 'last_data_received'),
)
def to_dict(self) -> Dict[str, Any]:
"""Convert measurement point to dictionary"""
return {
'id': str(self.id),
'city': self.city,
'measurement_point_id': self.measurement_point_id,
'name': self.name,
'description': self.description,
'latitude': self.latitude,
'longitude': self.longitude,
'district': self.district,
'zone': self.zone,
'road_type': self.road_type,
'measurement_type': self.measurement_type,
'point_category': self.point_category,
'is_active': self.is_active,
'installation_date': self.installation_date.isoformat() if self.installation_date else None,
'last_data_received': self.last_data_received.isoformat() if self.last_data_received else None,
'data_quality_rating': self.data_quality_rating,
'city_specific_metadata': self.city_specific_metadata,
'created_at': self.created_at.isoformat() if self.created_at else None
}
class TrafficDataBackgroundJob(Base):
"""
Track background data collection jobs for multiple cities
Supports scheduling and monitoring of data fetching processes
"""
__tablename__ = "traffic_background_jobs"
# Primary identification
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
# Job configuration
job_type = Column(String(50), nullable=False) # historical_fetch, cleanup, etc.
city = Column(String(50), nullable=False, index=True)
location_pattern = Column(String(200), nullable=True) # Location pattern or specific coords
# Scheduling
scheduled_at = Column(DateTime(timezone=True), nullable=False)
started_at = Column(DateTime(timezone=True), nullable=True)
completed_at = Column(DateTime(timezone=True), nullable=True)
# Status tracking
status = Column(String(20), nullable=False, default='pending') # pending, running, completed, failed
progress_percentage = Column(Float, default=0.0)
records_processed = Column(Integer, default=0)
records_stored = Column(Integer, default=0)
# Date range for data jobs
data_start_date = Column(DateTime(timezone=True), nullable=True)
data_end_date = Column(DateTime(timezone=True), nullable=True)
# Results and error handling
success_count = Column(Integer, default=0)
error_count = Column(Integer, default=0)
error_message = Column(Text, nullable=True)
job_metadata = Column(JSON, nullable=True) # Additional job-specific data
# Tenant association
tenant_id = Column(UUID(as_uuid=True), nullable=True, index=True)
# Audit fields
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(timezone.utc))
updated_at = Column(DateTime(timezone=True),
default=lambda: datetime.now(timezone.utc),
onupdate=lambda: datetime.now(timezone.utc))
__table_args__ = (
# Job monitoring
Index('idx_jobs_city_status', 'city', 'status', 'scheduled_at'),
Index('idx_jobs_tenant_status', 'tenant_id', 'status', 'scheduled_at'),
Index('idx_jobs_type_city', 'job_type', 'city', 'scheduled_at'),
# Cleanup queries
Index('idx_jobs_completed', 'status', 'completed_at'),
)
def to_dict(self) -> Dict[str, Any]:
"""Convert job to dictionary"""
return {
'id': str(self.id),
'job_type': self.job_type,
'city': self.city,
'location_pattern': self.location_pattern,
'scheduled_at': self.scheduled_at.isoformat() if self.scheduled_at else None,
'started_at': self.started_at.isoformat() if self.started_at else None,
'completed_at': self.completed_at.isoformat() if self.completed_at else None,
'status': self.status,
'progress_percentage': self.progress_percentage,
'records_processed': self.records_processed,
'records_stored': self.records_stored,
'data_start_date': self.data_start_date.isoformat() if self.data_start_date else None,
'data_end_date': self.data_end_date.isoformat() if self.data_end_date else None,
'success_count': self.success_count,
'error_count': self.error_count,
'error_message': self.error_message,
'job_metadata': self.job_metadata,
'tenant_id': str(self.tenant_id) if self.tenant_id else None,
'created_at': self.created_at.isoformat() if self.created_at else None,
'updated_at': self.updated_at.isoformat() if self.updated_at else None
}
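A small sketch of building a row of the enhanced model and using its JSON helpers; all values are made up, and id/created_at are only populated when the row is flushed:

record = TrafficData(
    location_id="40.4168,-3.7038",
    city="madrid",
    date=datetime.now(timezone.utc),
    traffic_volume=850,
    congestion_level="medium",
    average_speed=31.5,
    measurement_point_id="PM10001",
    source="madrid_opendata",
    is_synthetic=False,
)
record.set_city_specific_field("error_code", "N")         # lands in city_specific_data
assert record.get_city_specific_field("error_code") == "N"
payload = record.to_dict()                                # JSON-friendly dict for API responses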


@@ -0,0 +1,874 @@
# ================================================================
# services/data/app/repositories/traffic_repository.py
# ================================================================
"""
Traffic Repository - Enhanced for multiple cities with comprehensive data access patterns
Follows existing repository architecture while adding city-specific functionality
"""
from typing import Optional, List, Dict, Any, Type, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, func, desc, asc, text, update, delete
from sqlalchemy.orm import selectinload
from datetime import datetime, timezone, timedelta
import structlog
from .base import DataBaseRepository
from app.models.traffic import TrafficData, TrafficMeasurementPoint, TrafficDataBackgroundJob
from app.schemas.traffic import TrafficDataCreate, TrafficDataResponse
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class TrafficRepository(DataBaseRepository[TrafficData, TrafficDataCreate, Dict]):
"""
Enhanced repository for traffic data operations across multiple cities
Provides city-aware queries and advanced traffic analytics
"""
def __init__(self, model_class: Type, session: AsyncSession, cache_ttl: Optional[int] = 300):
super().__init__(model_class, session, cache_ttl)
# ================================================================
# CORE TRAFFIC DATA OPERATIONS
# ================================================================
async def get_by_location_and_date_range(
self,
latitude: float,
longitude: float,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
city: Optional[str] = None,
tenant_id: Optional[str] = None,
skip: int = 0,
limit: int = 100
) -> List[TrafficData]:
"""Get traffic data by location and date range with city filtering"""
try:
location_id = f"{latitude:.4f},{longitude:.4f}"
# Build base query
query = select(self.model).where(self.model.location_id == location_id)
# Add city filter if specified
if city:
query = query.where(self.model.city == city)
# Add tenant filter if specified
if tenant_id:
query = query.where(self.model.tenant_id == tenant_id)
# Add date range filters
if start_date:
start_date = self._ensure_utc_datetime(start_date)
query = query.where(self.model.date >= start_date)
if end_date:
end_date = self._ensure_utc_datetime(end_date)
query = query.where(self.model.date <= end_date)
# Order by date descending (most recent first)
query = query.order_by(desc(self.model.date))
# Apply pagination
query = query.offset(skip).limit(limit)
result = await self.session.execute(query)
return result.scalars().all()
except Exception as e:
logger.error("Failed to get traffic data by location and date range",
latitude=latitude, longitude=longitude,
city=city, error=str(e))
raise DatabaseError(f"Failed to get traffic data: {str(e)}")
async def get_by_city_and_date_range(
self,
city: str,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
district: Optional[str] = None,
measurement_point_ids: Optional[List[str]] = None,
include_synthetic: bool = True,
tenant_id: Optional[str] = None,
skip: int = 0,
limit: int = 1000
) -> List[TrafficData]:
"""Get traffic data by city with advanced filtering options"""
try:
# Build base query
query = select(self.model).where(self.model.city == city)
# Add tenant filter if specified
if tenant_id:
query = query.where(self.model.tenant_id == tenant_id)
# Add date range filters
if start_date:
start_date = self._ensure_utc_datetime(start_date)
query = query.where(self.model.date >= start_date)
if end_date:
end_date = self._ensure_utc_datetime(end_date)
query = query.where(self.model.date <= end_date)
# Add district filter
if district:
query = query.where(self.model.district == district)
# Add measurement point filter
if measurement_point_ids:
query = query.where(self.model.measurement_point_id.in_(measurement_point_ids))
# Filter synthetic data if requested
if not include_synthetic:
query = query.where(self.model.is_synthetic == False)
# Order by date and measurement point
query = query.order_by(desc(self.model.date), self.model.measurement_point_id)
# Apply pagination
query = query.offset(skip).limit(limit)
result = await self.session.execute(query)
return result.scalars().all()
except Exception as e:
logger.error("Failed to get traffic data by city",
city=city, district=district, error=str(e))
raise DatabaseError(f"Failed to get traffic data: {str(e)}")
async def get_latest_by_measurement_points(
self,
measurement_point_ids: List[str],
city: str,
hours_back: int = 24
) -> List[TrafficData]:
"""Get latest traffic data for specific measurement points"""
try:
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours_back)
query = select(self.model).where(
and_(
self.model.city == city,
self.model.measurement_point_id.in_(measurement_point_ids),
self.model.date >= cutoff_time
)
).order_by(
self.model.measurement_point_id,
desc(self.model.date)
)
result = await self.session.execute(query)
all_records = result.scalars().all()
# Get the latest record for each measurement point
latest_records = {}
for record in all_records:
point_id = record.measurement_point_id
if point_id not in latest_records:
latest_records[point_id] = record
return list(latest_records.values())
except Exception as e:
logger.error("Failed to get latest traffic data by measurement points",
city=city, points=len(measurement_point_ids), error=str(e))
raise DatabaseError(f"Failed to get latest traffic data: {str(e)}")
# ================================================================
# ANALYTICS AND AGGREGATIONS
# ================================================================
async def get_traffic_statistics_by_city(
self,
city: str,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
group_by: str = "daily"
) -> List[Dict[str, Any]]:
"""Get aggregated traffic statistics by city"""
try:
# Determine date truncation based on group_by
if group_by == "hourly":
date_trunc = "hour"
elif group_by == "daily":
date_trunc = "day"
elif group_by == "weekly":
date_trunc = "week"
elif group_by == "monthly":
date_trunc = "month"
else:
raise ValidationError(f"Invalid group_by value: {group_by}")
# Build aggregation query
if self.session.bind.dialect.name == 'postgresql':
query = text("""
SELECT
DATE_TRUNC(:date_trunc, date) as period,
city,
district,
COUNT(*) as record_count,
AVG(traffic_volume) as avg_traffic_volume,
MAX(traffic_volume) as max_traffic_volume,
AVG(pedestrian_count) as avg_pedestrian_count,
AVG(average_speed) as avg_speed,
COUNT(CASE WHEN congestion_level = 'high' THEN 1 END) as high_congestion_count,
COUNT(CASE WHEN is_synthetic = false THEN 1 END) as real_data_count,
COUNT(CASE WHEN has_pedestrian_inference = true THEN 1 END) as pedestrian_inference_count
FROM traffic_data
WHERE city = :city
""")
else:
# SQLite fallback
query = text("""
SELECT
DATE(date) as period,
city,
district,
COUNT(*) as record_count,
AVG(traffic_volume) as avg_traffic_volume,
MAX(traffic_volume) as max_traffic_volume,
AVG(pedestrian_count) as avg_pedestrian_count,
AVG(average_speed) as avg_speed,
SUM(CASE WHEN congestion_level = 'high' THEN 1 ELSE 0 END) as high_congestion_count,
SUM(CASE WHEN is_synthetic = 0 THEN 1 ELSE 0 END) as real_data_count,
SUM(CASE WHEN has_pedestrian_inference = 1 THEN 1 ELSE 0 END) as pedestrian_inference_count
FROM traffic_data
WHERE city = :city
""")
params = {
"city": city,
"date_trunc": date_trunc
}
# Add date filters
if start_date:
query = text(str(query) + " AND date >= :start_date")
params["start_date"] = self._ensure_utc_datetime(start_date)
if end_date:
query = text(str(query) + " AND date <= :end_date")
params["end_date"] = self._ensure_utc_datetime(end_date)
# Add GROUP BY and ORDER BY
query = text(str(query) + " GROUP BY period, city, district ORDER BY period DESC")
result = await self.session.execute(query, params)
rows = result.fetchall()
# Convert to list of dictionaries
statistics = []
for row in rows:
statistics.append({
"period": group_by,
"date": row.period,
"city": row.city,
"district": row.district,
"record_count": row.record_count,
"avg_traffic_volume": float(row.avg_traffic_volume or 0),
"max_traffic_volume": row.max_traffic_volume or 0,
"avg_pedestrian_count": float(row.avg_pedestrian_count or 0),
"avg_speed": float(row.avg_speed or 0),
"high_congestion_count": row.high_congestion_count or 0,
"real_data_percentage": round((row.real_data_count or 0) / max(1, row.record_count) * 100, 2),
"pedestrian_inference_percentage": round((row.pedestrian_inference_count or 0) / max(1, row.record_count) * 100, 2)
})
return statistics
except Exception as e:
logger.error("Failed to get traffic statistics by city",
city=city, group_by=group_by, error=str(e))
raise DatabaseError(f"Traffic statistics query failed: {str(e)}")
async def get_congestion_heatmap_data(
self,
city: str,
start_date: datetime,
end_date: datetime,
time_granularity: str = "hour"
) -> List[Dict[str, Any]]:
"""Get congestion data for heatmap visualization"""
try:
if time_granularity == "hour":
time_extract = "EXTRACT(hour FROM date)"
elif time_granularity == "day_of_week":
time_extract = "EXTRACT(dow FROM date)"
else:
time_extract = "EXTRACT(hour FROM date)"
query = text(f"""
SELECT
{time_extract} as time_period,
district,
measurement_point_id,
latitude,
longitude,
AVG(CASE
WHEN congestion_level = 'low' THEN 1
WHEN congestion_level = 'medium' THEN 2
WHEN congestion_level = 'high' THEN 3
WHEN congestion_level = 'blocked' THEN 4
ELSE 1
END) as avg_congestion_score,
COUNT(*) as data_points,
AVG(traffic_volume) as avg_traffic_volume,
AVG(pedestrian_count) as avg_pedestrian_count
FROM traffic_data
WHERE city = :city
AND date >= :start_date
AND date <= :end_date
AND latitude IS NOT NULL
AND longitude IS NOT NULL
GROUP BY time_period, district, measurement_point_id, latitude, longitude
ORDER BY time_period, district, avg_congestion_score DESC
""")
params = {
"city": city,
"start_date": self._ensure_utc_datetime(start_date),
"end_date": self._ensure_utc_datetime(end_date)
}
result = await self.session.execute(query, params)
rows = result.fetchall()
heatmap_data = []
for row in rows:
heatmap_data.append({
"time_period": int(row.time_period or 0),
"district": row.district,
"measurement_point_id": row.measurement_point_id,
"latitude": float(row.latitude),
"longitude": float(row.longitude),
"avg_congestion_score": float(row.avg_congestion_score),
"data_points": row.data_points,
"avg_traffic_volume": float(row.avg_traffic_volume or 0),
"avg_pedestrian_count": float(row.avg_pedestrian_count or 0)
})
return heatmap_data
except Exception as e:
logger.error("Failed to get congestion heatmap data",
city=city, error=str(e))
raise DatabaseError(f"Congestion heatmap query failed: {str(e)}")
# ================================================================
# BULK OPERATIONS AND DATA MANAGEMENT
# ================================================================
async def create_bulk_traffic_data(
self,
traffic_records: List[Dict[str, Any]],
city: str,
tenant_id: Optional[str] = None
) -> List[TrafficData]:
"""Create multiple traffic records in bulk with enhanced validation"""
try:
# Ensure all records have city and tenant_id
for record in traffic_records:
record["city"] = city
if tenant_id:
record["tenant_id"] = tenant_id
# Ensure dates are timezone-aware
if "date" in record and record["date"]:
record["date"] = self._ensure_utc_datetime(record["date"])
# Enhanced validation
validated_records = []
for record in traffic_records:
if self._validate_traffic_record(record):
validated_records.append(record)
else:
logger.warning("Invalid traffic record skipped",
city=city, record_keys=list(record.keys()))
if not validated_records:
logger.warning("No valid traffic records to create", city=city)
return []
# Use bulk create with deduplication
created_records = await self.bulk_create_with_deduplication(validated_records)
logger.info("Bulk traffic data creation completed",
city=city, requested=len(traffic_records),
validated=len(validated_records), created=len(created_records))
return created_records
except Exception as e:
logger.error("Failed to create bulk traffic data",
city=city, record_count=len(traffic_records), error=str(e))
raise DatabaseError(f"Bulk traffic creation failed: {str(e)}")
async def bulk_create_with_deduplication(
self,
records: List[Dict[str, Any]]
) -> List[TrafficData]:
"""Bulk create with automatic deduplication based on location, city, and date"""
try:
if not records:
return []
# Extract unique keys for deduplication check
unique_keys = []
for record in records:
key = (
record.get('location_id'),
record.get('city'),
record.get('date'),
record.get('measurement_point_id')
)
unique_keys.append(key)
# Check for existing records
location_ids = [key[0] for key in unique_keys if key[0]]
cities = [key[1] for key in unique_keys if key[1]]
dates = [key[2] for key in unique_keys if key[2]]
# For large datasets, use chunked deduplication to avoid memory issues
if len(location_ids) > 1000:
logger.info(f"Large dataset detected ({len(records)} records), using chunked deduplication")
new_records = []
chunk_size = 1000
for i in range(0, len(records), chunk_size):
chunk_records = records[i:i + chunk_size]
chunk_keys = unique_keys[i:i + chunk_size]
# Get unique values for this chunk
chunk_location_ids = list(set(key[0] for key in chunk_keys if key[0]))
chunk_cities = list(set(key[1] for key in chunk_keys if key[1]))
chunk_dates = list(set(key[2] for key in chunk_keys if key[2]))
if chunk_location_ids and chunk_cities and chunk_dates:
existing_query = select(
self.model.location_id,
self.model.city,
self.model.date,
self.model.measurement_point_id
).where(
and_(
self.model.location_id.in_(chunk_location_ids),
self.model.city.in_(chunk_cities),
self.model.date.in_(chunk_dates)
)
)
result = await self.session.execute(existing_query)
chunk_existing_keys = set(result.fetchall())
# Filter chunk duplicates
for j, record in enumerate(chunk_records):
key = chunk_keys[j]
if key not in chunk_existing_keys:
new_records.append(record)
else:
new_records.extend(chunk_records)
logger.debug("Chunked deduplication completed",
total_records=len(records),
new_records=len(new_records))
records = new_records
elif location_ids and cities and dates:
existing_query = select(
self.model.location_id,
self.model.city,
self.model.date,
self.model.measurement_point_id
).where(
and_(
self.model.location_id.in_(location_ids),
self.model.city.in_(cities),
self.model.date.in_(dates)
)
)
result = await self.session.execute(existing_query)
existing_keys = set(result.fetchall())
# Filter out duplicates
new_records = []
for i, record in enumerate(records):
key = unique_keys[i]
if key not in existing_keys:
new_records.append(record)
logger.debug("Standard deduplication completed",
total_records=len(records),
existing_records=len(existing_keys),
new_records=len(new_records))
records = new_records
# Proceed with bulk creation
return await self.bulk_create(records)
except Exception as e:
logger.error("Failed bulk create with deduplication", error=str(e))
raise DatabaseError(f"Bulk create with deduplication failed: {str(e)}")
def _validate_traffic_record(self, record: Dict[str, Any]) -> bool:
"""Enhanced validation for traffic records"""
required_fields = ['date', 'city']
# Check required fields
for field in required_fields:
if not record.get(field):
return False
# Validate city
city = record.get('city', '').lower()
if city not in ['madrid', 'barcelona', 'valencia', 'test']: # Extendable list
return False
# Validate data ranges
traffic_volume = record.get('traffic_volume')
if traffic_volume is not None and (traffic_volume < 0 or traffic_volume > 50000):
return False
pedestrian_count = record.get('pedestrian_count')
if pedestrian_count is not None and (pedestrian_count < 0 or pedestrian_count > 10000):
return False
average_speed = record.get('average_speed')
if average_speed is not None and (average_speed < 0 or average_speed > 200):
return False
congestion_level = record.get('congestion_level')
if congestion_level and congestion_level not in ['low', 'medium', 'high', 'blocked']:
return False
return True
# ================================================================
# TRAINING DATA SPECIFIC OPERATIONS
# ================================================================
async def get_training_data_by_location(
self,
latitude: float,
longitude: float,
start_date: datetime,
end_date: datetime,
tenant_id: Optional[str] = None,
include_pedestrian_inference: bool = True
) -> List[Dict[str, Any]]:
"""Get optimized training data for ML models"""
try:
location_id = f"{latitude:.4f},{longitude:.4f}"
query = select(self.model).where(
and_(
self.model.location_id == location_id,
self.model.date >= self._ensure_utc_datetime(start_date),
self.model.date <= self._ensure_utc_datetime(end_date)
)
)
if tenant_id:
query = query.where(self.model.tenant_id == tenant_id)
if include_pedestrian_inference:
# Prefer records with pedestrian inference
query = query.order_by(
desc(self.model.has_pedestrian_inference),
desc(self.model.data_quality_score),
self.model.date
)
else:
query = query.order_by(
desc(self.model.data_quality_score),
self.model.date
)
result = await self.session.execute(query)
records = result.scalars().all()
# Convert to training format with enhanced features
training_data = []
for record in records:
training_record = {
'date': record.date,
'traffic_volume': record.traffic_volume or 0,
'pedestrian_count': record.pedestrian_count or 0,
'congestion_level': record.congestion_level or 'medium',
'average_speed': record.average_speed or 25.0,
'city': record.city,
'district': record.district,
'measurement_point_id': record.measurement_point_id,
'source': record.source,
'is_synthetic': record.is_synthetic or False,
'has_pedestrian_inference': record.has_pedestrian_inference or False,
'data_quality_score': record.data_quality_score or 50.0,
# Enhanced features for training
'hour_of_day': record.date.hour if record.date else 12,
'day_of_week': record.date.weekday() if record.date else 0,
'month': record.date.month if record.date else 1,
# City-specific features
'city_specific_data': record.city_specific_data or {}
}
training_data.append(training_record)
logger.info("Retrieved training data",
location_id=location_id, records=len(training_data),
with_pedestrian_inference=sum(1 for r in training_data if r['has_pedestrian_inference']))
return training_data
except Exception as e:
logger.error("Failed to get training data",
latitude=latitude, longitude=longitude, error=str(e))
raise DatabaseError(f"Training data retrieval failed: {str(e)}")
async def get_historical_data_by_location(
self,
latitude: float,
longitude: float,
start_date: datetime,
end_date: datetime,
tenant_id: Optional[str] = None
) -> List[TrafficData]:
"""Get historical traffic data for a specific location and date range"""
return await self.get_by_location_and_date_range(
latitude=latitude,
longitude=longitude,
start_date=start_date,
end_date=end_date,
tenant_id=tenant_id,
limit=1000000 # Large limit for historical data
)
async def count_records_in_period(
self,
latitude: float,
longitude: float,
start_date: datetime,
end_date: datetime,
city: Optional[str] = None,
tenant_id: Optional[str] = None
) -> int:
"""Count traffic records for a specific location and time period"""
try:
location_id = f"{latitude:.4f},{longitude:.4f}"
query = select(func.count(self.model.id)).where(
and_(
self.model.location_id == location_id,
self.model.date >= self._ensure_utc_datetime(start_date),
self.model.date <= self._ensure_utc_datetime(end_date)
)
)
if city:
query = query.where(self.model.city == city)
if tenant_id:
query = query.where(self.model.tenant_id == tenant_id)
result = await self.session.execute(query)
count = result.scalar()
return count or 0
except Exception as e:
logger.error("Failed to count records in period",
latitude=latitude, longitude=longitude, error=str(e))
raise DatabaseError(f"Record count failed: {str(e)}")
# ================================================================
# DATA QUALITY AND MAINTENANCE
# ================================================================
async def update_data_quality_scores(self, city: str) -> int:
"""Update data quality scores based on various criteria"""
try:
# Calculate quality scores based on data completeness and consistency
query = text("""
UPDATE traffic_data
SET data_quality_score = (
CASE
WHEN traffic_volume IS NOT NULL THEN 20 ELSE 0 END +
CASE
WHEN pedestrian_count IS NOT NULL THEN 20 ELSE 0 END +
CASE
WHEN average_speed IS NOT NULL AND average_speed > 0 THEN 20 ELSE 0 END +
CASE
WHEN congestion_level IS NOT NULL THEN 15 ELSE 0 END +
CASE
WHEN measurement_point_id IS NOT NULL THEN 10 ELSE 0 END +
CASE
WHEN district IS NOT NULL THEN 10 ELSE 0 END +
CASE
WHEN has_pedestrian_inference = true THEN 5 ELSE 0 END
),
updated_at = :updated_at
WHERE city = :city AND data_quality_score IS NULL
""")
params = {
"city": city,
"updated_at": datetime.now(timezone.utc)
}
result = await self.session.execute(query, params)
updated_count = result.rowcount
await self.session.commit()
logger.info("Updated data quality scores",
city=city, updated_count=updated_count)
return updated_count
except Exception as e:
logger.error("Failed to update data quality scores",
city=city, error=str(e))
await self.session.rollback()
raise DatabaseError(f"Data quality update failed: {str(e)}")
async def cleanup_old_synthetic_data(
self,
city: str,
days_to_keep: int = 90
) -> int:
"""Clean up old synthetic data while preserving real data"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
query = delete(self.model).where(
and_(
self.model.city == city,
self.model.is_synthetic == True,
self.model.date < cutoff_date
)
)
result = await self.session.execute(query)
deleted_count = result.rowcount
await self.session.commit()
logger.info("Cleaned up old synthetic data",
city=city, deleted_count=deleted_count, days_kept=days_to_keep)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup old synthetic data",
city=city, error=str(e))
await self.session.rollback()
raise DatabaseError(f"Synthetic data cleanup failed: {str(e)}")
class TrafficMeasurementPointRepository(DataBaseRepository[TrafficMeasurementPoint, Dict, Dict]):
"""Repository for traffic measurement points across cities"""
async def get_points_near_location(
self,
latitude: float,
longitude: float,
city: str,
radius_km: float = 10.0,
limit: int = 20
) -> List[TrafficMeasurementPoint]:
"""Get measurement points near a location using spatial query"""
try:
# Simple haversine distance, wrapped in a subquery so the alias can be filtered in standard SQL (for more precision, use PostGIS)
query = text("""
SELECT * FROM (
SELECT *,
(6371 * acos(
cos(radians(:lat)) * cos(radians(latitude)) *
cos(radians(longitude) - radians(:lon)) +
sin(radians(:lat)) * sin(radians(latitude))
)) as distance_km
FROM traffic_measurement_points
WHERE city = :city
AND is_active = true
) AS nearby
WHERE distance_km <= :radius_km
ORDER BY distance_km
LIMIT :limit
""")
params = {
"lat": latitude,
"lon": longitude,
"city": city,
"radius_km": radius_km,
"limit": limit
}
result = await self.session.execute(query, params)
rows = result.fetchall()
# Convert rows to model instances
points = []
for row in rows:
point = TrafficMeasurementPoint()
for key, value in row._mapping.items():
if hasattr(point, key) and key != 'distance_km':
setattr(point, key, value)
points.append(point)
return points
except Exception as e:
logger.error("Failed to get measurement points near location",
latitude=latitude, longitude=longitude, city=city, error=str(e))
raise DatabaseError(f"Measurement points query failed: {str(e)}")
class TrafficBackgroundJobRepository(DataBaseRepository[TrafficDataBackgroundJob, Dict, Dict]):
"""Repository for managing background traffic data jobs"""
async def get_pending_jobs_by_city(self, city: str) -> List[TrafficDataBackgroundJob]:
"""Get pending background jobs for a specific city"""
try:
query = select(self.model).where(
and_(
self.model.city == city,
self.model.status == 'pending'
)
).order_by(self.model.scheduled_at)
result = await self.session.execute(query)
return result.scalars().all()
except Exception as e:
logger.error("Failed to get pending jobs by city", city=city, error=str(e))
raise DatabaseError(f"Background jobs query failed: {str(e)}")
async def update_job_progress(
self,
job_id: str,
progress_percentage: float,
records_processed: int = 0,
records_stored: int = 0
) -> bool:
"""Update job progress"""
try:
query = update(self.model).where(
self.model.id == job_id
).values(
progress_percentage=progress_percentage,
records_processed=records_processed,
records_stored=records_stored,
updated_at=datetime.now(timezone.utc)
)
result = await self.session.execute(query)
await self.session.commit()
return result.rowcount > 0
except Exception as e:
logger.error("Failed to update job progress", job_id=job_id, error=str(e))
await self.session.rollback()
raise DatabaseError(f"Job progress update failed: {str(e)}")
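A hedged sketch of the intended call pattern for the traffic repository; the AsyncSession is assumed to come from the service's usual session factory, which is outside this diff:

async def store_madrid_batch(session: AsyncSession) -> None:
    repo = TrafficRepository(TrafficData, session, cache_ttl=300)
    rows = [{
        "location_id": "40.4168,-3.7038",
        "city": "madrid",
        "date": datetime.now(timezone.utc),
        "traffic_volume": 850,
        "congestion_level": "medium",
        "measurement_point_id": "PM10001",
        "source": "madrid_opendata",
    }]
    created = await repo.create_bulk_traffic_data(rows, city="madrid")
    logger.info("Stored traffic rows", created=len(created))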


@@ -3,7 +3,7 @@
# ================================================================
"""Sales data schemas"""
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, field_validator
from datetime import datetime
from typing import Optional, List, Dict, Any
from uuid import UUID
@@ -20,7 +20,8 @@ class SalesDataCreate(BaseModel):
source: str = Field(default="manual", max_length=50)
notes: Optional[str] = Field(None, max_length=500)
@validator('product_name')
@field_validator('product_name')
@classmethod
def normalize_product_name(cls, v):
return v.strip().lower()


@@ -3,7 +3,7 @@
# ================================================================
"""Traffic data schemas"""
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, field_validator
from datetime import datetime
from typing import Optional, List
from uuid import UUID
@@ -14,7 +14,7 @@ class TrafficDataBase(BaseModel):
date: datetime = Field(..., description="Date and time of traffic measurement")
traffic_volume: Optional[int] = Field(None, ge=0, description="Vehicles per hour")
pedestrian_count: Optional[int] = Field(None, ge=0, description="Pedestrians per hour")
congestion_level: Optional[str] = Field(None, regex="^(low|medium|high)$", description="Traffic congestion level")
congestion_level: Optional[str] = Field(None, pattern="^(low|medium|high)$", description="Traffic congestion level")
average_speed: Optional[float] = Field(None, ge=0, le=200, description="Average speed in km/h")
source: str = Field("madrid_opendata", max_length=50, description="Data source")
raw_data: Optional[str] = Field(None, description="Raw data from source")
@@ -27,7 +27,7 @@ class TrafficDataUpdate(BaseModel):
"""Schema for updating traffic data"""
traffic_volume: Optional[int] = Field(None, ge=0)
pedestrian_count: Optional[int] = Field(None, ge=0)
congestion_level: Optional[str] = Field(None, regex="^(low|medium|high)$")
congestion_level: Optional[str] = Field(None, pattern="^(low|medium|high)$")
average_speed: Optional[float] = Field(None, ge=0, le=200)
raw_data: Optional[str] = None
@@ -37,7 +37,8 @@ class TrafficDataResponse(TrafficDataBase):
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
@validator('id', pre=True)
@field_validator('id', mode='before')
@classmethod
def convert_uuid_to_string(cls, v):
if isinstance(v, UUID):
return str(v)
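For context, a minimal self-contained sketch of the Pydantic v2 idioms these schema hunks migrate to: field_validator with mode='before' replacing validator(pre=True), and Field(pattern=...) replacing Field(regex=...). The model below is illustrative only and not part of this commit.
from uuid import UUID
from typing import Optional
from pydantic import BaseModel, Field, field_validator
class ExampleTrafficRecord(BaseModel):
    id: str
    congestion_level: Optional[str] = Field(None, pattern="^(low|medium|high)$")
    @field_validator('id', mode='before')
    @classmethod
    def convert_uuid_to_string(cls, v):
        # Coerce UUID instances to strings before normal field validation runs
        return str(v) if isinstance(v, UUID) else v
# ExampleTrafficRecord(id=UUID(int=1), congestion_level="low") validates;
# congestion_level="severe" raises a ValidationError because of the pattern constraint.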

View File

@@ -3,7 +3,7 @@
# ================================================================
"""Weather data schemas"""
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, field_validator
from datetime import datetime
from typing import Optional, List
from uuid import UUID
@@ -41,7 +41,8 @@ class WeatherDataResponse(WeatherDataBase):
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
@validator('id', pre=True)
@field_validator('id', mode='before')
@classmethod
def convert_uuid_to_string(cls, v):
if isinstance(v, UUID):
return str(v)
@@ -76,7 +77,8 @@ class WeatherForecastResponse(WeatherForecastBase):
created_at: datetime = Field(..., description="Creation timestamp")
updated_at: datetime = Field(..., description="Last update timestamp")
@validator('id', pre=True)
@field_validator('id', mode='before')
@classmethod
def convert_uuid_to_string(cls, v):
if isinstance(v, UUID):
return str(v)

View File

@@ -1,122 +1,283 @@
# ================================================================
# services/data/app/services/traffic_service.py - FIXED VERSION
# services/data/app/services/traffic_service.py
# ================================================================
"""Traffic data service with improved error handling"""
"""
Abstracted Traffic Service - Universal interface for traffic data across multiple cities
"""
from typing import List, Dict, Any, Optional
from datetime import datetime, timedelta
import asyncio
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_
import structlog
from app.external.apis.traffic import UniversalTrafficClient
from app.models.traffic import TrafficData
from app.external.madrid_opendata import MadridOpenDataClient
from app.schemas.external import TrafficDataResponse
import uuid
from app.core.performance import (
async_cache,
monitor_performance,
global_connection_pool,
global_performance_monitor,
batch_process
)
logger = structlog.get_logger()
class TrafficService:
"""
Abstracted traffic service providing a unified interface for traffic data.
Routes requests to the appropriate city-specific client automatically.
"""
def __init__(self):
self.madrid_client = MadridOpenDataClient()
self.universal_client = UniversalTrafficClient()
self.logger = structlog.get_logger(__name__)
async def get_current_traffic(self, latitude: float, longitude: float) -> Optional[TrafficDataResponse]:
"""Get current traffic data for location"""
@async_cache(ttl=300) # Cache for 5 minutes
@monitor_performance(monitor=global_performance_monitor)
async def get_current_traffic(
self,
latitude: float,
longitude: float,
tenant_id: Optional[str] = None
) -> Optional[Dict[str, Any]]:
"""
Get current traffic data for any supported location
Args:
latitude: Query location latitude
longitude: Query location longitude
tenant_id: Optional tenant identifier for logging/analytics
Returns:
Dict with current traffic data or None if not available
"""
try:
logger.debug("Getting current traffic", lat=latitude, lon=longitude)
traffic_data = await self.madrid_client.get_current_traffic(latitude, longitude)
self.logger.info("Getting current traffic data",
lat=latitude, lon=longitude, tenant_id=tenant_id)
# Delegate to universal client
traffic_data = await self.universal_client.get_current_traffic(latitude, longitude)
if traffic_data:
logger.debug("Traffic data received", source=traffic_data.get('source'))
# Add service metadata
traffic_data['service_metadata'] = {
'request_timestamp': datetime.now().isoformat(),
'tenant_id': tenant_id,
'service_version': '2.0',
'query_location': {'latitude': latitude, 'longitude': longitude}
}
# Validate and clean traffic data before creating response
# Use keyword arguments instead of unpacking
response = TrafficDataResponse(
date=traffic_data.get("date", datetime.now()),
traffic_volume=int(traffic_data.get("traffic_volume", 100)),
pedestrian_count=int(traffic_data.get("pedestrian_count", 150)),
congestion_level=str(traffic_data.get("congestion_level", "medium")),
average_speed=float(traffic_data.get("average_speed", 25.0)), # Fixed: use float, not int
source=str(traffic_data.get("source", "unknown"))
)
self.logger.info("Successfully retrieved current traffic data",
lat=latitude, lon=longitude,
source=traffic_data.get('source', 'unknown'))
logger.debug("Successfully created traffic response",
traffic_volume=response.traffic_volume,
congestion_level=response.congestion_level)
return response
return traffic_data
else:
logger.warning("No traffic data received from Madrid client")
self.logger.warning("No current traffic data available",
lat=latitude, lon=longitude)
return None
except Exception as e:
logger.error("Failed to get current traffic", error=str(e), lat=latitude, lon=longitude)
# Log the full traceback for debugging
import traceback
logger.error("Traffic service traceback", traceback=traceback.format_exc())
self.logger.error("Error getting current traffic data",
lat=latitude, lon=longitude, error=str(e))
return None
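# Illustrative call site (not part of this commit); the coordinates and tenant_id below
# are example values. Because of the @async_cache(ttl=300) decorator above, repeated
# calls for the same location within five minutes are served from the in-memory cache:
#
#     service = TrafficService()
#     current = await service.get_current_traffic(40.4168, -3.7038, tenant_id="demo")
#     if current:
#         print(current.get("congestion_level"), current["service_metadata"]["service_version"])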
async def get_historical_traffic(self,
latitude: float,
longitude: float,
start_date: datetime,
end_date: datetime,
db: AsyncSession) -> List[TrafficDataResponse]:
"""Get historical traffic data with enhanced storage for re-training"""
@async_cache(ttl=1800) # Cache for 30 minutes (historical data changes less frequently)
@monitor_performance(monitor=global_performance_monitor)
async def get_historical_traffic(
self,
latitude: float,
longitude: float,
start_date: datetime,
end_date: datetime,
tenant_id: Optional[str] = None,
db: Optional[AsyncSession] = None
) -> List[Dict[str, Any]]:
"""
Get historical traffic data for any supported location with database storage
Args:
latitude: Query location latitude
longitude: Query location longitude
start_date: Start date for historical data
end_date: End date for historical data
tenant_id: Optional tenant identifier
db: Optional database session for storage
Returns:
List of historical traffic data dictionaries
"""
try:
logger.debug("Getting historical traffic",
lat=latitude, lon=longitude,
start=start_date, end=end_date)
self.logger.info("Getting historical traffic data",
lat=latitude, lon=longitude,
start=start_date, end=end_date, tenant_id=tenant_id)
# Validate date range
if start_date >= end_date:
self.logger.warning("Invalid date range", start=start_date, end=end_date)
return []
# Check database first
location_id = f"{latitude:.4f},{longitude:.4f}"
stmt = select(TrafficData).where(
and_(
TrafficData.location_id == location_id,
TrafficData.date >= start_date,
TrafficData.date <= end_date
)
).order_by(TrafficData.date)
result = await db.execute(stmt)
db_records = result.scalars().all()
# Check database first if session provided
if db:
stmt = select(TrafficData).where(
and_(
TrafficData.location_id == location_id,
TrafficData.date >= start_date,
TrafficData.date <= end_date
)
).order_by(TrafficData.date)
result = await db.execute(stmt)
db_records = result.scalars().all()
if db_records:
self.logger.info("Historical traffic data found in database",
count=len(db_records))
return [self._convert_db_record_to_dict(record) for record in db_records]
if db_records:
logger.debug("Historical traffic data found in database", count=len(db_records))
return [TrafficDataResponse(
date=record.date,
traffic_volume=record.traffic_volume,
pedestrian_count=record.pedestrian_count,
congestion_level=record.congestion_level,
average_speed=record.average_speed,
source=record.source
) for record in db_records]
# If not in database, fetch from API and store
logger.debug("Fetching historical data from MADRID OPEN DATA")
traffic_data = await self.madrid_client.get_historical_traffic(
# Delegate to universal client
traffic_data = await self.universal_client.get_historical_traffic(
latitude, longitude, start_date, end_date
)
if traffic_data:
# Enhanced storage with better error handling and validation
stored_count = await self._store_traffic_data_batch(
traffic_data, location_id, db
)
logger.info("Traffic data stored for re-training",
fetched=len(traffic_data), stored=stored_count, location=location_id)
return [TrafficDataResponse(**item) for item in traffic_data]
# Add service metadata to each record
for record in traffic_data:
record['service_metadata'] = {
'request_timestamp': datetime.now().isoformat(),
'tenant_id': tenant_id,
'service_version': '2.0',
'query_location': {'latitude': latitude, 'longitude': longitude},
'date_range': {
'start': start_date.isoformat(),
'end': end_date.isoformat()
}
}
# Store in database if session provided
if db:
stored_count = await self._store_traffic_data_batch(
traffic_data, location_id, db
)
self.logger.info("Traffic data stored for re-training",
fetched=len(traffic_data), stored=stored_count,
location=location_id)
self.logger.info("Successfully retrieved historical traffic data",
lat=latitude, lon=longitude, records=len(traffic_data))
return traffic_data
else:
logger.warning("No historical traffic data received")
self.logger.info("No historical traffic data available",
lat=latitude, lon=longitude)
return []
except Exception as e:
logger.error("Failed to get historical traffic", error=str(e))
self.logger.error("Error getting historical traffic data",
lat=latitude, lon=longitude, error=str(e))
return []
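# Illustrative call site (not part of this commit); the dates and session are example
# values. When a db session is passed, records already present in TrafficData for the
# location and range are returned directly; otherwise the universal client is queried
# and, if a session is available, the fetched records are stored for re-training:
#
#     history = await service.get_historical_traffic(
#         40.4168, -3.7038,
#         start_date=datetime(2025, 6, 1),
#         end_date=datetime(2025, 6, 7),
#         tenant_id="demo",
#         db=session,  # optional AsyncSession; omit to skip database lookup and storage
#     )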
def _convert_db_record_to_dict(self, record: TrafficData) -> Dict[str, Any]:
"""Convert database record to dictionary format"""
return {
'date': record.date,
'traffic_volume': record.traffic_volume,
'pedestrian_count': record.pedestrian_count,
'congestion_level': record.congestion_level,
'average_speed': record.average_speed,
'source': record.source,
'location_id': record.location_id,
'raw_data': record.raw_data
}
async def get_traffic_events(
self,
latitude: float,
longitude: float,
radius_km: float = 5.0,
tenant_id: Optional[str] = None
) -> List[Dict[str, Any]]:
"""
Get traffic events and incidents for any supported location
Args:
latitude: Query location latitude
longitude: Query location longitude
radius_km: Search radius in kilometers
tenant_id: Optional tenant identifier
Returns:
List of traffic events
"""
try:
self.logger.info("Getting traffic events",
lat=latitude, lon=longitude, radius=radius_km, tenant_id=tenant_id)
# Delegate to universal client
events = await self.universal_client.get_events(latitude, longitude, radius_km)
# Add metadata to events
for event in events:
event['service_metadata'] = {
'request_timestamp': datetime.now().isoformat(),
'tenant_id': tenant_id,
'service_version': '2.0',
'query_location': {'latitude': latitude, 'longitude': longitude},
'search_radius_km': radius_km
}
self.logger.info("Retrieved traffic events",
lat=latitude, lon=longitude, events=len(events))
return events
except Exception as e:
self.logger.error("Error getting traffic events",
lat=latitude, lon=longitude, error=str(e))
return []
def get_location_info(self, latitude: float, longitude: float) -> Dict[str, Any]:
"""
Get information about traffic data availability for location
Args:
latitude: Query location latitude
longitude: Query location longitude
Returns:
Dict with location support information
"""
try:
info = self.universal_client.get_location_info(latitude, longitude)
# Add service layer information
info['service_layer'] = {
'version': '2.0',
'abstraction_level': 'universal',
'supported_operations': [
'current_traffic',
'historical_traffic',
'traffic_events',
'bulk_requests'
]
}
return info
except Exception as e:
self.logger.error("Error getting location info",
lat=latitude, lon=longitude, error=str(e))
return {
'supported': False,
'error': str(e),
'service_layer': {'version': '2.0'}
}
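# Illustrative usage (not part of this commit): get_location_info is synchronous and can
# be used to check coverage before issuing traffic queries for an arbitrary coordinate.
# The 'supported' flag is assumed to be part of the universal client's response:
#
#     info = service.get_location_info(48.8566, 2.3522)
#     if info.get("supported"):
#         events = await service.get_traffic_events(48.8566, 2.3522, radius_km=3.0)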
async def store_traffic_data(self,
latitude: float,
longitude: float,
@@ -176,7 +337,8 @@ class TrafficService:
else:
existing_dates = set()
# Store only new records
# Prepare batch of new records for bulk insert
batch_records = []
for data in traffic_data:
try:
record_date = data.get('date')
@@ -188,32 +350,41 @@ class TrafficService:
logger.warning("Invalid traffic data, skipping", data=data)
continue
traffic_record = TrafficData(
location_id=location_id,
date=record_date,
traffic_volume=data.get('traffic_volume'),
pedestrian_count=data.get('pedestrian_count'),
congestion_level=data.get('congestion_level'),
average_speed=data.get('average_speed'),
source=data.get('source', 'madrid_opendata'),
raw_data=str(data)
)
db.add(traffic_record)
stored_count += 1
# Commit in batches to avoid memory issues
if stored_count % 100 == 0:
await db.commit()
logger.debug(f"Committed batch of {stored_count} records")
# Prepare record data for bulk insert
record_data = {
'location_id': location_id,
'date': record_date,
'traffic_volume': data.get('traffic_volume'),
'pedestrian_count': data.get('pedestrian_count'),
'congestion_level': data.get('congestion_level'),
'average_speed': data.get('average_speed'),
'source': data.get('source', 'madrid_opendata'),
'raw_data': str(data)
}
batch_records.append(record_data)
except Exception as record_error:
logger.warning("Failed to store individual traffic record",
logger.warning("Failed to prepare traffic record",
error=str(record_error), data=data)
continue
# Final commit
await db.commit()
# Use efficient bulk insert instead of individual records
if batch_records:
# Process in chunks to avoid memory issues
chunk_size = 5000
for i in range(0, len(batch_records), chunk_size):
chunk = batch_records[i:i + chunk_size]
# Use SQLAlchemy bulk insert for maximum performance
await db.execute(
TrafficData.__table__.insert(),
chunk
)
await db.commit()
stored_count += len(chunk)
logger.debug(f"Bulk inserted {len(chunk)} records (total: {stored_count})")
logger.info(f"Successfully stored {stored_count} traffic records for location {location_id}")
except Exception as e:

View File

@@ -1,405 +0,0 @@
#!/usr/bin/env python3
"""
Updated Madrid Historical Traffic test for pytest inside Docker
Configured for June 2025 data availability (last available historical data)
"""
import pytest
import asyncio
from datetime import datetime, timedelta
from typing import List, Dict, Any
# Import from the actual service
from app.external.madrid_opendata import MadridOpenDataClient
from app.core.config import settings
import structlog
# Configure pytest for async
pytestmark = pytest.mark.asyncio
# Use actual logger
logger = structlog.get_logger()
class TestMadridTrafficInside:
"""Test class for Madrid traffic functionality inside Docker"""
@pytest.fixture
def client(self):
"""Create Madrid client for testing"""
return MadridOpenDataClient()
@pytest.fixture
def madrid_coords(self):
"""Madrid center coordinates"""
return 40.4168, -3.7038
@pytest.fixture
def june_2025_dates(self):
"""Date ranges for June 2025 (last available historical data)"""
return {
"quick": {
"start": datetime(2025, 6, 1, 0, 0),
"end": datetime(2025, 6, 1, 6, 0) # 6 hours on June 1st
},
"one_day": {
"start": datetime(2025, 6, 15, 0, 0), # Mid-June
"end": datetime(2025, 6, 16, 0, 0) # One full day
},
"three_days": {
"start": datetime(2025, 6, 10, 0, 0),
"end": datetime(2025, 6, 13, 0, 0) # 3 days in June
},
"recent_synthetic": {
"start": datetime.now() - timedelta(hours=6),
"end": datetime.now() # Recent data (will be synthetic)
}
}
async def test_quick_historical_traffic_june2025(self, client, madrid_coords, june_2025_dates):
"""Test quick historical traffic data from June 2025"""
lat, lon = madrid_coords
date_range = june_2025_dates["quick"]
start_time = date_range["start"]
end_time = date_range["end"]
print(f"\n=== Quick Test (June 2025 - 6 hours) ===")
print(f"Location: {lat}, {lon}")
print(f"Date range: {start_time.strftime('%Y-%m-%d %H:%M')} to {end_time.strftime('%Y-%m-%d %H:%M')}")
print(f"Note: Testing with June 2025 data (last available historical month)")
# Test the function
execution_start = datetime.now()
result = await client.get_historical_traffic(lat, lon, start_time, end_time)
execution_time = (datetime.now() - execution_start).total_seconds()
print(f"⏱️ Execution time: {execution_time:.2f} seconds")
print(f"📊 Records returned: {len(result)}")
# Assertions
assert isinstance(result, list), "Result should be a list"
assert len(result) > 0, "Should return at least some records"
assert execution_time < 5000, "Should execute in reasonable time (allowing for ZIP download)"
# Check first record structure
if result:
sample = result[0]
print(f"📋 Sample record keys: {list(sample.keys())}")
print(f"📡 Data source: {sample.get('source', 'unknown')}")
# Required fields
required_fields = ['date', 'traffic_volume', 'congestion_level', 'average_speed', 'source']
for field in required_fields:
assert field in sample, f"Missing required field: {field}"
# Data validation
assert isinstance(sample['traffic_volume'], int), "Traffic volume should be int"
assert 0 <= sample['traffic_volume'] <= 1000, "Traffic volume should be reasonable"
assert sample['congestion_level'] in ['low', 'medium', 'high', 'blocked'], "Invalid congestion level"
assert 5 <= sample['average_speed'] <= 100, "Speed should be reasonable"
assert isinstance(sample['date'], datetime), "Date should be datetime object"
# Check if we got real Madrid data or synthetic
if sample['source'] == 'madrid_opendata_zip':
print(f"🎉 SUCCESS: Got real Madrid historical data from ZIP!")
else:
print(f" Got synthetic data (real data may not be available)")
print(f"✅ All validations passed")
async def test_one_day_june2025(self, client, madrid_coords, june_2025_dates):
"""Test one day of June 2025 historical traffic data"""
lat, lon = madrid_coords
date_range = june_2025_dates["one_day"]
start_time = date_range["start"]
end_time = date_range["end"]
print(f"\n=== One Day Test (June 15, 2025) ===")
print(f"Date range: {start_time.strftime('%Y-%m-%d %H:%M')} to {end_time.strftime('%Y-%m-%d %H:%M')}")
result = await client.get_historical_traffic(lat, lon, start_time, end_time)
print(f"📊 Records returned: {len(result)}")
# Should have roughly 24 records (one per hour)
assert len(result) >= 20, "Should have at least 20 hourly records for one day"
assert len(result) <= 5000, "Should not return an excessive number of records for one day"
# Check data source
if result:
sources = set(r['source'] for r in result)
print(f"📡 Data sources: {', '.join(sources)}")
# If we got real data, check for realistic measurement point IDs
real_data_records = [r for r in result if r['source'] == 'madrid_opendata_zip']
if real_data_records:
point_ids = set(r['measurement_point_id'] for r in real_data_records)
print(f"🏷️ Real measurement points found: {len(point_ids)}")
print(f" Sample IDs: {list(point_ids)[:3]}")
# Check traffic patterns
if len(result) >= 24:
# Find rush hour records (7-9 AM, 6-8 PM)
rush_hour_records = [r for r in result if 7 <= r['date'].hour <= 9 or 18 <= r['date'].hour <= 20]
night_records = [r for r in result if r['date'].hour <= 6 or r['date'].hour >= 22]
if rush_hour_records and night_records:
avg_rush_traffic = sum(r['traffic_volume'] for r in rush_hour_records) / len(rush_hour_records)
avg_night_traffic = sum(r['traffic_volume'] for r in night_records) / len(night_records)
print(f"📈 Rush hour avg traffic: {avg_rush_traffic:.1f}")
print(f"🌙 Night avg traffic: {avg_night_traffic:.1f}")
# Rush hour should typically have more traffic than night
if avg_rush_traffic > avg_night_traffic:
print(f"✅ Traffic patterns look realistic")
else:
print(f"⚠️ Traffic patterns unusual (not necessarily wrong)")
async def test_three_days_june2025(self, client, madrid_coords, june_2025_dates):
"""Test three days of June 2025 historical traffic data"""
lat, lon = madrid_coords
date_range = june_2025_dates["three_days"]
start_time = date_range["start"]
end_time = date_range["end"]
print(f"\n=== Three Days Test (June 10-13, 2025) ===")
print(f"Date range: {start_time.strftime('%Y-%m-%d')} to {end_time.strftime('%Y-%m-%d')}")
result = await client.get_historical_traffic(lat, lon, start_time, end_time)
print(f"📊 Records returned: {len(result)}")
# Should have roughly 72 records (24 hours * 3 days)
assert len(result) >= 60, "Should have at least 60 records for 3 days"
assert len(result) <= 5000, "Should not return an excessive number of records for 3 days"
# Check data sources
sources = set(r['source'] for r in result)
print(f"📡 Data sources: {', '.join(sources)}")
# Calculate statistics
traffic_volumes = [r['traffic_volume'] for r in result]
speeds = [r['average_speed'] for r in result]
avg_traffic = sum(traffic_volumes) / len(traffic_volumes)
max_traffic = max(traffic_volumes)
min_traffic = min(traffic_volumes)
avg_speed = sum(speeds) / len(speeds)
print(f"📈 Statistics:")
print(f" Average traffic: {avg_traffic:.1f}")
print(f" Max traffic: {max_traffic}")
print(f" Min traffic: {min_traffic}")
print(f" Average speed: {avg_speed:.1f} km/h")
# Analyze by data source
real_data_records = [r for r in result if r['source'] == 'madrid_opendata_zip']
synthetic_records = [r for r in result if r['source'] != 'madrid_opendata_zip']
print(f"🔍 Data breakdown:")
print(f" Real Madrid data: {len(real_data_records)} records")
print(f" Synthetic data: {len(synthetic_records)} records")
if real_data_records:
# Show measurement points from real data
real_points = set(r['measurement_point_id'] for r in real_data_records)
print(f" Real measurement points: {len(real_points)}")
# Sanity checks
assert 10 <= avg_traffic <= 500, "Average traffic should be reasonable"
assert 10 <= avg_speed <= 60, "Average speed should be reasonable"
assert max_traffic >= avg_traffic, "Max should be >= average"
assert min_traffic <= avg_traffic, "Min should be <= average"
async def test_recent_vs_historical_data(self, client, madrid_coords, june_2025_dates):
"""Compare recent data (synthetic) vs June 2025 data (potentially real)"""
lat, lon = madrid_coords
print(f"\n=== Recent vs Historical Data Comparison ===")
# Test recent data (should be synthetic)
recent_range = june_2025_dates["recent_synthetic"]
recent_result = await client.get_historical_traffic(
lat, lon, recent_range["start"], recent_range["end"]
)
# Test June 2025 data (potentially real)
june_range = june_2025_dates["quick"]
june_result = await client.get_historical_traffic(
lat, lon, june_range["start"], june_range["end"]
)
print(f"📊 Recent data: {len(recent_result)} records")
print(f"📊 June 2025 data: {len(june_result)} records")
if recent_result:
recent_sources = set(r['source'] for r in recent_result)
print(f"📡 Recent sources: {', '.join(recent_sources)}")
if june_result:
june_sources = set(r['source'] for r in june_result)
print(f"📡 June sources: {', '.join(june_sources)}")
# Check if we successfully got real data from June
if 'madrid_opendata_zip' in june_sources:
print(f"🎉 SUCCESS: Real Madrid data successfully fetched from June 2025!")
# Show details of real data
real_records = [r for r in june_result if r['source'] == 'madrid_opendata_zip']
if real_records:
sample = real_records[0]
print(f"📋 Real data sample:")
print(f" Date: {sample['date']}")
print(f" Traffic volume: {sample['traffic_volume']}")
print(f" Measurement point: {sample['measurement_point_id']}")
print(f" Point name: {sample.get('measurement_point_name', 'N/A')}")
else:
print(f" June data is synthetic (real ZIP may not be accessible)")
async def test_madrid_zip_month_code(self, client):
"""Test the month code calculation for Madrid ZIP files"""
print(f"\n=== Madrid ZIP Month Code Test ===")
# Test the month code calculation function
test_cases = [
(2025, 6, 145), # Known: June 2025 = 145
(2025, 5, 144), # Known: May 2025 = 144
(2025, 4, 143), # Known: April 2025 = 143
(2025, 7, 146), # Predicted: July 2025 = 146
]
for year, month, expected_code in test_cases:
if hasattr(client, '_calculate_madrid_month_code'):
calculated_code = client._calculate_madrid_month_code(year, month)
status = "" if calculated_code == expected_code else "⚠️"
print(f"{status} {year}-{month:02d}: Expected {expected_code}, Got {calculated_code}")
# Generate ZIP URL
if calculated_code:
zip_url = f"https://datos.madrid.es/egob/catalogo/208627-{calculated_code}-transporte-ptomedida-historico.zip"
print(f" ZIP URL: {zip_url}")
else:
print(f"⚠️ Month code calculation function not available")
async def test_edge_case_large_date_range(self, client, madrid_coords):
"""Test edge case: date range too large"""
lat, lon = madrid_coords
start_time = datetime(2025, 1, 1) # 6+ months range
end_time = datetime(2025, 7, 1)
print(f"\n=== Edge Case: Large Date Range ===")
print(f"Testing 6-month range: {start_time.date()} to {end_time.date()}")
result = await client.get_historical_traffic(lat, lon, start_time, end_time)
print(f"📊 Records for 6-month range: {len(result)}")
# Should return empty list for ranges > 90 days
assert len(result) == 0, "Should return empty list for date ranges > 90 days"
print(f"✅ Correctly handled large date range")
async def test_edge_case_invalid_coordinates(self, client):
"""Test edge case: invalid coordinates"""
print(f"\n=== Edge Case: Invalid Coordinates ===")
start_time = datetime(2025, 6, 1)
end_time = datetime(2025, 6, 1, 6, 0)
# Test with invalid coordinates
result = await client.get_historical_traffic(999.0, 999.0, start_time, end_time)
print(f"📊 Records for invalid coords: {len(result)}")
# Should either return empty list or synthetic data
# The function should not crash
assert isinstance(result, list), "Should return list even with invalid coords"
print(f"✅ Handled invalid coordinates gracefully")
async def test_real_madrid_zip_access(self, client):
"""Test if we can access the actual Madrid ZIP files"""
print(f"\n=== Real Madrid ZIP Access Test ===")
# Test the known ZIP URLs you provided
test_urls = [
"https://datos.madrid.es/egob/catalogo/208627-145-transporte-ptomedida-historico.zip", # June 2025
"https://datos.madrid.es/egob/catalogo/208627-144-transporte-ptomedida-historico.zip", # May 2025
"https://datos.madrid.es/egob/catalogo/208627-143-transporte-ptomedida-historico.zip", # April 2025
]
for i, url in enumerate(test_urls):
month_name = ["June 2025", "May 2025", "April 2025"][i]
print(f"\nTesting {month_name}: {url}")
try:
if hasattr(client, '_fetch_historical_zip'):
zip_data = await client._fetch_historical_zip(url)
if zip_data:
print(f"✅ Successfully fetched ZIP: {len(zip_data)} bytes")
# Try to inspect ZIP contents
try:
import zipfile
from io import BytesIO
with zipfile.ZipFile(BytesIO(zip_data), 'r') as zip_file:
files = zip_file.namelist()
csv_files = [f for f in files if f.endswith('.csv')]
print(f"📁 ZIP contains {len(files)} files, {len(csv_files)} CSV files")
if csv_files:
print(f" CSV files: {csv_files[:2]}{'...' if len(csv_files) > 2 else ''}")
except Exception as e:
print(f"⚠️ Could not inspect ZIP contents: {e}")
else:
print(f"❌ Failed to fetch ZIP")
else:
print(f"⚠️ ZIP fetch function not available")
except Exception as e:
print(f"❌ Error testing ZIP access: {e}")
# Additional standalone test functions for manual running
async def run_manual_test():
"""Manual test function that can be run directly"""
print("="*60)
print("MADRID TRAFFIC TEST - JUNE 2025 DATA")
print("="*60)
client = MadridOpenDataClient()
madrid_lat, madrid_lon = 40.4168, -3.7038
# Test with June 2025 data (last available)
start_time = datetime(2025, 6, 15, 14, 0) # June 15, 2025 at 2 PM
end_time = datetime(2025, 6, 15, 18, 0) # Until 6 PM (4 hours)
print(f"\nTesting June 15, 2025 data (2 PM - 6 PM)...")
print(f"This should include afternoon traffic patterns")
result = await client.get_historical_traffic(madrid_lat, madrid_lon, start_time, end_time)
print(f"Result: {len(result)} records")
if result:
sources = set(r['source'] for r in result)
print(f"Data sources: {', '.join(sources)}")
if 'madrid_opendata_zip' in sources:
print(f"🎉 Successfully got real Madrid data!")
sample = result[0]
print(f"\nSample record:")
for key, value in sample.items():
if key == "date":
print(f" {key}: {value.strftime('%Y-%m-%d %H:%M:%S')}")
else:
print(f" {key}: {value}")
print(f"\n✅ Manual test completed!")
if __name__ == "__main__":
# If run directly, execute manual test
asyncio.run(run_manual_test())

View File

@@ -8,7 +8,7 @@ from datetime import datetime, timedelta
from typing import Dict, List, Optional, Any, Tuple
from dataclasses import dataclass
import asyncio
import logging
import structlog
from concurrent.futures import ThreadPoolExecutor
from datetime import timezone
import pandas as pd
@@ -24,7 +24,7 @@ from app.services.messaging import (
publish_job_failed
)
logger = logging.getLogger(__name__)
logger = structlog.get_logger()
@dataclass
class TrainingDataSet:
@@ -39,15 +39,14 @@ class TrainingDataOrchestrator:
"""
Enhanced orchestrator for data collection from multiple sources.
Ensures date alignment, handles data source constraints, and prepares data for ML training.
Uses the new abstracted traffic service layer for multi-city support.
"""
def __init__(self,
madrid_client=None,
weather_client=None,
date_alignment_service: DateAlignmentService = None):
self.data_client = DataClient()
self.date_alignment_service = date_alignment_service or DateAlignmentService()
self.max_concurrent_requests = 3
self.max_concurrent_requests = 5 # Increased for better performance
async def prepare_training_data(
self,
@@ -281,11 +280,11 @@ class TrainingDataOrchestrator:
)
tasks.append(("weather", weather_task))
# Traffic data collection
# Enhanced Traffic data collection (supports multiple cities)
if DataSourceType.MADRID_TRAFFIC in aligned_range.available_sources:
logger.info(f"🚛 Traffic data source available, creating collection task for date range: {aligned_range.start} to {aligned_range.end}")
logger.info(f"🚛 Traffic data source available for multiple cities, creating collection task for date range: {aligned_range.start} to {aligned_range.end}")
traffic_task = asyncio.create_task(
self._collect_traffic_data_with_timeout(lat, lon, aligned_range, tenant_id)
self._collect_traffic_data_with_timeout_enhanced(lat, lon, aligned_range, tenant_id)
)
tasks.append(("traffic", traffic_task))
else:
@@ -353,28 +352,31 @@ class TrainingDataOrchestrator:
logger.warning(f"Weather data collection failed: {e}, using synthetic data")
return self._generate_synthetic_weather_data(aligned_range)
async def _collect_traffic_data_with_timeout(
async def _collect_traffic_data_with_timeout_enhanced(
self,
lat: float,
lon: float,
aligned_range: AlignedDateRange,
tenant_id: str
) -> List[Dict[str, Any]]:
"""Collect traffic data with enhanced storage and retrieval for re-training"""
"""
Enhanced traffic data collection with multi-city support and improved storage
Uses the new abstracted traffic service layer
"""
try:
# Double-check Madrid constraint before making request
# Double-check constraints before making request
constraint_violated = self.date_alignment_service.check_madrid_current_month_constraint(aligned_range.end)
if constraint_violated:
logger.warning(f"🚫 Madrid current month constraint violation: end_date={aligned_range.end}, no traffic data available")
logger.warning(f"🚫 Current month constraint violation: end_date={aligned_range.end}, no traffic data available")
return []
else:
logger.info(f"Madrid constraint passed: end_date={aligned_range.end}, proceeding with traffic data request")
logger.info(f"Date constraints passed: end_date={aligned_range.end}, proceeding with traffic data request")
start_date_str = aligned_range.start.isoformat()
end_date_str = aligned_range.end.isoformat()
# Fetch traffic data - this will automatically store it for future re-training
# Enhanced: Fetch traffic data using new abstracted service
# This automatically detects the appropriate city and uses the right client
traffic_data = await self.data_client.fetch_traffic_data(
tenant_id=tenant_id,
start_date=start_date_str,
@@ -382,39 +384,82 @@ class TrainingDataOrchestrator:
latitude=lat,
longitude=lon)
# Validate traffic data
if self._validate_traffic_data(traffic_data):
logger.info(f"Collected and stored {len(traffic_data)} valid traffic records for re-training")
# Enhanced validation including pedestrian inference data
if self._validate_traffic_data_enhanced(traffic_data):
logger.info(f"Collected and stored {len(traffic_data)} valid enhanced traffic records for re-training")
# Log storage success for audit purposes
self._log_traffic_data_storage(lat, lon, aligned_range, len(traffic_data))
# Log storage success with enhanced metadata
self._log_enhanced_traffic_data_storage(lat, lon, aligned_range, len(traffic_data), traffic_data)
return traffic_data
else:
logger.warning("Invalid traffic data received")
logger.warning("Invalid enhanced traffic data received")
return []
except asyncio.TimeoutError:
logger.warning(f"Traffic data collection timed out")
logger.warning(f"Enhanced traffic data collection timed out")
return []
except Exception as e:
logger.warning(f"Traffic data collection failed: {e}")
logger.warning(f"Enhanced traffic data collection failed: {e}")
return []
# Keep original method for backwards compatibility
async def _collect_traffic_data_with_timeout(
self,
lat: float,
lon: float,
aligned_range: AlignedDateRange,
tenant_id: str
) -> List[Dict[str, Any]]:
"""Legacy traffic data collection method - redirects to enhanced version"""
return await self._collect_traffic_data_with_timeout_enhanced(lat, lon, aligned_range, tenant_id)
def _log_enhanced_traffic_data_storage(self,
lat: float,
lon: float,
aligned_range: AlignedDateRange,
record_count: int,
traffic_data: List[Dict[str, Any]]):
"""Enhanced logging for traffic data storage with detailed metadata"""
# Analyze the stored data for additional insights
cities_detected = set()
has_pedestrian_data = 0
data_sources = set()
districts_covered = set()
for record in traffic_data:
if 'city' in record and record['city']:
cities_detected.add(record['city'])
if 'pedestrian_count' in record and record['pedestrian_count'] is not None:
has_pedestrian_data += 1
if 'source' in record and record['source']:
data_sources.add(record['source'])
if 'district' in record and record['district']:
districts_covered.add(record['district'])
logger.info(
"Enhanced traffic data stored for re-training",
location=f"{lat:.4f},{lon:.4f}",
date_range=f"{aligned_range.start.isoformat()} to {aligned_range.end.isoformat()}",
records_stored=record_count,
cities_detected=list(cities_detected),
pedestrian_inference_coverage=f"{has_pedestrian_data}/{record_count}",
data_sources=list(data_sources),
districts_covered=list(districts_covered),
storage_timestamp=datetime.now().isoformat(),
purpose="enhanced_model_training_and_retraining",
architecture_version="2.0_abstracted"
)
def _log_traffic_data_storage(self,
lat: float,
lon: float,
aligned_range: AlignedDateRange,
record_count: int):
"""Log traffic data storage for audit and re-training tracking"""
logger.info(
"Traffic data stored for re-training",
location=f"{lat:.4f},{lon:.4f}",
date_range=f"{aligned_range.start.isoformat()} to {aligned_range.end.isoformat()}",
records_stored=record_count,
storage_timestamp=datetime.now().isoformat(),
purpose="model_training_and_retraining"
)
"""Legacy logging method - redirects to enhanced version"""
# Create minimal traffic data structure for enhanced logging
minimal_traffic_data = [{"city": "madrid", "source": "legacy"}] * min(record_count, 1)
self._log_enhanced_traffic_data_storage(lat, lon, aligned_range, record_count, minimal_traffic_data)
async def retrieve_stored_traffic_for_retraining(
self,
@@ -491,32 +536,73 @@ class TrainingDataOrchestrator:
return is_valid
def _validate_traffic_data(self, traffic_data: List[Dict[str, Any]]) -> bool:
"""Validate traffic data quality"""
def _validate_traffic_data_enhanced(self, traffic_data: List[Dict[str, Any]]) -> bool:
"""Enhanced validation for traffic data including pedestrian inference and city-specific fields"""
if not traffic_data:
return False
required_fields = ['date']
traffic_fields = ['traffic_volume', 'traffic_intensity', 'intensidad', 'trafico']
enhanced_fields = ['pedestrian_count', 'congestion_level', 'source']
city_specific_fields = ['city', 'measurement_point_id', 'district']
valid_records = 0
enhanced_records = 0
city_aware_records = 0
for record in traffic_data:
# Check required fields
if not all(field in record for field in required_fields):
continue
record_score = 0
# Check at least one traffic field exists
# Check required fields
if all(field in record and record[field] is not None for field in required_fields):
record_score += 1
# Check traffic data fields
if any(field in record and record[field] is not None for field in traffic_fields):
record_score += 1
# Check enhanced fields (pedestrian inference, etc.)
enhanced_count = sum(1 for field in enhanced_fields
if field in record and record[field] is not None)
if enhanced_count >= 2: # At least 2 enhanced fields
enhanced_records += 1
record_score += 1
# Check city-specific awareness
city_count = sum(1 for field in city_specific_fields
if field in record and record[field] is not None)
if city_count >= 1: # At least some city awareness
city_aware_records += 1
# Record is valid if it has basic requirements
if record_score >= 2:
valid_records += 1
# Consider valid if at least 30% of records are valid (traffic data is often sparse)
total_records = len(traffic_data)
validity_threshold = 0.3
is_valid = (valid_records / len(traffic_data)) >= validity_threshold
enhancement_threshold = 0.2 # Lower threshold for enhanced features
if not is_valid:
logger.warning(f"Traffic data validation failed: {valid_records}/{len(traffic_data)} valid records")
basic_validity = (valid_records / total_records) >= validity_threshold
has_enhancements = (enhanced_records / total_records) >= enhancement_threshold
has_city_awareness = (city_aware_records / total_records) >= enhancement_threshold
return is_valid
logger.info("Enhanced traffic data validation results",
total_records=total_records,
valid_records=valid_records,
enhanced_records=enhanced_records,
city_aware_records=city_aware_records,
basic_validity=basic_validity,
has_enhancements=has_enhancements,
has_city_awareness=has_city_awareness)
if not basic_validity:
logger.warning(f"Traffic data basic validation failed: {valid_records}/{total_records} valid records")
return basic_validity
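# Illustrative scoring example (not part of this commit): a record such as
#     {"date": datetime(2025, 6, 1, 8), "traffic_volume": 240, "pedestrian_count": 180,
#      "congestion_level": "medium", "source": "madrid_opendata", "city": "madrid"}
# scores 3 (required date + a traffic field + at least 2 enhanced fields), so it counts
# as a valid record, an enhanced record, and a city-aware record; a batch passes basic
# validation when at least 30% of its records reach a score of 2 or more.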
def _validate_traffic_data(self, traffic_data: List[Dict[str, Any]]) -> bool:
"""Legacy validation method - redirects to enhanced version"""
return self._validate_traffic_data_enhanced(traffic_data)
def _validate_data_sources(
self,