Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

feat: add AREA_CODE high performance geocoding #208

Merged
merged 2 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 115 additions & 4 deletions hub/data_imports/geocoding_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from asgiref.sync import sync_to_async

from hub.data_imports.utils import get_update_data
from hub.graphql.dataloaders import FieldDataLoaderFactory
from utils import google_maps, mapit_types
from utils.findthatpostcode import (
get_example_postcode_from_area_gss,
Expand All @@ -31,18 +32,32 @@ def find_config_item(source: "ExternalDataSource", key: str, value, default=None
)


class GeocoderContext:
    """
    Shared per-import context for DataLoader re-use.

    DataLoaders created while importing one record are cached here so
    that subsequent records in the same import pick up the existing
    loader instance — a prerequisite for the loaders' request batching
    to take effect.
    """

    def __init__(self) -> None:
        # loader-key -> DataLoader instance; populated lazily by the
        # dataloader factories as records are processed.
        self.dataloaders: dict = {}


# Enum of all geocoder implementations, past and present. Values are
# stored on imported records (see update_data["geocoder"]), presumably so
# a re-import can tell which geocoder produced existing data — confirm
# against the skip logic in import_record.
class Geocoder(Enum):
    POSTCODES_IO = "postcodes_io"
    FINDTHATPOSTCODE = "findthatpostcode"
    MAPBOX = "mapbox"
    GOOGLE = "google"
    AREA_GEOCODER_V2 = "AREA_GEOCODER_V2"
    AREA_CODE_GEOCODER_V2 = "AREA_CODE_GEOCODER_V2"
    ADDRESS_GEOCODER_V2 = "ADDRESS_GEOCODER_V2"
    COORDINATE_GEOCODER_V1 = "COORDINATE_GEOCODER_V1"


# Aliases for the current version of each geocoding strategy. import_record
# selects one of these per source config and records its .value on the
# imported data, so bumping a version here should trigger re-geocoding of
# records produced by an older version — TODO confirm against the
# "skip geocoding" check in import_record.
LATEST_AREA_GEOCODER = Geocoder.AREA_GEOCODER_V2
LATEST_AREA_CODE_GEOCODER = Geocoder.AREA_CODE_GEOCODER_V2
LATEST_ADDRESS_GEOCODER = Geocoder.ADDRESS_GEOCODER_V2
LATEST_COORDINATE_GEOCODER = Geocoder.COORDINATE_GEOCODER_V1

Expand Down Expand Up @@ -80,6 +95,7 @@ async def import_record(
source: "ExternalDataSource",
data_type: "DataType",
loaders: "Loaders",
geocoder_context: GeocoderContext,
):
from hub.models import ExternalDataSource, GenericData

Expand All @@ -93,8 +109,13 @@ async def import_record(
geocoding_config_type = source.geocoding_config.get("type", None)
importer_fn = None
if geocoding_config_type == ExternalDataSource.GeographyTypes.AREA:
geocoder = LATEST_AREA_GEOCODER
importer_fn = import_area_data
components = source.geocoding_config.get("components", [])
if len(components) == 1 and components[0].get("type") == "area_code":
geocoder = LATEST_AREA_CODE_GEOCODER
importer_fn = import_area_code_data
else:
geocoder = LATEST_AREA_GEOCODER
importer_fn = import_area_data
elif geocoding_config_type == ExternalDataSource.GeographyTypes.ADDRESS:
geocoder = LATEST_ADDRESS_GEOCODER
importer_fn = import_address_data
Expand All @@ -107,8 +128,7 @@ async def import_record(

# check if geocoding_config and dependent fields are the same; if so, skip geocoding
try:
generic_data = await GenericData.objects.aget(data_type=data_type, data=id)

generic_data = await loaders["generic_data"].load(id)
# First check if the configs are the same
if (
generic_data is not None
Expand Down Expand Up @@ -159,6 +179,94 @@ async def import_record(
data_type=data_type,
loaders=loaders,
update_data=update_data,
geocoder_context=geocoder_context,
)


async def import_area_code_data(
    record,
    source: "ExternalDataSource",
    data_type: "DataType",
    loaders: "Loaders",
    update_data: dict,
    geocoder_context: GeocoderContext,
):
    """
    Geocode a record whose geocoding_config is a single "area_code"
    component: the configured record field already contains an Area GSS
    code, so the Area can be fetched directly — batched across records via
    a DataLoader — instead of going through name-based area matching.

    Mutates `update_data` in place and persists it as a GenericData row
    keyed by the source's record ID. Returns early (persisting nothing)
    when the config or record lacks required values, or when no matching
    Area exists.
    """
    from hub.models import Area, GenericData

    update_data["geocoder"] = LATEST_AREA_CODE_GEOCODER.value

    geocoding_data = {}
    steps = []

    components = source.geocoding_config.get("components", [])
    if not components:
        return

    item = components[0]

    metadata = item.get("metadata", {})
    literal_lih_area_type__code = metadata.get("lih_area_type__code", None)
    literal_mapit_type = metadata.get("mapit_type", None)
    area_types = literal_lih_area_type__code or literal_mapit_type
    literal_area_field = item.get("field", None)

    # Validate the config before touching the record: previously
    # get_record_field was called with a possibly-None field name, and the
    # `area_code is None` check was dead code (str() never returns None).
    if area_types is None or literal_area_field is None:
        return

    raw_area_code = source.get_record_field(record, literal_area_field)
    if raw_area_code is None:
        return
    area_code = str(raw_area_code)

    parsed_area_types = [str(s).upper() for s in ensure_list(area_types)]

    # Restrict the lookup to the configured area type so an ambiguous GSS
    # code can't match an area of the wrong type.
    area_filters = {}
    if literal_lih_area_type__code:
        area_filters["area_type__code"] = literal_lih_area_type__code
    if literal_mapit_type:
        area_filters["mapit_type"] = literal_mapit_type

    # The factory caches loader classes and geocoder_context caches loader
    # instances, so all records in this import share one batched query per
    # (model, field, filters) combination.
    AreaLoader = FieldDataLoaderFactory.get_loader_class(
        Area, field="gss", filters=area_filters, select_related=["area_type"]
    )
    area = await AreaLoader(geocoder_context).load(area_code)
    if area is None:
        # NOTE(review): failed lookups persist nothing — no GenericData row
        # and no "failed" debug step. The original "failed" step branch and
        # postcode_data reset below this return were unreachable dead code
        # and have been removed.
        return

    steps.append(
        {
            "type": "area_code_matching",
            "area_types": parsed_area_types,
            "result": "success",
            "search_term": area_code,
            "data": {
                "centroid": area.polygon.centroid.json,
                "name": area.name,
                "id": area.id,
                "gss": area.gss,
            },
        }
    )

    geocoding_data["area_fields"] = {area.area_type.code: area.gss}
    update_data["geocode_data"].update({"data": geocoding_data})

    postcode_data = await get_postcode_data_for_area(area, loaders, steps)
    update_data["postcode_data"] = postcode_data
    update_data["area"] = area
    update_data["point"] = area.point

    # Update the geocode data regardless, for debugging purposes
    update_data["geocode_data"].update({"steps": steps})

    await GenericData.objects.aupdate_or_create(
        data_type=data_type, data=source.get_record_id(record), defaults=update_data
    )


Expand All @@ -168,6 +276,7 @@ async def import_area_data(
data_type: "DataType",
loaders: "Loaders",
update_data: dict,
geocoder_context: GeocoderContext,
):
from hub.models import Area, GenericData

Expand Down Expand Up @@ -419,6 +528,7 @@ async def import_address_data(
data_type: "DataType",
loaders: "Loaders",
update_data: dict,
geocoder_context: GeocoderContext,
):
"""
Converts a record fetched from the API into
Expand Down Expand Up @@ -580,6 +690,7 @@ async def import_coordinate_data(
data_type: "DataType",
loaders: "Loaders",
update_data: dict,
geocoder_context: GeocoderContext,
):
from hub.models import GenericData

Expand Down
34 changes: 25 additions & 9 deletions hub/graphql/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,15 @@
class BasicFieldDataLoader(dataloaders.BaseDjangoModelDataLoader):
field: str
filters: dict = {}
select_related: list = []
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these need to be = dict() and = list() for immutability's sake?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think changing them to dict() and list() would fix the immutability issue, but I think simply removing the default values works – the factory sets them to fresh {} and [] on class creation.


@classmethod
def queryset(cls, keys: list[str]):
if len(keys) == 0:
return []
return cls.model.objects.filter(
**{f"{cls.field}__in": set(keys)}, **cls.filters
)
).select_related(*cls.select_related)

@classmethod
@sync_to_async
Expand Down Expand Up @@ -67,27 +68,42 @@ class FieldDataLoaderFactory(factories.BaseDjangoModelDataLoaderFactory):

@classmethod
def get_loader_key(
cls, model: Type["DjangoModel"], field: str, filters: dict = {}, **kwargs
cls,
model: Type["DjangoModel"],
field: str,
filters: dict = {},
select_related: list = [],
**kwargs,
):
return model, field, json.dumps(filters)
return model, field, json.dumps(filters), json.dumps(select_related)

@classmethod
def get_loader_class_kwargs(
cls, model: Type["DjangoModel"], field: str, filters: dict = {}, **kwargs
cls,
model: Type["DjangoModel"],
field: str,
filters: dict = {},
select_related: list = [],
**kwargs,
):
return {"model": model, "field": field, "filters": filters}
return {
"model": model,
"field": field,
"filters": filters,
"select_related": select_related,
}

@classmethod
def as_resolver(
cls, field: str, filters: dict = {}
cls, field: str, filters: dict = {}, select_related: list = []
) -> Callable[["DjangoModel", Info], Coroutine]:
async def resolver(
root: "DjangoModel", info: "Info"
): # beware, first argument needs to be called 'root'
field_data: "StrawberryDjangoField" = info._field
return await cls.get_loader_class(field_data.django_model, field, filters)(
context=info.context
).load(getattr(root, field))
return await cls.get_loader_class(
field_data.django_model, field, filters, select_related
)(context=info.context).load(getattr(root, field))

return resolver

Expand Down
5 changes: 1 addition & 4 deletions hub/graphql/types/model_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,7 @@
from hub.graphql.types.postcodes import PostcodesIOResult
from hub.graphql.utils import attr_field, dict_key_field, fn_field
from hub.management.commands.import_mps import party_shades
from utils.geo_reference import (
AnalyticalAreaType,
area_to_postcode_io_key,
)
from utils.geo_reference import AnalyticalAreaType, area_to_postcode_io_key
from utils.postcode import get_postcode_data_for_gss
from utils.statistics import (
attempt_interpret_series_as_float,
Expand Down
25 changes: 24 additions & 1 deletion hub/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1750,11 +1750,14 @@ async def import_many(self, members: list):
)

loaders = await self.get_loaders()
geocoder_context = geocoding_config.GeocoderContext()

if self.uses_valid_geocoding_config():
await asyncio.gather(
*[
geocoding_config.import_record(record, self, data_type, loaders)
geocoding_config.import_record(
record, self, data_type, loaders, geocoder_context
)
for record in data
]
)
Expand Down Expand Up @@ -2044,6 +2047,25 @@ def get_import_data(self, **kwargs):
data_type__data_set__external_data_source_id=self.id
)

async def imported_data_loader(self, record_ids):
    """
    A dataloader function for getting already-imported GenericData
    for a given record ID. This ID should be the result of calling
    get_record_id(record) on the source data (i.e. the record that
    comes from the 3rd party data source) – NOT the GenericData.id.

    Returns a list aligned one-to-one with `record_ids` (None where no
    match exists), as the DataLoader batch contract requires.
    """
    queryset = GenericData.objects.filter(
        data_type__data_set__external_data_source_id=self.id,
        data__in=record_ids,
    )
    results = await sync_to_async(list)(queryset)
    # Index results by record ID in one pass instead of scanning the whole
    # result list once per requested ID (was O(len(record_ids) * len(results))).
    # setdefault keeps the first match, matching the previous next(...) behaviour.
    by_record_id = {}
    for result in results:
        by_record_id.setdefault(result.data, result)
    return [by_record_id.get(record_id) for record_id in record_ids]

def get_analytics_queryset(self, **kwargs):
return self.get_import_data()

Expand Down Expand Up @@ -2144,6 +2166,7 @@ async def get_loaders(self) -> Loaders:
postcodesIOFromPoint=DataLoader(load_fn=get_bulk_postcode_geo_from_coords),
fetch_record=DataLoader(load_fn=self.fetch_many_loader, cache=False),
source_loaders=await self.get_source_loaders(),
generic_data=DataLoader(load_fn=self.imported_data_loader),
)

return loaders
Expand Down
28 changes: 28 additions & 0 deletions hub/tests/fixtures/geocoding_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,3 +293,31 @@
"expected_area_gss": "E05014930",
},
]


# Fixture records for the AREA_CODE geocoder: each source record carries a
# ward GSS code directly, and geocoding is expected to resolve it to the
# identical WD23 area without any name matching.
area_code_geocoding_cases = [
    {
        "id": str(index),
        "ward": gss_code,
        "expected_area_type_code": "WD23",
        "expected_area_gss": gss_code,
    }
    for index, gss_code in enumerate(
        ["E05000993", "E05015081", "E05012085", "E05007461"], start=1
    )
]
Loading