# File: bakery-ia/services/external/app/repositories/poi_context_repository.py
# (272 lines, 8.6 KiB, Python)

"""
POI Context Repository
Data access layer for TenantPOIContext model.
Handles CRUD operations for POI detection results and ML features.
"""
import uuid
from datetime import datetime, timezone
from typing import List, Optional

import structlog
from sqlalchemy import delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.poi_context import TenantPOIContext
# Module-level structlog logger shared by the repository methods below.
logger = structlog.get_logger()
class POIContextRepository:
    """
    Repository for POI context data access.

    Manages storage and retrieval of POI detection results and ML features
    for tenant locations. Every mutating method commits its own transaction
    on the injected session. If a commit fails, the session is rolled back
    before the error propagates, so the caller can keep using the session.
    """

    def __init__(self, session: AsyncSession):
        """
        Initialize repository.

        Args:
            session: SQLAlchemy async session used for every query and commit
        """
        self.session = session

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_uuid(value: str | uuid.UUID) -> uuid.UUID:
        """
        Coerce an identifier to ``uuid.UUID``.

        Strings are parsed (``uuid.UUID`` raises ``ValueError`` when they are
        malformed); any other value is returned unchanged.
        """
        return uuid.UUID(value) if isinstance(value, str) else value

    @staticmethod
    def _parse_detection_timestamp(raw: object) -> datetime:
        """
        Normalise the ``detection_timestamp`` taken from detection results.

        Args:
            raw: An ISO-8601 string (a trailing ``Z`` is accepted), a
                ``datetime``, or anything else (e.g. ``None``).

        Returns:
            A timezone-aware datetime. Strings are parsed, datetimes are kept
            as given, and any other value falls back to the current UTC time.

        Raises:
            ValueError: If ``raw`` is a string that is not valid ISO-8601.
        """
        if isinstance(raw, datetime):
            parsed = raw
        elif isinstance(raw, str):
            # fromisoformat() only understands a "Z" suffix from Python 3.11.
            parsed = datetime.fromisoformat(raw.replace("Z", "+00:00"))
        else:
            return datetime.now(timezone.utc)
        if parsed.tzinfo is None:
            # NOTE(review): naive timestamps are assumed to be UTC, matching
            # the timezone.utc values used elsewhere in this repository --
            # confirm against whatever produces the detection results.
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    async def _commit(self) -> None:
        """
        Commit the current transaction, rolling back if the commit fails.

        After a failed flush/commit the session refuses further work until it
        is rolled back, so roll back before re-raising the original error.
        """
        try:
            await self.session.commit()
        except Exception:
            await self.session.rollback()
            raise

    async def _apply_updates(
        self,
        poi_context: TenantPOIContext,
        update_data: dict
    ) -> TenantPOIContext:
        """
        Apply ``update_data`` to an already-loaded context and persist it.

        Keys that are not attributes of the model are skipped and reported
        in a warning log rather than raising.

        Args:
            poi_context: Loaded TenantPOIContext to modify in place
            update_data: Mapping of attribute name -> new value

        Returns:
            The same instance, refreshed from the database after commit
        """
        applied_fields: List[str] = []
        ignored_fields: List[str] = []
        for key, value in update_data.items():
            if hasattr(poi_context, key):
                setattr(poi_context, key, value)
                applied_fields.append(key)
            else:
                ignored_fields.append(key)

        poi_context.updated_at = datetime.now(timezone.utc)

        await self._commit()
        await self.session.refresh(poi_context)

        logger.info(
            "POI context updated",
            tenant_id=str(poi_context.tenant_id),
            updated_fields=applied_fields
        )
        if ignored_fields:
            logger.warning(
                "POI context update ignored unknown fields",
                tenant_id=str(poi_context.tenant_id),
                ignored_fields=ignored_fields
            )
        return poi_context

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    async def create(self, poi_context_data: dict) -> TenantPOIContext:
        """
        Create new POI context record.

        Args:
            poi_context_data: Dictionary with POI context data. ``tenant_id``,
                ``latitude`` and ``longitude`` are required; every other key
                falls back to the default used below when absent.

        Returns:
            Created TenantPOIContext instance, refreshed after commit

        Raises:
            KeyError: If a required key is missing from ``poi_context_data``.
        """
        poi_context = TenantPOIContext(
            tenant_id=poi_context_data["tenant_id"],
            latitude=poi_context_data["latitude"],
            longitude=poi_context_data["longitude"],
            poi_detection_results=poi_context_data.get("poi_detection_results", {}),
            ml_features=poi_context_data.get("ml_features", {}),
            total_pois_detected=poi_context_data.get("total_pois_detected", 0),
            high_impact_categories=poi_context_data.get("high_impact_categories", []),
            relevant_categories=poi_context_data.get("relevant_categories", []),
            detection_timestamp=poi_context_data.get(
                "detection_timestamp",
                datetime.now(timezone.utc)
            ),
            detection_source=poi_context_data.get("detection_source", "overpass_api"),
            detection_status=poi_context_data.get("detection_status", "completed"),
            detection_error=poi_context_data.get("detection_error"),
            refresh_interval_days=poi_context_data.get("refresh_interval_days", 180)
        )

        # Calculate next refresh date (model helper, based on the fields set above)
        poi_context.next_refresh_date = poi_context.calculate_next_refresh()

        self.session.add(poi_context)
        await self._commit()
        await self.session.refresh(poi_context)

        logger.info(
            "POI context created",
            tenant_id=str(poi_context.tenant_id),
            total_pois=poi_context.total_pois_detected
        )
        return poi_context

    async def get_by_tenant_id(self, tenant_id: str | uuid.UUID) -> Optional[TenantPOIContext]:
        """
        Get POI context by tenant ID.

        Args:
            tenant_id: Tenant UUID (or its string form)

        Returns:
            TenantPOIContext or None if not found
        """
        stmt = select(TenantPOIContext).where(
            TenantPOIContext.tenant_id == self._to_uuid(tenant_id)
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_by_id(self, poi_context_id: str | uuid.UUID) -> Optional[TenantPOIContext]:
        """
        Get POI context by ID.

        Args:
            poi_context_id: POI context UUID (or its string form)

        Returns:
            TenantPOIContext or None if not found
        """
        stmt = select(TenantPOIContext).where(
            TenantPOIContext.id == self._to_uuid(poi_context_id)
        )
        result = await self.session.execute(stmt)
        return result.scalar_one_or_none()

    async def update(
        self,
        tenant_id: str | uuid.UUID,
        update_data: dict
    ) -> Optional[TenantPOIContext]:
        """
        Update POI context for tenant.

        Only keys that are attributes of the model are applied; ``updated_at``
        is always stamped with the current UTC time.

        Args:
            tenant_id: Tenant UUID (or its string form)
            update_data: Dictionary with fields to update

        Returns:
            Updated TenantPOIContext or None if not found
        """
        poi_context = await self.get_by_tenant_id(tenant_id)
        if poi_context is None:
            return None
        return await self._apply_updates(poi_context, update_data)

    async def create_or_update(
        self,
        tenant_id: str | uuid.UUID,
        poi_detection_results: dict
    ) -> TenantPOIContext:
        """
        Create new POI context or update existing one.

        Args:
            tenant_id: Tenant UUID (or its string form)
            poi_detection_results: Full POI detection results. ``location``
                with ``latitude``/``longitude`` is required; ``poi_categories``,
                ``ml_features``, ``summary``, ``relevant_categories``,
                ``detection_timestamp``, ``detection_status`` and
                ``detection_errors`` are optional.

        Returns:
            Created or updated TenantPOIContext

        Raises:
            KeyError: If ``location`` or its coordinates are missing.
            ValueError: If ``detection_timestamp`` is an invalid ISO string.
        """
        tenant_id = self._to_uuid(tenant_id)
        existing = await self.get_by_tenant_id(tenant_id)

        location = poi_detection_results["location"]
        # `or {}` also tolerates an explicit "summary": None.
        summary = poi_detection_results.get("summary") or {}
        # Resolve the status once so the stored status and the error decision
        # always agree (a missing status means "completed").
        detection_status = poi_detection_results.get("detection_status", "completed")
        detection_errors = poi_detection_results.get("detection_errors")

        poi_context_data = {
            "tenant_id": tenant_id,
            "latitude": location["latitude"],
            "longitude": location["longitude"],
            "poi_detection_results": poi_detection_results.get("poi_categories", {}),
            "ml_features": poi_detection_results.get("ml_features", {}),
            "total_pois_detected": summary.get("total_pois_detected", 0),
            "high_impact_categories": summary.get("high_impact_categories", []),
            "relevant_categories": poi_detection_results.get("relevant_categories", []),
            "detection_timestamp": self._parse_detection_timestamp(
                poi_detection_results.get("detection_timestamp")
            ),
            "detection_status": detection_status,
            # Record an error only for non-completed runs, and never persist
            # the literal string "None" when no error details were supplied.
            "detection_error": (
                None
                if detection_status == "completed" or detection_errors is None
                else str(detection_errors)
            ),
        }

        if existing is None:
            return await self.create(poi_context_data)

        # Update existing: mark_refreshed() updates next_refresh_date, then the
        # new detection fields and refresh time are applied to the loaded row
        # directly (no second SELECT).
        existing.mark_refreshed()
        return await self._apply_updates(
            existing,
            {**poi_context_data, "last_refreshed_at": datetime.now(timezone.utc)}
        )

    async def delete_by_tenant_id(self, tenant_id: str | uuid.UUID) -> bool:
        """
        Delete POI context for tenant.

        Args:
            tenant_id: Tenant UUID (or its string form)

        Returns:
            True if deleted, False if not found
        """
        tenant_id = self._to_uuid(tenant_id)
        stmt = delete(TenantPOIContext).where(
            TenantPOIContext.tenant_id == tenant_id
        )
        result = await self.session.execute(stmt)
        await self._commit()

        deleted = result.rowcount > 0
        if deleted:
            logger.info("POI context deleted", tenant_id=str(tenant_id))
        return deleted

    async def get_stale_contexts(self, limit: int = 100) -> List[TenantPOIContext]:
        """
        Get POI contexts that need refresh.

        A context is stale once its ``next_refresh_date`` is at or before the
        current UTC time. Results are ordered most-overdue first so that a
        bounded ``limit`` always picks the contexts that have waited longest.

        Args:
            limit: Maximum number of contexts to return

        Returns:
            List of stale TenantPOIContext instances
        """
        now = datetime.now(timezone.utc)
        stmt = (
            select(TenantPOIContext)
            .where(TenantPOIContext.next_refresh_date <= now)
            .order_by(TenantPOIContext.next_refresh_date)
            .limit(limit)
        )
        result = await self.session.execute(stmt)
        return list(result.scalars().all())

    async def count_by_status(self) -> dict:
        """
        Count POI contexts by detection status.

        Returns:
            Dictionary mapping each detection_status value to its row count
        """
        stmt = select(
            TenantPOIContext.detection_status,
            func.count(TenantPOIContext.id)
        ).group_by(TenantPOIContext.detection_status)
        result = await self.session.execute(stmt)
        return {status: count for status, count in result.all()}