Improve the traffic fetching system
services/data/app/repositories/traffic_repository.py (new file, 874 lines added)
@@ -0,0 +1,874 @@
# ================================================================
# services/data/app/repositories/traffic_repository.py
# ================================================================
"""
Traffic Repository - Enhanced for multiple cities with comprehensive data access patterns
Follows existing repository architecture while adding city-specific functionality
"""

from typing import Optional, List, Dict, Any, Type, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, func, desc, asc, text, update, delete
from sqlalchemy.orm import selectinload
from datetime import datetime, timezone, timedelta
import structlog

from .base import DataBaseRepository
from app.models.traffic import TrafficData, TrafficMeasurementPoint, TrafficDataBackgroundJob
from app.schemas.traffic import TrafficDataCreate, TrafficDataResponse
from shared.database.exceptions import DatabaseError, ValidationError

logger = structlog.get_logger()


class TrafficRepository(DataBaseRepository[TrafficData, TrafficDataCreate, Dict]):
    """
    Enhanced repository for traffic data operations across multiple cities
    Provides city-aware queries and advanced traffic analytics
    """

    def __init__(self, model_class: Type, session: AsyncSession, cache_ttl: Optional[int] = 300):
        super().__init__(model_class, session, cache_ttl)

    # ================================================================
    # CORE TRAFFIC DATA OPERATIONS
    # ================================================================

    async def get_by_location_and_date_range(
        self,
        latitude: float,
        longitude: float,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        city: Optional[str] = None,
        tenant_id: Optional[str] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[TrafficData]:
        """Get traffic data by location and date range with city filtering"""
        try:
            location_id = f"{latitude:.4f},{longitude:.4f}"

            # Build base query
            query = select(self.model).where(self.model.location_id == location_id)

            # Add city filter if specified
            if city:
                query = query.where(self.model.city == city)

            # Add tenant filter if specified
            if tenant_id:
                query = query.where(self.model.tenant_id == tenant_id)

            # Add date range filters
            if start_date:
                start_date = self._ensure_utc_datetime(start_date)
                query = query.where(self.model.date >= start_date)

            if end_date:
                end_date = self._ensure_utc_datetime(end_date)
                query = query.where(self.model.date <= end_date)

            # Order by date descending (most recent first)
            query = query.order_by(desc(self.model.date))

            # Apply pagination
            query = query.offset(skip).limit(limit)

            result = await self.session.execute(query)
            return result.scalars().all()

        except Exception as e:
            logger.error("Failed to get traffic data by location and date range",
                         latitude=latitude, longitude=longitude,
                         city=city, error=str(e))
            raise DatabaseError(f"Failed to get traffic data: {str(e)}")

    async def get_by_city_and_date_range(
        self,
        city: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        district: Optional[str] = None,
        measurement_point_ids: Optional[List[str]] = None,
        include_synthetic: bool = True,
        tenant_id: Optional[str] = None,
        skip: int = 0,
        limit: int = 1000
    ) -> List[TrafficData]:
        """Get traffic data by city with advanced filtering options"""
        try:
            # Build base query
            query = select(self.model).where(self.model.city == city)

            # Add tenant filter if specified
            if tenant_id:
                query = query.where(self.model.tenant_id == tenant_id)

            # Add date range filters
            if start_date:
                start_date = self._ensure_utc_datetime(start_date)
                query = query.where(self.model.date >= start_date)

            if end_date:
                end_date = self._ensure_utc_datetime(end_date)
                query = query.where(self.model.date <= end_date)

            # Add district filter
            if district:
                query = query.where(self.model.district == district)

            # Add measurement point filter
            if measurement_point_ids:
                query = query.where(self.model.measurement_point_id.in_(measurement_point_ids))

            # Filter synthetic data if requested
            if not include_synthetic:
                query = query.where(self.model.is_synthetic == False)

            # Order by date and measurement point
            query = query.order_by(desc(self.model.date), self.model.measurement_point_id)

            # Apply pagination
            query = query.offset(skip).limit(limit)

            result = await self.session.execute(query)
            return result.scalars().all()

        except Exception as e:
            logger.error("Failed to get traffic data by city",
                         city=city, district=district, error=str(e))
            raise DatabaseError(f"Failed to get traffic data: {str(e)}")

    async def get_latest_by_measurement_points(
        self,
        measurement_point_ids: List[str],
        city: str,
        hours_back: int = 24
    ) -> List[TrafficData]:
        """Get latest traffic data for specific measurement points"""
        try:
            cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours_back)

            query = select(self.model).where(
                and_(
                    self.model.city == city,
                    self.model.measurement_point_id.in_(measurement_point_ids),
                    self.model.date >= cutoff_time
                )
            ).order_by(
                self.model.measurement_point_id,
                desc(self.model.date)
            )

            result = await self.session.execute(query)
            all_records = result.scalars().all()

            # Get the latest record for each measurement point
            latest_records = {}
            for record in all_records:
                point_id = record.measurement_point_id
                if point_id not in latest_records:
                    latest_records[point_id] = record

            return list(latest_records.values())

        except Exception as e:
            logger.error("Failed to get latest traffic data by measurement points",
                         city=city, points=len(measurement_point_ids), error=str(e))
            raise DatabaseError(f"Failed to get latest traffic data: {str(e)}")

    # ================================================================
    # ANALYTICS AND AGGREGATIONS
    # ================================================================

    async def get_traffic_statistics_by_city(
        self,
        city: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        group_by: str = "daily"
    ) -> List[Dict[str, Any]]:
        """Get aggregated traffic statistics by city"""
        try:
            # Determine date truncation based on group_by
            if group_by == "hourly":
                date_trunc = "hour"
            elif group_by == "daily":
                date_trunc = "day"
            elif group_by == "weekly":
                date_trunc = "week"
            elif group_by == "monthly":
                date_trunc = "month"
            else:
                raise ValidationError(f"Invalid group_by value: {group_by}")

            # Build aggregation query
            if self.session.bind.dialect.name == 'postgresql':
                query = text("""
                    SELECT
                        DATE_TRUNC(:date_trunc, date) as period,
                        city,
                        district,
                        COUNT(*) as record_count,
                        AVG(traffic_volume) as avg_traffic_volume,
                        MAX(traffic_volume) as max_traffic_volume,
                        AVG(pedestrian_count) as avg_pedestrian_count,
                        AVG(average_speed) as avg_speed,
                        COUNT(CASE WHEN congestion_level = 'high' THEN 1 END) as high_congestion_count,
                        COUNT(CASE WHEN is_synthetic = false THEN 1 END) as real_data_count,
                        COUNT(CASE WHEN has_pedestrian_inference = true THEN 1 END) as pedestrian_inference_count
                    FROM traffic_data
                    WHERE city = :city
                """)
            else:
                # SQLite fallback
                query = text("""
                    SELECT
                        DATE(date) as period,
                        city,
                        district,
                        COUNT(*) as record_count,
                        AVG(traffic_volume) as avg_traffic_volume,
                        MAX(traffic_volume) as max_traffic_volume,
                        AVG(pedestrian_count) as avg_pedestrian_count,
                        AVG(average_speed) as avg_speed,
                        SUM(CASE WHEN congestion_level = 'high' THEN 1 ELSE 0 END) as high_congestion_count,
                        SUM(CASE WHEN is_synthetic = 0 THEN 1 ELSE 0 END) as real_data_count,
                        SUM(CASE WHEN has_pedestrian_inference = 1 THEN 1 ELSE 0 END) as pedestrian_inference_count
                    FROM traffic_data
                    WHERE city = :city
                """)

            params = {
                "city": city,
                "date_trunc": date_trunc
            }

            # Add date filters
            if start_date:
                query = text(str(query) + " AND date >= :start_date")
                params["start_date"] = self._ensure_utc_datetime(start_date)

            if end_date:
                query = text(str(query) + " AND date <= :end_date")
                params["end_date"] = self._ensure_utc_datetime(end_date)

            # Add GROUP BY and ORDER BY
            query = text(str(query) + " GROUP BY period, city, district ORDER BY period DESC")

            result = await self.session.execute(query, params)
            rows = result.fetchall()

            # Convert to list of dictionaries
            statistics = []
            for row in rows:
                statistics.append({
                    "period": group_by,
                    "date": row.period,
                    "city": row.city,
                    "district": row.district,
                    "record_count": row.record_count,
                    "avg_traffic_volume": float(row.avg_traffic_volume or 0),
                    "max_traffic_volume": row.max_traffic_volume or 0,
                    "avg_pedestrian_count": float(row.avg_pedestrian_count or 0),
                    "avg_speed": float(row.avg_speed or 0),
                    "high_congestion_count": row.high_congestion_count or 0,
                    "real_data_percentage": round((row.real_data_count or 0) / max(1, row.record_count) * 100, 2),
                    "pedestrian_inference_percentage": round((row.pedestrian_inference_count or 0) / max(1, row.record_count) * 100, 2)
                })

            return statistics

        except Exception as e:
            logger.error("Failed to get traffic statistics by city",
                         city=city, group_by=group_by, error=str(e))
            raise DatabaseError(f"Traffic statistics query failed: {str(e)}")
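
    # Illustrative shape of one entry returned by get_traffic_statistics_by_city()
    # (values below are made up; the keys mirror the dict built above):
    #   {"period": "daily", "date": "2024-05-01", "city": "madrid", "district": "Centro",
    #    "record_count": 288, "avg_traffic_volume": 412.5, "max_traffic_volume": 930,
    #    "avg_pedestrian_count": 120.4, "avg_speed": 27.3, "high_congestion_count": 12,
    #    "real_data_percentage": 87.5, "pedestrian_inference_percentage": 42.0}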

    async def get_congestion_heatmap_data(
        self,
        city: str,
        start_date: datetime,
        end_date: datetime,
        time_granularity: str = "hour"
    ) -> List[Dict[str, Any]]:
        """Get congestion data for heatmap visualization"""
        try:
            if time_granularity == "hour":
                time_extract = "EXTRACT(hour FROM date)"
            elif time_granularity == "day_of_week":
                time_extract = "EXTRACT(dow FROM date)"
            else:
                time_extract = "EXTRACT(hour FROM date)"

            query = text(f"""
                SELECT
                    {time_extract} as time_period,
                    district,
                    measurement_point_id,
                    latitude,
                    longitude,
                    AVG(CASE
                        WHEN congestion_level = 'low' THEN 1
                        WHEN congestion_level = 'medium' THEN 2
                        WHEN congestion_level = 'high' THEN 3
                        WHEN congestion_level = 'blocked' THEN 4
                        ELSE 1
                    END) as avg_congestion_score,
                    COUNT(*) as data_points,
                    AVG(traffic_volume) as avg_traffic_volume,
                    AVG(pedestrian_count) as avg_pedestrian_count
                FROM traffic_data
                WHERE city = :city
                    AND date >= :start_date
                    AND date <= :end_date
                    AND latitude IS NOT NULL
                    AND longitude IS NOT NULL
                GROUP BY time_period, district, measurement_point_id, latitude, longitude
                ORDER BY time_period, district, avg_congestion_score DESC
            """)

            params = {
                "city": city,
                "start_date": self._ensure_utc_datetime(start_date),
                "end_date": self._ensure_utc_datetime(end_date)
            }

            result = await self.session.execute(query, params)
            rows = result.fetchall()

            heatmap_data = []
            for row in rows:
                heatmap_data.append({
                    "time_period": int(row.time_period or 0),
                    "district": row.district,
                    "measurement_point_id": row.measurement_point_id,
                    "latitude": float(row.latitude),
                    "longitude": float(row.longitude),
                    "avg_congestion_score": float(row.avg_congestion_score),
                    "data_points": row.data_points,
                    "avg_traffic_volume": float(row.avg_traffic_volume or 0),
                    "avg_pedestrian_count": float(row.avg_pedestrian_count or 0)
                })

            return heatmap_data

        except Exception as e:
            logger.error("Failed to get congestion heatmap data",
                         city=city, error=str(e))
            raise DatabaseError(f"Congestion heatmap query failed: {str(e)}")

    # ================================================================
    # BULK OPERATIONS AND DATA MANAGEMENT
    # ================================================================

    async def create_bulk_traffic_data(
        self,
        traffic_records: List[Dict[str, Any]],
        city: str,
        tenant_id: Optional[str] = None
    ) -> List[TrafficData]:
        """Create multiple traffic records in bulk with enhanced validation"""
        try:
            # Ensure all records have city and tenant_id
            for record in traffic_records:
                record["city"] = city
                if tenant_id:
                    record["tenant_id"] = tenant_id
                # Ensure dates are timezone-aware
                if "date" in record and record["date"]:
                    record["date"] = self._ensure_utc_datetime(record["date"])

            # Enhanced validation
            validated_records = []
            for record in traffic_records:
                if self._validate_traffic_record(record):
                    validated_records.append(record)
                else:
                    logger.warning("Invalid traffic record skipped",
                                   city=city, record_keys=list(record.keys()))

            if not validated_records:
                logger.warning("No valid traffic records to create", city=city)
                return []

            # Use bulk create with deduplication
            created_records = await self.bulk_create_with_deduplication(validated_records)

            logger.info("Bulk traffic data creation completed",
                        city=city, requested=len(traffic_records),
                        validated=len(validated_records), created=len(created_records))

            return created_records

        except Exception as e:
            logger.error("Failed to create bulk traffic data",
                         city=city, record_count=len(traffic_records), error=str(e))
            raise DatabaseError(f"Bulk traffic creation failed: {str(e)}")

    async def bulk_create_with_deduplication(
        self,
        records: List[Dict[str, Any]]
    ) -> List[TrafficData]:
        """Bulk create with automatic deduplication based on location, city, and date"""
        try:
            if not records:
                return []

            # Extract unique keys for deduplication check
            unique_keys = []
            for record in records:
                key = (
                    record.get('location_id'),
                    record.get('city'),
                    record.get('date'),
                    record.get('measurement_point_id')
                )
                unique_keys.append(key)

            # Check for existing records
            location_ids = [key[0] for key in unique_keys if key[0]]
            cities = [key[1] for key in unique_keys if key[1]]
            dates = [key[2] for key in unique_keys if key[2]]

            # For large datasets, use chunked deduplication to avoid memory issues
            if len(location_ids) > 1000:
                logger.info(f"Large dataset detected ({len(records)} records), using chunked deduplication")
                new_records = []
                chunk_size = 1000

                for i in range(0, len(records), chunk_size):
                    chunk_records = records[i:i + chunk_size]
                    chunk_keys = unique_keys[i:i + chunk_size]

                    # Get unique values for this chunk
                    chunk_location_ids = list(set(key[0] for key in chunk_keys if key[0]))
                    chunk_cities = list(set(key[1] for key in chunk_keys if key[1]))
                    chunk_dates = list(set(key[2] for key in chunk_keys if key[2]))

                    if chunk_location_ids and chunk_cities and chunk_dates:
                        existing_query = select(
                            self.model.location_id,
                            self.model.city,
                            self.model.date,
                            self.model.measurement_point_id
                        ).where(
                            and_(
                                self.model.location_id.in_(chunk_location_ids),
                                self.model.city.in_(chunk_cities),
                                self.model.date.in_(chunk_dates)
                            )
                        )

                        result = await self.session.execute(existing_query)
                        chunk_existing_keys = set(result.fetchall())

                        # Filter chunk duplicates
                        for j, record in enumerate(chunk_records):
                            key = chunk_keys[j]
                            if key not in chunk_existing_keys:
                                new_records.append(record)
                    else:
                        new_records.extend(chunk_records)

                logger.debug("Chunked deduplication completed",
                             total_records=len(records),
                             new_records=len(new_records))
                records = new_records

            elif location_ids and cities and dates:
                existing_query = select(
                    self.model.location_id,
                    self.model.city,
                    self.model.date,
                    self.model.measurement_point_id
                ).where(
                    and_(
                        self.model.location_id.in_(location_ids),
                        self.model.city.in_(cities),
                        self.model.date.in_(dates)
                    )
                )

                result = await self.session.execute(existing_query)
                existing_keys = set(result.fetchall())

                # Filter out duplicates
                new_records = []
                for i, record in enumerate(records):
                    key = unique_keys[i]
                    if key not in existing_keys:
                        new_records.append(record)

                logger.debug("Standard deduplication completed",
                             total_records=len(records),
                             existing_records=len(existing_keys),
                             new_records=len(new_records))

                records = new_records

            # Proceed with bulk creation
            return await self.bulk_create(records)

        except Exception as e:
            logger.error("Failed bulk create with deduplication", error=str(e))
            raise DatabaseError(f"Bulk create with deduplication failed: {str(e)}")

    def _validate_traffic_record(self, record: Dict[str, Any]) -> bool:
        """Enhanced validation for traffic records"""
        required_fields = ['date', 'city']

        # Check required fields
        for field in required_fields:
            if not record.get(field):
                return False

        # Validate city
        city = record.get('city', '').lower()
        if city not in ['madrid', 'barcelona', 'valencia', 'test']:  # Extendable list
            return False

        # Validate data ranges
        traffic_volume = record.get('traffic_volume')
        if traffic_volume is not None and (traffic_volume < 0 or traffic_volume > 50000):
            return False

        pedestrian_count = record.get('pedestrian_count')
        if pedestrian_count is not None and (pedestrian_count < 0 or pedestrian_count > 10000):
            return False

        average_speed = record.get('average_speed')
        if average_speed is not None and (average_speed < 0 or average_speed > 200):
            return False

        congestion_level = record.get('congestion_level')
        if congestion_level and congestion_level not in ['low', 'medium', 'high', 'blocked']:
            return False

        return True
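
    # Illustrative example (made-up values): a record such as
    #   {"date": datetime(2024, 5, 1, tzinfo=timezone.utc), "city": "madrid",
    #    "traffic_volume": 420, "congestion_level": "high"}
    # passes _validate_traffic_record(), while a record missing "city" or using an
    # unrecognised congestion_level is skipped by create_bulk_traffic_data().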

    # ================================================================
    # TRAINING DATA SPECIFIC OPERATIONS
    # ================================================================

    async def get_training_data_by_location(
        self,
        latitude: float,
        longitude: float,
        start_date: datetime,
        end_date: datetime,
        tenant_id: Optional[str] = None,
        include_pedestrian_inference: bool = True
    ) -> List[Dict[str, Any]]:
        """Get optimized training data for ML models"""
        try:
            location_id = f"{latitude:.4f},{longitude:.4f}"

            query = select(self.model).where(
                and_(
                    self.model.location_id == location_id,
                    self.model.date >= self._ensure_utc_datetime(start_date),
                    self.model.date <= self._ensure_utc_datetime(end_date)
                )
            )

            if tenant_id:
                query = query.where(self.model.tenant_id == tenant_id)

            if include_pedestrian_inference:
                # Prefer records with pedestrian inference
                query = query.order_by(
                    desc(self.model.has_pedestrian_inference),
                    desc(self.model.data_quality_score),
                    self.model.date
                )
            else:
                query = query.order_by(
                    desc(self.model.data_quality_score),
                    self.model.date
                )

            result = await self.session.execute(query)
            records = result.scalars().all()

            # Convert to training format with enhanced features
            training_data = []
            for record in records:
                training_record = {
                    'date': record.date,
                    'traffic_volume': record.traffic_volume or 0,
                    'pedestrian_count': record.pedestrian_count or 0,
                    'congestion_level': record.congestion_level or 'medium',
                    'average_speed': record.average_speed or 25.0,
                    'city': record.city,
                    'district': record.district,
                    'measurement_point_id': record.measurement_point_id,
                    'source': record.source,
                    'is_synthetic': record.is_synthetic or False,
                    'has_pedestrian_inference': record.has_pedestrian_inference or False,
                    'data_quality_score': record.data_quality_score or 50.0,

                    # Enhanced features for training
                    'hour_of_day': record.date.hour if record.date else 12,
                    'day_of_week': record.date.weekday() if record.date else 0,
                    'month': record.date.month if record.date else 1,

                    # City-specific features
                    'city_specific_data': record.city_specific_data or {}
                }

                training_data.append(training_record)

            logger.info("Retrieved training data",
                        location_id=location_id, records=len(training_data),
                        with_pedestrian_inference=sum(1 for r in training_data if r['has_pedestrian_inference']))

            return training_data

        except Exception as e:
            logger.error("Failed to get training data",
                         latitude=latitude, longitude=longitude, error=str(e))
            raise DatabaseError(f"Training data retrieval failed: {str(e)}")

    async def get_historical_data_by_location(
        self,
        latitude: float,
        longitude: float,
        start_date: datetime,
        end_date: datetime,
        tenant_id: Optional[str] = None
    ) -> List[TrafficData]:
        """Get historical traffic data for a specific location and date range"""
        return await self.get_by_location_and_date_range(
            latitude=latitude,
            longitude=longitude,
            start_date=start_date,
            end_date=end_date,
            tenant_id=tenant_id,
            limit=1000000  # Large limit for historical data
        )

    async def count_records_in_period(
        self,
        latitude: float,
        longitude: float,
        start_date: datetime,
        end_date: datetime,
        city: Optional[str] = None,
        tenant_id: Optional[str] = None
    ) -> int:
        """Count traffic records for a specific location and time period"""
        try:
            location_id = f"{latitude:.4f},{longitude:.4f}"

            query = select(func.count(self.model.id)).where(
                and_(
                    self.model.location_id == location_id,
                    self.model.date >= self._ensure_utc_datetime(start_date),
                    self.model.date <= self._ensure_utc_datetime(end_date)
                )
            )

            if city:
                query = query.where(self.model.city == city)

            if tenant_id:
                query = query.where(self.model.tenant_id == tenant_id)

            result = await self.session.execute(query)
            count = result.scalar()

            return count or 0

        except Exception as e:
            logger.error("Failed to count records in period",
                         latitude=latitude, longitude=longitude, error=str(e))
            raise DatabaseError(f"Record count failed: {str(e)}")

    # ================================================================
    # DATA QUALITY AND MAINTENANCE
    # ================================================================

    async def update_data_quality_scores(self, city: str) -> int:
        """Update data quality scores based on various criteria"""
        try:
            # Calculate quality scores based on data completeness and consistency
            query = text("""
                UPDATE traffic_data
                SET data_quality_score = (
                    CASE WHEN traffic_volume IS NOT NULL THEN 20 ELSE 0 END +
                    CASE WHEN pedestrian_count IS NOT NULL THEN 20 ELSE 0 END +
                    CASE WHEN average_speed IS NOT NULL AND average_speed > 0 THEN 20 ELSE 0 END +
                    CASE WHEN congestion_level IS NOT NULL THEN 15 ELSE 0 END +
                    CASE WHEN measurement_point_id IS NOT NULL THEN 10 ELSE 0 END +
                    CASE WHEN district IS NOT NULL THEN 10 ELSE 0 END +
                    CASE WHEN has_pedestrian_inference = true THEN 5 ELSE 0 END
                ),
                updated_at = :updated_at
                WHERE city = :city AND data_quality_score IS NULL
            """)

            params = {
                "city": city,
                "updated_at": datetime.now(timezone.utc)
            }

            result = await self.session.execute(query, params)
            updated_count = result.rowcount
            await self.session.commit()

            logger.info("Updated data quality scores",
                        city=city, updated_count=updated_count)

            return updated_count

        except Exception as e:
            logger.error("Failed to update data quality scores",
                         city=city, error=str(e))
            await self.session.rollback()
            raise DatabaseError(f"Data quality update failed: {str(e)}")

    async def cleanup_old_synthetic_data(
        self,
        city: str,
        days_to_keep: int = 90
    ) -> int:
        """Clean up old synthetic data while preserving real data"""
        try:
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)

            query = delete(self.model).where(
                and_(
                    self.model.city == city,
                    self.model.is_synthetic == True,
                    self.model.date < cutoff_date
                )
            )

            result = await self.session.execute(query)
            deleted_count = result.rowcount
            await self.session.commit()

            logger.info("Cleaned up old synthetic data",
                        city=city, deleted_count=deleted_count, days_kept=days_to_keep)

            return deleted_count

        except Exception as e:
            logger.error("Failed to cleanup old synthetic data",
                         city=city, error=str(e))
            await self.session.rollback()
            raise DatabaseError(f"Synthetic data cleanup failed: {str(e)}")


class TrafficMeasurementPointRepository(DataBaseRepository[TrafficMeasurementPoint, Dict, Dict]):
    """Repository for traffic measurement points across cities"""

    async def get_points_near_location(
        self,
        latitude: float,
        longitude: float,
        city: str,
        radius_km: float = 10.0,
        limit: int = 20
    ) -> List[TrafficMeasurementPoint]:
        """Get measurement points near a location using spatial query"""
        try:
            # Simple distance calculation (for more precise, use PostGIS)
            query = text("""
                SELECT *,
                    (6371 * acos(
                        cos(radians(:lat)) * cos(radians(latitude)) *
                        cos(radians(longitude) - radians(:lon)) +
                        sin(radians(:lat)) * sin(radians(latitude))
                    )) as distance_km
                FROM traffic_measurement_points
                WHERE city = :city
                    AND is_active = true
                HAVING distance_km <= :radius_km
                ORDER BY distance_km
                LIMIT :limit
            """)

            params = {
                "lat": latitude,
                "lon": longitude,
                "city": city,
                "radius_km": radius_km,
                "limit": limit
            }

            result = await self.session.execute(query, params)
            rows = result.fetchall()

            # Convert rows to model instances
            points = []
            for row in rows:
                point = TrafficMeasurementPoint()
                for key, value in row._mapping.items():
                    if hasattr(point, key) and key != 'distance_km':
                        setattr(point, key, value)
                points.append(point)

            return points

        except Exception as e:
            logger.error("Failed to get measurement points near location",
                         latitude=latitude, longitude=longitude, city=city, error=str(e))
            raise DatabaseError(f"Measurement points query failed: {str(e)}")


class TrafficBackgroundJobRepository(DataBaseRepository[TrafficDataBackgroundJob, Dict, Dict]):
    """Repository for managing background traffic data jobs"""

    async def get_pending_jobs_by_city(self, city: str) -> List[TrafficDataBackgroundJob]:
        """Get pending background jobs for a specific city"""
        try:
            query = select(self.model).where(
                and_(
                    self.model.city == city,
                    self.model.status == 'pending'
                )
            ).order_by(self.model.scheduled_at)

            result = await self.session.execute(query)
            return result.scalars().all()

        except Exception as e:
            logger.error("Failed to get pending jobs by city", city=city, error=str(e))
            raise DatabaseError(f"Background jobs query failed: {str(e)}")

    async def update_job_progress(
        self,
        job_id: str,
        progress_percentage: float,
        records_processed: int = 0,
        records_stored: int = 0
    ) -> bool:
        """Update job progress"""
        try:
            query = update(self.model).where(
                self.model.id == job_id
            ).values(
                progress_percentage=progress_percentage,
                records_processed=records_processed,
                records_stored=records_stored,
                updated_at=datetime.now(timezone.utc)
            )

            result = await self.session.execute(query)
            await self.session.commit()

            return result.rowcount > 0

        except Exception as e:
            logger.error("Failed to update job progress", job_id=job_id, error=str(e))
            await self.session.rollback()
            raise DatabaseError(f"Job progress update failed: {str(e)}")
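

# ----------------------------------------------------------------
# Usage sketch (illustrative only, not part of the repository API).
# It assumes a local asyncpg DSN and wires a session by hand; in the
# service itself the session presumably comes from the app's own
# session factory / dependency injection.
# ----------------------------------------------------------------
if __name__ == "__main__":
    import asyncio

    from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine

    async def _demo() -> None:
        # Assumed connection string, for illustration only.
        engine = create_async_engine("postgresql+asyncpg://user:pass@localhost:5432/traffic")
        session_factory = async_sessionmaker(engine, expire_on_commit=False)

        async with session_factory() as session:
            repo = TrafficRepository(TrafficData, session)

            # Last 7 days of Madrid data, newest first.
            rows = await repo.get_by_city_and_date_range(
                city="madrid",
                start_date=datetime.now(timezone.utc) - timedelta(days=7),
                limit=100,
            )
            print(f"fetched {len(rows)} traffic records")

        await engine.dispose()

    asyncio.run(_demo())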