diff --git a/src/azul/__init__.py b/src/azul/__init__.py index 7599cd2ca..a691a5dc1 100644 --- a/src/azul/__init__.py +++ b/src/azul/__init__.py @@ -8,6 +8,7 @@ ) from enum import ( Enum, + auto, ) import functools from itertools import ( @@ -81,6 +82,11 @@ def cache_per_thread(f, /): mutable_furl = furl +class DocumentType(Enum): + contributions = auto() + aggregates = auto() + + class Config: """ See `environment` for documentation of these settings. @@ -938,14 +944,14 @@ def integration_test_catalogs(self) -> Mapping[CatalogName, Catalog]: def es_index_name(self, catalog: CatalogName, entity_type: str, - aggregate: bool + doc_type: DocumentType ) -> str: return str(IndexName(prefix=self._index_prefix, version=2, deployment=self.deployment_stage, catalog=catalog, entity_type=entity_type, - aggregate=aggregate)) + doc_type=doc_type)) def parse_es_index_name(self, index_name: str) -> 'IndexName': """ @@ -1511,20 +1517,46 @@ class IndexName: entity_type: str #: Whether the documents in the index are contributions or aggregates - aggregate: bool = False + doc_type: DocumentType = DocumentType.contributions index_name_version_re: ClassVar[re.Pattern] = re.compile(r'v(\d+)') def __attrs_post_init__(self): """ - >>> IndexName(prefix='azul', version=1, deployment='dev', entity_type='foo_bar') - IndexName(prefix='azul', version=1, deployment='dev', catalog=None, entity_type='foo_bar', aggregate=False) - - >>> IndexName(prefix='azul', version=1, deployment='dev', catalog=None, entity_type='foo_bar') - IndexName(prefix='azul', version=1, deployment='dev', catalog=None, entity_type='foo_bar', aggregate=False) - - >>> IndexName(prefix='azul', version=2, deployment='dev', catalog='main', entity_type='foo_bar') - IndexName(prefix='azul', version=2, deployment='dev', catalog='main', entity_type='foo_bar', aggregate=False) + >>> IndexName(prefix='azul', + ... version=1, + ... deployment='dev', + ... entity_type='foo_bar') # doctest: +NORMALIZE_WHITESPACE + IndexName(prefix='azul', + version=1, + deployment='dev', + catalog=None, + entity_type='foo_bar', + doc_type=<DocumentType.contributions: 1>) + + >>> IndexName(prefix='azul', + ... version=1, + ... deployment='dev', + ... catalog=None, + ... entity_type='foo_bar') # doctest: +NORMALIZE_WHITESPACE + IndexName(prefix='azul', + version=1, + deployment='dev', + catalog=None, + entity_type='foo_bar', + doc_type=<DocumentType.contributions: 1>) + + >>> IndexName(prefix='azul', + ... version=2, + ... deployment='dev', + ... catalog='main', + ... entity_type='foo_bar') # doctest: +NORMALIZE_WHITESPACE + IndexName(prefix='azul', + version=2, + deployment='dev', + catalog='main', + entity_type='foo_bar', + doc_type=<DocumentType.contributions: 1>) >>> IndexName(prefix='azul', version=1, deployment='dev', catalog='hca', entity_type='foo') Traceback (most recent call last): @@ -1581,20 +1613,45 @@ def parse(cls, index_name, expected_prefix=prefix) -> 'IndexName': """ Parse the name of an index from any deployment and any version of Azul.
- >>> IndexName.parse('azul_foo_dev') - IndexName(prefix='azul', version=1, deployment='dev', catalog=None, entity_type='foo', aggregate=False) - - >>> IndexName.parse('azul_foo_aggregate_dev') - IndexName(prefix='azul', version=1, deployment='dev', catalog=None, entity_type='foo', aggregate=True) - - >>> IndexName.parse('azul_foo_bar_dev') - IndexName(prefix='azul', version=1, deployment='dev', catalog=None, entity_type='foo_bar', aggregate=False) - - >>> IndexName.parse('azul_foo_bar_aggregate_dev') - IndexName(prefix='azul', version=1, deployment='dev', catalog=None, entity_type='foo_bar', aggregate=True) - - >>> IndexName.parse('good_foo_dev', expected_prefix='good') - IndexName(prefix='good', version=1, deployment='dev', catalog=None, entity_type='foo', aggregate=False) + >>> IndexName.parse('azul_foo_dev') # doctest: +NORMALIZE_WHITESPACE + IndexName(prefix='azul', + version=1, + deployment='dev', + catalog=None, + entity_type='foo', + doc_type=<DocumentType.contributions: 1>) + + >>> IndexName.parse('azul_foo_aggregate_dev') # doctest: +NORMALIZE_WHITESPACE + IndexName(prefix='azul', + version=1, + deployment='dev', + catalog=None, + entity_type='foo', + doc_type=<DocumentType.aggregates: 2>) + + >>> IndexName.parse('azul_foo_bar_dev') # doctest: +NORMALIZE_WHITESPACE + IndexName(prefix='azul', + version=1, + deployment='dev', + catalog=None, + entity_type='foo_bar', + doc_type=<DocumentType.contributions: 1>) + + >>> IndexName.parse('azul_foo_bar_aggregate_dev') # doctest: +NORMALIZE_WHITESPACE + IndexName(prefix='azul', + version=1, + deployment='dev', + catalog=None, + entity_type='foo_bar', + doc_type=<DocumentType.aggregates: 2>) + + >>> IndexName.parse('good_foo_dev', expected_prefix='good') # doctest: +NORMALIZE_WHITESPACE + IndexName(prefix='good', + version=1, + deployment='dev', + catalog=None, + entity_type='foo', + doc_type=<DocumentType.contributions: 1>) >>> IndexName.parse('bad_foo_dev') Traceback (most recent call last): @@ -1611,20 +1668,45 @@ def parse(cls, index_name, expected_prefix=prefix) -> 'IndexName': ... azul.RequirementError: entity_type ...
'' - >>> IndexName.parse('azul_v2_dev_main_foo') - IndexName(prefix='azul', version=2, deployment='dev', catalog='main', entity_type='foo', aggregate=False) - - >>> IndexName.parse('azul_v2_dev_main_foo_aggregate') - IndexName(prefix='azul', version=2, deployment='dev', catalog='main', entity_type='foo', aggregate=True) - - >>> IndexName.parse('azul_v2_dev_main_foo_bar') - IndexName(prefix='azul', version=2, deployment='dev', catalog='main', entity_type='foo_bar', aggregate=False) - - >>> IndexName.parse('azul_v2_dev_main_foo_bar_aggregate') - IndexName(prefix='azul', version=2, deployment='dev', catalog='main', entity_type='foo_bar', aggregate=True) - - >>> IndexName.parse('azul_v2_staging_hca_foo_bar_aggregate') - IndexName(prefix='azul', version=2, deployment='staging', catalog='hca', entity_type='foo_bar', aggregate=True) + >>> IndexName.parse('azul_v2_dev_main_foo') # doctest: +NORMALIZE_WHITESPACE + IndexName(prefix='azul', + version=2, + deployment='dev', + catalog='main', + entity_type='foo', + doc_type=<DocumentType.contributions: 1>) + + >>> IndexName.parse('azul_v2_dev_main_foo_aggregate') # doctest: +NORMALIZE_WHITESPACE + IndexName(prefix='azul', + version=2, + deployment='dev', + catalog='main', + entity_type='foo', + doc_type=<DocumentType.aggregates: 2>) + + >>> IndexName.parse('azul_v2_dev_main_foo_bar') # doctest: +NORMALIZE_WHITESPACE + IndexName(prefix='azul', + version=2, + deployment='dev', + catalog='main', + entity_type='foo_bar', + doc_type=<DocumentType.contributions: 1>) + + >>> IndexName.parse('azul_v2_dev_main_foo_bar_aggregate') # doctest: +NORMALIZE_WHITESPACE + IndexName(prefix='azul', + version=2, + deployment='dev', + catalog='main', + entity_type='foo_bar', + doc_type=<DocumentType.aggregates: 2>) + + >>> IndexName.parse('azul_v2_staging_hca_foo_bar_aggregate') # doctest: +NORMALIZE_WHITESPACE + IndexName(prefix='azul', + version=2, + deployment='staging', + catalog='hca', + entity_type='foo_bar', + doc_type=<DocumentType.aggregates: 2>) >>> IndexName.parse('azul_v2_staging__foo_bar__aggregate') # doctest: +ELLIPSIS Traceback (most recent call last): @@ -1653,9 +1735,9 @@ def parse(cls, index_name, expected_prefix=prefix) -> 'IndexName': *index_name, deployment = index_name if index_name[-1] == 'aggregate': *index_name, _ = index_name - aggregate = True + doc_type = DocumentType.aggregates else: - aggregate = False + doc_type = DocumentType.contributions entity_type = '_'.join(index_name) Config.validate_entity_type(entity_type) return cls(prefix=prefix, @@ -1663,38 +1745,50 @@ def parse(cls, index_name, expected_prefix=prefix) -> 'IndexName': deployment=deployment, catalog=catalog, entity_type=entity_type, - aggregate=aggregate) + doc_type=doc_type) def __str__(self) -> str: """ >>> str(IndexName(version=1, deployment='dev', entity_type='foo')) 'azul_foo_dev' - >>> str(IndexName(version=1, deployment='dev', entity_type='foo', aggregate=True)) + >>> str(IndexName(version=1, deployment='dev', entity_type='foo', doc_type=DocumentType.aggregates)) 'azul_foo_aggregate_dev' >>> str(IndexName(version=1, deployment='dev', entity_type='foo_bar')) 'azul_foo_bar_dev' - >>> str(IndexName(version=1, deployment='dev', entity_type='foo_bar', aggregate=True)) + >>> str(IndexName(version=1, deployment='dev', entity_type='foo_bar', doc_type=DocumentType.aggregates)) 'azul_foo_bar_aggregate_dev' >>> str(IndexName(version=2, deployment='dev', catalog='main', entity_type='foo')) 'azul_v2_dev_main_foo' - >>> str(IndexName(version=2, deployment='dev', catalog='main', entity_type='foo', aggregate=True)) + >>> str(IndexName(version=2, + ... deployment='dev', + ... catalog='main', + ... entity_type='foo', + ...
doc_type=DocumentType.aggregates)) 'azul_v2_dev_main_foo_aggregate' >>> str(IndexName(version=2, deployment='dev', catalog='main', entity_type='foo_bar')) 'azul_v2_dev_main_foo_bar' - >>> str(IndexName(version=2, deployment='dev', catalog='main', entity_type='foo_bar', aggregate=True)) + >>> str(IndexName(version=2, + ... deployment='dev', + ... catalog='main', + ... entity_type='foo_bar', + ... doc_type=DocumentType.aggregates)) 'azul_v2_dev_main_foo_bar_aggregate' - >>> str(IndexName(version=2, deployment='staging', catalog='hca', entity_type='foo_bar', aggregate=True)) + >>> str(IndexName(version=2, + ... deployment='staging', + ... catalog='hca', + ... entity_type='foo_bar', + ... doc_type=DocumentType.aggregates)) 'azul_v2_staging_hca_foo_bar_aggregate' """ - aggregate = ['aggregate'] if self.aggregate else [] + aggregate = ['aggregate'] if self.doc_type is DocumentType.aggregates else [] if self.version == 1: require(self.catalog is None) return '_'.join([ diff --git a/src/azul/indexer/document.py b/src/azul/indexer/document.py index a49a12839..0bccb1a81 100644 --- a/src/azul/indexer/document.py +++ b/src/azul/indexer/document.py @@ -32,6 +32,7 @@ from azul import ( CatalogName, + DocumentType, IndexName, config, ) @@ -107,7 +108,7 @@ class DocumentCoordinates(Generic[E], metaclass=ABCMeta): be generic in E, the type of EntityReference. """ entity: E - aggregate: bool + doc_type: DocumentType @property def index_name(self) -> str: @@ -119,7 +120,7 @@ def index_name(self) -> str: assert isinstance(self.entity, CataloguedEntityReference) return config.es_index_name(catalog=self.entity.catalog, entity_type=self.entity.entity_type, - aggregate=self.aggregate) + doc_type=self.doc_type) @property @abstractmethod @@ -132,7 +133,12 @@ def from_hit(cls, ) -> 'DocumentCoordinates[CataloguedEntityReference]': index_name = config.parse_es_index_name(hit['_index']) document_id = hit['_id'] - subcls = AggregateCoordinates if index_name.aggregate else ContributionCoordinates + if index_name.doc_type is DocumentType.contributions: + subcls = ContributionCoordinates + elif index_name.doc_type is DocumentType.aggregates: + subcls = AggregateCoordinates + else: + assert False, index_name.doc_type assert issubclass(subcls, cls) return subcls._from_index(index_name, document_id) @@ -164,7 +170,7 @@ def with_catalog(self, @attr.s(frozen=True, auto_attribs=True, kw_only=True, slots=True) class ContributionCoordinates(DocumentCoordinates[E], Generic[E]): - aggregate: bool = attr.ib(init=False, default=False) + doc_type: DocumentType = attr.ib(init=False, default=DocumentType.contributions) bundle: BundleFQID deleted: bool @@ -200,7 +206,7 @@ def _from_index(cls, document_id: str ) -> 'ContributionCoordinates[CataloguedEntityReference]': entity_type = index_name.entity_type - assert index_name.aggregate is False + assert index_name.doc_type is DocumentType.contributions entity_id, bundle_uuid, bundle_version, deleted = document_id.split('_') if deleted == 'deleted': deleted = True @@ -229,7 +235,7 @@ class AggregateCoordinates(DocumentCoordinates[CataloguedEntityReference]): Document coordinates for aggregates. Aggregate coordinates always carry a catalog. 
""" - aggregate: bool = attr.ib(init=False, default=True) + doc_type: DocumentType = attr.ib(init=False, default=DocumentType.aggregates) @classmethod def _from_index(cls, @@ -237,7 +243,7 @@ def _from_index(cls, document_id: str ) -> 'AggregateCoordinates': entity_type = index_name.entity_type - assert index_name.aggregate is True + assert index_name.doc_type is DocumentType.aggregates return cls(entity=CataloguedEntityReference(catalog=index_name.catalog, entity_type=entity_type, entity_id=document_id)) @@ -957,7 +963,7 @@ class Contribution(Document[ContributionCoordinates[E]]): def __attrs_post_init__(self): assert isinstance(self.coordinates, ContributionCoordinates) - assert self.coordinates.aggregate is False + assert self.coordinates.doc_type is DocumentType.contributions @classmethod def field_types(cls, field_types: FieldTypes) -> FieldTypes: @@ -1035,7 +1041,7 @@ def __init__(self, def __attrs_post_init__(self): assert isinstance(self.coordinates, AggregateCoordinates) - assert self.coordinates.aggregate is True + assert self.coordinates.doc_type is DocumentType.aggregates @classmethod def field_types(cls, field_types: FieldTypes) -> FieldTypes: diff --git a/src/azul/indexer/index_service.py b/src/azul/indexer/index_service.py index 09babed32..9b44d74b0 100644 --- a/src/azul/indexer/index_service.py +++ b/src/azul/indexer/index_service.py @@ -40,6 +40,7 @@ from azul import ( CatalogName, + DocumentType, cache, config, freeze, @@ -117,7 +118,7 @@ def repository_plugin(self, catalog: CatalogName) -> RepositoryPlugin: def settings(self, index_name) -> JSON: index_name = config.parse_es_index_name(index_name) - aggregate = index_name.aggregate + aggregate = index_name.doc_type is DocumentType.aggregates catalog = index_name.catalog assert catalog is not None, catalog if config.catalogs[catalog].is_integration_test_catalog: @@ -166,9 +167,9 @@ def index_names(self, catalog: CatalogName) -> list[str]: return [ config.es_index_name(catalog=catalog, entity_type=entity_type, - aggregate=aggregate) + doc_type=doc_type) for entity_type in self.entity_types(catalog) - for aggregate in (False, True) + for doc_type in (DocumentType.contributions, DocumentType.aggregates) ] def fetch_bundle(self, @@ -516,7 +517,7 @@ def _read_contributions(self, for entity in tallies.keys(): index = config.es_index_name(catalog=entity.catalog, entity_type=entity.entity_type, - aggregate=False) + doc_type=DocumentType.contributions) entity_ids_by_index[index].add(entity.entity_id) query = { diff --git a/src/azul/service/elasticsearch_service.py b/src/azul/service/elasticsearch_service.py index 16740c621..ee312f492 100644 --- a/src/azul/service/elasticsearch_service.py +++ b/src/azul/service/elasticsearch_service.py @@ -45,6 +45,7 @@ from azul import ( CatalogName, + DocumentType, cached_property, config, reject, @@ -668,4 +669,4 @@ def create_request(self, catalog, entity_type) -> Search: return Search(using=self._es_client, index=config.es_index_name(catalog=catalog, entity_type=entity_type, - aggregate=True)) + doc_type=DocumentType.aggregates)) diff --git a/test/indexer/__init__.py b/test/indexer/__init__.py index 2586792c9..acc31285c 100644 --- a/test/indexer/__init__.py +++ b/test/indexer/__init__.py @@ -150,7 +150,7 @@ def _load_canned_result(self, bundle_fqid: BundleFQID) -> MutableJSONs: index_name = IndexName.parse(hit['_index']) hit['_index'] = config.es_index_name(catalog=self.catalog, entity_type=index_name.entity_type, - aggregate=index_name.aggregate) + doc_type=index_name.doc_type) return 
expected_hits @classmethod diff --git a/test/indexer/test_indexer.py b/test/indexer/test_indexer.py index dc147575a..f5c6e23b5 100644 --- a/test/indexer/test_indexer.py +++ b/test/indexer/test_indexer.py @@ -17,6 +17,7 @@ ) import re from typing import ( + Iterable, Optional, cast, ) @@ -38,6 +39,7 @@ ) from azul import ( + DocumentType, RequirementError, cached_property, config, @@ -54,6 +56,7 @@ Contribution, ContributionCoordinates, EntityReference, + EntityType, null_bool, null_int, null_str, @@ -172,16 +175,18 @@ def test_deletion(self): self.assertEqual(len(hits), size * 2) num_aggregates, num_contribs = 0, 0 for hit in hits: - entity_type, aggregate = self._parse_index_name(hit) - if aggregate: + entity_type, doc_type = self._parse_index_name(hit) + if doc_type is DocumentType.aggregates: doc = aggregate_cls.from_index(field_types, hit) self.assertNotEqual(doc.contents, {}) num_aggregates += 1 - else: + elif doc_type is DocumentType.contributions: doc = Contribution.from_index(field_types, hit) self.assertEqual(bundle_fqid.upcast(), doc.coordinates.bundle) self.assertFalse(doc.coordinates.deleted) num_contribs += 1 + else: + assert False, doc_type self.assertEqual(num_aggregates, size) self.assertEqual(num_contribs, size) @@ -192,12 +197,13 @@ def test_deletion(self): self.assertEqual(len(hits), 2 * size) docs_by_entity: dict[EntityReference, list[Contribution]] = defaultdict(list) for hit in hits: + # FIXME: Why parse the hit as a contribution before asserting doc type? doc = Contribution.from_index(field_types, hit) docs_by_entity[doc.entity].append(doc) - entity_type, aggregate = self._parse_index_name(hit) + entity_type, doc_type = self._parse_index_name(hit) # Since there is only one bundle and it was deleted, # nothing should be aggregated - self.assertFalse(aggregate) + self.assertNotEqual(doc_type, DocumentType.aggregates) self.assertEqual(bundle_fqid.upcast(), doc.coordinates.bundle) for pair in docs_by_entity.values(): @@ -207,9 +213,19 @@ def test_deletion(self): self.index_service.delete_indices(self.catalog) self.index_service.create_indices(self.catalog) - def _parse_index_name(self, hit) -> tuple[str, bool]: + def _parse_index_name(self, hit) -> tuple[str, DocumentType]: index_name = config.parse_es_index_name(hit['_index']) - return index_name.entity_type, index_name.aggregate + return index_name.entity_type, index_name.doc_type + + def _filter_hits(self, + hits: JSONs, + doc_type: Optional[DocumentType] = None, + entity_type: Optional[EntityType] = None, + ) -> Iterable[JSON]: + for hit in hits: + hit_entity_type, hit_doc_type = self._parse_index_name(hit) + if entity_type in (None, hit_entity_type) and doc_type in (None, hit_doc_type): + yield hit def test_duplicate_notification(self): # Contribute the bundle once @@ -308,11 +324,7 @@ def _assert_index_counts(self, just_deletion): if h['_source']['bundle_deleted'] ] - def is_aggregate(h): - _, aggregate_ = self._parse_index_name(h) - return aggregate_ - - actual_aggregates = [h for h in hits if is_aggregate(h)] + actual_aggregates = list(self._filter_hits(hits, DocumentType.aggregates)) self.assertEqual(len(actual_addition_contributions), num_expected_addition_contributions) @@ -361,23 +373,26 @@ def test_multi_entity_contributing_bundles(self): hits_after = self._get_all_hits() num_docs_by_index_after = self._num_docs_by_index(hits_after) - for entity_type, aggregate in num_docs_by_index_after.keys(): + for entity_type, doc_type in num_docs_by_index_after.keys(): # Both bundles reference two files. 
They both share one file and # exclusively own another one. Deleting one of the bundles removes # the file owned exclusively by that bundle, as well as the bundle itself. - if aggregate: + if doc_type is DocumentType.aggregates: difference = 1 if entity_type in ('files', 'bundles') else 0 - self.assertEqual(num_docs_by_index_after[entity_type, aggregate], - num_docs_by_index_before[entity_type, aggregate] - difference) - elif entity_type in ('bundles', 'samples', 'projects', 'cell_suspensions'): - # Count one extra deletion contribution - self.assertEqual(num_docs_by_index_after[entity_type, aggregate], - num_docs_by_index_before[entity_type, aggregate] + 1) + self.assertEqual(num_docs_by_index_after[entity_type, doc_type], + num_docs_by_index_before[entity_type, doc_type] - difference) + elif doc_type is DocumentType.contributions: + if entity_type in ('bundles', 'samples', 'projects', 'cell_suspensions'): + # Count one extra deletion contribution + self.assertEqual(num_docs_by_index_after[entity_type, doc_type], + num_docs_by_index_before[entity_type, doc_type] + 1) + else: + # Count two extra deletion contributions for the two files + self.assertEqual(entity_type, 'files') + self.assertEqual(num_docs_by_index_after[entity_type, doc_type], + num_docs_by_index_before[entity_type, doc_type] + 2) else: - # Count two extra deletion contributions for the two files - self.assertEqual(entity_type, 'files') - self.assertEqual(num_docs_by_index_after[entity_type, aggregate], - num_docs_by_index_before[entity_type, aggregate] + 2) + assert False, doc_type entity = CataloguedEntityReference(catalog=self.catalog, entity_id=old_file_uuid, @@ -418,7 +433,7 @@ def _walkthrough(v): bundle.metadata_files = _walkthrough(bundle.metadata_files) return old_file_uuid - def _num_docs_by_index(self, hits) -> Mapping[tuple[str, bool], int]: + def _num_docs_by_index(self, hits) -> Mapping[tuple[str, DocumentType], int]: return Counter(map(self._parse_index_name, hits)) def test_indexed_matrices(self): @@ -843,15 +858,13 @@ def test_indexed_matrices(self): } } matrices = {} - for hit in hits: - entity_type, aggregate = self._parse_index_name(hit) - if entity_type == 'projects' and aggregate: - project_id = hit['_source']['entity_id'] - assert project_id not in matrices, project_id - matrices[project_id] = { - k: hit['_source']['contents'][k] - for k in ('matrices', 'contributed_analyses') - } + for hit in self._filter_hits(hits, DocumentType.aggregates, 'projects'): + project_id = hit['_source']['entity_id'] + assert project_id not in matrices, project_id + matrices[project_id] = { + k: hit['_source']['contents'][k] + for k in ('matrices', 'contributed_analyses') + } self.assertEqual(expected_matrices, matrices) def test_organic_matrix_bundle(self): @@ -861,7 +874,7 @@ def test_organic_matrix_bundle(self): self._index_canned_bundle(bundle) hits = self._get_all_hits() for hit in hits: - entity_type, aggregate = self._parse_index_name(hit) + entity_type, doc_type = self._parse_index_name(hit) contents = hit['_source']['contents'] for file in contents['files']: if file['file_format'] == 'Rds': @@ -870,7 +883,10 @@ def test_organic_matrix_bundle(self): else: expected_source = self.translated_str_null expected_cell_count = self.translated_bool_null - if aggregate and entity_type not in ('bundles', 'files'): + if ( + doc_type is DocumentType.aggregates and + entity_type not in ('bundles', 'files') + ): expected_source = [expected_source] self.assertEqual(expected_source, file['file_source']) if 'matrix_cell_count' in 
file: @@ -889,9 +905,9 @@ def test_sequence_files_with_file_source(self): files = set() contributed_analyses = set() for hit in hits: - entity_type, aggregate = self._parse_index_name(hit) + entity_type, doc_type = self._parse_index_name(hit) contents = hit['_source']['contents'] - if entity_type == 'files': + if entity_type == 'files' and doc_type in (DocumentType.aggregates, DocumentType.contributions): file = one(contents['files']) files.add( ( @@ -900,7 +916,7 @@ def test_sequence_files_with_file_source(self): null_bool.from_index(file['is_intermediate']) ) ) - elif entity_type == 'projects' and aggregate: + elif entity_type == 'projects' and doc_type is DocumentType.aggregates: self.assertEqual([], contents['matrices']) for file in one(contents['contributed_analyses'])['file']: contributed_analyses.add( @@ -949,10 +965,10 @@ def test_derived_files(self): self.assertEqual(len(hits), (num_files + 1 + 1 + 1 + 1) * 2) num_contribs, num_aggregates = Counter(), Counter() for hit in hits: - entity_type, aggregate = self._parse_index_name(hit) + entity_type, doc_type = self._parse_index_name(hit) source = hit['_source'] contents = source['contents'] - if aggregate: + if doc_type is DocumentType.aggregates: num_aggregates[entity_type] += 1 bundle = one(source['bundles']) actual_fqid = self.bundle_fqid(uuid=bundle['uuid'], @@ -964,13 +980,15 @@ def test_derived_files(self): self.assertEqual(num_files, len(contents['files'])) else: self.assertEqual(num_files, sum(file['count'] for file in contents['files'])) - else: + elif doc_type is DocumentType.contributions: num_contribs[entity_type] += 1 actual_fqid = self.bundle_fqid(uuid=source['bundle_uuid'], version=source['bundle_version']) self.assertEqual(analysis_bundle, actual_fqid) self.assertEqual(1 if entity_type == 'files' else num_files, len(contents['files'])) + else: + assert False, doc_type self.assertEqual(1, len(contents['specimens'])) self.assertEqual(1, len(contents['projects'])) num_expected = dict(files=num_files, @@ -1007,7 +1025,7 @@ def _assert_old_bundle(self, num_expected_new_contributions: int = 0, num_expected_new_deleted_contributions: int = 0, ignore_aggregates: bool = False - ) -> Mapping[tuple[str, bool], JSON]: + ) -> Mapping[tuple[str, DocumentType], JSON]: """ Assert that the old bundle is still indexed correctly @@ -1028,34 +1046,44 @@ def _assert_old_bundle(self, self.assertEqual(6 + 6 + num_expected_new_contributions + num_expected_new_deleted_contributions * 2, len(hits)) hits_by_id = {} for hit in hits: - entity_type, aggregate = self._parse_index_name(hit) - if aggregate and ignore_aggregates: + entity_type, doc_type = self._parse_index_name(hit) + if doc_type is DocumentType.aggregates and ignore_aggregates: continue source = hit['_source'] - hits_by_id[source['entity_id'], aggregate] = hit - version = one(source['bundles'])['version'] if aggregate else source['bundle_version'] - if aggregate or self.old_bundle.version == version: + hits_by_id[source['entity_id'], doc_type] = hit + if doc_type is DocumentType.aggregates: + version = one(source['bundles'])['version'] + elif doc_type is DocumentType.contributions: + version = source['bundle_version'] + else: + assert False, doc_type + if ( + doc_type is DocumentType.aggregates + or (doc_type is DocumentType.contributions and self.old_bundle.version == version) + ): contents = source['contents'] project = one(contents['projects']) self.assertEqual('Single cell transcriptome patterns.', get(project['project_title'])) self.assertEqual('Single of human pancreas', 
get(project['project_short_name'])) self.assertIn('John Dear', get(project['laboratory'])) - if aggregate and entity_type != 'projects': + if doc_type is DocumentType.aggregates and entity_type != 'projects': self.assertIn('Farmers Trucks', project['institutions']) - else: + elif doc_type is DocumentType.contributions: self.assertIn('Farmers Trucks', [c.get('institution') for c in project['contributors']]) donor = one(contents['donors']) self.assertIn('Australopithecus', donor['genus_species']) - if not aggregate: + if doc_type is DocumentType.contributions: self.assertFalse(source['bundle_deleted']) - else: + elif doc_type is DocumentType.contributions: if source['bundle_deleted']: num_actual_new_deleted_contributions += 1 else: self.assertLess(self.old_bundle.version, version) num_actual_new_contributions += 1 + else: + assert False, doc_type # We count the deleted contributions here too since they should have a # corresponding addition contribution self.assertEqual(num_expected_new_contributions + num_expected_new_deleted_contributions, @@ -1066,7 +1094,7 @@ def _assert_old_bundle(self, def _assert_new_bundle(self, num_expected_old_contributions: int = 0, - old_hits_by_id: Optional[Mapping[tuple[str, bool], JSON]] = None + old_hits_by_id: Optional[Mapping[tuple[str, DocumentType], JSON]] = None ) -> None: num_actual_old_contributions = 0 hits = self._get_all_hits() @@ -1074,31 +1102,40 @@ def _assert_new_bundle(self, # (two files, one project, one cell suspension, one sample and one bundle) # One contribution and one aggregate per entity self.assertEqual(6 + 6 + num_expected_old_contributions, len(hits)) + + def get_version(source, doc_type): + if doc_type is DocumentType.aggregates: + return one(source['bundles'])['version'] + elif doc_type is DocumentType.contributions: + return source['bundle_version'] + else: + assert False, doc_type + for hit in hits: - entity_type, aggregate = self._parse_index_name(hit) + entity_type, doc_type = self._parse_index_name(hit) source = hit['_source'] - version = one(source['bundles'])['version'] if aggregate else source['bundle_version'] + version = get_version(source, doc_type) contents = source['contents'] project = one(contents['projects']) - if not aggregate and version != self.new_bundle.version: + if doc_type is DocumentType.contributions and version != self.new_bundle.version: self.assertLess(version, self.new_bundle.version) num_actual_old_contributions += 1 continue if old_hits_by_id is not None: - old_hit = old_hits_by_id[source['entity_id'], aggregate] + old_hit = old_hits_by_id[source['entity_id'], doc_type] old_source = old_hit['_source'] - old_version = one(old_source['bundles'])['version'] if aggregate else old_source['bundle_version'] + old_version = get_version(old_source, doc_type) self.assertLess(old_version, version) old_contents = old_source['contents'] old_project = one(old_contents['projects']) self.assertNotEqual(old_project['project_title'], project['project_title']) self.assertNotEqual(old_project['project_short_name'], project['project_short_name']) self.assertNotEqual(old_project['laboratory'], project['laboratory']) - if aggregate and entity_type != 'projects': + if doc_type is DocumentType.aggregates and entity_type != 'projects': self.assertNotEqual(old_project['institutions'], project['institutions']) - else: + elif doc_type is DocumentType.contributions: self.assertNotEqual(old_project['contributors'], project['contributors']) self.assertNotEqual(old_contents['donors'][0]['genus_species'], 
contents['donors'][0]['genus_species']) @@ -1110,9 +1147,9 @@ def _assert_new_bundle(self, get(project['project_short_name'])) self.assertNotIn('Sarah Teichmann', project['laboratory']) self.assertIn('Molecular Atlas', project['laboratory']) - if aggregate and entity_type != 'projects': + if doc_type is DocumentType.aggregates and entity_type != 'projects': self.assertNotIn('Farmers Trucks', project['institutions']) - else: + elif doc_type is DocumentType.contributions: self.assertNotIn('Farmers Trucks', [c.get('institution') for c in project['contributors']]) @@ -1161,34 +1198,38 @@ def mocked_mget(self, body, _source_includes): # 1 samples agg + 1 projects agg + 2 cell suspension agg + 2 bundle agg + 4 file agg = 22 hits self.assertEqual(22, len(hits)) for hit in hits: - entity_type, aggregate = self._parse_index_name(hit) + entity_type, doc_type = self._parse_index_name(hit) contents = hit['_source']['contents'] - if aggregate: + if doc_type is DocumentType.aggregates: self.assertEqual(hit['_id'], hit['_source']['entity_id']) if entity_type == 'files': contents = hit['_source']['contents'] self.assertEqual(1, len(contents['files'])) - if aggregate: + if doc_type is DocumentType.aggregates: file_uuids.add(contents['files'][0]['uuid']) elif entity_type in ('samples', 'projects'): - if aggregate: + if doc_type is DocumentType.aggregates: self.assertEqual(2, len(hit['_source']['bundles'])) # All four files are fastqs so they are grouped together self.assertEqual(4, one(contents['files'])['count']) - else: + elif doc_type is DocumentType.contributions: self.assertEqual(2, len(contents['files'])) + else: + assert False, doc_type elif entity_type == 'bundles': - if aggregate: + if doc_type is DocumentType.aggregates: self.assertEqual(1, len(hit['_source']['bundles'])) self.assertEqual(2, len(contents['files'])) else: self.assertEqual(2, len(contents['files'])) elif entity_type == 'cell_suspensions': - if aggregate: + if doc_type is DocumentType.aggregates: self.assertEqual(1, len(hit['_source']['bundles'])) self.assertEqual(1, len(contents['files'])) - else: + elif doc_type is DocumentType.contributions: self.assertEqual(2, len(contents['files'])) + else: + assert False, doc_type else: self.fail() file_document_ids = set() @@ -1208,7 +1249,7 @@ def test_indexing_matrix_related_files(self): hits = self._get_all_hits() zarrs = [] for hit in hits: - entity_type, aggregate = self._parse_index_name(hit) + entity_type, doc_type = self._parse_index_name(hit) if entity_type == 'files': file = one(hit['_source']['contents']['files']) if len(file['related_files']) > 0: @@ -1217,7 +1258,7 @@ def test_indexing_matrix_related_files(self): elif file['file_format'] == 'matrix': # Matrix of Loom or CSV format possibly self.assertNotIn('.zarr', file['name']) - elif not aggregate: + elif doc_type is DocumentType.contributions: for file in hit['_source']['contents']['files']: self.assertEqual(file['related_files'], []) @@ -1240,9 +1281,9 @@ def test_indexing_with_skipped_matrix_file(self): file_names, aggregate_file_names = set(), set() entities_with_matrix_files = set() for hit in hits: - entity_type, aggregate = self._parse_index_name(hit) + entity_type, doc_type = self._parse_index_name(hit) files = hit['_source']['contents']['files'] - if aggregate: + if doc_type is DocumentType.aggregates: if entity_type == 'files': aggregate_file_names.add(one(files)['name']) else: @@ -1256,10 +1297,12 @@ def test_indexing_with_skipped_matrix_file(self): if file['file_format'] == 'matrix': self.assertEqual(1, 
file['count']) entities_with_matrix_files.add(hit['_source']['entity_id']) - else: + elif doc_type is DocumentType.contributions: for file in files: file_name = file['name'] file_names.add(file_name) + else: + assert False, doc_type # a project, a specimen, a cell suspension and a bundle self.assertEqual(4, len(entities_with_matrix_files)) self.assertEqual(aggregate_file_names, file_names) @@ -1284,24 +1327,26 @@ def test_plate_bundle(self): expected_cell_count = 380 documents_with_cell_suspension = 0 for hit in hits: - entity_type, aggregate = self._parse_index_name(hit) + entity_type, doc_type = self._parse_index_name(hit) contents = hit['_source']['contents'] cell_suspensions = contents['cell_suspensions'] if entity_type == 'files' and contents['files'][0]['file_format'] == 'pdf': # The PDF files in that bundle aren't linked to a specimen self.assertEqual(0, len(cell_suspensions)) else: - if aggregate: + if doc_type is DocumentType.aggregates: bundles = hit['_source']['bundles'] self.assertEqual(1, len(bundles)) self.assertEqual(one(contents['sequencing_protocols'])['paired_end'], [ self.translated_bool_true, ]) - else: + elif doc_type is DocumentType.contributions: self.assertEqual( {p.get('paired_end') for p in contents['sequencing_protocols']}, {self.translated_bool_true, } ) + else: + assert False, doc_type specimens = contents['specimens'] for specimen in specimens: self.assertEqual({'bone marrow', 'temporal lobe'}, set(specimen['organ_part'])) @@ -1309,7 +1354,7 @@ def test_plate_bundle(self): self.assertEqual({'bone marrow', 'temporal lobe'}, set(cell_suspension['organ_part'])) self.assertEqual({'Plasma cells'}, set(cell_suspension['selected_cell_type'])) - self.assertEqual(1 if entity_type == 'cell_suspensions' or aggregate else 384, + self.assertEqual(1 if entity_type == 'cell_suspensions' or doc_type is DocumentType.aggregates else 384, len(cell_suspensions)) if entity_type == 'cell_suspensions': counted_cell_count += one(cell_suspensions)['total_estimated_cells'] @@ -1339,8 +1384,8 @@ def test_well_bundles(self): self.assertGreater(len(hits), 0) for hit in hits: contents = hit['_source']['contents'] - entity_type, aggregate = self._parse_index_name(hit) - if aggregate: + entity_type, doc_type = self._parse_index_name(hit) + if doc_type is DocumentType.aggregates: cell_suspensions = contents['cell_suspensions'] self.assertEqual(1, len(cell_suspensions)) # Each bundle contributes a well with one cell. 
The data files @@ -1352,9 +1397,11 @@ def test_well_bundles(self): self.assertEqual(expected_cells, cell_suspensions[0]['total_estimated_cells']) self.assertEqual(one(contents['analysis_protocols'])['workflow'], ['smartseq2_v2.1.0']) - else: + elif doc_type is DocumentType.contributions: self.assertEqual({p['workflow'] for p in contents['analysis_protocols']}, {'smartseq2_v2.1.0'}) + else: + assert False, doc_type def test_pooled_specimens(self): """ @@ -1368,8 +1415,8 @@ def test_pooled_specimens(self): hits = self._get_all_hits() self.assertGreater(len(hits), 0) for hit in hits: - entity_type, aggregate = self._parse_index_name(hit) - if aggregate: + entity_type, doc_type = self._parse_index_name(hit) + if doc_type is DocumentType.aggregates: contents = hit['_source']['contents'] cell_suspensions = contents['cell_suspensions'] self.assertEqual(1, len(cell_suspensions)) @@ -1423,7 +1470,8 @@ def test_organoid_priority(self): for hit in hits: contents = hit['_source']['contents'] - entity_type, aggregate = self._parse_index_name(hit) + entity_type, doc_type = self._parse_index_name(hit) + aggregate = doc_type is DocumentType.aggregates if entity_type != 'files' or one(contents['files'])['file_format'] != 'pdf': inner_cell_suspensions += len(contents['cell_suspensions']) @@ -1477,7 +1525,7 @@ def test_accessions_fields(self): 'array_express': ['E-AAAA-00'], 'insdc_study': ['PRJNA000000'] } - entity_type, aggregate = self._parse_index_name(hit) + entity_type, doc_type = self._parse_index_name(hit) if entity_type == 'projects': expected_accessions = [ {'namespace': namespace, 'accession': accession} @@ -1507,17 +1555,17 @@ def test_cell_counts(self): ] actual = NestedDict(2, list) for hit in sorted(hits, key=lambda d: d['_id']): - entity_type, aggregate = self._parse_index_name(hit) + entity_type, doc_type = self._parse_index_name(hit) contents = hit['_source']['contents'] for inner_entity_type, field_name in field_paths: for inner_entity in contents[inner_entity_type]: value = inner_entity[field_name] - insort(actual[aggregate][entity_type][inner_entity_type], value) + insort(actual[doc_type][entity_type][inner_entity_type], value) expected = NestedDict(1, dict) - for aggregate in False, True: + for doc_type in DocumentType.contributions, DocumentType.aggregates: for entity_type in self.index_service.entity_types(self.catalog): - expected[aggregate][entity_type] = { + expected[doc_type][entity_type] = { 'cell_suspensions': [0, 20000, 20000], 'files': [2100, 15000, 15000], 'projects': [10000, 10000, 10000] @@ -1526,8 +1574,8 @@ def test_cell_counts(self): 'cell_suspensions': [40000], 'files': [17100], 'projects': [10000] - } if aggregate and entity_type == 'projects' else { - 'cell_suspensions': [20000, 20000] if aggregate else [0, 20000, 20000], + } if doc_type is DocumentType.aggregates and entity_type == 'projects' else { + 'cell_suspensions': [20000, 20000] if doc_type is DocumentType.aggregates else [0, 20000, 20000], 'files': [2100, 15000], 'projects': [10000, 10000] } @@ -1536,9 +1584,7 @@ def test_cell_counts(self): def test_no_cell_count_contributions(self): def assert_cell_suspension(expected: JSON, hits: list[JSON]): - project_hit = one(hit - for hit in hits - if ('projects', True) == self._parse_index_name(hit)) + project_hit = one(self._filter_hits(hits, DocumentType.aggregates, 'projects')) contents = project_hit['_source']['contents'] cell_suspension = cast(JSON, one(contents['cell_suspensions'])) actual_result = { @@ -1590,17 +1636,22 @@ def test_imaging_bundle(self): hits = 
self._get_all_hits() sources = defaultdict(list) for hit in hits: - entity_type, aggregate = self._parse_index_name(hit) - sources[entity_type, aggregate].append(hit['_source']) + entity_type, doc_type = self._parse_index_name(hit) + sources[entity_type, doc_type].append(hit['_source']) # bundle has 240 imaging_protocol_0.json['target'] items, each with # an assay_type of 'in situ sequencing' - assay_type = ['in situ sequencing'] if aggregate else {'in situ sequencing': 240} + if doc_type is DocumentType.aggregates: + assay_type = ['in situ sequencing'] + elif doc_type is DocumentType.contributions: + assay_type = {'in situ sequencing': 240} + else: + assert False, doc_type self.assertEqual( one(hit['_source']['contents']['imaging_protocols'])['assay_type'], assay_type ) - for aggregate in True, False: - with self.subTest(aggregate=aggregate): + for doc_type in DocumentType.contributions, DocumentType.aggregates: + with self.subTest(doc_type=doc_type): self.assertEqual( { 'bundles': 1, @@ -1610,15 +1661,15 @@ def test_imaging_bundle(self): }, { entity_type: len(sources) - for (entity_type, _aggregate), sources in sources.items() - if _aggregate is aggregate + for (entity_type, _doc_type), sources in sources.items() + if _doc_type is doc_type } ) # This imaging bundle contains 6 data files in JSON format self.assertEqual( Counter({'tiff': 221, 'json': 6}), Counter(one(source['contents']['files'])['file_format'] - for source in sources['files', aggregate]) + for source in sources['files', doc_type]) ) def test_cell_line_sample(self): @@ -1635,13 +1686,15 @@ def test_cell_line_sample(self): hits = self._get_all_hits() for hit in hits: contents = hit['_source']['contents'] - entity_type, aggregate = self._parse_index_name(hit) + entity_type, doc_type = self._parse_index_name(hit) + aggregate = doc_type is DocumentType.aggregates + contribution = doc_type is DocumentType.contributions if entity_type == 'samples': sample = one(contents['samples']) sample_entity_type = sample['entity_type'] if aggregate: document_ids = one(contents[sample_entity_type])['document_id'] - else: + elif contribution: document_ids = [d['document_id'] for d in contents[sample_entity_type]] entity = one( d @@ -1649,6 +1702,8 @@ def test_cell_line_sample(self): if d['document_id'] == sample['document_id'] ) self.assertEqual(sample['biomaterial_id'], entity['biomaterial_id']) + else: + assert False, doc_type self.assertTrue(sample['document_id'] in document_ids) self.assertEqual(one(contents['specimens'])['organ'], ['blood'] if aggregate else 'blood') @@ -1656,8 +1711,11 @@ def test_cell_line_sample(self): self.assertEqual(len(contents['cell_lines']), 1 if aggregate else 2) if aggregate: cell_lines_model_organ = set(one(contents['cell_lines'])['model_organ']) - else: + elif contribution: cell_lines_model_organ = {cl['model_organ'] for cl in contents['cell_lines']} + else: + assert False, doc_type + self.assertEqual(cell_lines_model_organ, {'blood (parent_cell_line)', 'blood (child_cell_line)'}) self.assertEqual(one(contents['cell_suspensions'])['organ'], @@ -1677,16 +1735,18 @@ def test_multiple_samples(self): hits = self._get_all_hits() for hit in hits: contents = hit['_source']['contents'] - entity_type, aggregate = self._parse_index_name(hit) + entity_type, doc_type = self._parse_index_name(hit) cell_suspension = one(contents['cell_suspensions']) self.assertEqual(cell_suspension['organ'], ['embryo', 'immune system']) self.assertEqual(cell_suspension['organ_part'], ['skin epidermis', self.translated_str_null]) - if 
aggregate and entity_type != 'samples': + if doc_type is DocumentType.aggregates and entity_type != 'samples': self.assertEqual(one(contents['samples'])['entity_type'], sample_entity_types) - else: + elif doc_type in (DocumentType.aggregates, DocumentType.contributions): for sample in contents['samples']: self.assertIn(sample['entity_type'], sample_entity_types) + else: + assert False, doc_type def test_sample_with_no_donor(self): """ @@ -1739,11 +1799,11 @@ def test_sample_with_no_donor(self): hits = self._get_all_hits() for hit in hits: contents = hit['_source']['contents'] - entity_type, aggregate = self._parse_index_name(hit) + entity_type, doc_type = self._parse_index_name(hit) if entity_type == 'projects': - if aggregate: + if doc_type is DocumentType.aggregates: self.assertElasticEqual([aggregate_donor], contents['donors']) - else: + elif doc_type is DocumentType.contributions: sample_id = one(contents['samples'])['document_id'] if sample_id == '70d2b85a-8055-4027-a0d9-29452a49d668': self.assertEqual([donor], contents['donors']) @@ -1755,6 +1815,8 @@ def test_sample_with_no_donor(self): self.assertEqual([] if True else [donor_none], contents['donors']) else: assert False, sample_id + else: + assert False, doc_type def test_files_content_description(self): bundle_fqid = self.bundle_fqid(uuid='ffac201f-4b1c-4455-bd58-19c1a9e863b4', @@ -1763,13 +1825,15 @@ def test_files_content_description(self): hits = self._get_all_hits() for hit in hits: contents = hit['_source']['contents'] - entity_type, aggregate = self._parse_index_name(hit) - if aggregate: + entity_type, doc_type = self._parse_index_name(hit) + if doc_type is DocumentType.aggregates: # bundle aggregates keep individual files num_inner_files = 2 if entity_type == 'bundles' else 1 - else: + elif doc_type is DocumentType.contributions: # one inner file per file contribution num_inner_files = 1 if entity_type == 'files' else 2 + else: + assert False, doc_type self.assertEqual(len(contents['files']), num_inner_files) for file in contents['files']: self.assertEqual(file['content_description'], ['RNA sequence']) @@ -1782,18 +1846,16 @@ def test_related_files_field_exclusion(self): # Check that the dynamic mapping has the related_files field disabled index = config.es_index_name(catalog=self.catalog, entity_type='files', - aggregate=False) + doc_type=DocumentType.contributions) mapping = self.es_client.indices.get_mapping(index=index) contents = mapping[index]['mappings']['properties']['contents'] self.assertFalse(contents['properties']['files']['properties']['related_files']['enabled']) # Ensure that related_files exists hits = self._get_all_hits() - for hit in hits: - entity_type, aggregate = self._parse_index_name(hit) - if aggregate and entity_type == 'files': - file = one(hit['_source']['contents']['files']) - self.assertIn('related_files', file) + for hit in self._filter_hits(hits, DocumentType.aggregates, 'files'): + file = one(hit['_source']['contents']['files']) + self.assertIn('related_files', file) # … but that it can't be used for queries zattrs_file = '377f2f5a-4a45-4c62-8fb0-db9ef33f5cf0.zarr/.zattrs' @@ -1827,10 +1889,8 @@ def test_downstream_entities(self): self._index_bundle(bundle) def get_aggregates(hits, type): - for hit in hits: - entity_type, aggregate = self._parse_index_name(hit) - if entity_type == type and aggregate: - yield hit['_source']['contents'] + for hit in self._filter_hits(hits, DocumentType.aggregates, type): + yield hit['_source']['contents'] hits = self._get_all_hits() samples =
list(get_aggregates(hits, 'samples')) diff --git a/test/indexer/test_projects.py b/test/indexer/test_projects.py index bed58d31d..3e1e72ffc 100644 --- a/test/indexer/test_projects.py +++ b/test/indexer/test_projects.py @@ -5,6 +5,7 @@ ) from azul import ( + DocumentType, config, ) from azul.es import ( @@ -66,9 +67,10 @@ def test_hca_extraction(self): for aggregate in True, False: with self.subTest(aggregate=aggregate): def index_name(entity_type): + doc_type = DocumentType.aggregates if aggregate else DocumentType.contributions return config.es_index_name(catalog=self.catalog, entity_type=entity_type, - aggregate=aggregate) + doc_type=doc_type) total_projects = self.es_client.count(index=index_name('projects')) # Three unique projects, six project contributions diff --git a/test/service/__init__.py b/test/service/__init__.py index c843251f5..b5f917b41 100644 --- a/test/service/__init__.py +++ b/test/service/__init__.py @@ -31,6 +31,7 @@ LocalAppTestCase, ) from azul import ( + DocumentType, JSON, cached_property, config, @@ -183,7 +184,7 @@ def _add_docs(self, num_docs): def _index_name(self): return config.es_index_name(catalog=self.catalog, entity_type='files', - aggregate=True) + doc_type=DocumentType.aggregates) class StorageServiceTestMixin: diff --git a/test/service/test_response.py b/test/service/test_response.py index 8024dc191..3ed59c321 100644 --- a/test/service/test_response.py +++ b/test/service/test_response.py @@ -45,6 +45,7 @@ LocalAppTestCase, ) from azul import ( + DocumentType, cached_property, config, ) @@ -152,7 +153,7 @@ def _get_hits(self, entity_type: str, entity_id: str): # Tests are assumed to only ever run with the azul dev index results = self.es_client.search(index=config.es_index_name(catalog=self.catalog, entity_type=entity_type, - aggregate=True), + doc_type=DocumentType.aggregates), body=body) return self._index_service.translate_fields(catalog=self.catalog, doc=[results['hits']['hits'][0]['_source']],