Skip to content

Commit

Permalink
Ensure field types do not conflict between document types
Browse files Browse the repository at this point in the history
  • Loading branch information
nadove-ucsc committed Sep 20, 2023
1 parent f2694c6 commit 61ef8b1
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 4 deletions.
43 changes: 43 additions & 0 deletions src/azul/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,49 @@ def dict_merge(dicts: Iterable[Mapping]) -> Mapping:
return dict(items)


def deep_dict_merge(dicts: Iterable[Mapping]) -> Mapping:
    """
    Merge all dictionaries yielded by the argument. If more than one dictionary
    contains a given key, and all values associated with this key are themselves
    dictionaries, then the value present in the result is the recursive merging
    of those nested dictionaries. The input dictionaries are not modified, but
    values that do not need to be merged are shared with the result, not copied.

    >>> deep_dict_merge(({0: 1}, {1: 0}))
    {0: 1, 1: 0}

    >>> deep_dict_merge(({0: {'a': 1}}, {0: {'b': 2}}))
    {0: {'a': 1, 'b': 2}}

    Key collisions where either value is not a dictionary raise a `ValueError`,
    unless the values are the same object or compare equal to each other, in
    which case the entry from the *earlier* dictionary takes precedence. This
    behavior is the opposite of `dict_merge`, where later entries take
    precedence.

    >>> deep_dict_merge(({0: 1}, {0: 2}))
    Traceback (most recent call last):
    ...
    ValueError: 1 != 2

    >>> l1, l2 = [], []
    >>> d = deep_dict_merge(({0: l1}, {0: l2}))
    >>> d
    {0: []}
    >>> id(d[0]) == id(l1)
    True

    A value that does not compare equal to itself, such as NaN, does not
    collide with itself.

    >>> deep_dict_merge(({0: float('nan')},))
    {0: nan}
    """
    merged = {}
    for m in dicts:
        for k, v2 in m.items():
            # For a key not seen before, `setdefault` stores and returns `v2`
            # itself, so `v1 is v2` and there is nothing to reconcile.
            v1 = merged.setdefault(k, v2)
            # The identity test must come first: comparing a newly inserted
            # value against itself would report a bogus collision for values
            # such as NaN that are unequal to themselves, and would needlessly
            # deep-compare every nested structure on insertion.
            if v1 is not v2 and v1 != v2:
                if isinstance(v1, dict) and isinstance(v2, dict):
                    # The recursive call builds a new dictionary, leaving both
                    # nested input dictionaries untouched.
                    merged[k] = deep_dict_merge((v1, v2))
                else:
                    raise ValueError(f'{v1!r} != {v2!r}')

    return merged


K = TypeVar('K')
V = TypeVar('V')

Expand Down
11 changes: 7 additions & 4 deletions src/azul/indexer/document_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
cache,
config,
)
from azul.collections import (
deep_dict_merge,
)
from azul.indexer.document import (
Aggregate,
CataloguedFieldTypes,
Expand Down Expand Up @@ -88,10 +91,10 @@ def field_types(self, catalog: CatalogName) -> FieldTypes:
aggregate_cls = self.aggregate_class(catalog)
for transformer_cls in self.transformer_types(catalog):
field_types.update(transformer_cls.field_types())
return {
**Contribution.field_types(field_types),
**aggregate_cls.field_types(field_types)
}
return deep_dict_merge((
Contribution.field_types(field_types),
aggregate_cls.field_types(field_types)
))

def catalogued_field_types(self) -> CataloguedFieldTypes:
return {
Expand Down

0 comments on commit 61ef8b1

Please sign in to comment.