Skip to content

Commit

Permalink
fix(moler_decoder): Fix hydrogen handling in scaffolds with explicit …
Browse files Browse the repository at this point in the history
…attachment points
  • Loading branch information
kmaziarz committed Nov 23, 2023
1 parent a9a3fc9 commit 94ca0b4
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions molecule_generation/layers/moler_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,8 +1193,8 @@ def decode(
for graph_repr, init_mol, mol_id in zip(graph_representations, initial_molecules, mol_ids):
num_free_bond_slots = [0] * len(init_mol.GetAtoms())

atom_ids_to_remove = []
atom_ids_to_keep = []
atom_id_pairs_to_disconnect: List[Tuple[int, int]] = []
atom_ids_to_keep: List[int] = []

for atom in init_mol.GetAtoms():
if atom.GetAtomicNum() == 0:
Expand All @@ -1220,22 +1220,29 @@ def decode(
neighbour_idx = begin_idx if begin_idx != atom.GetIdx() else end_idx
num_free_bond_slots[neighbour_idx] += 1

atom_ids_to_remove.append(atom.GetIdx())
atom_id_pairs_to_disconnect.append((atom.GetIdx(), neighbour_idx))
else:
atom_ids_to_keep.append(atom.GetIdx())

if not atom_ids_to_remove:
init_mol_original = init_mol
if not atom_id_pairs_to_disconnect:
# No explicit attachment points, so assume we can connect anywhere.
num_free_bond_slots = None
else:
num_free_bond_slots = [num_free_bond_slots[idx] for idx in atom_ids_to_keep]
init_mol = Chem.RWMol(init_mol)

# Save the atom list to be able to extract neighbour atoms by their original id.
original_atom_list = list(init_mol.GetAtoms())

# Remove atoms starting from largest index, so that we don't have to account for
# indices shifting during removal.
for atom_idx in reversed(atom_ids_to_remove):
# indices of atoms to remove shifting due to other removals.
for atom_idx, neighbour_idx in reversed(atom_id_pairs_to_disconnect):
init_mol.RemoveAtom(atom_idx)

neighbour_atom = original_atom_list[neighbour_idx]
neighbour_atom.SetNumExplicitHs(neighbour_atom.GetNumExplicitHs() + 1)

# Determine how the scaffold atoms will get reordered when we canonicalize it, so we can
# permute `num_free_bond_slots` appropriately.
canonical_ordering = compute_canonical_atom_order(init_mol)
Expand All @@ -1245,6 +1252,10 @@ def decode(
# renumbering to `num_free_bond_slots` earlier.
init_mol = Chem.MolFromSmiles(Chem.MolToSmiles(init_mol))

if init_mol is None:
scaffold = Chem.MolToSmiles(init_mol_original)
raise ValueError(f"Scaffold {scaffold} could not be processed")

# Clear aromatic flags in the scaffold, since partial graphs during training never have
# them set (however we _do_ run `AtomIsAromaticFeatureExtractor`, it just always returns
# 0 for partial graphs during training).
Expand Down

0 comments on commit 94ca0b4

Please sign in to comment.