# ================================================================
# services/data/app/repositories/traffic_repository.py
# ================================================================
"""
Traffic Repository - Enhanced for multiple cities with comprehensive data access patterns
Follows existing repository architecture while adding city-specific functionality
"""
from typing import Optional, List, Dict, Any, Type, Tuple
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_, func, desc, asc, text, update, delete
from sqlalchemy.orm import selectinload
from datetime import datetime, timezone, timedelta
import structlog
from .base import DataBaseRepository
from app.models.traffic import TrafficData, TrafficMeasurementPoint, TrafficDataBackgroundJob
from app.schemas.traffic import TrafficDataCreate, TrafficDataResponse
from shared.database.exceptions import DatabaseError, ValidationError
logger = structlog.get_logger()
class TrafficRepository(DataBaseRepository[TrafficData, TrafficDataCreate, Dict]):
"""
Enhanced repository for traffic data operations across multiple cities
Provides city-aware queries and advanced traffic analytics
"""
def __init__(self, model_class: Type, session: AsyncSession, cache_ttl: Optional[int] = 300):
super().__init__(model_class, session, cache_ttl)
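    # Typical construction (a sketch; `async_session_factory` stands in for
    # however this service creates its AsyncSession):
    #
    #   async with async_session_factory() as session:
    #       repo = TrafficRepository(TrafficData, session, cache_ttl=300)
    #       recent = await repo.get_by_city_and_date_range(city="madrid", limit=500)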
# ================================================================
# CORE TRAFFIC DATA OPERATIONS
# ================================================================
async def get_by_location_and_date_range(
self,
latitude: float,
longitude: float,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
city: Optional[str] = None,
tenant_id: Optional[str] = None,
skip: int = 0,
limit: int = 100
) -> List[TrafficData]:
"""Get traffic data by location and date range with city filtering"""
try:
location_id = f"{latitude:.4f},{longitude:.4f}"
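            # NOTE: the 4-decimal rounding above (~11 m of latitude) must match
            # the convention used when records were written, or lookups will
            # return nothing.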
# Build base query
query = select(self.model).where(self.model.location_id == location_id)
# Add city filter if specified
if city:
query = query.where(self.model.city == city)
# Add tenant filter if specified
if tenant_id:
query = query.where(self.model.tenant_id == tenant_id)
# Add date range filters
if start_date:
start_date = self._ensure_utc_datetime(start_date)
query = query.where(self.model.date >= start_date)
if end_date:
end_date = self._ensure_utc_datetime(end_date)
query = query.where(self.model.date <= end_date)
# Order by date descending (most recent first)
query = query.order_by(desc(self.model.date))
# Apply pagination
query = query.offset(skip).limit(limit)
result = await self.session.execute(query)
return result.scalars().all()
except Exception as e:
logger.error("Failed to get traffic data by location and date range",
latitude=latitude, longitude=longitude,
city=city, error=str(e))
raise DatabaseError(f"Failed to get traffic data: {str(e)}")
async def get_by_city_and_date_range(
self,
city: str,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
district: Optional[str] = None,
measurement_point_ids: Optional[List[str]] = None,
include_synthetic: bool = True,
tenant_id: Optional[str] = None,
skip: int = 0,
limit: int = 1000
) -> List[TrafficData]:
"""Get traffic data by city with advanced filtering options"""
try:
# Build base query
query = select(self.model).where(self.model.city == city)
# Add tenant filter if specified
if tenant_id:
query = query.where(self.model.tenant_id == tenant_id)
# Add date range filters
if start_date:
start_date = self._ensure_utc_datetime(start_date)
query = query.where(self.model.date >= start_date)
if end_date:
end_date = self._ensure_utc_datetime(end_date)
query = query.where(self.model.date <= end_date)
# Add district filter
if district:
query = query.where(self.model.district == district)
# Add measurement point filter
if measurement_point_ids:
query = query.where(self.model.measurement_point_id.in_(measurement_point_ids))
# Filter synthetic data if requested
if not include_synthetic:
query = query.where(self.model.is_synthetic == False)
# Order by date and measurement point
query = query.order_by(desc(self.model.date), self.model.measurement_point_id)
# Apply pagination
query = query.offset(skip).limit(limit)
result = await self.session.execute(query)
return result.scalars().all()
except Exception as e:
logger.error("Failed to get traffic data by city",
city=city, district=district, error=str(e))
raise DatabaseError(f"Failed to get traffic data: {str(e)}")
async def get_latest_by_measurement_points(
self,
measurement_point_ids: List[str],
city: str,
hours_back: int = 24
) -> List[TrafficData]:
"""Get latest traffic data for specific measurement points"""
try:
cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours_back)
query = select(self.model).where(
and_(
self.model.city == city,
self.model.measurement_point_id.in_(measurement_point_ids),
self.model.date >= cutoff_time
)
).order_by(
self.model.measurement_point_id,
desc(self.model.date)
)
result = await self.session.execute(query)
all_records = result.scalars().all()
# Get the latest record for each measurement point
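            # Rows arrive ordered by (measurement_point_id, date DESC), so the
            # first row seen for each point is its most recent measurement.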
latest_records = {}
for record in all_records:
point_id = record.measurement_point_id
if point_id not in latest_records:
latest_records[point_id] = record
return list(latest_records.values())
except Exception as e:
logger.error("Failed to get latest traffic data by measurement points",
city=city, points=len(measurement_point_ids), error=str(e))
raise DatabaseError(f"Failed to get latest traffic data: {str(e)}")
# ================================================================
# ANALYTICS AND AGGREGATIONS
# ================================================================
async def get_traffic_statistics_by_city(
self,
city: str,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
group_by: str = "daily"
) -> List[Dict[str, Any]]:
"""Get aggregated traffic statistics by city"""
try:
# Determine date truncation based on group_by
if group_by == "hourly":
date_trunc = "hour"
elif group_by == "daily":
date_trunc = "day"
elif group_by == "weekly":
date_trunc = "week"
elif group_by == "monthly":
date_trunc = "month"
else:
raise ValidationError(f"Invalid group_by value: {group_by}")
# Build aggregation query
if self.session.bind.dialect.name == 'postgresql':
query = text("""
SELECT
DATE_TRUNC(:date_trunc, date) as period,
city,
district,
COUNT(*) as record_count,
AVG(traffic_volume) as avg_traffic_volume,
MAX(traffic_volume) as max_traffic_volume,
AVG(pedestrian_count) as avg_pedestrian_count,
AVG(average_speed) as avg_speed,
COUNT(CASE WHEN congestion_level = 'high' THEN 1 END) as high_congestion_count,
COUNT(CASE WHEN is_synthetic = false THEN 1 END) as real_data_count,
COUNT(CASE WHEN has_pedestrian_inference = true THEN 1 END) as pedestrian_inference_count
FROM traffic_data
WHERE city = :city
""")
else:
# SQLite fallback
query = text("""
SELECT
DATE(date) as period,
city,
district,
COUNT(*) as record_count,
AVG(traffic_volume) as avg_traffic_volume,
MAX(traffic_volume) as max_traffic_volume,
AVG(pedestrian_count) as avg_pedestrian_count,
AVG(average_speed) as avg_speed,
SUM(CASE WHEN congestion_level = 'high' THEN 1 ELSE 0 END) as high_congestion_count,
SUM(CASE WHEN is_synthetic = 0 THEN 1 ELSE 0 END) as real_data_count,
SUM(CASE WHEN has_pedestrian_inference = 1 THEN 1 ELSE 0 END) as pedestrian_inference_count
FROM traffic_data
WHERE city = :city
""")
            # Assemble the final statement as a plain string and wrap it in
            # text() once; the :date_trunc bind only exists in the PostgreSQL
            # variant, so it is only passed there.
            sql = str(query)
            params: Dict[str, Any] = {"city": city}
            if ":date_trunc" in sql:
                params["date_trunc"] = date_trunc
            # Add date filters
            if start_date:
                sql += " AND date >= :start_date"
                params["start_date"] = self._ensure_utc_datetime(start_date)
            if end_date:
                sql += " AND date <= :end_date"
                params["end_date"] = self._ensure_utc_datetime(end_date)
            # Add GROUP BY and ORDER BY
            sql += " GROUP BY period, city, district ORDER BY period DESC"
            result = await self.session.execute(text(sql), params)
rows = result.fetchall()
# Convert to list of dictionaries
statistics = []
for row in rows:
statistics.append({
"period": group_by,
"date": row.period,
"city": row.city,
"district": row.district,
"record_count": row.record_count,
"avg_traffic_volume": float(row.avg_traffic_volume or 0),
"max_traffic_volume": row.max_traffic_volume or 0,
"avg_pedestrian_count": float(row.avg_pedestrian_count or 0),
"avg_speed": float(row.avg_speed or 0),
"high_congestion_count": row.high_congestion_count or 0,
"real_data_percentage": round((row.real_data_count or 0) / max(1, row.record_count) * 100, 2),
"pedestrian_inference_percentage": round((row.pedestrian_inference_count or 0) / max(1, row.record_count) * 100, 2)
})
return statistics
except Exception as e:
logger.error("Failed to get traffic statistics by city",
city=city, group_by=group_by, error=str(e))
raise DatabaseError(f"Traffic statistics query failed: {str(e)}")
async def get_congestion_heatmap_data(
self,
city: str,
start_date: datetime,
end_date: datetime,
time_granularity: str = "hour"
) -> List[Dict[str, Any]]:
"""Get congestion data for heatmap visualization"""
try:
if time_granularity == "hour":
time_extract = "EXTRACT(hour FROM date)"
elif time_granularity == "day_of_week":
time_extract = "EXTRACT(dow FROM date)"
else:
time_extract = "EXTRACT(hour FROM date)"
query = text(f"""
SELECT
{time_extract} as time_period,
district,
measurement_point_id,
latitude,
longitude,
AVG(CASE
WHEN congestion_level = 'low' THEN 1
WHEN congestion_level = 'medium' THEN 2
WHEN congestion_level = 'high' THEN 3
WHEN congestion_level = 'blocked' THEN 4
ELSE 1
END) as avg_congestion_score,
COUNT(*) as data_points,
AVG(traffic_volume) as avg_traffic_volume,
AVG(pedestrian_count) as avg_pedestrian_count
FROM traffic_data
WHERE city = :city
AND date >= :start_date
AND date <= :end_date
AND latitude IS NOT NULL
AND longitude IS NOT NULL
GROUP BY time_period, district, measurement_point_id, latitude, longitude
ORDER BY time_period, district, avg_congestion_score DESC
""")
params = {
"city": city,
"start_date": self._ensure_utc_datetime(start_date),
"end_date": self._ensure_utc_datetime(end_date)
}
result = await self.session.execute(query, params)
rows = result.fetchall()
heatmap_data = []
for row in rows:
heatmap_data.append({
"time_period": int(row.time_period or 0),
"district": row.district,
"measurement_point_id": row.measurement_point_id,
"latitude": float(row.latitude),
"longitude": float(row.longitude),
"avg_congestion_score": float(row.avg_congestion_score),
"data_points": row.data_points,
"avg_traffic_volume": float(row.avg_traffic_volume or 0),
"avg_pedestrian_count": float(row.avg_pedestrian_count or 0)
})
return heatmap_data
except Exception as e:
logger.error("Failed to get congestion heatmap data",
city=city, error=str(e))
raise DatabaseError(f"Congestion heatmap query failed: {str(e)}")
# ================================================================
# BULK OPERATIONS AND DATA MANAGEMENT
# ================================================================
async def create_bulk_traffic_data(
self,
traffic_records: List[Dict[str, Any]],
city: str,
tenant_id: Optional[str] = None
) -> List[TrafficData]:
"""Create multiple traffic records in bulk with enhanced validation"""
try:
# Ensure all records have city and tenant_id
for record in traffic_records:
record["city"] = city
if tenant_id:
record["tenant_id"] = tenant_id
# Ensure dates are timezone-aware
if "date" in record and record["date"]:
record["date"] = self._ensure_utc_datetime(record["date"])
# Enhanced validation
validated_records = []
for record in traffic_records:
if self._validate_traffic_record(record):
validated_records.append(record)
else:
logger.warning("Invalid traffic record skipped",
city=city, record_keys=list(record.keys()))
if not validated_records:
logger.warning("No valid traffic records to create", city=city)
return []
# Use bulk create with deduplication
created_records = await self.bulk_create_with_deduplication(validated_records)
logger.info("Bulk traffic data creation completed",
city=city, requested=len(traffic_records),
validated=len(validated_records), created=len(created_records))
return created_records
except Exception as e:
logger.error("Failed to create bulk traffic data",
city=city, record_count=len(traffic_records), error=str(e))
raise DatabaseError(f"Bulk traffic creation failed: {str(e)}")
async def bulk_create_with_deduplication(
self,
records: List[Dict[str, Any]]
) -> List[TrafficData]:
"""Bulk create with automatic deduplication based on location, city, and date"""
try:
if not records:
return []
# Extract unique keys for deduplication check
unique_keys = []
for record in records:
key = (
record.get('location_id'),
record.get('city'),
record.get('date'),
record.get('measurement_point_id')
)
unique_keys.append(key)
# Check for existing records
location_ids = [key[0] for key in unique_keys if key[0]]
cities = [key[1] for key in unique_keys if key[1]]
dates = [key[2] for key in unique_keys if key[2]]
# For large datasets, use chunked deduplication to avoid memory issues
if len(location_ids) > 1000:
logger.info(f"Large dataset detected ({len(records)} records), using chunked deduplication")
new_records = []
chunk_size = 1000
for i in range(0, len(records), chunk_size):
chunk_records = records[i:i + chunk_size]
chunk_keys = unique_keys[i:i + chunk_size]
# Get unique values for this chunk
chunk_location_ids = list(set(key[0] for key in chunk_keys if key[0]))
chunk_cities = list(set(key[1] for key in chunk_keys if key[1]))
chunk_dates = list(set(key[2] for key in chunk_keys if key[2]))
if chunk_location_ids and chunk_cities and chunk_dates:
existing_query = select(
self.model.location_id,
self.model.city,
self.model.date,
self.model.measurement_point_id
).where(
and_(
self.model.location_id.in_(chunk_location_ids),
self.model.city.in_(chunk_cities),
self.model.date.in_(chunk_dates)
)
)
result = await self.session.execute(existing_query)
chunk_existing_keys = set(result.fetchall())
# Filter chunk duplicates
for j, record in enumerate(chunk_records):
key = chunk_keys[j]
if key not in chunk_existing_keys:
new_records.append(record)
else:
new_records.extend(chunk_records)
logger.debug("Chunked deduplication completed",
total_records=len(records),
new_records=len(new_records))
records = new_records
elif location_ids and cities and dates:
existing_query = select(
self.model.location_id,
self.model.city,
self.model.date,
self.model.measurement_point_id
).where(
and_(
self.model.location_id.in_(location_ids),
self.model.city.in_(cities),
self.model.date.in_(dates)
)
)
result = await self.session.execute(existing_query)
existing_keys = set(result.fetchall())
# Filter out duplicates
new_records = []
for i, record in enumerate(records):
key = unique_keys[i]
if key not in existing_keys:
new_records.append(record)
logger.debug("Standard deduplication completed",
total_records=len(records),
existing_records=len(existing_keys),
new_records=len(new_records))
records = new_records
# Proceed with bulk creation
return await self.bulk_create(records)
except Exception as e:
logger.error("Failed bulk create with deduplication", error=str(e))
raise DatabaseError(f"Bulk create with deduplication failed: {str(e)}")
def _validate_traffic_record(self, record: Dict[str, Any]) -> bool:
"""Enhanced validation for traffic records"""
required_fields = ['date', 'city']
# Check required fields
for field in required_fields:
if not record.get(field):
return False
# Validate city
city = record.get('city', '').lower()
if city not in ['madrid', 'barcelona', 'valencia', 'test']: # Extendable list
return False
# Validate data ranges
traffic_volume = record.get('traffic_volume')
if traffic_volume is not None and (traffic_volume < 0 or traffic_volume > 50000):
return False
pedestrian_count = record.get('pedestrian_count')
if pedestrian_count is not None and (pedestrian_count < 0 or pedestrian_count > 10000):
return False
average_speed = record.get('average_speed')
if average_speed is not None and (average_speed < 0 or average_speed > 200):
return False
congestion_level = record.get('congestion_level')
if congestion_level and congestion_level not in ['low', 'medium', 'high', 'blocked']:
return False
return True
# ================================================================
# TRAINING DATA SPECIFIC OPERATIONS
# ================================================================
async def get_training_data_by_location(
self,
latitude: float,
longitude: float,
start_date: datetime,
end_date: datetime,
tenant_id: Optional[str] = None,
include_pedestrian_inference: bool = True
) -> List[Dict[str, Any]]:
"""Get optimized training data for ML models"""
try:
location_id = f"{latitude:.4f},{longitude:.4f}"
query = select(self.model).where(
and_(
self.model.location_id == location_id,
self.model.date >= self._ensure_utc_datetime(start_date),
self.model.date <= self._ensure_utc_datetime(end_date)
)
)
if tenant_id:
query = query.where(self.model.tenant_id == tenant_id)
if include_pedestrian_inference:
# Prefer records with pedestrian inference
query = query.order_by(
desc(self.model.has_pedestrian_inference),
desc(self.model.data_quality_score),
self.model.date
)
else:
query = query.order_by(
desc(self.model.data_quality_score),
self.model.date
)
result = await self.session.execute(query)
records = result.scalars().all()
# Convert to training format with enhanced features
training_data = []
for record in records:
training_record = {
'date': record.date,
'traffic_volume': record.traffic_volume or 0,
'pedestrian_count': record.pedestrian_count or 0,
'congestion_level': record.congestion_level or 'medium',
'average_speed': record.average_speed or 25.0,
'city': record.city,
'district': record.district,
'measurement_point_id': record.measurement_point_id,
'source': record.source,
'is_synthetic': record.is_synthetic or False,
'has_pedestrian_inference': record.has_pedestrian_inference or False,
'data_quality_score': record.data_quality_score or 50.0,
# Enhanced features for training
'hour_of_day': record.date.hour if record.date else 12,
'day_of_week': record.date.weekday() if record.date else 0,
'month': record.date.month if record.date else 1,
# City-specific features
'city_specific_data': record.city_specific_data or {}
}
training_data.append(training_record)
logger.info("Retrieved training data",
location_id=location_id, records=len(training_data),
with_pedestrian_inference=sum(1 for r in training_data if r['has_pedestrian_inference']))
return training_data
except Exception as e:
logger.error("Failed to get training data",
latitude=latitude, longitude=longitude, error=str(e))
raise DatabaseError(f"Training data retrieval failed: {str(e)}")
async def get_historical_data_by_location(
self,
latitude: float,
longitude: float,
start_date: datetime,
end_date: datetime,
tenant_id: Optional[str] = None
) -> List[TrafficData]:
"""Get historical traffic data for a specific location and date range"""
return await self.get_by_location_and_date_range(
latitude=latitude,
longitude=longitude,
start_date=start_date,
end_date=end_date,
tenant_id=tenant_id,
limit=1000000 # Large limit for historical data
)
async def count_records_in_period(
self,
latitude: float,
longitude: float,
start_date: datetime,
end_date: datetime,
city: Optional[str] = None,
tenant_id: Optional[str] = None
) -> int:
"""Count traffic records for a specific location and time period"""
try:
location_id = f"{latitude:.4f},{longitude:.4f}"
query = select(func.count(self.model.id)).where(
and_(
self.model.location_id == location_id,
self.model.date >= self._ensure_utc_datetime(start_date),
self.model.date <= self._ensure_utc_datetime(end_date)
)
)
if city:
query = query.where(self.model.city == city)
if tenant_id:
query = query.where(self.model.tenant_id == tenant_id)
result = await self.session.execute(query)
count = result.scalar()
return count or 0
except Exception as e:
logger.error("Failed to count records in period",
latitude=latitude, longitude=longitude, error=str(e))
raise DatabaseError(f"Record count failed: {str(e)}")
# ================================================================
# DATA QUALITY AND MAINTENANCE
# ================================================================
async def update_data_quality_scores(self, city: str) -> int:
"""Update data quality scores based on various criteria"""
try:
# Calculate quality scores based on data completeness and consistency
query = text("""
UPDATE traffic_data
SET data_quality_score = (
CASE
WHEN traffic_volume IS NOT NULL THEN 20 ELSE 0 END +
CASE
WHEN pedestrian_count IS NOT NULL THEN 20 ELSE 0 END +
CASE
WHEN average_speed IS NOT NULL AND average_speed > 0 THEN 20 ELSE 0 END +
CASE
WHEN congestion_level IS NOT NULL THEN 15 ELSE 0 END +
CASE
WHEN measurement_point_id IS NOT NULL THEN 10 ELSE 0 END +
CASE
WHEN district IS NOT NULL THEN 10 ELSE 0 END +
CASE
WHEN has_pedestrian_inference = true THEN 5 ELSE 0 END
),
updated_at = :updated_at
WHERE city = :city AND data_quality_score IS NULL
""")
params = {
"city": city,
"updated_at": datetime.now(timezone.utc)
}
result = await self.session.execute(query, params)
updated_count = result.rowcount
await self.session.commit()
logger.info("Updated data quality scores",
city=city, updated_count=updated_count)
return updated_count
except Exception as e:
logger.error("Failed to update data quality scores",
city=city, error=str(e))
await self.session.rollback()
raise DatabaseError(f"Data quality update failed: {str(e)}")
async def cleanup_old_synthetic_data(
self,
city: str,
days_to_keep: int = 90
) -> int:
"""Clean up old synthetic data while preserving real data"""
try:
cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep)
query = delete(self.model).where(
and_(
self.model.city == city,
self.model.is_synthetic == True,
self.model.date < cutoff_date
)
)
result = await self.session.execute(query)
deleted_count = result.rowcount
await self.session.commit()
logger.info("Cleaned up old synthetic data",
city=city, deleted_count=deleted_count, days_kept=days_to_keep)
return deleted_count
except Exception as e:
logger.error("Failed to cleanup old synthetic data",
city=city, error=str(e))
await self.session.rollback()
raise DatabaseError(f"Synthetic data cleanup failed: {str(e)}")
class TrafficMeasurementPointRepository(DataBaseRepository[TrafficMeasurementPoint, Dict, Dict]):
"""Repository for traffic measurement points across cities"""
async def get_points_near_location(
self,
latitude: float,
longitude: float,
city: str,
radius_km: float = 10.0,
limit: int = 20
) -> List[TrafficMeasurementPoint]:
"""Get measurement points near a location using spatial query"""
try:
            # Simple distance calculation (for more precise results, use PostGIS).
            # The distance alias is computed in a subquery so it can be filtered
            # with a plain WHERE clause; filtering the alias via HAVING without a
            # GROUP BY is a MySQL-only extension and is rejected by PostgreSQL.
            query = text("""
                SELECT * FROM (
                    SELECT *,
                        (6371 * acos(
                            cos(radians(:lat)) * cos(radians(latitude)) *
                            cos(radians(longitude) - radians(:lon)) +
                            sin(radians(:lat)) * sin(radians(latitude))
                        )) as distance_km
                    FROM traffic_measurement_points
                    WHERE city = :city
                    AND is_active = true
                ) AS nearby
                WHERE distance_km <= :radius_km
                ORDER BY distance_km
                LIMIT :limit
            """)
params = {
"lat": latitude,
"lon": longitude,
"city": city,
"radius_km": radius_km,
"limit": limit
}
result = await self.session.execute(query, params)
rows = result.fetchall()
# Convert rows to model instances
points = []
for row in rows:
point = TrafficMeasurementPoint()
for key, value in row._mapping.items():
if hasattr(point, key) and key != 'distance_km':
setattr(point, key, value)
points.append(point)
return points
except Exception as e:
logger.error("Failed to get measurement points near location",
latitude=latitude, longitude=longitude, city=city, error=str(e))
raise DatabaseError(f"Measurement points query failed: {str(e)}")
class TrafficBackgroundJobRepository(DataBaseRepository[TrafficDataBackgroundJob, Dict, Dict]):
"""Repository for managing background traffic data jobs"""
async def get_pending_jobs_by_city(self, city: str) -> List[TrafficDataBackgroundJob]:
"""Get pending background jobs for a specific city"""
try:
query = select(self.model).where(
and_(
self.model.city == city,
self.model.status == 'pending'
)
).order_by(self.model.scheduled_at)
result = await self.session.execute(query)
return result.scalars().all()
except Exception as e:
logger.error("Failed to get pending jobs by city", city=city, error=str(e))
raise DatabaseError(f"Background jobs query failed: {str(e)}")
async def update_job_progress(
self,
job_id: str,
progress_percentage: float,
records_processed: int = 0,
records_stored: int = 0
) -> bool:
"""Update job progress"""
try:
query = update(self.model).where(
self.model.id == job_id
).values(
progress_percentage=progress_percentage,
records_processed=records_processed,
records_stored=records_stored,
updated_at=datetime.now(timezone.utc)
)
result = await self.session.execute(query)
await self.session.commit()
return result.rowcount > 0
except Exception as e:
logger.error("Failed to update job progress", job_id=job_id, error=str(e))
await self.session.rollback()
raise DatabaseError(f"Job progress update failed: {str(e)}")