# ================================================================ # services/data/app/repositories/traffic_repository.py # ================================================================ """ Traffic Repository - Enhanced for multiple cities with comprehensive data access patterns Follows existing repository architecture while adding city-specific functionality """ from typing import Optional, List, Dict, Any, Type, Tuple from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select, and_, or_, func, desc, asc, text, update, delete from sqlalchemy.orm import selectinload from datetime import datetime, timezone, timedelta import structlog from .base import DataBaseRepository from app.models.traffic import TrafficData, TrafficMeasurementPoint, TrafficDataBackgroundJob from app.schemas.traffic import TrafficDataCreate, TrafficDataResponse from shared.database.exceptions import DatabaseError, ValidationError logger = structlog.get_logger() class TrafficRepository(DataBaseRepository[TrafficData, TrafficDataCreate, Dict]): """ Enhanced repository for traffic data operations across multiple cities Provides city-aware queries and advanced traffic analytics """ def __init__(self, model_class: Type, session: AsyncSession, cache_ttl: Optional[int] = 300): super().__init__(model_class, session, cache_ttl) # ================================================================ # CORE TRAFFIC DATA OPERATIONS # ================================================================ async def get_by_location_and_date_range( self, latitude: float, longitude: float, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, city: Optional[str] = None, tenant_id: Optional[str] = None, skip: int = 0, limit: int = 100 ) -> List[TrafficData]: """Get traffic data by location and date range with city filtering""" try: location_id = f"{latitude:.4f},{longitude:.4f}" # Build base query query = select(self.model).where(self.model.location_id == location_id) # Add city filter if specified if city: query = query.where(self.model.city == city) # Add tenant filter if specified if tenant_id: query = query.where(self.model.tenant_id == tenant_id) # Add date range filters if start_date: start_date = self._ensure_utc_datetime(start_date) query = query.where(self.model.date >= start_date) if end_date: end_date = self._ensure_utc_datetime(end_date) query = query.where(self.model.date <= end_date) # Order by date descending (most recent first) query = query.order_by(desc(self.model.date)) # Apply pagination query = query.offset(skip).limit(limit) result = await self.session.execute(query) return result.scalars().all() except Exception as e: logger.error("Failed to get traffic data by location and date range", latitude=latitude, longitude=longitude, city=city, error=str(e)) raise DatabaseError(f"Failed to get traffic data: {str(e)}") async def get_by_city_and_date_range( self, city: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, district: Optional[str] = None, measurement_point_ids: Optional[List[str]] = None, include_synthetic: bool = True, tenant_id: Optional[str] = None, skip: int = 0, limit: int = 1000 ) -> List[TrafficData]: """Get traffic data by city with advanced filtering options""" try: # Build base query query = select(self.model).where(self.model.city == city) # Add tenant filter if specified if tenant_id: query = query.where(self.model.tenant_id == tenant_id) # Add date range filters if start_date: start_date = self._ensure_utc_datetime(start_date) query = query.where(self.model.date >= start_date) if end_date: end_date = self._ensure_utc_datetime(end_date) query = query.where(self.model.date <= end_date) # Add district filter if district: query = query.where(self.model.district == district) # Add measurement point filter if measurement_point_ids: query = query.where(self.model.measurement_point_id.in_(measurement_point_ids)) # Filter synthetic data if requested if not include_synthetic: query = query.where(self.model.is_synthetic == False) # Order by date and measurement point query = query.order_by(desc(self.model.date), self.model.measurement_point_id) # Apply pagination query = query.offset(skip).limit(limit) result = await self.session.execute(query) return result.scalars().all() except Exception as e: logger.error("Failed to get traffic data by city", city=city, district=district, error=str(e)) raise DatabaseError(f"Failed to get traffic data: {str(e)}") async def get_latest_by_measurement_points( self, measurement_point_ids: List[str], city: str, hours_back: int = 24 ) -> List[TrafficData]: """Get latest traffic data for specific measurement points""" try: cutoff_time = datetime.now(timezone.utc) - timedelta(hours=hours_back) query = select(self.model).where( and_( self.model.city == city, self.model.measurement_point_id.in_(measurement_point_ids), self.model.date >= cutoff_time ) ).order_by( self.model.measurement_point_id, desc(self.model.date) ) result = await self.session.execute(query) all_records = result.scalars().all() # Get the latest record for each measurement point latest_records = {} for record in all_records: point_id = record.measurement_point_id if point_id not in latest_records: latest_records[point_id] = record return list(latest_records.values()) except Exception as e: logger.error("Failed to get latest traffic data by measurement points", city=city, points=len(measurement_point_ids), error=str(e)) raise DatabaseError(f"Failed to get latest traffic data: {str(e)}") # ================================================================ # ANALYTICS AND AGGREGATIONS # ================================================================ async def get_traffic_statistics_by_city( self, city: str, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, group_by: str = "daily" ) -> List[Dict[str, Any]]: """Get aggregated traffic statistics by city""" try: # Determine date truncation based on group_by if group_by == "hourly": date_trunc = "hour" elif group_by == "daily": date_trunc = "day" elif group_by == "weekly": date_trunc = "week" elif group_by == "monthly": date_trunc = "month" else: raise ValidationError(f"Invalid group_by value: {group_by}") # Build aggregation query if self.session.bind.dialect.name == 'postgresql': query = text(""" SELECT DATE_TRUNC(:date_trunc, date) as period, city, district, COUNT(*) as record_count, AVG(traffic_volume) as avg_traffic_volume, MAX(traffic_volume) as max_traffic_volume, AVG(pedestrian_count) as avg_pedestrian_count, AVG(average_speed) as avg_speed, COUNT(CASE WHEN congestion_level = 'high' THEN 1 END) as high_congestion_count, COUNT(CASE WHEN is_synthetic = false THEN 1 END) as real_data_count, COUNT(CASE WHEN has_pedestrian_inference = true THEN 1 END) as pedestrian_inference_count FROM traffic_data WHERE city = :city """) else: # SQLite fallback query = text(""" SELECT DATE(date) as period, city, district, COUNT(*) as record_count, AVG(traffic_volume) as avg_traffic_volume, MAX(traffic_volume) as max_traffic_volume, AVG(pedestrian_count) as avg_pedestrian_count, AVG(average_speed) as avg_speed, SUM(CASE WHEN congestion_level = 'high' THEN 1 ELSE 0 END) as high_congestion_count, SUM(CASE WHEN is_synthetic = 0 THEN 1 ELSE 0 END) as real_data_count, SUM(CASE WHEN has_pedestrian_inference = 1 THEN 1 ELSE 0 END) as pedestrian_inference_count FROM traffic_data WHERE city = :city """) params = { "city": city, "date_trunc": date_trunc } # Add date filters if start_date: query = text(str(query) + " AND date >= :start_date") params["start_date"] = self._ensure_utc_datetime(start_date) if end_date: query = text(str(query) + " AND date <= :end_date") params["end_date"] = self._ensure_utc_datetime(end_date) # Add GROUP BY and ORDER BY query = text(str(query) + " GROUP BY period, city, district ORDER BY period DESC") result = await self.session.execute(query, params) rows = result.fetchall() # Convert to list of dictionaries statistics = [] for row in rows: statistics.append({ "period": group_by, "date": row.period, "city": row.city, "district": row.district, "record_count": row.record_count, "avg_traffic_volume": float(row.avg_traffic_volume or 0), "max_traffic_volume": row.max_traffic_volume or 0, "avg_pedestrian_count": float(row.avg_pedestrian_count or 0), "avg_speed": float(row.avg_speed or 0), "high_congestion_count": row.high_congestion_count or 0, "real_data_percentage": round((row.real_data_count or 0) / max(1, row.record_count) * 100, 2), "pedestrian_inference_percentage": round((row.pedestrian_inference_count or 0) / max(1, row.record_count) * 100, 2) }) return statistics except Exception as e: logger.error("Failed to get traffic statistics by city", city=city, group_by=group_by, error=str(e)) raise DatabaseError(f"Traffic statistics query failed: {str(e)}") async def get_congestion_heatmap_data( self, city: str, start_date: datetime, end_date: datetime, time_granularity: str = "hour" ) -> List[Dict[str, Any]]: """Get congestion data for heatmap visualization""" try: if time_granularity == "hour": time_extract = "EXTRACT(hour FROM date)" elif time_granularity == "day_of_week": time_extract = "EXTRACT(dow FROM date)" else: time_extract = "EXTRACT(hour FROM date)" query = text(f""" SELECT {time_extract} as time_period, district, measurement_point_id, latitude, longitude, AVG(CASE WHEN congestion_level = 'low' THEN 1 WHEN congestion_level = 'medium' THEN 2 WHEN congestion_level = 'high' THEN 3 WHEN congestion_level = 'blocked' THEN 4 ELSE 1 END) as avg_congestion_score, COUNT(*) as data_points, AVG(traffic_volume) as avg_traffic_volume, AVG(pedestrian_count) as avg_pedestrian_count FROM traffic_data WHERE city = :city AND date >= :start_date AND date <= :end_date AND latitude IS NOT NULL AND longitude IS NOT NULL GROUP BY time_period, district, measurement_point_id, latitude, longitude ORDER BY time_period, district, avg_congestion_score DESC """) params = { "city": city, "start_date": self._ensure_utc_datetime(start_date), "end_date": self._ensure_utc_datetime(end_date) } result = await self.session.execute(query, params) rows = result.fetchall() heatmap_data = [] for row in rows: heatmap_data.append({ "time_period": int(row.time_period or 0), "district": row.district, "measurement_point_id": row.measurement_point_id, "latitude": float(row.latitude), "longitude": float(row.longitude), "avg_congestion_score": float(row.avg_congestion_score), "data_points": row.data_points, "avg_traffic_volume": float(row.avg_traffic_volume or 0), "avg_pedestrian_count": float(row.avg_pedestrian_count or 0) }) return heatmap_data except Exception as e: logger.error("Failed to get congestion heatmap data", city=city, error=str(e)) raise DatabaseError(f"Congestion heatmap query failed: {str(e)}") # ================================================================ # BULK OPERATIONS AND DATA MANAGEMENT # ================================================================ async def create_bulk_traffic_data( self, traffic_records: List[Dict[str, Any]], city: str, tenant_id: Optional[str] = None ) -> List[TrafficData]: """Create multiple traffic records in bulk with enhanced validation""" try: # Ensure all records have city and tenant_id for record in traffic_records: record["city"] = city if tenant_id: record["tenant_id"] = tenant_id # Ensure dates are timezone-aware if "date" in record and record["date"]: record["date"] = self._ensure_utc_datetime(record["date"]) # Enhanced validation validated_records = [] for record in traffic_records: if self._validate_traffic_record(record): validated_records.append(record) else: logger.warning("Invalid traffic record skipped", city=city, record_keys=list(record.keys())) if not validated_records: logger.warning("No valid traffic records to create", city=city) return [] # Use bulk create with deduplication created_records = await self.bulk_create_with_deduplication(validated_records) logger.info("Bulk traffic data creation completed", city=city, requested=len(traffic_records), validated=len(validated_records), created=len(created_records)) return created_records except Exception as e: logger.error("Failed to create bulk traffic data", city=city, record_count=len(traffic_records), error=str(e)) raise DatabaseError(f"Bulk traffic creation failed: {str(e)}") async def bulk_create_with_deduplication( self, records: List[Dict[str, Any]] ) -> List[TrafficData]: """Bulk create with automatic deduplication based on location, city, and date""" try: if not records: return [] # Extract unique keys for deduplication check unique_keys = [] for record in records: key = ( record.get('location_id'), record.get('city'), record.get('date'), record.get('measurement_point_id') ) unique_keys.append(key) # Check for existing records location_ids = [key[0] for key in unique_keys if key[0]] cities = [key[1] for key in unique_keys if key[1]] dates = [key[2] for key in unique_keys if key[2]] # For large datasets, use chunked deduplication to avoid memory issues if len(location_ids) > 1000: logger.info(f"Large dataset detected ({len(records)} records), using chunked deduplication") new_records = [] chunk_size = 1000 for i in range(0, len(records), chunk_size): chunk_records = records[i:i + chunk_size] chunk_keys = unique_keys[i:i + chunk_size] # Get unique values for this chunk chunk_location_ids = list(set(key[0] for key in chunk_keys if key[0])) chunk_cities = list(set(key[1] for key in chunk_keys if key[1])) chunk_dates = list(set(key[2] for key in chunk_keys if key[2])) if chunk_location_ids and chunk_cities and chunk_dates: existing_query = select( self.model.location_id, self.model.city, self.model.date, self.model.measurement_point_id ).where( and_( self.model.location_id.in_(chunk_location_ids), self.model.city.in_(chunk_cities), self.model.date.in_(chunk_dates) ) ) result = await self.session.execute(existing_query) chunk_existing_keys = set(result.fetchall()) # Filter chunk duplicates for j, record in enumerate(chunk_records): key = chunk_keys[j] if key not in chunk_existing_keys: new_records.append(record) else: new_records.extend(chunk_records) logger.debug("Chunked deduplication completed", total_records=len(records), new_records=len(new_records)) records = new_records elif location_ids and cities and dates: existing_query = select( self.model.location_id, self.model.city, self.model.date, self.model.measurement_point_id ).where( and_( self.model.location_id.in_(location_ids), self.model.city.in_(cities), self.model.date.in_(dates) ) ) result = await self.session.execute(existing_query) existing_keys = set(result.fetchall()) # Filter out duplicates new_records = [] for i, record in enumerate(records): key = unique_keys[i] if key not in existing_keys: new_records.append(record) logger.debug("Standard deduplication completed", total_records=len(records), existing_records=len(existing_keys), new_records=len(new_records)) records = new_records # Proceed with bulk creation return await self.bulk_create(records) except Exception as e: logger.error("Failed bulk create with deduplication", error=str(e)) raise DatabaseError(f"Bulk create with deduplication failed: {str(e)}") def _validate_traffic_record(self, record: Dict[str, Any]) -> bool: """Enhanced validation for traffic records""" required_fields = ['date', 'city'] # Check required fields for field in required_fields: if not record.get(field): return False # Validate city city = record.get('city', '').lower() if city not in ['madrid', 'barcelona', 'valencia', 'test']: # Extendable list return False # Validate data ranges traffic_volume = record.get('traffic_volume') if traffic_volume is not None and (traffic_volume < 0 or traffic_volume > 50000): return False pedestrian_count = record.get('pedestrian_count') if pedestrian_count is not None and (pedestrian_count < 0 or pedestrian_count > 10000): return False average_speed = record.get('average_speed') if average_speed is not None and (average_speed < 0 or average_speed > 200): return False congestion_level = record.get('congestion_level') if congestion_level and congestion_level not in ['low', 'medium', 'high', 'blocked']: return False return True # ================================================================ # TRAINING DATA SPECIFIC OPERATIONS # ================================================================ async def get_training_data_by_location( self, latitude: float, longitude: float, start_date: datetime, end_date: datetime, tenant_id: Optional[str] = None, include_pedestrian_inference: bool = True ) -> List[Dict[str, Any]]: """Get optimized training data for ML models""" try: location_id = f"{latitude:.4f},{longitude:.4f}" query = select(self.model).where( and_( self.model.location_id == location_id, self.model.date >= self._ensure_utc_datetime(start_date), self.model.date <= self._ensure_utc_datetime(end_date) ) ) if tenant_id: query = query.where(self.model.tenant_id == tenant_id) if include_pedestrian_inference: # Prefer records with pedestrian inference query = query.order_by( desc(self.model.has_pedestrian_inference), desc(self.model.data_quality_score), self.model.date ) else: query = query.order_by( desc(self.model.data_quality_score), self.model.date ) result = await self.session.execute(query) records = result.scalars().all() # Convert to training format with enhanced features training_data = [] for record in records: training_record = { 'date': record.date, 'traffic_volume': record.traffic_volume or 0, 'pedestrian_count': record.pedestrian_count or 0, 'congestion_level': record.congestion_level or 'medium', 'average_speed': record.average_speed or 25.0, 'city': record.city, 'district': record.district, 'measurement_point_id': record.measurement_point_id, 'source': record.source, 'is_synthetic': record.is_synthetic or False, 'has_pedestrian_inference': record.has_pedestrian_inference or False, 'data_quality_score': record.data_quality_score or 50.0, # Enhanced features for training 'hour_of_day': record.date.hour if record.date else 12, 'day_of_week': record.date.weekday() if record.date else 0, 'month': record.date.month if record.date else 1, # City-specific features 'city_specific_data': record.city_specific_data or {} } training_data.append(training_record) logger.info("Retrieved training data", location_id=location_id, records=len(training_data), with_pedestrian_inference=sum(1 for r in training_data if r['has_pedestrian_inference'])) return training_data except Exception as e: logger.error("Failed to get training data", latitude=latitude, longitude=longitude, error=str(e)) raise DatabaseError(f"Training data retrieval failed: {str(e)}") async def get_historical_data_by_location( self, latitude: float, longitude: float, start_date: datetime, end_date: datetime, tenant_id: Optional[str] = None ) -> List[TrafficData]: """Get historical traffic data for a specific location and date range""" return await self.get_by_location_and_date_range( latitude=latitude, longitude=longitude, start_date=start_date, end_date=end_date, tenant_id=tenant_id, limit=1000000 # Large limit for historical data ) async def count_records_in_period( self, latitude: float, longitude: float, start_date: datetime, end_date: datetime, city: Optional[str] = None, tenant_id: Optional[str] = None ) -> int: """Count traffic records for a specific location and time period""" try: location_id = f"{latitude:.4f},{longitude:.4f}" query = select(func.count(self.model.id)).where( and_( self.model.location_id == location_id, self.model.date >= self._ensure_utc_datetime(start_date), self.model.date <= self._ensure_utc_datetime(end_date) ) ) if city: query = query.where(self.model.city == city) if tenant_id: query = query.where(self.model.tenant_id == tenant_id) result = await self.session.execute(query) count = result.scalar() return count or 0 except Exception as e: logger.error("Failed to count records in period", latitude=latitude, longitude=longitude, error=str(e)) raise DatabaseError(f"Record count failed: {str(e)}") # ================================================================ # DATA QUALITY AND MAINTENANCE # ================================================================ async def update_data_quality_scores(self, city: str) -> int: """Update data quality scores based on various criteria""" try: # Calculate quality scores based on data completeness and consistency query = text(""" UPDATE traffic_data SET data_quality_score = ( CASE WHEN traffic_volume IS NOT NULL THEN 20 ELSE 0 END + CASE WHEN pedestrian_count IS NOT NULL THEN 20 ELSE 0 END + CASE WHEN average_speed IS NOT NULL AND average_speed > 0 THEN 20 ELSE 0 END + CASE WHEN congestion_level IS NOT NULL THEN 15 ELSE 0 END + CASE WHEN measurement_point_id IS NOT NULL THEN 10 ELSE 0 END + CASE WHEN district IS NOT NULL THEN 10 ELSE 0 END + CASE WHEN has_pedestrian_inference = true THEN 5 ELSE 0 END ), updated_at = :updated_at WHERE city = :city AND data_quality_score IS NULL """) params = { "city": city, "updated_at": datetime.now(timezone.utc) } result = await self.session.execute(query, params) updated_count = result.rowcount await self.session.commit() logger.info("Updated data quality scores", city=city, updated_count=updated_count) return updated_count except Exception as e: logger.error("Failed to update data quality scores", city=city, error=str(e)) await self.session.rollback() raise DatabaseError(f"Data quality update failed: {str(e)}") async def cleanup_old_synthetic_data( self, city: str, days_to_keep: int = 90 ) -> int: """Clean up old synthetic data while preserving real data""" try: cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_to_keep) query = delete(self.model).where( and_( self.model.city == city, self.model.is_synthetic == True, self.model.date < cutoff_date ) ) result = await self.session.execute(query) deleted_count = result.rowcount await self.session.commit() logger.info("Cleaned up old synthetic data", city=city, deleted_count=deleted_count, days_kept=days_to_keep) return deleted_count except Exception as e: logger.error("Failed to cleanup old synthetic data", city=city, error=str(e)) await self.session.rollback() raise DatabaseError(f"Synthetic data cleanup failed: {str(e)}") class TrafficMeasurementPointRepository(DataBaseRepository[TrafficMeasurementPoint, Dict, Dict]): """Repository for traffic measurement points across cities""" async def get_points_near_location( self, latitude: float, longitude: float, city: str, radius_km: float = 10.0, limit: int = 20 ) -> List[TrafficMeasurementPoint]: """Get measurement points near a location using spatial query""" try: # Simple distance calculation (for more precise, use PostGIS) query = text(""" SELECT *, (6371 * acos( cos(radians(:lat)) * cos(radians(latitude)) * cos(radians(longitude) - radians(:lon)) + sin(radians(:lat)) * sin(radians(latitude)) )) as distance_km FROM traffic_measurement_points WHERE city = :city AND is_active = true HAVING distance_km <= :radius_km ORDER BY distance_km LIMIT :limit """) params = { "lat": latitude, "lon": longitude, "city": city, "radius_km": radius_km, "limit": limit } result = await self.session.execute(query, params) rows = result.fetchall() # Convert rows to model instances points = [] for row in rows: point = TrafficMeasurementPoint() for key, value in row._mapping.items(): if hasattr(point, key) and key != 'distance_km': setattr(point, key, value) points.append(point) return points except Exception as e: logger.error("Failed to get measurement points near location", latitude=latitude, longitude=longitude, city=city, error=str(e)) raise DatabaseError(f"Measurement points query failed: {str(e)}") class TrafficBackgroundJobRepository(DataBaseRepository[TrafficDataBackgroundJob, Dict, Dict]): """Repository for managing background traffic data jobs""" async def get_pending_jobs_by_city(self, city: str) -> List[TrafficDataBackgroundJob]: """Get pending background jobs for a specific city""" try: query = select(self.model).where( and_( self.model.city == city, self.model.status == 'pending' ) ).order_by(self.model.scheduled_at) result = await self.session.execute(query) return result.scalars().all() except Exception as e: logger.error("Failed to get pending jobs by city", city=city, error=str(e)) raise DatabaseError(f"Background jobs query failed: {str(e)}") async def update_job_progress( self, job_id: str, progress_percentage: float, records_processed: int = 0, records_stored: int = 0 ) -> bool: """Update job progress""" try: query = update(self.model).where( self.model.id == job_id ).values( progress_percentage=progress_percentage, records_processed=records_processed, records_stored=records_stored, updated_at=datetime.now(timezone.utc) ) result = await self.session.execute(query) await self.session.commit() return result.rowcount > 0 except Exception as e: logger.error("Failed to update job progress", job_id=job_id, error=str(e)) await self.session.rollback() raise DatabaseError(f"Job progress update failed: {str(e)}")