diff --git a/doc/figure_scripts/caffeine_example.py b/doc/figure_scripts/caffeine_example.py new file mode 100644 index 0000000..3b550db --- /dev/null +++ b/doc/figure_scripts/caffeine_example.py @@ -0,0 +1,11 @@ +import matplotlib.pyplot as plt +from fgutils import Parser +from fgutils.vis import plot_as_mol + +parser = Parser() +mol = parser("CN1C=NC2=C1C(=O)N(C(=O)N2C)C") + +fig, ax = plt.subplots(1, 1, dpi=200) +plot_as_mol(mol, ax) +plt.savefig("doc/figures/caffeine_example.png", bbox_inches="tight", transparent=True) + diff --git a/doc/figure_scripts/diels_alder_example.py b/doc/figure_scripts/diels_alder_example.py new file mode 100644 index 0000000..a73cb4f --- /dev/null +++ b/doc/figure_scripts/diels_alder_example.py @@ -0,0 +1,41 @@ +import matplotlib.pyplot as plt +from fgutils.proxy import ProxyGroup, ProxyGraph, ReactionProxy +from fgutils.proxy_collection.common import common_groups +from fgutils.vis import plot_reaction + + +electron_donating_group = ProxyGroup( + "electron_donating_group", pattern="{alkyl,aryl,amine}" +) +electron_withdrawing_group = ProxyGroup( + "electron_withdrawing_group", + pattern="{alkohol,ether,aldehyde,ester,nitrile}", +) +diene_group = ProxyGroup( + "diene", + ProxyGraph("C<2,1>C<1,2>C<2,1>C{electron_donating_group}", anchor=[0, 3]), +) +dienophile_group = ProxyGroup( + "dienophile", + ProxyGraph("C<2,1>C{electron_withdrawing_group}", anchor=[0, 1]), +) +groups = common_groups + [ + electron_donating_group, + electron_withdrawing_group, + diene_group, + dienophile_group, +] + +proxy = ReactionProxy("{diene}1<0,1>{dienophile}<0,1>1", groups) + +r, c = 3, 2 +fig, ax = plt.subplots(r, c, dpi=400) +for ri in range(r): + for ci in range(c): + g, h = next(proxy) + ax[ri, ci].axis("off") + plot_reaction(g, h, ax[ri, ci]) + +plt.tight_layout() +plt.savefig("doc/figures/diels_alder_example.png", bbox_inches="tight", transparent=True) +plt.show() diff --git a/doc/figure_scripts/labeled_node_example.py 
b/doc/figure_scripts/labeled_node_example.py new file mode 100644 index 0000000..e3ada3b --- /dev/null +++ b/doc/figure_scripts/labeled_node_example.py @@ -0,0 +1,20 @@ +import matplotlib.pyplot as plt +from fgutils import Parser +from fgutils.proxy import MolProxy, ProxyGroup +from fgutils.vis import plot_graph + +pattern = "CC(=O)O{propyl}" +propyl_group = ProxyGroup("propyl", pattern="CCC") +parser = Parser() +proxy = MolProxy(pattern, propyl_group, parser=parser) + +g = parser(pattern) +mol = next(proxy) + +fig, ax = plt.subplots(1, 2, dpi=100, figsize=(12, 4)) +plot_graph(g, ax[0], show_labels=True) +plot_graph(mol, ax[1]) +plt.savefig( + "doc/figures/labeled_node_example.png", bbox_inches="tight", transparent=True +) +plt.show() diff --git a/doc/figure_scripts/simple_its_example.py b/doc/figure_scripts/simple_its_example.py new file mode 100644 index 0000000..14daec3 --- /dev/null +++ b/doc/figure_scripts/simple_its_example.py @@ -0,0 +1,14 @@ +import matplotlib.pyplot as plt +from fgutils import Parser +from fgutils.proxy import Proxy +from fgutils.vis import plot_its, plot_graph + + +pattern = "C1<2,1>C<1,2>C<2,1>C(C)<0,1>C<2,1>C(O)<0,1>1" +parser = Parser() +g = parser(pattern) + +fig, ax = plt.subplots(1, 1) +plot_graph(g, ax) +plt.savefig("doc/figures/simple_its_example.png", bbox_inches="tight", transparent=True) +plt.show() diff --git a/doc/figures/caffeine_example.png b/doc/figures/caffeine_example.png new file mode 100644 index 0000000..e0f0029 Binary files /dev/null and b/doc/figures/caffeine_example.png differ diff --git a/doc/figures/diels_alder_example.png b/doc/figures/diels_alder_example.png new file mode 100644 index 0000000..34f7ca7 Binary files /dev/null and b/doc/figures/diels_alder_example.png differ diff --git a/doc/figures/labeled_node_example.png b/doc/figures/labeled_node_example.png new file mode 100644 index 0000000..0558679 Binary files /dev/null and b/doc/figures/labeled_node_example.png differ diff --git 
a/doc/figures/simple_its_example.png b/doc/figures/simple_its_example.png new file mode 100644 index 0000000..c782e9a Binary files /dev/null and b/doc/figures/simple_its_example.png differ diff --git a/doc/index.rst b/doc/index.rst index 6f794f0..04b2227 100644 --- a/doc/index.rst +++ b/doc/index.rst @@ -8,8 +8,8 @@ Welcome to FGUtils's documentation! .. toctree:: :maxdepth: 1 - :caption: Contents: + pattern_syntax references diff --git a/doc/pattern_syntax.rst b/doc/pattern_syntax.rst new file mode 100644 index 0000000..df0224a --- /dev/null +++ b/doc/pattern_syntax.rst @@ -0,0 +1,121 @@ +============== +Pattern Syntax +============== + +FGUtils has its own graph description language. The syntax is closely related +to the SMILES format for molecules and reactions. It is kind of an extension +to SMILES to support modeling ITS graphs and reaction patterns. To convert the +SMILES-like description into a graph object use the +:py:class:`~fgutils.parse.Parser` class. The Caffeine molecular graph can be +obtained as follows:: + + import matplotlib.pyplot as plt + from fgutils import Parser + from fgutils.vis import plot_as_mol + + parser = Parser() + mol = parser("CN1C=NC2=C1C(=O)N(C(=O)N2C)C") + + fig, ax = plt.subplots(1, 1) + plot_as_mol(mol, ax) + plt.show() + +.. image:: figures/caffeine_example.png + :width: 300 + +Besides parsing common SMILES it is possible to generate molecule-like graphs +with more abstract nodes, i.e., arbitrary node labels. Arbitrary node labels +are surrounded by ``{}`` (e.g. ``{label}``). This abstract labeling can be used +to substitute nodes with specific patterns. This can be done by using a +:py:class:`~fgutils.proxy.Proxy`. 
Propyl acetate can be created by replacing +the labeled node with the propyl group:: + + import matplotlib.pyplot as plt + from fgutils import Parser + from fgutils.proxy import MolProxy, ProxyGroup + from fgutils.vis import plot_graph + + pattern = "CC(=O)O{propyl}" + propyl_group = ProxyGroup("propyl", pattern="CCC") + parser = Parser() + proxy = MolProxy(pattern, propyl_group, parser=parser) + + g = parser(pattern) + mol = next(proxy) + + fig, ax = plt.subplots(1, 2, dpi=100, figsize=(12, 4)) + plot_graph(g, ax[0], show_labels=True) + plot_graph(mol, ax[1]) + plt.show() + +.. image:: figures/labeled_node_example.png + :width: 600 + + +.. note:: + + A node can have more than one label. This can be done by separating the + labels with a comma, e.g.: ``{label_1,label_2}``. + +Another extension to the SMILES notation is the encoding of bond changes. This +feature is required to model reaction mechanisms as ITS graphs. Changing bonds +are surrounded by ``<>`` (e.g. ``<1,2>`` for the formation of a double bond +from a single bond). The extended notation allows the automated generation of +reaction examples with complete atom-to-atom maps. The following code snippet +demonstrates the generation of a few Diels-Alder reactions. 
The ``diene`` and +``dienophile`` groups can of course be extended to increase the variety of the +samples:: + + + import matplotlib.pyplot as plt + from fgutils.proxy import ProxyGroup, ProxyGraph, ReactionProxy + from fgutils.proxy_collection.common import common_groups + from fgutils.vis import plot_reaction + + + electron_donating_group = ProxyGroup( + "electron_donating_group", pattern="{alkyl,aryl,amine}" + ) + electron_withdrawing_group = ProxyGroup( + "electron_withdrawing_group", + pattern="{alkohol,ether,aldehyde,ester,nitrile}", + ) + diene_group = ProxyGroup( + "diene", + ProxyGraph("C<2,1>C<1,2>C<2,1>C{electron_donating_group}", anchor=[0, 3]), + ) + dienophile_group = ProxyGroup( + "dienophile", + ProxyGraph("C<2,1>C{electron_withdrawing_group}", anchor=[0, 1]), + ) + groups = common_groups + [ + electron_donating_group, + electron_withdrawing_group, + diene_group, + dienophile_group, + ] + + proxy = ReactionProxy("{diene}1<0,1>{dienophile}<0,1>1", groups) + + r, c = 3, 2 + fig, ax = plt.subplots(r, c, dpi=400) + for ri in range(r): + for ci in range(c): + g, h = next(proxy) + ax[ri, ci].axis("off") + plot_reaction(g, h, ax[ri, ci]) + plt.tight_layout() + plt.show() + +.. image:: figures/diels_alder_example.png + :width: 1000 + +.. note:: + + The ``electron_donating_group`` and ``electron_withdrawing_group`` serve as + a collection of other groups to simplify the notation. They consist of a + single node with multiple labels. When iterating the next sample from the + proxy (``next(proxy)``) the labeled nodes get replaced by the pattern from + one of the groups. The group/label is chosen randomly with uniform + distribution. + diff --git a/doc/references.rst b/doc/references.rst index 16c852b..18fa2c4 100644 --- a/doc/references.rst +++ b/doc/references.rst @@ -2,7 +2,14 @@ References ========== -Proxy +parse +===== + +.. automodule:: fgutils.parse + :members: + + +proxy ===== .. 
automodule:: fgutils.proxy diff --git a/fgutils/__init__.py b/fgutils/__init__.py index 33e8e13..b6a21eb 100644 --- a/fgutils/__init__.py +++ b/fgutils/__init__.py @@ -1,3 +1,4 @@ from .permutation import PermutationMapper from .query import FGQuery from .proxy import ReactionProxy +from .parse import Parser diff --git a/fgutils/chem/its.py b/fgutils/chem/its.py new file mode 100644 index 0000000..a2542bf --- /dev/null +++ b/fgutils/chem/its.py @@ -0,0 +1,89 @@ +import collections +import networkx as nx +import rdkit.Chem as Chem + +from fgutils.rdkit import graph_to_mol +from fgutils.const import SYMBOL_KEY, AAM_KEY, BOND_KEY + + +def _add_its_nodes(ITS, G, H, eta, symbol_key): + eta_G, eta_G_inv, eta_H, eta_H_inv = eta[0], eta[1], eta[2], eta[3] + for n, d in G.nodes(data=True): + n_ITS = eta_G[n] + n_H = eta_H_inv[n_ITS] + if n_ITS is not None and n_H is not None: + ITS.add_node(n_ITS, symbol=d[symbol_key], idx_map=(n, n_H)) + for n, d in H.nodes(data=True): + n_ITS = eta_H[n] + n_G = eta_G_inv[n_ITS] + if n_ITS is not None and n_G is not None and n_ITS not in ITS.nodes: + ITS.add_node(n_ITS, symbol=d[symbol_key], idx_map=(n_G, n)) + + +def _add_its_edges(ITS, G, H, eta, bond_key): + eta_G, eta_G_inv, eta_H, eta_H_inv = eta[0], eta[1], eta[2], eta[3] + for n1, n2, d in G.edges(data=True): + if n1 > n2: + continue + e_G = d[bond_key] + n_ITS1 = eta_G[n1] + n_ITS2 = eta_G[n2] + n_H1 = eta_H_inv[n_ITS1] + n_H2 = eta_H_inv[n_ITS2] + e_H = None + if H.has_edge(n_H1, n_H2): + e_H = H[n_H1][n_H2][bond_key] + if not ITS.has_edge(n_ITS1, n_ITS2) and n_ITS1 > 0 and n_ITS2 > 0: + ITS.add_edge(n_ITS1, n_ITS2, bond=(e_G, e_H)) + + for n1, n2, d in H.edges(data=True): + if n1 > n2: + continue + e_H = d[bond_key] + n_ITS1 = eta_H[n1] + n_ITS2 = eta_H[n2] + n_G1 = eta_G_inv[n_ITS1] + n_G2 = eta_G_inv[n_ITS2] + if n_G1 is None or n_G2 is None: + continue + if not G.has_edge(n_G1, n_G2) and n_ITS1 > 0 and n_ITS2 > 0: + ITS.add_edge(n_ITS1, n_ITS2, bond=(None, e_H)) + + +def 
get_its(G: nx.Graph, H: nx.Graph) -> nx.Graph: + """ + + Get the ITS graph of reaction G \u2192 H. G and H must be molecular graphs + with node labels 'aam' and 'symbol' and bond label 'bond'. + + :param G: Reactant molecular graph. + :param H: Product molecular graph. + + :returns: Returns the ITS graph. + """ + eta_G = collections.defaultdict(lambda: None) + eta_G_inv = collections.defaultdict(lambda: None) + eta_H = collections.defaultdict(lambda: None) + eta_H_inv = collections.defaultdict(lambda: None) + eta = (eta_G, eta_G_inv, eta_H, eta_H_inv) + + for n, d in G.nodes(data=True): + if d is None: + raise ValueError("Graph node {} has no data.".format(n)) + if AAM_KEY in d.keys() and d[AAM_KEY] >= 0: + eta_G[n] = d[AAM_KEY] + eta_G_inv[d[AAM_KEY]] = n + for n, d in H.nodes(data=True): + if d is None: + raise ValueError("Graph node {} has no data.".format(n)) + if AAM_KEY in d.keys() and d[AAM_KEY] >= 0: + eta_H[n] = d[AAM_KEY] + eta_H_inv[d[AAM_KEY]] = n + + ITS = nx.Graph() + _add_its_nodes(ITS, G, H, eta, SYMBOL_KEY) + _add_its_edges(ITS, G, H, eta, BOND_KEY) + + return ITS + + diff --git a/fgutils/const.py b/fgutils/const.py new file mode 100644 index 0000000..1b6627d --- /dev/null +++ b/fgutils/const.py @@ -0,0 +1,5 @@ +AAM_KEY = "aam" +SYMBOL_KEY = "symbol" +BOND_KEY = "bond" +IS_LABELED_KEY = "is_labeled" +LABELS_KEY = "labels" diff --git a/fgutils/parse.py b/fgutils/parse.py index 5497f1e..a53b23e 100644 --- a/fgutils/parse.py +++ b/fgutils/parse.py @@ -2,19 +2,23 @@ import numpy as np import networkx as nx +from fgutils.const import SYMBOL_KEY + + +token_specification = [ + ("ATOM", r"H|Br|Cl|Se|Sn|Si|C|N|O|P|S|F|B|I|b|c|n|o|p|s"), + ("BOND", r"\.|-|=|#|$|:|/|\\"), + ("BRANCH_START", r"\("), + ("BRANCH_END", r"\)"), + ("RING_NUM", r"\d+"), + ("WILDCARD", r"R"), + ("RC_BOND", r"<\d*,\d*>"), + ("NODE_LABEL", r"\{[a-zA-Z0-9_,-]+\}"), + ("MISMATCH", r"."), +] + def tokenize(pattern): - token_specification = [ - ("ATOM", 
r"H|Br|Cl|Se|Sn|Si|C|N|O|P|S|F|B|I|b|c|n|o|p|s"), - ("BOND", r"\.|-|=|#|$|:|/|\\"), - ("BRANCH_START", r"\("), - ("BRANCH_END", r"\)"), - ("RING_NUM", r"\d+"), - ("WILDCARD", r"R"), - ("RC_BOND", r"<\d*,\d*>"), - ("NODE_LABEL", r"\{[a-zA-Z0-9_,-]+\}"), - ("MISMATCH", r"."), - ] token_re = "|".join("(?P<%s>%s)" % pair for pair in token_specification) for m in re.finditer(token_re, pattern): ttype = m.lastgroup @@ -26,12 +30,34 @@ def tokenize(pattern): class Parser: + """ + + Class to convert a SMILES like graph description into a NetworkX graph. + + Example for parsing acetic acid:: + + parser = Parser() + g = parser("CC(O)=O") # Returns graph with 4 nodes and 3 edges + + :param use_multigraph: + + Flag to specify if the resulting graph object should be of type + networkx.MultiGraph or networkx.Graph. The difference is that a + MultiGraph can have more than one edge between two nodes. For parsing + molecule like graphs this is not necessary because bond types are + encoded as edge labels. (Default = False) + + :param verbose: + + Flag to print information during parsing. 
(Default = False) + + """ def __init__(self, use_multigraph=False, verbose=False): self.bond_to_order_map = {"-": 1, "=": 2, "#": 3, "$": 4, ":": 1.5, ".": None} self.verbose = verbose self.use_multigraph = use_multigraph self.__clear() - + def __clear(self): if self.use_multigraph: self.graph = nx.MultiGraph() @@ -48,7 +74,7 @@ def __print_process_token(self, ttype, value): "Process Token: {:>15}={} | Anchor: {}@{} Bond: {}".format( ttype, value, - self.graph.nodes[self.anchor]["symbol"] + self.graph.nodes[self.anchor][SYMBOL_KEY] if self.anchor is not None else "None", self.anchor, @@ -65,7 +91,7 @@ def __process_token_add_node(self, ttype, value, idx): value = "#" self.graph.add_node(idx, symbol=value, labels=labels, is_labeled=is_labeled) if self.anchor is not None: - anchor_sym = self.graph.nodes[self.anchor]["symbol"] + anchor_sym = self.graph.nodes[self.anchor][SYMBOL_KEY] if self.bond_order == 1 and anchor_sym.islower() and value.islower(): self.bond_order = 1.5 if self.bond_order is not None: @@ -86,9 +112,9 @@ def __process_token_rc_bond(self, value): def __process_token_ring(self, value): if value in self.rings.keys(): - anchor_sym = self.graph.nodes[self.anchor]["symbol"] + anchor_sym = self.graph.nodes[self.anchor][SYMBOL_KEY] ring_anchor = self.rings[value] - ring_anchor_sym = self.graph.nodes[ring_anchor]["symbol"] + ring_anchor_sym = self.graph.nodes[ring_anchor][SYMBOL_KEY] if anchor_sym.islower() != ring_anchor_sym.islower(): raise SyntaxError( ( @@ -124,7 +150,27 @@ def __process_token(self, ttype, value, idx) -> bool: return False return True - def parse(self, pattern, idx_offset=0): + def parse(self, pattern: str, idx_offset: int = 0): + """ + + Method to parse a SMILES like graph pattern. + + :param pattern: + + The pattern to convert into a graph. The pattern is a tree-like + description of the graph. It is strongly oriented at the SMILES + notation. 
+ + :param idx_offset: + + The index offset argument provides the starting value for the + consecutive node numbering. (Default = 0) + + :returns: + + Returns the converted graph object. + + """ self.__clear() for ttype, value, col in tokenize(pattern): self.__print_process_token(ttype, value) diff --git a/fgutils/rdkit.py b/fgutils/rdkit.py index 473fd88..05744f2 100644 --- a/fgutils/rdkit.py +++ b/fgutils/rdkit.py @@ -2,6 +2,8 @@ import rdkit.Chem as Chem import rdkit.Chem.rdmolfiles as rdmolfiles +from fgutils.const import IS_LABELED_KEY, SYMBOL_KEY, AAM_KEY, LABELS_KEY, BOND_KEY + def mol_to_graph(mol: Chem.rdchem.Mol) -> nx.Graph: bond_order_map = { @@ -39,21 +41,25 @@ def graph_to_mol(g: nx.Graph) -> Chem.rdchem.Mol: rw_mol = Chem.rdchem.RWMol() idx_map = {} for n, d in g.nodes(data=True): - atom_symbol = _get_rdkit_atom_sym(d["symbol"]) - if "is_labeled" in d.keys() and d["is_labeled"]: + if d is None: + raise ValueError("Graph node {} has no data.".format(n)) + atom_symbol = _get_rdkit_atom_sym(d[SYMBOL_KEY]) + if IS_LABELED_KEY in d.keys() and d[IS_LABELED_KEY]: raise ValueError( "Graph contains labeled nodes. 
Node {} with label [{}].".format( - n, ",".join(d["labels"]) + n, ",".join(d[LABELS_KEY]) ) ) idx = rw_mol.AddAtom(Chem.rdchem.Atom(atom_symbol)) idx_map[n] = idx - if "aam" in d.keys() and d["aam"] >= 0: - rw_mol.GetAtomWithIdx(idx).SetAtomMapNum(d["aam"]) + if AAM_KEY in d.keys() and d[AAM_KEY] >= 0: + rw_mol.GetAtomWithIdx(idx).SetAtomMapNum(d[AAM_KEY]) for n1, n2, d in g.edges(data=True): + if d is None: + raise ValueError("Graph edge {} has no data.".format((n1, n2))) idx1 = idx_map[n1] idx2 = idx_map[n2] - rw_mol.AddBond(idx1, idx2, bond_order_map[d["bond"]]) + rw_mol.AddBond(idx1, idx2, bond_order_map[d[BOND_KEY]]) return rw_mol.GetMol() diff --git a/fgutils/vis.py b/fgutils/vis.py new file mode 100644 index 0000000..9b256cb --- /dev/null +++ b/fgutils/vis.py @@ -0,0 +1,174 @@ +import io +import networkx as nx +import rdkit.Chem.rdmolfiles as rdmolfiles +import rdkit.Chem.Draw.rdMolDraw2D as rdMolDraw2D +import rdkit.Chem.rdChemReactions as rdChemReactions + +import rdkit.Chem as Chem +import rdkit.Chem.rdmolfiles as rdmolfiles +import rdkit.Chem.rdDepictor as rdDepictor + +from PIL import Image + +from fgutils.rdkit import graph_to_mol, graph_to_smiles +from fgutils.const import SYMBOL_KEY, AAM_KEY, BOND_KEY, IS_LABELED_KEY, LABELS_KEY + + +def _get_its_as_mol(its: nx.Graph) -> Chem.rdchem.Mol: + _its = its.copy() + for n in _its.nodes: + _its.nodes[n][AAM_KEY] = n + for u, v in _its.edges(): + _its[u][v][BOND_KEY] = 1 + return graph_to_mol(_its) + + +def _get_graph_as_mol(g: nx.Graph) -> Chem.rdchem.Mol: + _g = g.copy() + for n, d in _g.nodes(data=True): + if d[IS_LABELED_KEY]: + _g.nodes[n][SYMBOL_KEY] = "C" + _g.nodes[n][IS_LABELED_KEY] = False + _g.nodes[n][AAM_KEY] = n + for u, v in _g.edges(): + _g[u][v][BOND_KEY] = 1 + return graph_to_mol(_g) + + +def plot_its(its, ax, use_mol_coords=True): + bond_char = {None: "∅", 0: "∅", 1: "—", 2: "=", 3: "≡"} + + if use_mol_coords: + mol = _get_its_as_mol(its) + positions = {} + conformer = 
rdDepictor.Compute2DCoords(mol) + for i, atom in enumerate(mol.GetAtoms()): + aam = atom.GetAtomMapNum() + apos = mol.GetConformer(conformer).GetAtomPosition(i) + positions[aam] = [apos.x, apos.y] + else: + positions = nx.spring_layout(its) + + ax.axis("equal") + ax.axis("off") + + nx.draw_networkx_edges(its, positions, edge_color="#000000", ax=ax) + nx.draw_networkx_nodes(its, positions, node_color="#FFFFFF", node_size=500, ax=ax) + + labels = {n: "{}:{}".format(d[SYMBOL_KEY], n) for n, d in its.nodes(data=True)} + edge_labels = {} + for u, v, d in its.edges(data=True): + bc1 = d[BOND_KEY][0] + bc2 = d[BOND_KEY][1] + if bc1 == bc2: + continue + if bc1 in bond_char.keys(): + bc1 = bond_char[bc1] + if bc2 in bond_char.keys(): + bc2 = bond_char[bc2] + edge_labels[(u, v)] = "({},{})".format(bc1, bc2) + + nx.draw_networkx_labels(its, positions, labels=labels, ax=ax) + nx.draw_networkx_edge_labels(its, positions, edge_labels=edge_labels, ax=ax) + + +def plot_as_mol(g: nx.Graph, ax, use_mol_coords=True): + bond_char = {None: "∅", 1: "—", 2: "=", 3: "≡"} + + if use_mol_coords: + mol = graph_to_mol(g) + positions = {} + conformer = rdDepictor.Compute2DCoords(mol) + for i, atom in enumerate(mol.GetAtoms()): + aidx = atom.GetIdx() + apos = mol.GetConformer(conformer).GetAtomPosition(i) + positions[aidx] = [apos.x, apos.y] + else: + positions = nx.spring_layout(g) + + ax.axis("equal") + ax.axis("off") + + nx.draw_networkx_edges(g, positions, edge_color="#909090", ax=ax) + nx.draw_networkx_nodes(g, positions, node_color="#FFFFFF", node_size=500, ax=ax) + + labels = {n: "{}".format(d[SYMBOL_KEY]) for n, d in g.nodes(data=True)} + edge_labels = {} + for u, v, d in g.edges(data=True): + bc = d[BOND_KEY] + if bc in bond_char.keys(): + bc = bond_char[bc] + edge_labels[(u, v)] = "{}".format(bc) + + nx.draw_networkx_labels(g, positions, labels=labels, ax=ax) + nx.draw_networkx_edge_labels(g, positions, edge_labels=edge_labels, ax=ax) + + +def get_rxn_img(smiles): + drawer = 
rdMolDraw2D.MolDraw2DCairo(1600, 900) + if ">>" in smiles: + rxn = rdChemReactions.ReactionFromSmarts(smiles, useSmiles=True) + drawer.DrawReaction(rxn) + else: + mol = rdmolfiles.MolFromSmiles(smiles) + if mol is None: + mol = rdmolfiles.MolFromSmarts(smiles) + drawer.DrawMolecule(mol) + drawer.FinishDrawing() + img = Image.open(io.BytesIO(drawer.GetDrawingText())) + nonwhite_positions = [ + (x, y) + for x in range(img.size[0]) + for y in range(img.size[1]) + if img.getdata()[x + y * img.size[0]] != (255, 255, 255) # type: ignore + ] + rect = ( + min([x - 10 for x, _ in nonwhite_positions]), + min([y - 10 for _, y in nonwhite_positions]), + max([x + 10 for x, _ in nonwhite_positions]), + max([y + 10 for _, y in nonwhite_positions]), + ) + return img.crop(rect) + + +def plot_graph(g: nx.Graph, ax, use_mol_coords=True, show_labels=False): + bond_char = {None: "∅", 1: "—", 2: "=", 3: "≡"} + + if use_mol_coords: + mol = _get_graph_as_mol(g) + positions = {} + conformer = rdDepictor.Compute2DCoords(mol) + for i, atom in enumerate(mol.GetAtoms()): + aidx = atom.GetIdx() + apos = mol.GetConformer(conformer).GetAtomPosition(i) + positions[aidx] = [apos.x, apos.y] + else: + positions = nx.spring_layout(g) + + ax.axis("equal") + ax.axis("off") + + nx.draw_networkx_edges(g, positions, edge_color="#909090", ax=ax) + nx.draw_networkx_nodes(g, positions, node_color="#FFFFFF", node_size=500, ax=ax) + + labels = {} # {n: "{}".format(d[SYMBOL_KEY]) for n, d in g.nodes(data=True)} + for n, d in g.nodes(data=True): + lbl = "{}".format(d[SYMBOL_KEY]) + if d[IS_LABELED_KEY] and show_labels: + lbl = "{}".format(d[LABELS_KEY]) + labels[n] = lbl + + edge_labels = {} + for u, v, d in g.edges(data=True): + bc = d[BOND_KEY] + if bc in bond_char.keys(): + bc = bond_char[bc] + edge_labels[(u, v)] = "{}".format(bc) + + nx.draw_networkx_labels(g, positions, labels=labels, ax=ax) + nx.draw_networkx_edge_labels(g, positions, edge_labels=edge_labels, ax=ax) + + +def plot_reaction(g: nx.Graph, h: 
nx.Graph, ax): + rxn_smiles = "{}>>{}".format(graph_to_smiles(g), graph_to_smiles(h)) + ax.imshow(get_rxn_img(rxn_smiles))