250 lines
8.0 KiB
Python
250 lines
8.0 KiB
Python
# services/external/app/repositories/city_data_repository.py
|
|
"""
|
|
City Data Repository - Manages shared city-based data storage
|
|
"""
|
|
|
|
from typing import List, Dict, Any, Optional
|
|
from datetime import datetime
|
|
from sqlalchemy import select, delete, and_
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
import structlog
|
|
|
|
from app.models.city_weather import CityWeatherData
|
|
from app.models.city_traffic import CityTrafficData
|
|
|
|
# Module-level structlog logger; CityDataRepository emits its store/error events through it.
logger = structlog.get_logger()
|
|
|
|
|
|
class CityDataRepository:
    """Repository for city-based historical weather and traffic data.

    Every method works through the ``AsyncSession`` supplied to the
    constructor. Methods that write (bulk stores and deletes) commit on
    success; on failure they roll the session back, log the error and
    re-raise the original exception.
    """

    def __init__(self, session: AsyncSession):
        """Bind the repository to an async SQLAlchemy session."""
        self.session = session

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _city_range_clause(
        model: Any,
        city_id: str,
        start_date: datetime,
        end_date: datetime
    ):
        """Build ``city_id == city_id AND start_date <= date <= end_date`` for *model*.

        Both bounds are inclusive. ``model`` is a mapped class exposing
        ``city_id`` and ``date`` columns (``CityWeatherData`` or
        ``CityTrafficData`` in this repository).
        """
        return and_(
            model.city_id == city_id,
            model.date >= start_date,
            model.date <= end_date
        )

    async def _count_in_range(
        self,
        model: Any,
        city_id: str,
        start_date: datetime,
        end_date: datetime
    ) -> int:
        """Count *model* rows for a city within an inclusive date range.

        Issues ``SELECT COUNT(*)`` so the database does the counting instead
        of every matching row being loaded into Python just to be measured.
        """
        from sqlalchemy import func

        stmt = (
            select(func.count())
            .select_from(model)
            .where(self._city_range_clause(model, city_id, start_date, end_date))
        )
        result = await self.session.execute(stmt)
        return result.scalar_one()

    async def _delete_before(
        self,
        model: Any,
        city_id: str,
        cutoff_date: datetime,
        label: str
    ) -> int:
        """Delete a city's *model* rows dated strictly before ``cutoff_date``.

        Args:
            model: Mapped class to delete from.
            city_id: City whose rows are deleted.
            cutoff_date: Exclusive upper bound on ``date``.
            label: Human-readable data kind used in the error log event.

        Returns:
            The driver-reported number of deleted rows (``rowcount``).

        Raises:
            Exception: Whatever the DELETE/commit raised, after rollback.
        """
        stmt = delete(model).where(
            and_(
                model.city_id == city_id,
                model.date < cutoff_date
            )
        )
        try:
            result = await self.session.execute(stmt)
            await self.session.commit()
        except Exception as e:
            # Roll back so a failed DELETE does not leave the session
            # mid-transaction, mirroring the bulk_store_* methods.
            await self.session.rollback()
            logger.error(
                f"Error deleting {label} data",
                city_id=city_id,
                error=str(e)
            )
            raise
        return result.rowcount

    # ------------------------------------------------------------------
    # Weather
    # ------------------------------------------------------------------

    async def bulk_store_weather(
        self,
        city_id: str,
        weather_records: List[Dict[str, Any]]
    ) -> int:
        """Bulk insert weather records for a city.

        Args:
            city_id: Identifier stamped on every inserted row.
            weather_records: One dict per row. Read keys are ``date``,
                ``temperature``, ``precipitation``, ``humidity``,
                ``wind_speed``, ``pressure``, ``description``, ``source``
                and ``raw_data``; absent keys become ``None`` except
                ``source``, which defaults to ``'ingestion'``.

        Returns:
            Number of rows inserted; 0 (without touching the session) when
            ``weather_records`` is empty.

        Raises:
            Exception: Whatever the insert/commit raised, after the session
                has been rolled back and the error logged.
        """
        if not weather_records:
            return 0

        try:
            objects = [
                CityWeatherData(
                    city_id=city_id,
                    date=record.get('date'),
                    temperature=record.get('temperature'),
                    precipitation=record.get('precipitation'),
                    humidity=record.get('humidity'),
                    wind_speed=record.get('wind_speed'),
                    pressure=record.get('pressure'),
                    description=record.get('description'),
                    source=record.get('source', 'ingestion'),
                    raw_data=record.get('raw_data')
                )
                for record in weather_records
            ]
            self.session.add_all(objects)
            await self.session.commit()
        except Exception as e:
            await self.session.rollback()
            logger.error(
                "Error storing weather data",
                city_id=city_id,
                error=str(e)
            )
            raise

        logger.info(
            "Weather data stored",
            city_id=city_id,
            records=len(objects)
        )
        return len(objects)

    async def get_weather_by_city_and_range(
        self,
        city_id: str,
        start_date: datetime,
        end_date: datetime
    ) -> List[CityWeatherData]:
        """Get a city's weather rows with ``start_date <= date <= end_date``.

        Returns:
            Matching ``CityWeatherData`` rows ordered by ``date`` ascending.
        """
        stmt = (
            select(CityWeatherData)
            .where(self._city_range_clause(CityWeatherData, city_id, start_date, end_date))
            .order_by(CityWeatherData.date)
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def delete_weather_before(
        self,
        city_id: str,
        cutoff_date: datetime
    ) -> int:
        """Delete a city's weather rows dated strictly before ``cutoff_date``.

        Returns:
            Number of rows deleted, as reported by the driver.
        """
        return await self._delete_before(CityWeatherData, city_id, cutoff_date, "weather")

    # ------------------------------------------------------------------
    # Traffic
    # ------------------------------------------------------------------

    async def bulk_store_traffic(
        self,
        city_id: str,
        traffic_records: List[Dict[str, Any]]
    ) -> int:
        """Bulk insert traffic records for a city.

        Args:
            city_id: Identifier stamped on every inserted row.
            traffic_records: One dict per row. Read keys are ``date``,
                ``traffic_volume``, ``pedestrian_count``,
                ``congestion_level``, ``average_speed``, ``source`` and
                ``raw_data``; absent keys become ``None`` except ``source``,
                which defaults to ``'ingestion'``.

        Returns:
            Number of rows inserted; 0 (without touching the session) when
            ``traffic_records`` is empty.

        Raises:
            Exception: Whatever the insert/commit raised, after the session
                has been rolled back and the error logged.
        """
        if not traffic_records:
            return 0

        try:
            objects = [
                CityTrafficData(
                    city_id=city_id,
                    date=record.get('date'),
                    traffic_volume=record.get('traffic_volume'),
                    pedestrian_count=record.get('pedestrian_count'),
                    congestion_level=record.get('congestion_level'),
                    average_speed=record.get('average_speed'),
                    source=record.get('source', 'ingestion'),
                    raw_data=record.get('raw_data')
                )
                for record in traffic_records
            ]
            self.session.add_all(objects)
            await self.session.commit()
        except Exception as e:
            await self.session.rollback()
            logger.error(
                "Error storing traffic data",
                city_id=city_id,
                error=str(e)
            )
            raise

        logger.info(
            "Traffic data stored",
            city_id=city_id,
            records=len(objects)
        )
        return len(objects)

    async def get_traffic_by_city_and_range(
        self,
        city_id: str,
        start_date: datetime,
        end_date: datetime
    ) -> List[CityTrafficData]:
        """Get a city's traffic for an inclusive date range, aggregated per day.

        Rows are grouped by calendar day (``CAST(date AS DATE)``) and
        ``traffic_volume``, ``pedestrian_count`` and ``average_speed`` are
        averaged, so callers get one row per day instead of every raw record.

        Returns:
            Transient ``CityTrafficData`` objects (never added to the
            session), one per day, ordered by day:

            * ``date`` is midnight of that day.
            * ``traffic_volume`` / ``pedestrian_count`` are the daily averages
              truncated to ``int``; ``average_speed`` is the average as
              ``float``. A NULL average becomes ``None``; an average of 0 is
              kept as 0.
            * ``congestion_level`` is always ``'medium'`` because a
              categorical level is not averaged.
            * ``source`` is the day's ``MAX(source)``, or ``'aggregated'``
              when that is NULL/empty.
        """
        from sqlalchemy import func, cast, Date

        # Aggregate to daily averages to avoid loading every raw
        # (e.g. hourly) record into memory.
        day = cast(CityTrafficData.date, Date)
        stmt = (
            select(
                day.label('date'),
                func.avg(CityTrafficData.traffic_volume).label('traffic_volume'),
                func.avg(CityTrafficData.pedestrian_count).label('pedestrian_count'),
                func.avg(CityTrafficData.average_speed).label('average_speed'),
                func.max(CityTrafficData.source).label('source')
            )
            .where(self._city_range_clause(CityTrafficData, city_id, start_date, end_date))
            .group_by(day)
            .order_by(day)
        )

        result = await self.session.execute(stmt)

        # NOTE(review): this produces a naive midnight datetime; confirm it
        # matches how CityTrafficData.date is stored (naive vs. tz-aware).
        # The `is not None` checks are deliberate: a real average of 0 must
        # stay 0 instead of being collapsed to None by a truthiness test.
        return [
            CityTrafficData(
                city_id=city_id,
                date=datetime.combine(row.date, datetime.min.time()),
                traffic_volume=(
                    int(row.traffic_volume) if row.traffic_volume is not None else None
                ),
                pedestrian_count=(
                    int(row.pedestrian_count) if row.pedestrian_count is not None else None
                ),
                congestion_level='medium',  # Default since we're averaging
                average_speed=(
                    float(row.average_speed) if row.average_speed is not None else None
                ),
                source=row.source or 'aggregated'
            )
            for row in result
        ]

    async def delete_traffic_before(
        self,
        city_id: str,
        cutoff_date: datetime
    ) -> int:
        """Delete a city's traffic rows dated strictly before ``cutoff_date``.

        Returns:
            Number of rows deleted, as reported by the driver.
        """
        return await self._delete_before(CityTrafficData, city_id, cutoff_date, "traffic")

    # ------------------------------------------------------------------
    # Coverage
    # ------------------------------------------------------------------

    async def get_data_coverage(
        self,
        city_id: str,
        start_date: datetime,
        end_date: datetime
    ) -> Dict[str, int]:
        """Count how much data exists for a city in an inclusive date range.

        Counting happens in the database (``COUNT(*)``); no rows are loaded.

        Returns:
            Dict with row counts: ``{'weather': X, 'traffic': Y}``.
        """
        return {
            'weather': await self._count_in_range(
                CityWeatherData, city_id, start_date, end_date
            ),
            'traffic': await self._count_in_range(
                CityTrafficData, city_id, start_date, end_date
            )
        }
|