Skip to content

Commit

Permalink
Add match_stereo to compare.
Browse files Browse the repository at this point in the history
  • Loading branch information
bergwerf committed Dec 4, 2024
1 parent 4aed664 commit 981350c
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 24 deletions.
42 changes: 31 additions & 11 deletions api.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@
}


bond_dir = {
'single_up': Chem.BondDir.BEGINWEDGE,
'single_down': Chem.BondDir.BEGINDASH
}


def check_authorization(provided_key: str) -> bool:
allowed_key = environ.get('API_KEY')
if not allowed_key:
Expand Down Expand Up @@ -59,18 +65,26 @@ def graph_to_mol(g: nx.Graph) -> Chem.rdchem.RWMol:
index[v] = mol.AddAtom(a)

for u, v in g.edges:
mol.AddBond(index[u], index[v], bond_type[g.edges[(u,v)]['type']])
type = g.edges[(u,v)]['type']
bond = mol.AddBond(index[u], index[v], bond_type[type])
if type in bond_dir:
mol.GetBondWithIdx(bond - 1).SetBondDir(bond_dir[type])

# Build 2D conformer.
conf = Chem.Conformer(mol.GetNumAtoms())
conf.Set3D(False)
for v in g.nodes:
# Note that we add a Z-coordinate with value 0.
conf.SetAtomPosition(index[v], g.nodes[v]['position'] + (0,))

# Assign conformer and compute stereochemistry.
mol = mol.GetMol()
mol.AddConformer(conf)
Chem.AssignStereochemistryFrom3D(mol)

Chem.SanitizeMol(mol)
Chem.DetectBondStereochemistry(mol)
Chem.AssignChiralTypesFromBondDirs(mol)
Chem.AssignStereochemistry(mol)

return mol

Expand All @@ -93,21 +107,27 @@ def _graph_to_smiles(g: nx.Graph) -> str:
return Chem.rdmolfiles.MolToSmiles(mol, isomericSmiles=True)


def compare_components(g1: nx.Graph, g2: nx.Graph) -> bool:
return (
nx.is_isomorphic(g1, g2, _node_match, _edge_match) and
_graph_to_smiles(g1) == _graph_to_smiles(g2))
def compare_components(g1: nx.Graph, g2: nx.Graph, match_stereo: bool) -> bool:
if nx.is_isomorphic(g1, g2, _node_match, _edge_match):
if match_stereo:
try:
return _graph_to_smiles(g1) == _graph_to_smiles(g2)
except ValueError:
return False
else:
return True


def compare(diagram1, diagram2) -> bool:
def compare(diagram1, diagram2, match_stereo = True) -> bool:
g1 = json_to_graph(diagram1)
g2 = json_to_graph(diagram2)
cs1 = list(map(lambda c: g1.subgraph(c), nx.connected_components(g1)))
cs2 = list(map(lambda c: g2.subgraph(c), nx.connected_components(g2)))

# Check if every component in g1 matches a unique component in g2.
for c1 in cs1:
match = (i for i, c2 in enumerate(cs2) if compare_components(c1, c2))
match = (i for i, c2 in enumerate(cs2) if
compare_components(c1, c2, match_stereo))
index = next(match, None)
if index is None:
return False
Expand All @@ -119,9 +139,8 @@ def compare(diagram1, diagram2) -> bool:


def validate(diagram) -> bool:
mol = json_to_mol(diagram)
try:
Chem.SanitizeMol(mol)
json_to_mol(diagram)
return True
except ValueError:
return False
Expand All @@ -137,7 +156,8 @@ def handle_compare(body):
return {
'equal': compare(
body['reference_diagram'],
body['student_diagram'])
body['student_diagram'],
body['match_stereo'])
}


Expand Down
Loading

0 comments on commit 981350c

Please sign in to comment.