From 94ca0b471158edc31c24f632841d798260998fc3 Mon Sep 17 00:00:00 2001 From: Krzysztof Maziarz Date: Thu, 23 Nov 2023 13:47:53 +0000 Subject: [PATCH] fix(moler_decoder): Fix hydrogen handling in scaffolds with explicit attachment points --- molecule_generation/layers/moler_decoder.py | 23 +++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/molecule_generation/layers/moler_decoder.py b/molecule_generation/layers/moler_decoder.py index b1e0df9..154a0c9 100644 --- a/molecule_generation/layers/moler_decoder.py +++ b/molecule_generation/layers/moler_decoder.py @@ -1193,8 +1193,8 @@ def decode( for graph_repr, init_mol, mol_id in zip(graph_representations, initial_molecules, mol_ids): num_free_bond_slots = [0] * len(init_mol.GetAtoms()) - atom_ids_to_remove = [] - atom_ids_to_keep = [] + atom_id_pairs_to_disconnect: List[Tuple[int, int]] = [] + atom_ids_to_keep: List[int] = [] for atom in init_mol.GetAtoms(): if atom.GetAtomicNum() == 0: @@ -1220,22 +1220,29 @@ def decode( neighbour_idx = begin_idx if begin_idx != atom.GetIdx() else end_idx num_free_bond_slots[neighbour_idx] += 1 - atom_ids_to_remove.append(atom.GetIdx()) + atom_id_pairs_to_disconnect.append((atom.GetIdx(), neighbour_idx)) else: atom_ids_to_keep.append(atom.GetIdx()) - if not atom_ids_to_remove: + init_mol_original = init_mol + if not atom_id_pairs_to_disconnect: # No explicit attachment points, so assume we can connect anywhere. num_free_bond_slots = None else: num_free_bond_slots = [num_free_bond_slots[idx] for idx in atom_ids_to_keep] init_mol = Chem.RWMol(init_mol) + # Save the atom list to be able to extract neighbour atoms by their original id. + original_atom_list = list(init_mol.GetAtoms()) + # Remove atoms starting from largest index, so that we don't have to account for - # indices shifting during removal. - for atom_idx in reversed(atom_ids_to_remove): + # indices of atoms to remove shifting due to other removals. + for atom_idx, neighbour_idx in reversed(atom_id_pairs_to_disconnect): init_mol.RemoveAtom(atom_idx) + neighbour_atom = original_atom_list[neighbour_idx] + neighbour_atom.SetNumExplicitHs(neighbour_atom.GetNumExplicitHs() + 1) + # Determine how the scaffold atoms will get reordered when we canonicalize it, so we can # permute `num_free_bond_slots` appropriately. canonical_ordering = compute_canonical_atom_order(init_mol) @@ -1245,6 +1252,10 @@ def decode( # renumbering to `num_free_bond_slots` earlier. init_mol = Chem.MolFromSmiles(Chem.MolToSmiles(init_mol)) + if init_mol is None: + scaffold = Chem.MolToSmiles(init_mol_original) + raise ValueError(f"Scaffold {scaffold} could not be processed") + # Clear aromatic flags in the scaffold, since partial graphs during training never have # them set (however we _do_ run `AtomIsAromaticFeatureExtractor`, it just always returns # 0 for partial graphs during training).