Skip to content

Commit

Permalink
Merge pull request #208 from commonknowledge/feat/area-code-geocoding
Browse files Browse the repository at this point in the history
feat: add AREA_CODE high performance geocoding
  • Loading branch information
joaquimds authored Feb 5, 2025
2 parents 331a448 + 9c9bc6b commit 5154e33
Show file tree
Hide file tree
Showing 7 changed files with 391 additions and 48 deletions.
119 changes: 115 additions & 4 deletions hub/data_imports/geocoding_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from asgiref.sync import sync_to_async

from hub.data_imports.utils import get_update_data
from hub.graphql.dataloaders import FieldDataLoaderFactory
from utils import google_maps, mapit_types
from utils.findthatpostcode import (
get_example_postcode_from_area_gss,
Expand All @@ -31,18 +32,32 @@ def find_config_item(source: "ExternalDataSource", key: str, value, default=None
)


class GeocoderContext:
"""
Context class to support DataLoader creation and re-use
(existing dataloaders are stored here, so each record can
re-use a previously created loader. This is a necessary
component for dataloader batching to work).
"""

def __init__(self):
self.dataloaders = {}


# enum of geocoders: postcodes_io, mapbox, google
class Geocoder(Enum):
POSTCODES_IO = "postcodes_io"
FINDTHATPOSTCODE = "findthatpostcode"
MAPBOX = "mapbox"
GOOGLE = "google"
AREA_GEOCODER_V2 = "AREA_GEOCODER_V2"
AREA_CODE_GEOCODER_V2 = "AREA_CODE_GEOCODER_V2"
ADDRESS_GEOCODER_V2 = "ADDRESS_GEOCODER_V2"
COORDINATE_GEOCODER_V1 = "COORDINATE_GEOCODER_V1"


LATEST_AREA_GEOCODER = Geocoder.AREA_GEOCODER_V2
LATEST_AREA_CODE_GEOCODER = Geocoder.AREA_CODE_GEOCODER_V2
LATEST_ADDRESS_GEOCODER = Geocoder.ADDRESS_GEOCODER_V2
LATEST_COORDINATE_GEOCODER = Geocoder.COORDINATE_GEOCODER_V1

Expand Down Expand Up @@ -80,6 +95,7 @@ async def import_record(
source: "ExternalDataSource",
data_type: "DataType",
loaders: "Loaders",
geocoder_context: GeocoderContext,
):
from hub.models import ExternalDataSource, GenericData

Expand All @@ -93,8 +109,13 @@ async def import_record(
geocoding_config_type = source.geocoding_config.get("type", None)
importer_fn = None
if geocoding_config_type == ExternalDataSource.GeographyTypes.AREA:
geocoder = LATEST_AREA_GEOCODER
importer_fn = import_area_data
components = source.geocoding_config.get("components", [])
if len(components) == 1 and components[0].get("type") == "area_code":
geocoder = LATEST_AREA_CODE_GEOCODER
importer_fn = import_area_code_data
else:
geocoder = LATEST_AREA_GEOCODER
importer_fn = import_area_data
elif geocoding_config_type == ExternalDataSource.GeographyTypes.ADDRESS:
geocoder = LATEST_ADDRESS_GEOCODER
importer_fn = import_address_data
Expand All @@ -107,8 +128,7 @@ async def import_record(

# check if geocoding_config and dependent fields are the same; if so, skip geocoding
try:
generic_data = await GenericData.objects.aget(data_type=data_type, data=id)

generic_data = await loaders["generic_data"].load(id)
# First check if the configs are the same
if (
generic_data is not None
Expand Down Expand Up @@ -159,6 +179,94 @@ async def import_record(
data_type=data_type,
loaders=loaders,
update_data=update_data,
geocoder_context=geocoder_context,
)


async def import_area_code_data(
record,
source: "ExternalDataSource",
data_type: "DataType",
loaders: "Loaders",
update_data: dict,
geocoder_context: GeocoderContext,
):
from hub.models import Area, GenericData

update_data["geocoder"] = LATEST_AREA_CODE_GEOCODER.value

area = None
geocoding_data = {}
steps = []

components = source.geocoding_config.get("components", [])
if not components:
return

item = components[0]

literal_lih_area_type__code = item.get("metadata", {}).get(
"lih_area_type__code", None
)
literal_mapit_type = item.get("metadata", {}).get("mapit_type", None)
area_types = literal_lih_area_type__code or literal_mapit_type
literal_area_field = item.get("field", None)
area_code = str(source.get_record_field(record, literal_area_field))

if area_types is None or literal_area_field is None or area_code is None:
return

parsed_area_types = [str(s).upper() for s in ensure_list(area_types)]

area_filters = {}
if literal_lih_area_type__code:
area_filters["area_type__code"] = literal_lih_area_type__code
if literal_mapit_type:
area_filters["mapit_type"] = literal_mapit_type

AreaLoader = FieldDataLoaderFactory.get_loader_class(
Area, field="gss", filters=area_filters, select_related=["area_type"]
)

area = await AreaLoader(geocoder_context).load(area_code)
if area is None:
return

step = {
"type": "area_code_matching",
"area_types": parsed_area_types,
"result": "failed" if area is None else "success",
"search_term": area_code,
"data": (
{
"centroid": area.polygon.centroid.json,
"name": area.name,
"id": area.id,
"gss": area.gss,
}
if area is not None
else None
),
}
steps.append(step)

geocoding_data["area_fields"] = geocoding_data.get("area_fields", {})
geocoding_data["area_fields"][area.area_type.code] = area.gss
update_data["geocode_data"].update({"data": geocoding_data})
if area is not None:
postcode_data = await get_postcode_data_for_area(area, loaders, steps)
update_data["postcode_data"] = postcode_data
update_data["area"] = area
update_data["point"] = area.point
else:
# Reset geocoding data
update_data["postcode_data"] = None

# Update the geocode data regardless, for debugging purposes
update_data["geocode_data"].update({"steps": steps})

await GenericData.objects.aupdate_or_create(
data_type=data_type, data=source.get_record_id(record), defaults=update_data
)


Expand All @@ -168,6 +276,7 @@ async def import_area_data(
data_type: "DataType",
loaders: "Loaders",
update_data: dict,
geocoder_context: GeocoderContext,
):
from hub.models import Area, GenericData

Expand Down Expand Up @@ -419,6 +528,7 @@ async def import_address_data(
data_type: "DataType",
loaders: "Loaders",
update_data: dict,
geocoder_context: GeocoderContext,
):
"""
Converts a record fetched from the API into
Expand Down Expand Up @@ -580,6 +690,7 @@ async def import_coordinate_data(
data_type: "DataType",
loaders: "Loaders",
update_data: dict,
geocoder_context: GeocoderContext,
):
from hub.models import GenericData

Expand Down
36 changes: 26 additions & 10 deletions hub/graphql/dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,16 @@

class BasicFieldDataLoader(dataloaders.BaseDjangoModelDataLoader):
field: str
filters: dict = {}
filters: dict
select_related: list

@classmethod
def queryset(cls, keys: list[str]):
if len(keys) == 0:
return []
return cls.model.objects.filter(
**{f"{cls.field}__in": set(keys)}, **cls.filters
)
).select_related(*cls.select_related)

@classmethod
@sync_to_async
Expand Down Expand Up @@ -67,27 +68,42 @@ class FieldDataLoaderFactory(factories.BaseDjangoModelDataLoaderFactory):

@classmethod
def get_loader_key(
cls, model: Type["DjangoModel"], field: str, filters: dict = {}, **kwargs
cls,
model: Type["DjangoModel"],
field: str,
filters: dict = {},
select_related: list = [],
**kwargs,
):
return model, field, json.dumps(filters)
return model, field, json.dumps(filters), json.dumps(select_related)

@classmethod
def get_loader_class_kwargs(
cls, model: Type["DjangoModel"], field: str, filters: dict = {}, **kwargs
cls,
model: Type["DjangoModel"],
field: str,
filters: dict = {},
select_related: list = [],
**kwargs,
):
return {"model": model, "field": field, "filters": filters}
return {
"model": model,
"field": field,
"filters": filters,
"select_related": select_related,
}

@classmethod
def as_resolver(
cls, field: str, filters: dict = {}
cls, field: str, filters: dict = {}, select_related: list = []
) -> Callable[["DjangoModel", Info], Coroutine]:
async def resolver(
root: "DjangoModel", info: "Info"
): # beware, first argument needs to be called 'root'
field_data: "StrawberryDjangoField" = info._field
return await cls.get_loader_class(field_data.django_model, field, filters)(
context=info.context
).load(getattr(root, field))
return await cls.get_loader_class(
field_data.django_model, field, filters, select_related
)(context=info.context).load(getattr(root, field))

return resolver

Expand Down
5 changes: 1 addition & 4 deletions hub/graphql/types/model_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,7 @@
from hub.graphql.types.postcodes import PostcodesIOResult
from hub.graphql.utils import attr_field, dict_key_field, fn_field
from hub.management.commands.import_mps import party_shades
from utils.geo_reference import (
AnalyticalAreaType,
area_to_postcode_io_key,
)
from utils.geo_reference import AnalyticalAreaType, area_to_postcode_io_key
from utils.postcode import get_postcode_data_for_gss
from utils.statistics import (
attempt_interpret_series_as_float,
Expand Down
25 changes: 24 additions & 1 deletion hub/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1750,11 +1750,14 @@ async def import_many(self, members: list):
)

loaders = await self.get_loaders()
geocoder_context = geocoding_config.GeocoderContext()

if self.uses_valid_geocoding_config():
await asyncio.gather(
*[
geocoding_config.import_record(record, self, data_type, loaders)
geocoding_config.import_record(
record, self, data_type, loaders, geocoder_context
)
for record in data
]
)
Expand Down Expand Up @@ -2044,6 +2047,25 @@ def get_import_data(self, **kwargs):
data_type__data_set__external_data_source_id=self.id
)

async def imported_data_loader(self, record_ids):
"""
A dataloader function for getting already-imported GenericData
for a given record ID. This ID should be the result of calling
get_record_id(record) on the source data (i.e. the record that
comes from the 3rd party data source) – NOT the GenericData.id.
"""
results = GenericData.objects.filter(
data_type__data_set__external_data_source_id=self.id, data__in=record_ids
)
results = await sync_to_async(list)(results)
return [
next(
(result for result in results if result.data == id),
None,
)
for id in record_ids
]

def get_analytics_queryset(self, **kwargs):
return self.get_import_data()

Expand Down Expand Up @@ -2144,6 +2166,7 @@ async def get_loaders(self) -> Loaders:
postcodesIOFromPoint=DataLoader(load_fn=get_bulk_postcode_geo_from_coords),
fetch_record=DataLoader(load_fn=self.fetch_many_loader, cache=False),
source_loaders=await self.get_source_loaders(),
generic_data=DataLoader(load_fn=self.imported_data_loader),
)

return loaders
Expand Down
28 changes: 28 additions & 0 deletions hub/tests/fixtures/geocoding_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,3 +293,31 @@
"expected_area_gss": "E05014930",
},
]


area_code_geocoding_cases = [
{
"id": "1",
"ward": "E05000993",
"expected_area_type_code": "WD23",
"expected_area_gss": "E05000993",
},
{
"id": "2",
"ward": "E05015081",
"expected_area_type_code": "WD23",
"expected_area_gss": "E05015081",
},
{
"id": "3",
"ward": "E05012085",
"expected_area_type_code": "WD23",
"expected_area_gss": "E05012085",
},
{
"id": "4",
"ward": "E05007461",
"expected_area_type_code": "WD23",
"expected_area_gss": "E05007461",
},
]
Loading

0 comments on commit 5154e33

Please sign in to comment.