195 lines
8.1 KiB
Python
195 lines
8.1 KiB
Python
# ================================================================
|
|
# services/data/app/repositories/traffic_repository.py
|
|
# ================================================================
|
|
"""
|
|
Traffic Repository - Enhanced for multiple cities with comprehensive data access patterns
|
|
Follows existing repository architecture while adding city-specific functionality
|
|
"""
|
|
|
|
from typing import Optional, List, Dict, Any, Type, Tuple
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, and_, or_, func, desc, asc, text, update, delete
|
|
from sqlalchemy.orm import selectinload
|
|
from datetime import datetime, timezone, timedelta
|
|
import structlog
|
|
|
|
from app.models.traffic import TrafficData
|
|
from app.schemas.traffic import TrafficDataCreate, TrafficDataResponse
|
|
from shared.database.exceptions import DatabaseError, ValidationError
|
|
|
|
logger = structlog.get_logger()
|
|
|
|
|
|
class TrafficRepository:
    """
    Enhanced repository for traffic data operations across multiple cities.

    Provides city-aware queries, duplicate-safe batch ingestion, and
    retrieval helpers for ML training pipelines. All operations run against
    the async SQLAlchemy session supplied at construction time; the session's
    open/close lifecycle is owned by the caller, but batch writes commit or
    roll back on their own.
    """

    def __init__(self, session: "AsyncSession"):
        self.session = session
        # Convenience alias; some methods reference TrafficData directly.
        self.model = TrafficData

    # ================================================================
    # CORE TRAFFIC DATA OPERATIONS
    # ================================================================

    async def get_by_location_and_date_range(
        self,
        latitude: float,
        longitude: float,
        start_date: datetime,
        end_date: datetime,
        tenant_id: Optional[str] = None
    ) -> List["TrafficData"]:
        """Get traffic data rows for a location within a date range.

        Args:
            latitude: Location latitude; rounded to 4 decimal places when
                building the canonical location_id key.
            longitude: Location longitude; rounded the same way.
            start_date: Inclusive lower bound on TrafficData.date.
            end_date: Inclusive upper bound on TrafficData.date.
            tenant_id: Optional tenant scope; when provided only rows for
                that tenant are returned.

        Returns:
            Matching TrafficData rows ordered by date ascending.

        Raises:
            DatabaseError: If building or executing the query fails.
        """
        try:
            # Canonical location key: "lat,lon" at 4-decimal precision.
            # Must match the format used when records are stored.
            location_id = f"{latitude:.4f},{longitude:.4f}"

            query = select(self.model).where(self.model.location_id == location_id)

            if tenant_id:
                query = query.where(self.model.tenant_id == tenant_id)

            # Bounds are applied independently so either side may be falsy/None.
            if start_date:
                query = query.where(self.model.date >= start_date)
            if end_date:
                query = query.where(self.model.date <= end_date)

            query = query.order_by(self.model.date)

            result = await self.session.execute(query)
            return result.scalars().all()

        except Exception as e:
            logger.error("Failed to get traffic data by location and date range",
                         latitude=latitude, longitude=longitude,
                         error=str(e))
            # Chain the original exception so the root cause is preserved.
            raise DatabaseError(f"Failed to get traffic data: {str(e)}") from e

    async def store_traffic_data_batch(
        self,
        traffic_data_list: List[Dict[str, Any]],
        location_id: str,
        tenant_id: Optional[str] = None
    ) -> int:
        """Store a batch of traffic data records with validation and duplicate handling.

        Records whose (location_id, date) pair already exists are skipped,
        records failing _validate_traffic_data are dropped, and the remainder
        is inserted with a single bulk statement followed by a commit.

        Args:
            traffic_data_list: Raw record dicts; each should carry at least a
                'date' key plus optional metric fields.
            location_id: Canonical "lat,lon" key the records belong to.
            tenant_id: Optional tenant identifier stored on every row.

        Returns:
            Number of rows actually inserted (0 for an empty/duplicate batch).

        Raises:
            DatabaseError: On any failure; the session is rolled back first.
        """
        stored_count = 0
        try:
            if not traffic_data_list:
                return 0

            # Look up already-stored dates to avoid duplicate inserts.
            # PostgreSQL caps a statement at 32767 bind parameters, so the
            # IN() lookups are chunked well below that limit.
            dates = [data.get('date') for data in traffic_data_list if data.get('date')]
            existing_dates = set()
            if dates:
                batch_size = 30000  # safe margin under the 32767-parameter cap
                for i in range(0, len(dates), batch_size):
                    date_batch = dates[i:i + batch_size]
                    existing_stmt = select(TrafficData.date).where(
                        and_(
                            TrafficData.location_id == location_id,
                            TrafficData.date.in_(date_batch)
                        )
                    )
                    result = await self.session.execute(existing_stmt)
                    existing_dates.update({row[0] for row in result.fetchall()})
                # Structured kwargs logging, consistent with the rest of the module.
                logger.debug("Found existing traffic records",
                             existing_count=len(existing_dates),
                             location_id=location_id)

            batch_records = []
            for data in traffic_data_list:
                record_date = data.get('date')
                if not record_date or record_date in existing_dates:
                    continue  # Skip undated records and duplicates

                # Validate data before preparing it for insertion.
                if self._validate_traffic_data(data):
                    batch_records.append({
                        'location_id': location_id,
                        'city': data.get('city', 'madrid'),  # Default to madrid for historical data
                        'tenant_id': tenant_id,  # Include tenant_id in batch insert
                        'date': record_date,
                        'traffic_volume': data.get('traffic_volume'),
                        'pedestrian_count': data.get('pedestrian_count'),
                        'congestion_level': data.get('congestion_level'),
                        'average_speed': data.get('average_speed'),
                        'source': data.get('source', 'unknown'),
                        'raw_data': str(data)
                    })

            if batch_records:
                # Bulk insert for performance (single executemany round trip).
                await self.session.execute(
                    TrafficData.__table__.insert(),
                    batch_records
                )
                await self.session.commit()
                stored_count = len(batch_records)
                logger.info("Successfully stored traffic records",
                            stored_count=stored_count,
                            location_id=location_id)

        except Exception as e:
            logger.error("Failed to store traffic data batch",
                         error=str(e), location_id=location_id)
            await self.session.rollback()
            raise DatabaseError(f"Batch store failed: {str(e)}") from e

        return stored_count

    def _validate_traffic_data(self, data: Dict[str, Any]) -> bool:
        """Validate a single traffic record before storage.

        Rules:
            - 'date' is required and must be truthy.
            - traffic_volume / pedestrian_count, when present, must lie in [0, 10000].
            - average_speed, when present, must lie in [0, 200].
            - congestion_level, when truthy, must be one of
              'low', 'medium', 'high', 'blocked'.

        Returns:
            True if the record is acceptable for insertion, False otherwise.
        """
        # Required field check (only 'date' is mandatory today).
        if not data.get('date'):
            return False

        # Validate numeric ranges; None means "not provided" and is allowed.
        traffic_volume = data.get('traffic_volume')
        if traffic_volume is not None and not 0 <= traffic_volume <= 10000:
            return False

        pedestrian_count = data.get('pedestrian_count')
        if pedestrian_count is not None and not 0 <= pedestrian_count <= 10000:
            return False

        average_speed = data.get('average_speed')
        if average_speed is not None and not 0 <= average_speed <= 200:
            return False

        # Falsy values (None, '') are treated as "not provided" and pass.
        congestion_level = data.get('congestion_level')
        if congestion_level and congestion_level not in ['low', 'medium', 'high', 'blocked']:
            return False

        return True

    async def get_historical_traffic_for_training(self,
                                                  latitude: float,
                                                  longitude: float,
                                                  start_date: datetime,
                                                  end_date: datetime) -> List["TrafficData"]:
        """Retrieve stored traffic data for training ML models.

        Args:
            latitude: Location latitude (4-decimal location_id precision).
            longitude: Location longitude (same precision).
            start_date: Inclusive lower bound on TrafficData.date.
            end_date: Inclusive upper bound on TrafficData.date.

        Returns:
            TrafficData rows for the location ordered by date ascending.

        Raises:
            DatabaseError: If the query fails.
        """
        # Build the key before the try block so the except handler can always
        # reference it (previously a NameError risk if formatting failed).
        location_id = f"{latitude:.4f},{longitude:.4f}"
        try:
            stmt = select(TrafficData).where(
                and_(
                    TrafficData.location_id == location_id,
                    TrafficData.date >= start_date,
                    TrafficData.date <= end_date
                )
            ).order_by(TrafficData.date)

            result = await self.session.execute(stmt)
            return result.scalars().all()

        except Exception as e:
            logger.error("Failed to retrieve traffic data for training",
                         error=str(e), location_id=location_id)
            raise DatabaseError(f"Training data retrieval failed: {str(e)}") from e