From 7e180dcdb4280ec48fb9958c55218367ee9ce1f4 Mon Sep 17 00:00:00 2001
From: Noa Aviel Dove
Date: Wed, 23 Aug 2023 21:43:14 -0700
Subject: [PATCH] fixup! Refactor indices to accomodate replicas (#5358)

---
 test/indexer/test_indexer.py | 224 ++++++++++++++++++++---------------
 1 file changed, 131 insertions(+), 93 deletions(-)

diff --git a/test/indexer/test_indexer.py b/test/indexer/test_indexer.py
index 03ecaaabe..b7dbc0483 100644
--- a/test/indexer/test_indexer.py
+++ b/test/indexer/test_indexer.py
@@ -17,6 +17,7 @@
 )
 import re
 from typing import (
+    Iterable,
     Optional,
     cast,
 )
@@ -55,6 +56,7 @@
     Contribution,
     ContributionCoordinates,
     EntityReference,
+    EntityType,
     null_bool,
     null_int,
     null_str,
@@ -173,16 +175,18 @@ def test_deletion(self):
         self.assertEqual(len(hits), size * 2)
         num_aggregates, num_contribs = 0, 0
         for hit in hits:
-            entity_type, aggregate = self._parse_index_name(hit)
-            if aggregate:
+            entity_type, doc_type = self._parse_index_name(hit)
+            if doc_type is DocumentType.aggregates:
                 doc = aggregate_cls.from_index(field_types, hit)
                 self.assertNotEqual(doc.contents, {})
                 num_aggregates += 1
-            else:
+            elif doc_type is DocumentType.contributions:
                 doc = Contribution.from_index(field_types, hit)
                 self.assertEqual(bundle_fqid.upcast(), doc.coordinates.bundle)
                 self.assertFalse(doc.coordinates.deleted)
                 num_contribs += 1
+            else:
+                assert False, doc_type
 
         self.assertEqual(num_aggregates, size)
         self.assertEqual(num_contribs, size)
@@ -193,11 +197,12 @@ def test_deletion(self):
         self.assertEqual(len(hits), 2 * size)
         docs_by_entity: dict[EntityReference, list[Contribution]] = defaultdict(list)
         for hit in hits:
+            # FIXME: Why parse the hit as a contribution before asserting doc type?
             doc = Contribution.from_index(field_types, hit)
             docs_by_entity[doc.entity].append(doc)
-            entity_type, aggregate = self._parse_index_name(hit)
+            entity_type, doc_type = self._parse_index_name(hit)
             # Since there is only one bundle and it was deleted, nothing should be aggregated
-            self.assertFalse(aggregate)
+            self.assertNotEqual(doc_type, DocumentType.aggregates)
             self.assertEqual(bundle_fqid.upcast(), doc.coordinates.bundle)
 
         for pair in docs_by_entity.values():
@@ -206,9 +211,19 @@ def test_deletion(self):
             self.index_service.delete_indices(self.catalog)
             self.index_service.create_indices(self.catalog)
 
-    def _parse_index_name(self, hit) -> tuple[str, bool]:
+    def _parse_index_name(self, hit) -> tuple[str, DocumentType]:
         index_name = config.parse_es_index_name(hit['_index'])
-        return index_name.entity_type, index_name.doc_type is DocumentType.aggregates
+        return index_name.entity_type, index_name.doc_type
+
+    def _filter_hits(self,
+                     hits: JSONs,
+                     doc_type: Optional[DocumentType] = None,
+                     entity_type: Optional[EntityType] = None,
+                     ) -> Iterable[JSON]:
+        for hit in hits:
+            hit_entity_type, hit_doc_type = self._parse_index_name(hit)
+            if entity_type in (None, hit_entity_type) and doc_type in (None, hit_doc_type):
+                yield hit
 
     def test_duplicate_notification(self):
         # Contribute the bundle once
@@ -299,11 +314,7 @@ def _assert_index_counts(self, just_deletion):
         actual_addition_contributions = [h for h in hits if not h['_source']['bundle_deleted']]
         actual_deletion_contributions = [h for h in hits if h['_source']['bundle_deleted']]
-        def is_aggregate(h):
-            _, aggregate_ = self._parse_index_name(h)
-            return aggregate_
-
-        actual_aggregates = [h for h in hits if is_aggregate(h)]
+        actual_aggregates = list(self._filter_hits(hits, DocumentType.aggregates))
         self.assertEqual(len(actual_addition_contributions),
                          num_expected_addition_contributions)
         self.assertEqual(len(actual_deletion_contributions),
                          num_expected_deletion_contributions)
@@ -347,23 +358,26 @@ def test_multi_entity_contributing_bundles(self):
         hits_after = self._get_all_hits()
         num_docs_by_index_after = self._num_docs_by_index(hits_after)
 
-        for entity_type, aggregate in num_docs_by_index_after.keys():
+        for entity_type, doc_type in num_docs_by_index_after.keys():
             # Both bundles reference two files. They both share one file
             # and exclusively own another one. Deleting one of the bundles removes the file owned exclusively by
             # that bundle, as well as the bundle itself.
-            if aggregate:
+            if doc_type is DocumentType.aggregates:
                 difference = 1 if entity_type in ('files', 'bundles') else 0
-                self.assertEqual(num_docs_by_index_after[entity_type, aggregate],
-                                 num_docs_by_index_before[entity_type, aggregate] - difference)
-            elif entity_type in ('bundles', 'samples', 'projects', 'cell_suspensions'):
-                # Count one extra deletion contribution
-                self.assertEqual(num_docs_by_index_after[entity_type, aggregate],
-                                 num_docs_by_index_before[entity_type, aggregate] + 1)
+                self.assertEqual(num_docs_by_index_after[entity_type, doc_type],
+                                 num_docs_by_index_before[entity_type, doc_type] - difference)
+            elif doc_type is DocumentType.contributions:
+                if entity_type in ('bundles', 'samples', 'projects', 'cell_suspensions'):
+                    # Count one extra deletion contribution
+                    self.assertEqual(num_docs_by_index_after[entity_type, doc_type],
+                                     num_docs_by_index_before[entity_type, doc_type] + 1)
+                else:
+                    # Count two extra deletion contributions for the two files
+                    self.assertEqual(entity_type, 'files')
+                    self.assertEqual(num_docs_by_index_after[entity_type, doc_type],
+                                     num_docs_by_index_before[entity_type, doc_type] + 2)
             else:
-                # Count two extra deletion contributions for the two files
-                self.assertEqual(entity_type, 'files')
-                self.assertEqual(num_docs_by_index_after[entity_type, aggregate],
-                                 num_docs_by_index_before[entity_type, aggregate] + 2)
+                assert False, doc_type
 
         entity = CataloguedEntityReference(catalog=self.catalog,
                                            entity_id=old_file_uuid,
@@ -404,7 +418,7 @@ def _walkthrough(v):
         bundle.metadata_files = _walkthrough(bundle.metadata_files)
         return old_file_uuid
 
-    def _num_docs_by_index(self, hits) -> Mapping[tuple[str, bool], int]:
+    def _num_docs_by_index(self, hits) -> Mapping[tuple[str, DocumentType], int]:
         return Counter(map(self._parse_index_name, hits))
 
     def test_indexed_matrices(self):
@@ -829,15 +843,13 @@ def test_indexed_matrices(self):
             }
         }
         matrices = {}
-        for hit in hits:
-            entity_type, aggregate = self._parse_index_name(hit)
-            if entity_type == 'projects' and aggregate:
-                project_id = hit['_source']['entity_id']
-                assert project_id not in matrices, project_id
-                matrices[project_id] = {
-                    k: hit['_source']['contents'][k]
-                    for k in ('matrices', 'contributed_analyses')
-                }
+        for hit in self._filter_hits(hits, DocumentType.aggregates, 'projects'):
+            project_id = hit['_source']['entity_id']
+            assert project_id not in matrices, project_id
+            matrices[project_id] = {
+                k: hit['_source']['contents'][k]
+                for k in ('matrices', 'contributed_analyses')
+            }
         self.assertEqual(expected_matrices, matrices)
 
     def test_organic_matrix_bundle(self):
@@ -847,7 +859,7 @@ def test_organic_matrix_bundle(self):
         self._index_canned_bundle(bundle)
         hits = self._get_all_hits()
         for hit in hits:
-            entity_type, aggregate = self._parse_index_name(hit)
+            entity_type, doc_type = self._parse_index_name(hit)
             contents = hit['_source']['contents']
             for file in contents['files']:
                 if file['file_format'] == 'Rds':
@@ -856,7 +868,10 @@
                 else:
                     expected_source = self.translated_str_null
                     expected_cell_count = self.translated_bool_null
-                if aggregate and entity_type not in ('bundles', 'files'):
+                if (
+                    doc_type is DocumentType.aggregates and
+                    entity_type not in ('bundles', 'files')
+                ):
                     expected_source = [expected_source]
                 self.assertEqual(expected_source, file['file_source'])
                 if 'matrix_cell_count' in file:
@@ -875,7 +890,7 @@ def test_sequence_files_with_file_source(self):
         files = set()
         contributed_analyses = set()
         for hit in hits:
-            entity_type, aggregate = self._parse_index_name(hit)
+            entity_type, doc_type = self._parse_index_name(hit)
             contents = hit['_source']['contents']
             if entity_type == 'files':
                 file = one(contents['files'])
@@ -886,7 +901,7 @@ def test_sequence_files_with_file_source(self):
                         null_bool.from_index(file['is_intermediate'])
                     )
                 )
-            elif entity_type == 'projects' and aggregate:
+            elif entity_type == 'projects' and doc_type is DocumentType.aggregates:
                 self.assertEqual([], contents['matrices'])
                 for file in one(contents['contributed_analyses'])['file']:
                     contributed_analyses.add(
@@ -934,10 +949,10 @@ def test_derived_files(self):
         self.assertEqual(len(hits), (num_files + 1 + 1 + 1 + 1) * 2)
         num_contribs, num_aggregates = Counter(), Counter()
         for hit in hits:
-            entity_type, aggregate = self._parse_index_name(hit)
+            entity_type, doc_type = self._parse_index_name(hit)
             source = hit['_source']
             contents = source['contents']
-            if aggregate:
+            if doc_type is DocumentType.aggregates:
                 num_aggregates[entity_type] += 1
                 bundle = one(source['bundles'])
                 actual_fqid = self.bundle_fqid(uuid=bundle['uuid'],
@@ -949,12 +964,14 @@ def test_derived_files(self):
                     self.assertEqual(num_files, len(contents['files']))
                 else:
                     self.assertEqual(num_files, sum(file['count'] for file in contents['files']))
-            else:
+            elif doc_type is DocumentType.contributions:
                 num_contribs[entity_type] += 1
                 actual_fqid = self.bundle_fqid(uuid=source['bundle_uuid'],
                                                version=source['bundle_version'])
                 self.assertEqual(analysis_bundle, actual_fqid)
                 self.assertEqual(1 if entity_type == 'files' else num_files, len(contents['files']))
+            else:
+                assert False, doc_type
             self.assertEqual(1, len(contents['specimens']))
             self.assertEqual(1, len(contents['projects']))
         num_expected = dict(files=num_files, samples=1, cell_suspensions=1, projects=1, bundles=1)
@@ -984,7 +1001,7 @@ def _assert_old_bundle(self,
                            num_expected_new_contributions: int = 0,
                            num_expected_new_deleted_contributions: int = 0,
                            ignore_aggregates=False
-                           ) -> Mapping[tuple[str, bool], JSON]:
+                           ) -> Mapping[tuple[str, DocumentType], JSON]:
         """
         Assert that the old bundle is still indexed correctly

@@ -1004,13 +1021,15 @@ def _assert_old_bundle(self,
         self.assertEqual(6 + 6 + num_expected_new_contributions + num_expected_new_deleted_contributions * 2, len(hits))
         hits_by_id = {}
         for hit in hits:
-            entity_type, aggregate = self._parse_index_name(hit)
+            entity_type, doc_type = self._parse_index_name(hit)
+            aggregate = doc_type is DocumentType.aggregates
+            contribution = doc_type is DocumentType.contributions
             if aggregate and ignore_aggregates:
                 continue
             source = hit['_source']
-            hits_by_id[source['entity_id'], aggregate] = hit
+            hits_by_id[source['entity_id'], doc_type] = hit
             version = one(source['bundles'])['version'] if aggregate else source['bundle_version']
-            if aggregate or self.old_bundle.version == version:
+            if aggregate or (contribution and self.old_bundle.version == version):
                 contents = source['contents']
                 project = one(contents['projects'])
                 self.assertEqual('Single cell transcriptome patterns.', get(project['project_title']))
@@ -1024,12 +1043,14 @@ def _assert_old_bundle(self,
                     self.assertIn('Australopithecus', donor['genus_species'])
                 if not aggregate:
                     self.assertFalse(source['bundle_deleted'])
-            else:
+            elif contribution:
                 if source['bundle_deleted']:
                     num_actual_new_deleted_contributions += 1
                 else:
                     self.assertLess(self.old_bundle.version, version)
                     num_actual_new_contributions += 1
+            else:
+                assert False, doc_type
         # We count the deleted contributions here too since they should have a corresponding addition contribution
         self.assertEqual(num_expected_new_contributions + num_expected_new_deleted_contributions,
                         num_actual_new_contributions)
@@ -1038,7 +1059,7 @@ def _assert_old_bundle(self,
 
     def _assert_new_bundle(self,
                            num_expected_old_contributions: int = 0,
-                           old_hits_by_id: Optional[Mapping[tuple[str, bool], JSON]] = None
+                           old_hits_by_id: Optional[Mapping[tuple[str, DocumentType], JSON]] = None
                            ) -> None:
         num_actual_old_contributions = 0
         hits = self._get_all_hits()
@@ -1046,7 +1067,8 @@ def _assert_new_bundle(self,
         # One contribution and one aggregate per entity
         self.assertEqual(6 + 6 + num_expected_old_contributions, len(hits))
         for hit in hits:
-            entity_type, aggregate = self._parse_index_name(hit)
+            entity_type, doc_type = self._parse_index_name(hit)
+            aggregate = doc_type is DocumentType.aggregates
             source = hit['_source']
             version = one(source['bundles'])['version'] if aggregate else source['bundle_version']
             contents = source['contents']
@@ -1130,7 +1152,8 @@ def mocked_mget(self, body, _source_includes):
         # 1 samples agg + 1 projects agg + 2 cell suspension agg + 2 bundle agg + 4 file agg = 22 hits
         self.assertEqual(22, len(hits))
         for hit in hits:
-            entity_type, aggregate = self._parse_index_name(hit)
+            entity_type, doc_type = self._parse_index_name(hit)
+            aggregate = doc_type is DocumentType.aggregates
             contents = hit['_source']['contents']
             if aggregate:
                 self.assertEqual(hit['_id'], hit['_source']['entity_id'])
@@ -1177,7 +1200,7 @@ def test_indexing_matrix_related_files(self):
         hits = self._get_all_hits()
         zarrs = []
         for hit in hits:
-            entity_type, aggregate = self._parse_index_name(hit)
+            entity_type, doc_type = self._parse_index_name(hit)
             if entity_type == 'files':
                 file = one(hit['_source']['contents']['files'])
                 if len(file['related_files']) > 0:
@@ -1186,7 +1209,7 @@ def test_indexing_matrix_related_files(self):
                 elif file['file_format'] == 'matrix':
                     # Matrix of Loom or CSV format possibly
                     self.assertNotIn('.zarr', file['name'])
-            elif not aggregate:
+            elif doc_type is DocumentType.contributions:
                 for file in hit['_source']['contents']['files']:
                     self.assertEqual(file['related_files'], [])
 
@@ -1209,9 +1232,9 @@ def test_indexing_with_skipped_matrix_file(self):
         file_names, aggregate_file_names = set(), set()
         entities_with_matrix_files = set()
         for hit in hits:
-            entity_type, aggregate = self._parse_index_name(hit)
+            entity_type, doc_type = self._parse_index_name(hit)
             files = hit['_source']['contents']['files']
-            if aggregate:
+            if doc_type is DocumentType.aggregates:
                 if entity_type == 'files':
                     aggregate_file_names.add(one(files)['name'])
                 else:
@@ -1225,10 +1248,12 @@ def test_indexing_with_skipped_matrix_file(self):
                         if file['file_format'] == 'matrix':
                             self.assertEqual(1, file['count'])
                             entities_with_matrix_files.add(hit['_source']['entity_id'])
-            else:
+            elif doc_type is DocumentType.contributions:
                 for file in files:
                     file_name = file['name']
                     file_names.add(file_name)
+            else:
+                assert False, doc_type
         self.assertEqual(4, len(entities_with_matrix_files))  # a project, a specimen, a cell suspension and a bundle
         self.assertEqual(aggregate_file_names, file_names)
         matrix_file_names = {file_name for file_name in file_names if '.zarr/' in file_name}
@@ -1246,30 +1271,33 @@ def test_plate_bundle(self):
         expected_cell_count = 380  # 384 wells in total, four of them empty, the rest with a single cell
         documents_with_cell_suspension = 0
         for hit in hits:
-            entity_type, aggregate = self._parse_index_name(hit)
+            entity_type, doc_type = self._parse_index_name(hit)
             contents = hit['_source']['contents']
             cell_suspensions = contents['cell_suspensions']
             if entity_type == 'files' and contents['files'][0]['file_format'] == 'pdf':
                 # The PDF files in that bundle aren't linked to a specimen
                 self.assertEqual(0, len(cell_suspensions))
             else:
-                if aggregate:
+                if doc_type is DocumentType.aggregates:
                     bundles = hit['_source']['bundles']
                     self.assertEqual(1, len(bundles))
                     self.assertEqual(one(contents['sequencing_protocols'])['paired_end'], [
                         self.translated_bool_true,
                     ])
-                else:
+                elif doc_type is DocumentType.contributions:
                     self.assertEqual({p.get('paired_end') for p in contents['sequencing_protocols']}, {
                         self.translated_bool_true,
                     })
+                else:
+                    assert False
                 specimens = contents['specimens']
                 for specimen in specimens:
                     self.assertEqual({'bone marrow', 'temporal lobe'}, set(specimen['organ_part']))
                 for cell_suspension in cell_suspensions:
                     self.assertEqual({'bone marrow', 'temporal lobe'}, set(cell_suspension['organ_part']))
                     self.assertEqual({'Plasma cells'}, set(cell_suspension['selected_cell_type']))
-            self.assertEqual(1 if entity_type == 'cell_suspensions' or aggregate else 384, len(cell_suspensions))
+            self.assertEqual(1 if entity_type == 'cell_suspensions' or doc_type is DocumentType.aggregates else 384,
                             len(cell_suspensions))
             if entity_type == 'cell_suspensions':
                 counted_cell_count += one(cell_suspensions)['total_estimated_cells']
             else:
@@ -1294,8 +1322,8 @@ def test_well_bundles(self):
         self.assertGreater(len(hits), 0)
         for hit in hits:
             contents = hit["_source"]['contents']
-            entity_type, aggregate = self._parse_index_name(hit)
-            if aggregate:
+            entity_type, doc_type = self._parse_index_name(hit)
+            if doc_type is DocumentType.aggregates:
                 cell_suspensions = contents['cell_suspensions']
                 self.assertEqual(1, len(cell_suspensions))
                 # Each bundle contributes a well with one cell. The data files in each bundle are derived from
@@ -1304,8 +1332,10 @@ def test_well_bundles(self):
                 expected_cells = 1 if entity_type in ('files', 'cell_suspensions', 'bundles') else 2
                 self.assertEqual(expected_cells, cell_suspensions[0]['total_estimated_cells'])
                 self.assertEqual(one(contents['analysis_protocols'])['workflow'], ['smartseq2_v2.1.0'])
-            else:
+            elif doc_type is DocumentType.contributions:
                 self.assertEqual({p['workflow'] for p in contents['analysis_protocols']}, {'smartseq2_v2.1.0'})
+            else:
+                assert False, doc_type
 
     def test_pooled_specimens(self):
         """
@@ -1319,8 +1349,8 @@ def test_pooled_specimens(self):
         hits = self._get_all_hits()
         self.assertGreater(len(hits), 0)
         for hit in hits:
-            entity_type, aggregate = self._parse_index_name(hit)
-            if aggregate:
+            entity_type, doc_type = self._parse_index_name(hit)
+            if doc_type is DocumentType.aggregates:
                 contents = hit["_source"]['contents']
                 cell_suspensions = contents['cell_suspensions']
                 self.assertEqual(1, len(cell_suspensions))
@@ -1368,7 +1398,8 @@ def test_organoid_priority(self):
 
         for hit in hits:
             contents = hit['_source']['contents']
-            entity_type, aggregate = self._parse_index_name(hit)
+            entity_type, doc_type = self._parse_index_name(hit)
+            aggregate = doc_type is DocumentType.aggregates
 
             if entity_type != 'files' or one(contents['files'])['file_format'] != 'pdf':
                 inner_cell_suspensions += len(contents['cell_suspensions'])
@@ -1415,7 +1446,7 @@ def test_accessions_fields(self):
                 'insdc_project': ['SRP000000', 'SRP000001'],
                 'insdc_study': ['PRJNA000000']
             }
-            entity_type, aggregate = self._parse_index_name(hit)
+            entity_type, doc_type = self._parse_index_name(hit)
            if entity_type == 'project':
                 expected_accessions = [
                     {'namespace': namespace, 'accession': accession}
@@ -1445,7 +1476,8 @@ def test_cell_counts(self):
         ]
         actual = NestedDict(2, list)
         for hit in sorted(hits, key=lambda d: d['_id']):
-            entity_type, aggregate = self._parse_index_name(hit)
+            entity_type, doc_type = self._parse_index_name(hit)
+            aggregate = doc_type is DocumentType.aggregates
             contents = hit['_source']['contents']
             for inner_entity_type, field_name in field_paths:
                 for inner_entity in contents[inner_entity_type]:
@@ -1474,9 +1506,7 @@ def test_cell_counts(self):
     def test_no_cell_count_contributions(self):
 
         def assert_cell_suspension(expected: JSON, hits: list[JSON]):
-            project_hit = one(hit
-                              for hit in hits
-                              if ('projects', True) == self._parse_index_name(hit))
+            project_hit = one(self._filter_hits(hits, DocumentType.aggregates, 'projects'))
             contents = project_hit['_source']['contents']
             cell_suspension = cast(JSON, one(contents['cell_suspensions']))
             actual_result = {
@@ -1528,10 +1558,10 @@ def test_imaging_bundle(self):
         hits = self._get_all_hits()
         sources = defaultdict(list)
         for hit in hits:
-            entity_type, aggregate = self._parse_index_name(hit)
-            sources[entity_type, aggregate].append(hit['_source'])
+            entity_type, doc_type = self._parse_index_name(hit)
+            sources[entity_type, doc_type].append(hit['_source'])
             # bundle has 240 imaging_protocol_0.json['target'] items, each with an assay_type of 'in situ sequencing'
-            assay_type = ['in situ sequencing'] if aggregate else {'in situ sequencing': 240}
+            assay_type = ['in situ sequencing'] if doc_type is DocumentType.aggregates else {'in situ sequencing': 240}
             self.assertEqual(one(hit['_source']['contents']['imaging_protocols'])['assay_type'], assay_type)
         for aggregate in True, False:
             with self.subTest(aggregate=aggregate):
@@ -1568,24 +1598,30 @@ def test_cell_line_sample(self):
         hits = self._get_all_hits()
         for hit in hits:
             contents = hit['_source']['contents']
-            entity_type, aggregate = self._parse_index_name(hit)
+            entity_type, doc_type = self._parse_index_name(hit)
+            aggregate = doc_type is DocumentType.aggregates
+            contribution = doc_type is DocumentType.contributions
             if entity_type == 'samples':
                 sample = one(contents['samples'])
                 sample_entity_type = sample['entity_type']
                 if aggregate:
                     document_ids = one(contents[sample_entity_type])['document_id']
-                else:
+                elif contribution:
                     document_ids = [d['document_id'] for d in contents[sample_entity_type]]
                     entity = one([d for d in contents[sample_entity_type] if d['document_id'] == sample['document_id']])
                     self.assertEqual(sample['biomaterial_id'], entity['biomaterial_id'])
+                else:
+                    assert False, doc_type
                 self.assertTrue(sample['document_id'] in document_ids)
             self.assertEqual(one(contents['specimens'])['organ'], ['blood'] if aggregate else 'blood')
             self.assertEqual(one(contents['specimens'])['organ_part'], ['venous blood'])
             self.assertEqual(len(contents['cell_lines']), 1 if aggregate else 2)
             if aggregate:
                 cell_lines_model_organ = set(one(contents['cell_lines'])['model_organ'])
-            else:
+            elif contribution:
                 cell_lines_model_organ = {cl['model_organ'] for cl in contents['cell_lines']}
+            else:
+                assert False, doc_type
             self.assertEqual(cell_lines_model_organ, {'blood (parent_cell_line)', 'blood (child_cell_line)'})
             self.assertEqual(one(contents['cell_suspensions'])['organ'], ['blood (child_cell_line)'])
             self.assertEqual(one(contents['cell_suspensions'])['organ_part'], [self.translated_str_null])
@@ -1602,15 +1638,17 @@ def test_multiple_samples(self):
         hits = self._get_all_hits()
         for hit in hits:
             contents = hit['_source']['contents']
-            entity_type, aggregate = self._parse_index_name(hit)
+            entity_type, doc_type = self._parse_index_name(hit)
             cell_suspension = one(contents['cell_suspensions'])
             self.assertEqual(cell_suspension['organ'], ['embryo', 'immune system'])
             self.assertEqual(cell_suspension['organ_part'], ['skin epidermis', self.translated_str_null])
-            if aggregate and entity_type != 'samples':
+            if doc_type is DocumentType.aggregates and entity_type != 'samples':
                 self.assertEqual(one(contents['samples'])['entity_type'], sample_entity_types)
-            else:
+            elif doc_type is DocumentType.contributions:
                 for sample in contents['samples']:
                     self.assertIn(sample['entity_type'], sample_entity_types)
+            else:
+                assert False, doc_type
 
     def test_sample_with_no_donor(self):
         """
@@ -1663,11 +1701,11 @@ def test_sample_with_no_donor(self):
         hits = self._get_all_hits()
         for hit in hits:
             contents = hit['_source']['contents']
-            entity_type, aggregate = self._parse_index_name(hit)
+            entity_type, doc_type = self._parse_index_name(hit)
             if entity_type == 'projects':
-                if aggregate:
+                if doc_type is DocumentType.aggregates:
                     self.assertElasticEqual([aggregate_donor], contents['donors'])
-                else:
+                elif doc_type is DocumentType.contributions:
                     sample_id = one(contents['samples'])['document_id']
                     if sample_id == '70d2b85a-8055-4027-a0d9-29452a49d668':
                         self.assertEqual([donor], contents['donors'])
@@ -1679,6 +1717,8 @@ def test_sample_with_no_donor(self):
                         self.assertEqual([] if True else [donor_none], contents['donors'])
                     else:
                         assert False, sample_id
+                else:
+                    assert False, doc_type
 
     def test_files_content_description(self):
         bundle_fqid = self.bundle_fqid(uuid='ffac201f-4b1c-4455-bd58-19c1a9e863b4',
@@ -1687,13 +1727,15 @@ def test_files_content_description(self):
         hits = self._get_all_hits()
         for hit in hits:
             contents = hit['_source']['contents']
-            entity_type, aggregate = self._parse_index_name(hit)
-            if aggregate:
+            entity_type, doc_type = self._parse_index_name(hit)
+            if doc_type is DocumentType.aggregates:
                 # bundle aggregates keep individual files
                 num_inner_files = 2 if entity_type == 'bundles' else 1
-            else:
+            elif doc_type is DocumentType.contributions:
                 # one inner file per file contribution
                 num_inner_files = 1 if entity_type == 'files' else 2
+            else:
+                assert False, doc_type
             self.assertEqual(len(contents['files']), num_inner_files)
             for file in contents['files']:
                 self.assertEqual(file['content_description'], ['RNA sequence'])
@@ -1713,11 +1755,9 @@ def test_related_files_field_exclusion(self):
 
         # Ensure that related_files exists
         hits = self._get_all_hits()
-        for hit in hits:
-            entity_type, aggregate = self._parse_index_name(hit)
-            if aggregate and entity_type == 'files':
-                file = one(hit['_source']['contents']['files'])
-                self.assertIn('related_files', file)
+        for hit in self._filter_hits(hits, DocumentType.aggregates, 'files'):
+            file = one(hit['_source']['contents']['files'])
+            self.assertIn('related_files', file)
 
         # … but that it can't be used for queries
         zattrs_file = '377f2f5a-4a45-4c62-8fb0-db9ef33f5cf0.zarr/.zattrs'
@@ -1751,10 +1791,8 @@ def test_downstream_entities(self):
             self._index_bundle(bundle)
 
         def get_aggregates(hits, type):
-            for hit in hits:
-                entity_type, aggregate = self._parse_index_name(hit)
-                if entity_type == type and aggregate:
-                    yield hit['_source']['contents']
+            for hit in self._filter_hits(hits, DocumentType.aggregates, type):
+                yield hit['_source']['contents']
 
         hits = self._get_all_hits()
         samples = list(get_aggregates(hits, 'samples'))
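
Note: the core of this fixup is the switch from a boolean aggregate flag to
the DocumentType enum, plus the new _filter_hits generator that replaces the
many hand-rolled parse-and-filter loops. The following is a minimal standalone
sketch of that filtering idiom, not part of the patch itself: DocumentType
here is a stand-in for Azul's azul.indexer.document.DocumentType, and plain
dicts stand in for Elasticsearch hits whose index names would really be parsed
via config.parse_es_index_name.

    from enum import Enum, auto
    from typing import Iterable, Optional

    class DocumentType(Enum):
        # Stand-in for azul.indexer.document.DocumentType
        contributions = auto()
        aggregates = auto()

    def filter_hits(hits: Iterable[dict],
                    doc_type: Optional[DocumentType] = None,
                    entity_type: Optional[str] = None,
                    ) -> Iterable[dict]:
        # A parameter left as None acts as a wildcard: `x in (None, y)` is
        # true when the filter is unset or equals the hit's value, the same
        # idiom _filter_hits uses in the patch above.
        for hit in hits:
            if entity_type in (None, hit['entity_type']) and doc_type in (None, hit['doc_type']):
                yield hit

    hits = [
        {'id': 1, 'entity_type': 'projects', 'doc_type': DocumentType.aggregates},
        {'id': 2, 'entity_type': 'files', 'doc_type': DocumentType.contributions},
    ]
    assert [h['id'] for h in filter_hits(hits, DocumentType.aggregates)] == [1]
    assert [h['id'] for h in filter_hits(hits, entity_type='files')] == [2]
    assert [h['id'] for h in filter_hits(hits)] == [1, 2]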