# ================================================================
# services/data/app/repositories/traffic_repository.py
# ================================================================
"""
Traffic Repository - Enhanced for multiple cities with comprehensive data access patterns
Follows existing repository architecture while adding city-specific functionality
"""

from typing import Optional, List, Dict, Any, Type, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, func, desc, asc, text, update, delete
from sqlalchemy.orm import selectinload
from datetime import datetime, timezone, timedelta
import structlog

from .base import DataBaseRepository
from app.models.traffic import TrafficData, TrafficMeasurementPoint, TrafficDataBackgroundJob
from app.schemas.traffic import TrafficDataCreate, TrafficDataResponse
from shared.database.exceptions import DatabaseError, ValidationError

logger = structlog.get_logger()


class TrafficRepository(DataBaseRepository[TrafficData, TrafficDataCreate, Dict]):
    """
    Enhanced repository for traffic data operations across multiple cities
    Provides city-aware queries and advanced traffic analytics
    """

    def __init__(self, model_class: Type, session: AsyncSession, cache_ttl: Optional[int] = 300):
        super().__init__(model_class, session, cache_ttl)
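
    # Illustrative usage sketch (not part of the original module): how a service
    # layer might construct and call this repository. The async session factory
    # name `async_session_factory` is an assumption made for this example only;
    # the constructor signature and query method are the ones defined in this file.
    #
    #   async with async_session_factory() as session:
    #       repo = TrafficRepository(TrafficData, session)
    #       rows = await repo.get_by_city_and_date_range(
    #           city="madrid",
    #           start_date=datetime(2024, 1, 1, tzinfo=timezone.utc),
    #           end_date=datetime(2024, 1, 31, tzinfo=timezone.utc),
    #           include_synthetic=False,
    #       )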

    # ================================================================
    # CORE TRAFFIC DATA OPERATIONS
    # ================================================================

    async def get_by_location_and_date_range(
        self,
        latitude: float,
        longitude: float,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        city: Optional[str] = None,
        tenant_id: Optional[str] = None,
        skip: int = 0,
        limit: int = 100
    ) -> List[TrafficData]:
        """Get traffic data by location and date range with city filtering"""
        try:
            location_id = f"{latitude:.4f},{longitude:.4f}"

            # Build base query
            query = select(self.model).where(self.model.location_id == location_id)

            # Add city filter if specified
            if city:
                query = query.where(self.model.city == city)

            # Add tenant filter if specified
            if tenant_id:
                query = query.where(self.model.tenant_id == tenant_id)

            # Add date range filters
            if start_date:
                start_date = self._ensure_utc_datetime(start_date)
                query = query.where(self.model.date >= start_date)

            if end_date:
                end_date = self._ensure_utc_datetime(end_date)
                query = query.where(self.model.date <= end_date)

            # Order by date descending (most recent first)
            query = query.order_by(desc(self.model.date))

            # Apply pagination
            query = query.offset(skip).limit(limit)

            result = await self.session.execute(query)
            return result.scalars().all()

        except Exception as e:
            logger.error("Failed to get traffic data by location and date range",
                         latitude=latitude, longitude=longitude,
                         city=city, error=str(e))
            raise DatabaseError(f"Failed to get traffic data: {str(e)}")

    async def get_by_city_and_date_range(
        self,
        city: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        district: Optional[str] = None,
        measurement_point_ids: Optional[List[str]] = None,
        include_synthetic: bool = True,
        tenant_id: Optional[str] = None,
        skip: int = 0,
        limit: int = 1000
    ) -> List[TrafficData]:
        """Get traffic data by city with advanced filtering options"""
        try:
            # Build base query
            query = select(self.model).where(self.model.city == city)

            # Add tenant filter if specified
            if tenant_id:
                query = query.where(self.model.tenant_id == tenant_id)

            # Add date range filters
            if start_date:
                start_date = self._ensure_utc_datetime(start_date)
                query = query.where(self.model.date >= start_date)

            if end_date:
                end_date = self._ensure_utc_datetime(end_date)
                query = query.where(self.model.date <= end_date)

            # Add district filter
            if district:
                query = query.where(self.model.district == district)

            # Add measurement point filter
            if measurement_point_ids:
                query = query.where(self.model.measurement_point_id.in_(measurement_point_ids))

            # Filter synthetic data if requested
            if not include_synthetic:
                query = query.where(self.model.is_synthetic == False)

            # Order by date and measurement point
            query = query.order_by(desc(self.model.date), self.model.measurement_point_id)

            # Apply pagination
            query = query.offset(skip).limit(limit)

            result = await self.session.execute(query)
            return result.scalars().all()

        except Exception as e:
            logger.error("Failed to get traffic data by city",
                         city=city, district=district, error=str(e))
            raise DatabaseError(f"Failed to get traffic data: {str(e)}")

    async def get_latest_by_measurement_points(
        self,
        measurement_point_ids: List[str],
        city: str,
        hours_back: int = 24
    ) -> List[TrafficData]:
        """Get latest traffic data for specific measurement points"""
        try:
            cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours_back)

            query = select(self.model).where(
                and_(
                    self.model.city == city,
                    self.model.measurement_point_id.in_(measurement_point_ids),
                    self.model.date >= cutoff_time
                )
            ).order_by(
                self.model.measurement_point_id,
                desc(self.model.date)
            )

            result = await self.session.execute(query)
            all_records = result.scalars().all()

            # Get the latest record for each measurement point
            latest_records = {}
            for record in all_records:
                point_id = record.measurement_point_id
                if point_id not in latest_records:
                    latest_records[point_id] = record

            return list(latest_records.values())

        except Exception as e:
            logger.error("Failed to get latest traffic data by measurement points",
                         city=city, points=len(measurement_point_ids), error=str(e))
            raise DatabaseError(f"Failed to get latest traffic data: {str(e)}")

    # ================================================================
    # ANALYTICS AND AGGREGATIONS
    # ================================================================

    async def get_traffic_statistics_by_city(
        self,
        city: str,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
        group_by: str = "daily"
    ) -> List[Dict[str, Any]]:
        """Get aggregated traffic statistics by city"""
        try:
            # Determine date truncation based on group_by
            if group_by == "hourly":
                date_trunc = "hour"
            elif group_by == "daily":
                date_trunc = "day"
            elif group_by == "weekly":
                date_trunc = "week"
            elif group_by == "monthly":
                date_trunc = "month"
            else:
                raise ValidationError(f"Invalid group_by value: {group_by}")

            # Build aggregation query
            if self.session.bind.dialect.name == 'postgresql':
                query = text("""
                    SELECT
                        DATE_TRUNC(:date_trunc, date) as period,
                        city,
                        district,
                        COUNT(*) as record_count,
                        AVG(traffic_volume) as avg_traffic_volume,
                        MAX(traffic_volume) as max_traffic_volume,
                        AVG(pedestrian_count) as avg_pedestrian_count,
                        AVG(average_speed) as avg_speed,
                        COUNT(CASE WHEN congestion_level = 'high' THEN 1 END) as high_congestion_count,
                        COUNT(CASE WHEN is_synthetic = false THEN 1 END) as real_data_count,
                        COUNT(CASE WHEN has_pedestrian_inference = true THEN 1 END) as pedestrian_inference_count
                    FROM traffic_data
                    WHERE city = :city
                """)
            else:
                # SQLite fallback
                query = text("""
                    SELECT
                        DATE(date) as period,
                        city,
                        district,
                        COUNT(*) as record_count,
                        AVG(traffic_volume) as avg_traffic_volume,
                        MAX(traffic_volume) as max_traffic_volume,
                        AVG(pedestrian_count) as avg_pedestrian_count,
                        AVG(average_speed) as avg_speed,
                        SUM(CASE WHEN congestion_level = 'high' THEN 1 ELSE 0 END) as high_congestion_count,
                        SUM(CASE WHEN is_synthetic = 0 THEN 1 ELSE 0 END) as real_data_count,
                        SUM(CASE WHEN has_pedestrian_inference = 1 THEN 1 ELSE 0 END) as pedestrian_inference_count
                    FROM traffic_data
                    WHERE city = :city
                """)

            params = {
                "city": city,
                "date_trunc": date_trunc
            }

            # Add date filters
            if start_date:
                query = text(str(query) + " AND date >= :start_date")
                params["start_date"] = self._ensure_utc_datetime(start_date)

            if end_date:
                query = text(str(query) + " AND date <= :end_date")
                params["end_date"] = self._ensure_utc_datetime(end_date)

            # Add GROUP BY and ORDER BY
            query = text(str(query) + " GROUP BY period, city, district ORDER BY period DESC")

            result = await self.session.execute(query, params)
            rows = result.fetchall()

            # Convert to list of dictionaries
            statistics = []
            for row in rows:
                statistics.append({
                    "period": group_by,
                    "date": row.period,
                    "city": row.city,
                    "district": row.district,
                    "record_count": row.record_count,
                    "avg_traffic_volume": float(row.avg_traffic_volume or 0),
                    "max_traffic_volume": row.max_traffic_volume or 0,
                    "avg_pedestrian_count": float(row.avg_pedestrian_count or 0),
                    "avg_speed": float(row.avg_speed or 0),
                    "high_congestion_count": row.high_congestion_count or 0,
                    "real_data_percentage": round((row.real_data_count or 0) / max(1, row.record_count) * 100, 2),
                    "pedestrian_inference_percentage": round((row.pedestrian_inference_count or 0) / max(1, row.record_count) * 100, 2)
                })

            return statistics

        except Exception as e:
            logger.error("Failed to get traffic statistics by city",
                         city=city, group_by=group_by, error=str(e))
            raise DatabaseError(f"Traffic statistics query failed: {str(e)}")
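
    # Illustrative shape of one element returned by get_traffic_statistics_by_city
    # above (the field values are invented for the example; the keys are the ones
    # built in the method):
    #
    #   {
    #       "period": "daily",
    #       "date": datetime(2024, 3, 1),
    #       "city": "madrid",
    #       "district": "Centro",
    #       "record_count": 288,
    #       "avg_traffic_volume": 412.5,
    #       "max_traffic_volume": 1230,
    #       "avg_pedestrian_count": 85.2,
    #       "avg_speed": 27.4,
    #       "high_congestion_count": 31,
    #       "real_data_percentage": 93.75,
    #       "pedestrian_inference_percentage": 41.67,
    #   }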

    async def get_congestion_heatmap_data(
        self,
        city: str,
        start_date: datetime,
        end_date: datetime,
        time_granularity: str = "hour"
    ) -> List[Dict[str, Any]]:
        """Get congestion data for heatmap visualization"""
        try:
            if time_granularity == "hour":
                time_extract = "EXTRACT(hour FROM date)"
            elif time_granularity == "day_of_week":
                time_extract = "EXTRACT(dow FROM date)"
            else:
                time_extract = "EXTRACT(hour FROM date)"

            query = text(f"""
                SELECT
                    {time_extract} as time_period,
                    district,
                    measurement_point_id,
                    latitude,
                    longitude,
                    AVG(CASE
                        WHEN congestion_level = 'low' THEN 1
                        WHEN congestion_level = 'medium' THEN 2
                        WHEN congestion_level = 'high' THEN 3
                        WHEN congestion_level = 'blocked' THEN 4
                        ELSE 1
                    END) as avg_congestion_score,
                    COUNT(*) as data_points,
                    AVG(traffic_volume) as avg_traffic_volume,
                    AVG(pedestrian_count) as avg_pedestrian_count
                FROM traffic_data
                WHERE city = :city
                    AND date >= :start_date
                    AND date <= :end_date
                    AND latitude IS NOT NULL
                    AND longitude IS NOT NULL
                GROUP BY time_period, district, measurement_point_id, latitude, longitude
                ORDER BY time_period, district, avg_congestion_score DESC
            """)

            params = {
                "city": city,
                "start_date": self._ensure_utc_datetime(start_date),
                "end_date": self._ensure_utc_datetime(end_date)
            }

            result = await self.session.execute(query, params)
            rows = result.fetchall()

            heatmap_data = []
            for row in rows:
                heatmap_data.append({
                    "time_period": int(row.time_period or 0),
                    "district": row.district,
                    "measurement_point_id": row.measurement_point_id,
                    "latitude": float(row.latitude),
                    "longitude": float(row.longitude),
                    "avg_congestion_score": float(row.avg_congestion_score),
                    "data_points": row.data_points,
                    "avg_traffic_volume": float(row.avg_traffic_volume or 0),
                    "avg_pedestrian_count": float(row.avg_pedestrian_count or 0)
                })

            return heatmap_data

        except Exception as e:
            logger.error("Failed to get congestion heatmap data",
                         city=city, error=str(e))
            raise DatabaseError(f"Congestion heatmap query failed: {str(e)}")

    # ================================================================
    # BULK OPERATIONS AND DATA MANAGEMENT
    # ================================================================

    async def create_bulk_traffic_data(
        self,
        traffic_records: List[Dict[str, Any]],
        city: str,
        tenant_id: Optional[str] = None
    ) -> List[TrafficData]:
        """Create multiple traffic records in bulk with enhanced validation"""
        try:
            # Ensure all records have city and tenant_id
            for record in traffic_records:
                record["city"] = city
                if tenant_id:
                    record["tenant_id"] = tenant_id
                # Ensure dates are timezone-aware
                if "date" in record and record["date"]:
                    record["date"] = self._ensure_utc_datetime(record["date"])

            # Enhanced validation
            validated_records = []
            for record in traffic_records:
                if self._validate_traffic_record(record):
                    validated_records.append(record)
                else:
                    logger.warning("Invalid traffic record skipped",
                                   city=city, record_keys=list(record.keys()))

            if not validated_records:
                logger.warning("No valid traffic records to create", city=city)
                return []

            # Use bulk create with deduplication
            created_records = await self.bulk_create_with_deduplication(validated_records)

            logger.info("Bulk traffic data creation completed",
                        city=city, requested=len(traffic_records),
                        validated=len(validated_records), created=len(created_records))

            return created_records

        except Exception as e:
            logger.error("Failed to create bulk traffic data",
                         city=city, record_count=len(traffic_records), error=str(e))
            raise DatabaseError(f"Bulk traffic creation failed: {str(e)}")

    async def bulk_create_with_deduplication(
        self,
        records: List[Dict[str, Any]]
    ) -> List[TrafficData]:
        """Bulk create with automatic deduplication based on location, city, and date"""
        try:
            if not records:
                return []

            # Extract unique keys for deduplication check
            unique_keys = []
            for record in records:
                key = (
                    record.get('location_id'),
                    record.get('city'),
                    record.get('date'),
                    record.get('measurement_point_id')
                )
                unique_keys.append(key)

            # Check for existing records
            location_ids = [key[0] for key in unique_keys if key[0]]
            cities = [key[1] for key in unique_keys if key[1]]
            dates = [key[2] for key in unique_keys if key[2]]

            # For large datasets, use chunked deduplication to avoid memory issues
            if len(location_ids) > 1000:
                logger.info(f"Large dataset detected ({len(records)} records), using chunked deduplication")
                new_records = []
                chunk_size = 1000

                for i in range(0, len(records), chunk_size):
                    chunk_records = records[i:i + chunk_size]
                    chunk_keys = unique_keys[i:i + chunk_size]

                    # Get unique values for this chunk
                    chunk_location_ids = list(set(key[0] for key in chunk_keys if key[0]))
                    chunk_cities = list(set(key[1] for key in chunk_keys if key[1]))
                    chunk_dates = list(set(key[2] for key in chunk_keys if key[2]))

                    if chunk_location_ids and chunk_cities and chunk_dates:
                        existing_query = select(
                            self.model.location_id,
                            self.model.city,
                            self.model.date,
                            self.model.measurement_point_id
                        ).where(
                            and_(
                                self.model.location_id.in_(chunk_location_ids),
                                self.model.city.in_(chunk_cities),
                                self.model.date.in_(chunk_dates)
                            )
                        )

                        result = await self.session.execute(existing_query)
                        chunk_existing_keys = set(result.fetchall())

                        # Filter chunk duplicates
                        for j, record in enumerate(chunk_records):
                            key = chunk_keys[j]
                            if key not in chunk_existing_keys:
                                new_records.append(record)
                    else:
                        new_records.extend(chunk_records)

                logger.debug("Chunked deduplication completed",
                             total_records=len(records),
                             new_records=len(new_records))
                records = new_records

            elif location_ids and cities and dates:
                existing_query = select(
                    self.model.location_id,
                    self.model.city,
                    self.model.date,
                    self.model.measurement_point_id
                ).where(
                    and_(
                        self.model.location_id.in_(location_ids),
                        self.model.city.in_(cities),
                        self.model.date.in_(dates)
                    )
                )

                result = await self.session.execute(existing_query)
                existing_keys = set(result.fetchall())

                # Filter out duplicates
                new_records = []
                for i, record in enumerate(records):
                    key = unique_keys[i]
                    if key not in existing_keys:
                        new_records.append(record)

                logger.debug("Standard deduplication completed",
                             total_records=len(records),
                             existing_records=len(existing_keys),
                             new_records=len(new_records))

                records = new_records

            # Proceed with bulk creation
            return await self.bulk_create(records)

        except Exception as e:
            logger.error("Failed bulk create with deduplication", error=str(e))
            raise DatabaseError(f"Bulk create with deduplication failed: {str(e)}")

    def _validate_traffic_record(self, record: Dict[str, Any]) -> bool:
        """Enhanced validation for traffic records"""
        required_fields = ['date', 'city']

        # Check required fields
        for field in required_fields:
            if not record.get(field):
                return False

        # Validate city
        city = record.get('city', '').lower()
        if city not in ['madrid', 'barcelona', 'valencia', 'test']:  # Extendable list
            return False

        # Validate data ranges
        traffic_volume = record.get('traffic_volume')
        if traffic_volume is not None and (traffic_volume < 0 or traffic_volume > 50000):
            return False

        pedestrian_count = record.get('pedestrian_count')
        if pedestrian_count is not None and (pedestrian_count < 0 or pedestrian_count > 10000):
            return False

        average_speed = record.get('average_speed')
        if average_speed is not None and (average_speed < 0 or average_speed > 200):
            return False

        congestion_level = record.get('congestion_level')
        if congestion_level and congestion_level not in ['low', 'medium', 'high', 'blocked']:
            return False

        return True
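
    # Illustrative record that passes _validate_traffic_record above (the field
    # values are invented for the example; the field names are the ones checked
    # in the method):
    #
    #   {
    #       "date": datetime(2024, 3, 1, 8, 0, tzinfo=timezone.utc),
    #       "city": "madrid",
    #       "traffic_volume": 420,
    #       "pedestrian_count": 55,
    #       "average_speed": 31.5,
    #       "congestion_level": "medium",
    #   }
    #
    # The same record with city="paris" or traffic_volume=-1 would be rejected.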

    # ================================================================
    # TRAINING DATA SPECIFIC OPERATIONS
    # ================================================================

    async def get_training_data_by_location(
        self,
        latitude: float,
        longitude: float,
        start_date: datetime,
        end_date: datetime,
        tenant_id: Optional[str] = None,
        include_pedestrian_inference: bool = True
    ) -> List[Dict[str, Any]]:
        """Get optimized training data for ML models"""
        try:
            location_id = f"{latitude:.4f},{longitude:.4f}"

            query = select(self.model).where(
                and_(
                    self.model.location_id == location_id,
                    self.model.date >= self._ensure_utc_datetime(start_date),
                    self.model.date <= self._ensure_utc_datetime(end_date)
                )
            )

            if tenant_id:
                query = query.where(self.model.tenant_id == tenant_id)

            if include_pedestrian_inference:
                # Prefer records with pedestrian inference
                query = query.order_by(
                    desc(self.model.has_pedestrian_inference),
                    desc(self.model.data_quality_score),
                    self.model.date
                )
            else:
                query = query.order_by(
                    desc(self.model.data_quality_score),
                    self.model.date
                )

            result = await self.session.execute(query)
            records = result.scalars().all()

            # Convert to training format with enhanced features
            training_data = []
            for record in records:
                training_record = {
                    'date': record.date,
                    'traffic_volume': record.traffic_volume or 0,
                    'pedestrian_count': record.pedestrian_count or 0,
                    'congestion_level': record.congestion_level or 'medium',
                    'average_speed': record.average_speed or 25.0,
                    'city': record.city,
                    'district': record.district,
                    'measurement_point_id': record.measurement_point_id,
                    'source': record.source,
                    'is_synthetic': record.is_synthetic or False,
                    'has_pedestrian_inference': record.has_pedestrian_inference or False,
                    'data_quality_score': record.data_quality_score or 50.0,

                    # Enhanced features for training
                    'hour_of_day': record.date.hour if record.date else 12,
                    'day_of_week': record.date.weekday() if record.date else 0,
                    'month': record.date.month if record.date else 1,

                    # City-specific features
                    'city_specific_data': record.city_specific_data or {}
                }

                training_data.append(training_record)

            logger.info("Retrieved training data",
                        location_id=location_id, records=len(training_data),
                        with_pedestrian_inference=sum(1 for r in training_data if r['has_pedestrian_inference']))

            return training_data

        except Exception as e:
            logger.error("Failed to get training data",
                         latitude=latitude, longitude=longitude, error=str(e))
            raise DatabaseError(f"Training data retrieval failed: {str(e)}")

    async def get_historical_data_by_location(
        self,
        latitude: float,
        longitude: float,
        start_date: datetime,
        end_date: datetime,
        tenant_id: Optional[str] = None
    ) -> List[TrafficData]:
        """Get historical traffic data for a specific location and date range"""
        return await self.get_by_location_and_date_range(
            latitude=latitude,
            longitude=longitude,
            start_date=start_date,
            end_date=end_date,
            tenant_id=tenant_id,
            limit=1000000  # Large limit for historical data
        )

    async def count_records_in_period(
        self,
        latitude: float,
        longitude: float,
        start_date: datetime,
        end_date: datetime,
        city: Optional[str] = None,
        tenant_id: Optional[str] = None
    ) -> int:
        """Count traffic records for a specific location and time period"""
        try:
            location_id = f"{latitude:.4f},{longitude:.4f}"

            query = select(func.count(self.model.id)).where(
                and_(
                    self.model.location_id == location_id,
                    self.model.date >= self._ensure_utc_datetime(start_date),
                    self.model.date <= self._ensure_utc_datetime(end_date)
                )
            )

            if city:
                query = query.where(self.model.city == city)

            if tenant_id:
                query = query.where(self.model.tenant_id == tenant_id)

            result = await self.session.execute(query)
            count = result.scalar()

            return count or 0

        except Exception as e:
            logger.error("Failed to count records in period",
                         latitude=latitude, longitude=longitude, error=str(e))
            raise DatabaseError(f"Record count failed: {str(e)}")

    # ================================================================
    # DATA QUALITY AND MAINTENANCE
    # ================================================================

    async def update_data_quality_scores(self, city: str) -> int:
        """Update data quality scores based on various criteria"""
        try:
            # Calculate quality scores based on data completeness and consistency
            query = text("""
                UPDATE traffic_data
                SET data_quality_score = (
                    CASE WHEN traffic_volume IS NOT NULL THEN 20 ELSE 0 END +
                    CASE WHEN pedestrian_count IS NOT NULL THEN 20 ELSE 0 END +
                    CASE WHEN average_speed IS NOT NULL AND average_speed > 0 THEN 20 ELSE 0 END +
                    CASE WHEN congestion_level IS NOT NULL THEN 15 ELSE 0 END +
                    CASE WHEN measurement_point_id IS NOT NULL THEN 10 ELSE 0 END +
                    CASE WHEN district IS NOT NULL THEN 10 ELSE 0 END +
                    CASE WHEN has_pedestrian_inference = true THEN 5 ELSE 0 END
                ),
                updated_at = :updated_at
                WHERE city = :city AND data_quality_score IS NULL
            """)

            params = {
                "city": city,
                "updated_at": datetime.now(timezone.utc)
            }

            result = await self.session.execute(query, params)
            updated_count = result.rowcount
            await self.session.commit()

            logger.info("Updated data quality scores",
                        city=city, updated_count=updated_count)

            return updated_count

        except Exception as e:
            logger.error("Failed to update data quality scores",
                         city=city, error=str(e))
            await self.session.rollback()
            raise DatabaseError(f"Data quality update failed: {str(e)}")
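
    # Note on the scoring in update_data_quality_scores above: the components are
    # 20 + 20 + 20 + 15 + 10 + 10 + 5, so a fully populated record with pedestrian
    # inference scores the maximum of 100, while a record carrying only the required
    # date and city fields scores 0.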

    async def cleanup_old_synthetic_data(
        self,
        city: str,
        days_to_keep: int = 90
    ) -> int:
        """Clean up old synthetic data while preserving real data"""
        try:
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)

            query = delete(self.model).where(
                and_(
                    self.model.city == city,
                    self.model.is_synthetic == True,
                    self.model.date < cutoff_date
                )
            )

            result = await self.session.execute(query)
            deleted_count = result.rowcount
            await self.session.commit()

            logger.info("Cleaned up old synthetic data",
                        city=city, deleted_count=deleted_count, days_kept=days_to_keep)

            return deleted_count

        except Exception as e:
            logger.error("Failed to cleanup old synthetic data",
                         city=city, error=str(e))
            await self.session.rollback()
            raise DatabaseError(f"Synthetic data cleanup failed: {str(e)}")


class TrafficMeasurementPointRepository(DataBaseRepository[TrafficMeasurementPoint, Dict, Dict]):
    """Repository for traffic measurement points across cities"""

    async def get_points_near_location(
        self,
        latitude: float,
        longitude: float,
        city: str,
        radius_km: float = 10.0,
        limit: int = 20
    ) -> List[TrafficMeasurementPoint]:
        """Get measurement points near a location using spatial query"""
        try:
            # Simple great-circle distance calculation (for more precision, use PostGIS).
            # The computed alias is filtered in an outer query because PostgreSQL and
            # SQLite do not accept HAVING on a non-aggregate alias.
            query = text("""
                SELECT * FROM (
                    SELECT *,
                        (6371 * acos(
                            cos(radians(:lat)) * cos(radians(latitude)) *
                            cos(radians(longitude) - radians(:lon)) +
                            sin(radians(:lat)) * sin(radians(latitude))
                        )) as distance_km
                    FROM traffic_measurement_points
                    WHERE city = :city
                        AND is_active = true
                ) AS candidate_points
                WHERE distance_km <= :radius_km
                ORDER BY distance_km
                LIMIT :limit
            """)

            params = {
                "lat": latitude,
                "lon": longitude,
                "city": city,
                "radius_km": radius_km,
                "limit": limit
            }

            result = await self.session.execute(query, params)
            rows = result.fetchall()

            # Convert rows to model instances
            points = []
            for row in rows:
                point = TrafficMeasurementPoint()
                for key, value in row._mapping.items():
                    if hasattr(point, key) and key != 'distance_km':
                        setattr(point, key, value)
                points.append(point)

            return points

        except Exception as e:
            logger.error("Failed to get measurement points near location",
                         latitude=latitude, longitude=longitude, city=city, error=str(e))
            raise DatabaseError(f"Measurement points query failed: {str(e)}")
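

# Illustrative cross-check (not used by the repositories above and not part of the
# original module): the same spherical-law-of-cosines distance that the SQL in
# get_points_near_location computes, expressed in Python. A sketch like this can be
# handy in tests to verify the radius filtering against a few known coordinates.
def _great_circle_distance_km(lat1: float, lon1: float, lat2: float, lon2: float) -> float:
    """Approximate great-circle distance in kilometres (Earth radius 6371 km)."""
    import math

    # Clamp the cosine argument to [-1, 1] to avoid domain errors from rounding.
    arg = (
        math.cos(math.radians(lat1)) * math.cos(math.radians(lat2)) *
        math.cos(math.radians(lon2) - math.radians(lon1)) +
        math.sin(math.radians(lat1)) * math.sin(math.radians(lat2))
    )
    return 6371 * math.acos(max(-1.0, min(1.0, arg)))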


class TrafficBackgroundJobRepository(DataBaseRepository[TrafficDataBackgroundJob, Dict, Dict]):
    """Repository for managing background traffic data jobs"""

    async def get_pending_jobs_by_city(self, city: str) -> List[TrafficDataBackgroundJob]:
        """Get pending background jobs for a specific city"""
        try:
            query = select(self.model).where(
                and_(
                    self.model.city == city,
                    self.model.status == 'pending'
                )
            ).order_by(self.model.scheduled_at)

            result = await self.session.execute(query)
            return result.scalars().all()

        except Exception as e:
            logger.error("Failed to get pending jobs by city", city=city, error=str(e))
            raise DatabaseError(f"Background jobs query failed: {str(e)}")

    async def update_job_progress(
        self,
        job_id: str,
        progress_percentage: float,
        records_processed: int = 0,
        records_stored: int = 0
    ) -> bool:
        """Update job progress"""
        try:
            query = update(self.model).where(
                self.model.id == job_id
            ).values(
                progress_percentage=progress_percentage,
                records_processed=records_processed,
                records_stored=records_stored,
                updated_at=datetime.now(timezone.utc)
            )

            result = await self.session.execute(query)
            await self.session.commit()

            return result.rowcount > 0

        except Exception as e:
            logger.error("Failed to update job progress", job_id=job_id, error=str(e))
            await self.session.rollback()
            raise DatabaseError(f"Job progress update failed: {str(e)}")