From bf90332dddd57398cd183d4e4d2b9de817efd534 Mon Sep 17 00:00:00 2001
From: Kevin Wu
Date: Tue, 17 Sep 2024 13:27:56 -0700
Subject: [PATCH] Fix for cases where tokenization failed (#54)

* Append None to structure_contexts on exception

* Assign chain_id based on non-null contexts

* Add test case

* Apply suggestions from code review

Co-authored-by: Alex Rogozhnikov

* Add comment

---------

Co-authored-by: Alex Rogozhnikov
---
 chai_lab/data/dataset/inference_dataset.py | 18 ++++++++----
 tests/test_inference_dataset.py            | 33 ++++++++++++++++++++++
 2 files changed, 45 insertions(+), 6 deletions(-)
 create mode 100644 tests/test_inference_dataset.py

diff --git a/chai_lab/data/dataset/inference_dataset.py b/chai_lab/data/dataset/inference_dataset.py
index b2bc9d5..e2189ef 100644
--- a/chai_lab/data/dataset/inference_dataset.py
+++ b/chai_lab/data/dataset/inference_dataset.py
@@ -10,6 +10,9 @@
     AllAtomResidueTokenizer,
     _make_sym_ids,
 )
+from chai_lab.data.dataset.structure.all_atom_structure_context import (
+    AllAtomStructureContext,
+)
 from chai_lab.data.dataset.structure.chain import Chain
 from chai_lab.data.parsing.fasta import get_residue_name, read_fasta
 from chai_lab.data.parsing.input_validation import (
@@ -164,19 +167,22 @@ def load_chains_from_raw(
     )
 
     # Tokenize the entity data
-    structure_contexts = []
+    structure_contexts: list[AllAtomStructureContext | None] = []
     sym_ids = _make_sym_ids([x.entity_id for x in entities])
-    for idx, (entity_data, sym_id) in enumerate(zip(entities, sym_ids)):
+    for entity_data, sym_id in zip(entities, sym_ids):
+        # chain index should not count null contexts that result from failed tokenization
+        chain_index = sum(ctx is not None for ctx in structure_contexts) + 1
         try:
             tok = tokenizer._tokenize_entity(
                 entity_data,
-                chain_id=idx + 1,
+                chain_id=chain_index,
                 sym_id=sym_id,
             )
-            structure_contexts.append(tok)
         except Exception:
-            logger.exception(f"Failed to tokenize input {inputs[idx]}")
-
+            logger.exception(f"Failed to tokenize input {entity_data=} {sym_id=}")
+            tok = None
+        structure_contexts.append(tok)
+    assert len(structure_contexts) == len(entities)
     # Join the untokenized entity data with the tokenized chain data, removing
     # chains we failed to tokenize
     chains = [
diff --git a/tests/test_inference_dataset.py b/tests/test_inference_dataset.py
new file mode 100644
index 0000000..d803cde
--- /dev/null
+++ b/tests/test_inference_dataset.py
@@ -0,0 +1,33 @@
+"""
+Tests for inference dataset.
+"""
+
+from chai_lab.data.dataset.inference_dataset import Input, load_chains_from_raw
+from chai_lab.data.dataset.structure.all_atom_residue_tokenizer import (
+    AllAtomResidueTokenizer,
+)
+from chai_lab.data.parsing.structure.entity_type import EntityType
+from chai_lab.data.sources.rdkit import RefConformerGenerator
+
+
+def test_malformed_smiles():
+    """Malformed SMILES should be dropped."""
+    # Zn ligand is malformed (should be [Zn+2])
+    inputs = [
+        Input("RKDESES", entity_type=EntityType.PROTEIN.value, entity_name="foo"),
+        Input("Zn", entity_type=EntityType.LIGAND.value, entity_name="bar"),
+        Input("RKEEE", entity_type=EntityType.PROTEIN.value, entity_name="baz"),
+        Input("EEEEEEEEEEEE", entity_type=EntityType.PROTEIN.value, entity_name="boz"),
+    ]
+    chains = load_chains_from_raw(
+        inputs,
+        identifier="test",
+        tokenizer=AllAtomResidueTokenizer(RefConformerGenerator()),
+    )
+    assert len(chains) == 3
+    for chain in chains:
+        # NOTE this check is only valid because there are no residues that are tokenized per-atom
+        # Ensures that the entity data and the structure context in each chain are paired correctly
+        assert chain.structure_context.num_tokens == len(
+            chain.entity_data.full_sequence
+        )