"""
POI Context Repository

Data access layer for the TenantPOIContext model.
Handles CRUD operations for POI detection results and ML features.
"""

import uuid
from datetime import datetime, timezone
from typing import List, Optional

import structlog
from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.poi_context import TenantPOIContext

logger = structlog.get_logger()


class POIContextRepository:
    """
    Repository for POI context data access.

    Manages storage and retrieval of POI detection results
    and ML features for tenant locations. Every mutating method
    commits the session passed to the constructor.
    """

    def __init__(self, session: AsyncSession):
        """
        Initialize repository.

        Args:
            session: SQLAlchemy async session used for all queries and commits
        """
        self.session = session

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
        """
        Coerce a UUID given as a string into ``uuid.UUID``; other values pass through.

        Raises:
            ValueError: if ``value`` is a string that is not a valid UUID
        """
        return uuid.UUID(value) if isinstance(value, str) else value

    @staticmethod
    def _parse_detection_timestamp(raw: object) -> datetime:
        """
        Normalise a detection timestamp to a timezone-aware datetime.

        Accepts an ISO-8601 string or a ``datetime``. A trailing ``Z`` is
        rewritten to ``+00:00`` because ``datetime.fromisoformat`` only
        accepts ``Z`` from Python 3.11 onwards. Naive values are treated as
        UTC. Any other value (including ``None``) yields the current UTC time.

        Raises:
            ValueError: if ``raw`` is a string that is not valid ISO-8601
        """
        if isinstance(raw, str):
            parsed = datetime.fromisoformat(raw.replace("Z", "+00:00"))
        elif isinstance(raw, datetime):
            parsed = raw
        else:
            return datetime.now(timezone.utc)

        if parsed.tzinfo is None:
            # NOTE(review): naive timestamps are assumed to be UTC, matching the
            # UTC "now" used everywhere else in this repository — confirm with
            # the detection service.
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    @staticmethod
    def _apply_fields(poi_context: TenantPOIContext, update_data: dict) -> List[str]:
        """
        Copy ``update_data`` onto ``poi_context`` and bump ``updated_at``.

        Keys that are not attributes of the instance are ignored.

        Returns:
            The keys that were actually applied, in input order.
        """
        applied: List[str] = []
        for key, value in update_data.items():
            # NOTE(review): hasattr also matches methods and the primary key;
            # consider restricting this to mapped columns.
            if hasattr(poi_context, key):
                setattr(poi_context, key, value)
                applied.append(key)

        poi_context.updated_at = datetime.now(timezone.utc)
        return applied

    async def _commit_and_log_update(
        self,
        poi_context: TenantPOIContext,
        tenant_id: uuid.UUID,
        updated_fields: List[str],
    ) -> TenantPOIContext:
        """Commit pending changes, reload ``poi_context`` from the DB and log the update."""
        await self.session.commit()
        await self.session.refresh(poi_context)

        logger.info(
            "POI context updated",
            tenant_id=str(tenant_id),
            updated_fields=updated_fields,
        )
        return poi_context

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def create(self, poi_context_data: dict) -> TenantPOIContext:
        """
        Create a new POI context record and commit it.

        Args:
            poi_context_data: Dictionary with POI context data. ``tenant_id``,
                ``latitude`` and ``longitude`` are required. Optional keys fall
                back to these defaults:
                - empty dict/list for the result and category fields
                - ``0`` for ``total_pois_detected``
                - current UTC time for ``detection_timestamp``
                - ``"overpass_api"`` for ``detection_source``
                - ``"completed"`` for ``detection_status``
                - ``None`` for ``detection_error``
                - ``180`` for ``refresh_interval_days``

        Returns:
            The created TenantPOIContext instance, refreshed from the database.

        Raises:
            KeyError: if a required key is missing
            ValueError: if ``tenant_id`` is a string that is not a valid UUID
        """
        poi_context = TenantPOIContext(
            tenant_id=self._to_uuid(poi_context_data["tenant_id"]),
            latitude=poi_context_data["latitude"],
            longitude=poi_context_data["longitude"],
            poi_detection_results=poi_context_data.get("poi_detection_results", {}),
            ml_features=poi_context_data.get("ml_features", {}),
            total_pois_detected=poi_context_data.get("total_pois_detected", 0),
            high_impact_categories=poi_context_data.get("high_impact_categories", []),
            relevant_categories=poi_context_data.get("relevant_categories", []),
            detection_timestamp=poi_context_data.get(
                "detection_timestamp", datetime.now(timezone.utc)
            ),
            detection_source=poi_context_data.get("detection_source", "overpass_api"),
            detection_status=poi_context_data.get("detection_status", "completed"),
            detection_error=poi_context_data.get("detection_error"),
            refresh_interval_days=poi_context_data.get("refresh_interval_days", 180),
        )

        # calculate_next_refresh is a model method (not visible here). It is
        # called after construction so it can read the fields set above.
        poi_context.next_refresh_date = poi_context.calculate_next_refresh()

        self.session.add(poi_context)
        await self.session.commit()
        await self.session.refresh(poi_context)

        logger.info(
            "POI context created",
            tenant_id=str(poi_context.tenant_id),
            total_pois=poi_context.total_pois_detected,
        )
        return poi_context

    async def get_by_tenant_id(self, tenant_id: str | uuid.UUID) -> Optional[TenantPOIContext]:
        """
        Get the POI context for a tenant.

        Args:
            tenant_id: Tenant UUID, or its string form

        Returns:
            TenantPOIContext, or None if not found.

        Raises:
            ValueError: if ``tenant_id`` is a string that is not a valid UUID
        """
        stmt = select(TenantPOIContext).where(
            TenantPOIContext.tenant_id == self._to_uuid(tenant_id)
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_id(self, poi_context_id: str | uuid.UUID) -> Optional[TenantPOIContext]:
        """
        Get a POI context by its own primary key.

        Args:
            poi_context_id: POI context UUID, or its string form

        Returns:
            TenantPOIContext, or None if not found.

        Raises:
            ValueError: if ``poi_context_id`` is a string that is not a valid UUID
        """
        stmt = select(TenantPOIContext).where(
            TenantPOIContext.id == self._to_uuid(poi_context_id)
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def update(
        self,
        tenant_id: str | uuid.UUID,
        update_data: dict
    ) -> Optional[TenantPOIContext]:
        """
        Update the POI context for a tenant and commit.

        Args:
            tenant_id: Tenant UUID, or its string form
            update_data: Mapping of attribute name to new value. Keys that are
                not attributes of the model are ignored.

        Returns:
            The updated TenantPOIContext, or None if the tenant has no context.
        """
        tenant_id = self._to_uuid(tenant_id)

        poi_context = await self.get_by_tenant_id(tenant_id)
        if poi_context is None:
            return None

        applied = self._apply_fields(poi_context, update_data)
        return await self._commit_and_log_update(poi_context, tenant_id, applied)

    async def create_or_update(
        self,
        tenant_id: str | uuid.UUID,
        poi_detection_results: dict
    ) -> TenantPOIContext:
        """
        Create a new POI context, or refresh the existing one, from detection output.

        Args:
            tenant_id: Tenant UUID, or its string form
            poi_detection_results: Full POI detection results. Must contain
                ``location.latitude`` and ``location.longitude``. Other keys are
                optional: ``poi_categories``, ``ml_features``, ``summary``,
                ``relevant_categories``, ``detection_timestamp``,
                ``detection_status`` and ``detection_errors``.

        Returns:
            The created or updated TenantPOIContext.

        Raises:
            KeyError: if the location coordinates are missing
            ValueError: if ``detection_timestamp`` is a malformed string
        """
        tenant_id = self._to_uuid(tenant_id)

        # Build (and validate) the record before touching the database.
        location = poi_detection_results["location"]
        summary = poi_detection_results.get("summary") or {}
        detection_status = poi_detection_results.get("detection_status", "completed")
        detection_errors = poi_detection_results.get("detection_errors")

        poi_context_data = {
            "tenant_id": tenant_id,
            "latitude": location["latitude"],
            "longitude": location["longitude"],
            "poi_detection_results": poi_detection_results.get("poi_categories", {}),
            "ml_features": poi_detection_results.get("ml_features", {}),
            "total_pois_detected": summary.get("total_pois_detected", 0),
            "high_impact_categories": summary.get("high_impact_categories", []),
            "relevant_categories": poi_detection_results.get("relevant_categories", []),
            "detection_timestamp": self._parse_detection_timestamp(
                poi_detection_results.get("detection_timestamp")
            ),
            "detection_status": detection_status,
            # Judge "completed" by the same defaulted status that is stored, and
            # never persist str(None) when no errors were reported.
            "detection_error": (
                None
                if detection_status == "completed" or detection_errors is None
                else str(detection_errors)
            ),
        }

        existing = await self.get_by_tenant_id(tenant_id)
        if existing is None:
            return await self.create(poi_context_data)

        applied = self._apply_fields(
            existing,
            {**poi_context_data, "last_refreshed_at": datetime.now(timezone.utc)},
        )
        # mark_refreshed is a model method (not visible here). It runs after
        # the new detection data is on the instance, so any next_refresh_date
        # it derives from instance fields sees the fresh values.
        existing.mark_refreshed()
        return await self._commit_and_log_update(existing, tenant_id, applied)

    async def delete_by_tenant_id(self, tenant_id: str | uuid.UUID) -> bool:
        """
        Delete the POI context for a tenant and commit.

        Args:
            tenant_id: Tenant UUID, or its string form

        Returns:
            True if a row was deleted, False if none matched.
        """
        tenant_id = self._to_uuid(tenant_id)

        stmt = delete(TenantPOIContext).where(
            TenantPOIContext.tenant_id == tenant_id
        )
        result = await self.session.execute(stmt)
        await self.session.commit()

        deleted = result.rowcount > 0
        if deleted:
            logger.info("POI context deleted", tenant_id=str(tenant_id))
        return deleted

    async def get_stale_contexts(self, limit: int = 100) -> List[TenantPOIContext]:
        """
        Get POI contexts whose next refresh date is due.

        Results are ordered most-overdue first, so a truncated batch is
        deterministic. Rows with a NULL ``next_refresh_date`` are not matched,
        because SQL ``NULL <= now`` is not true.

        Args:
            limit: Maximum number of contexts to return

        Returns:
            List of stale TenantPOIContext instances.
        """
        now = datetime.now(timezone.utc)
        stmt = (
            select(TenantPOIContext)
            .where(TenantPOIContext.next_refresh_date <= now)
            .order_by(TenantPOIContext.next_refresh_date)
            .limit(limit)
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def count_by_status(self) -> dict:
        """
        Count POI contexts grouped by detection status.

        Returns:
            Mapping of detection_status to row count. Statuses with no rows are absent.
        """
        from sqlalchemy import func

        stmt = select(
            TenantPOIContext.detection_status,
            func.count(TenantPOIContext.id),
        ).group_by(TenantPOIContext.detection_status)

        result = await self.session.execute(stmt)
        return {status: count for status, count in result.all()}