Skip to content

Commit

Permalink
Update features and fix bug (#13)
Browse files Browse the repository at this point in the history
* update

* test mod compatible

* add graph visualizer

* fix lint

* add copyright for FGUtils

* prepare release

* add partial map expansion

* add testcase for partial expansion ver1

* update new features

* prepare release
  • Loading branch information
TieuLongPhan authored Nov 27, 2024
1 parent cfb522a commit 965ab58
Show file tree
Hide file tree
Showing 27 changed files with 1,345 additions and 166 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@
*.json
test_mod.py
test_format.py
*dev_zone
test_format.py
97 changes: 97 additions & 0 deletions Test/SynAAM/test_aam_validator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import unittest
from synutility.SynAAM.aam_validator import AAMValidator


class TestAMMValidator(unittest.TestCase):
    """Exercise ``AAMValidator`` on pairs of atom-mapped reaction SMILES.

    Every fixture is a ``(reference, candidate)`` tuple of mapped reaction
    SMILES. The tests pin which pairs the validator accepts or rejects.
    """

    def setUp(self):
        """Build the three fixture pairs and the record list for batch validation."""
        # Pair that ``smiles_check`` must accept with the "RC" method.
        self.true_pair = (
            (
                "[CH:8]=1[S:9][CH:10]=[C:6]([C:5]#[C:4][CH2:3][N:2]([C:11]2=[CH:12]"
                + "[CH:13]=[CH:14][CH:15]=[CH:16]2)[CH3:1])[CH:7]=1.[OH2:17]>>[C:5]([N:2]"
                + "([CH3:1])[C:11]1=[CH:12][CH:13]=[CH:14][CH:15]=[CH:16]1)([C:6]2="
                + "[CH:10][S:9][CH:8]=[CH:7]2)=[CH:4][CH:3]=[O:17]"
            ),
            (
                "[OH2:17].[cH:12]1[cH:13][cH:14][cH:15][cH:16][c:11]1[N:2]([CH3:1])"
                + "[CH2:3][C:4]#[C:5][c:6]1[cH:10][s:9][cH:8][cH:7]1>>[cH:12]1[cH:13]"
                + "[cH:14][cH:15][cH:16][c:11]1[N:2]([CH3:1])[C:5](=[CH:4][CH:3]=[O:17])"
                + "[c:6]1[cH:10][s:9][cH:8][cH:7]1"
            ),
        )
        # Pair that ``smiles_check`` must reject with the "RC" method.
        self.false_pair = (
            (
                "[CH:8]=1[S:9][CH:10]=[C:6]([C:5]#[C:4][CH2:3][N:2]([C:11]2=[CH:12]"
                + "[CH:13]=[CH:14][CH:15]=[CH:16]2)[CH3:1])[CH:7]=1.[OH2:17]>>[C:5]"
                + "([N:2]([CH3:1])[C:11]1=[CH:12][CH:13]=[CH:14][CH:15]=[CH:16]1)"
                + "([C:6]2=[CH:10][S:9][CH:8]=[CH:7]2)=[CH:4][CH:3]=[O:17]"
            ),
            (
                "[CH3:1][N:2]([CH2:3][C:4]#[C:5][c:7]1[cH:8][cH:9][s:10][cH:11]1)"
                + "[c:12]1[cH:13][cH:14][cH:15][cH:16][cH:17]1.[OH2:6]>>[CH3:1][N:2]"
                + "([C:3](=[CH:4][CH:5]=[O:6])[c:7]1[cH:8][cH:9][s:10][cH:11]1)"
                + "[c:12]1[cH:13][cH:14][cH:15][cH:16][cH:17]1"
            ),
        )
        # The two mappings differ only in which oxygen (map 3 or 4) ends up as
        # the carbonyl oxygen of the product and which one leaves as water.
        self.tautomer = (
            "[CH3:1][C:2](=[O:3])[OH:4].[CH3:5][CH2:6][OH:7]>>[CH3:1][C:2](=[O:3])"
            + "[O:7][CH2:6][CH3:5].[OH2:4]",
            "[CH3:1][C:2](=[O:3])[OH:4].[CH3:5][CH2:6][OH:7]>>"
            + "[CH3:1][C:2](=[O:4])[O:7][CH2:6][CH3:5].[OH2:3]",
        )

        # One {"ref", "map"} record per fixture pair; the numbered attributes
        # alias the very same dict objects stored in ``self.data``.
        fixture_pairs = (self.true_pair, self.false_pair, self.tautomer)
        self.data = [{"ref": ref, "map": mapped} for ref, mapped in fixture_pairs]
        self.data_dict_1, self.data_dict_2, self.data_dict_3 = self.data

    def test_smiles_check(self):
        """``smiles_check`` accepts the matching pair and rejects the mismatched one."""
        good_ref, good_map = self.true_pair
        self.assertTrue(
            AAMValidator.smiles_check(
                good_ref, good_map, check_method="RC", ignore_aromaticity=False
            )
        )

        bad_ref, bad_map = self.false_pair
        self.assertFalse(
            AAMValidator.smiles_check(
                bad_ref, bad_map, check_method="RC", ignore_aromaticity=False
            )
        )

    def test_smiles_check_tautomer(self):
        """The tautomer pair fails ``smiles_check`` but passes ``smiles_check_tautomer``."""
        reference, candidate = self.tautomer

        strict_verdict = AAMValidator.smiles_check(
            reference,
            candidate,
            check_method="RC",
            ignore_aromaticity=False,
        )
        self.assertFalse(strict_verdict)

        tautomer_verdict = AAMValidator.smiles_check_tautomer(
            reference,
            candidate,
            check_method="RC",
            ignore_aromaticity=True,
        )
        self.assertTrue(tautomer_verdict)

    def test_validate_smiles_dataframe(self):
        """``validate_smiles`` reports the expected accuracy and success rate."""
        results = AAMValidator.validate_smiles(
            data=self.data,
            ground_truth_col="ref",
            mapped_cols=["map"],
            check_method="RC",
            ignore_aromaticity=False,
            n_jobs=2,
            verbose=0,
            ignore_tautomers=False,
        )
        report = results[0]
        # Accuracy is a float; compare to two decimals rather than relying on
        # exact floating-point equality.
        self.assertAlmostEqual(report["accuracy"], 66.67, places=2)
        self.assertEqual(report["success_rate"], 100)

# Run this module's tests when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
25 changes: 25 additions & 0 deletions Test/SynAAM/test_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import unittest
from synutility.SynIO.Format.chemical_conversion import smart_to_gml
from synutility.SynAAM.inference import aam_infer


class TestAAMInference(unittest.TestCase):
    """Check ``aam_infer`` against a known atom-mapped reaction SMILES."""

    def setUp(self):
        """Prepare the unmapped reaction, the rule in GML form, and the expected result."""
        rule_smarts = "[Br:1][CH3:2].[OH:3][H:4]>>[Br:1][H:4].[CH3:2][OH:3]"

        self.rsmi = "BrCc1ccc(Br)cc1.COCCO>>Br.COCCOCc1ccc(Br)cc1"
        self.gml = smart_to_gml(rule_smarts)
        self.expect = (
            "[Br:1][CH2:2][C:3]1=[CH:4][CH:6]=[C:7]([Br:8])[CH:9]"
            + "=[CH:5]1.[CH3:10][O:11][CH2:12][CH2:13][O:14][H:15]>>"
            + "[Br:1][H:15].[CH2:2]([C:3]1=[CH:4][CH:6]=[C:7]([Br:8])"
            + "[CH:9]=[CH:5]1)[O:14][CH2:13][CH2:12][O:11][CH3:10]"
        )

    def test_aam_infer(self):
        """The first entry returned by ``aam_infer`` must equal the expected mapping."""
        inferred = aam_infer(self.rsmi, self.gml)
        first_mapping = inferred[0]
        self.assertEqual(first_mapping, self.expect)


# Run this module's tests when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
import networkx as nx
from synutility.SynIO.Format.its_construction import ITSConstruction
from synutility.SynAAM.its_construction import ITSConstruction


class TestITSConstruction(unittest.TestCase):
Expand Down
2 changes: 1 addition & 1 deletion Test/SynAAM/test_normalize_aam.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_fix_rsmi(self):
for both reactants and products."""
input_rsmi = "[C:0]>>[C:1]"
expected_rsmi = "[C:1]>>[C:2]"
self.assertEqual(self.normalizer.fix_rsmi(input_rsmi), expected_rsmi)
self.assertEqual(self.normalizer.fix_aam_rsmi(input_rsmi), expected_rsmi)

def test_extract_subgraph(self):
"""Test extraction of a subgraph based on specified indices."""
Expand Down
23 changes: 21 additions & 2 deletions Test/SynAAM/test_partial_expand.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
import unittest
from synutility.SynAAM.aam_validator import AAMValidator
from synutility.SynAAM.its_construction import ITSConstruction
from synutility.SynIO.Format.chemical_conversion import rsmi_to_graph

from synutility.SynAAM.partial_expand import PartialExpand


Expand All @@ -16,7 +20,7 @@ def test_expand(self):
# Perform the expansion
output_rsmi = PartialExpand.expand(input_rsmi)
# Assert the result matches the expected output
self.assertEqual(output_rsmi, expected_rsmi)
self.assertTrue(AAMValidator.smiles_check(output_rsmi, expected_rsmi, "ITS"))

def test_expand_2(self):
input_rsmi = "CC[CH2:3][Cl:1].[NH2:2][H:4]>>CC[CH2:3][NH2:2].[Cl:1][H:4]"
Expand All @@ -25,7 +29,22 @@ def test_expand_2(self):
"[CH3:1][CH2:2][CH2:3][Cl:4].[NH2:5][H:6]"
+ ">>[CH3:1][CH2:2][CH2:3][NH2:5].[Cl:4][H:6]"
)
self.assertEqual(output_rsmi, expected_rsmi)
self.assertTrue(AAMValidator.smiles_check(output_rsmi, expected_rsmi, "ITS"))

def test_graph_expand(self):
    """Expand an ITS graph built from a rule over an unmapped reaction.

    The ITS graph is constructed from the reactant/product graphs of a small
    mapped rule; ``PartialExpand.graph_expand`` is then applied to the
    unmapped reaction and the output is compared with the expected mapping
    through ``AAMValidator.smiles_check``.
    """
    rule_rsmi = "[Br:1][CH3:2].[OH:3][H:4]>>[Br:1][H:4].[CH3:2][OH:3]"
    input_rsmi = "BrCc1ccc(Br)cc1.COCCO>>Br.COCCOCc1ccc(Br)cc1"
    expect = (
        "[Br:1][CH2:2][C:3]1=[CH:4][CH:6]=[C:7]([Br:8])[CH:9]=[CH:5]1."
        + "[CH3:10][O:11][CH2:12][CH2:13][O:14][H:15]>>[Br:1][H:15]"
        + ".[CH2:2]([C:3]1=[CH:4][CH:6]=[C:7]([Br:8])[CH:9]=[CH:5]1)"
        + "[O:14][CH2:13][CH2:12][O:11][CH3:10]"
    )

    reactant_graph, product_graph = rsmi_to_graph(rule_rsmi, sanitize=False)
    its_graph = ITSConstruction().ITSGraph(reactant_graph, product_graph)
    expanded = PartialExpand.graph_expand(its_graph, input_rsmi)

    mapping_matches = AAMValidator.smiles_check(expanded, expect)
    self.assertTrue(mapping_matches)


if __name__ == "__main__":
Expand Down
55 changes: 55 additions & 0 deletions Test/SynGraph/Transform/test_multi_step.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import unittest
from synutility.SynIO.Format.chemical_conversion import smart_to_gml
from synutility.SynGraph.Transform.multi_step import (
perform_multi_step_reaction,
remove_reagent_from_smiles,
calculate_max_depth,
find_all_paths,
)


class TestMultiStep(unittest.TestCase):
    """Tests for the helpers imported from ``SynGraph.Transform.multi_step``."""

    def setUp(self) -> None:
        """Create the GML rules, the rule application order, and the overall reaction."""
        rule_smarts = [
            "[CH2:4]([CH:5]=[O:6])[H:7]>>[CH2:4]=[CH:5][O:6][H:7]",
            (
                "[CH2:2]=[O:3].[CH2:4]=[CH:5][O:6][H:7]>>[CH2:2]([O:3][H:7])[CH2:4]"
                + "[CH:5]=[O:6]"
            ),
            "[CH2:4]([CH:5]=[O:6])[H:8]>>[CH2:4]=[CH:5][O:6][H:8]",
            (
                "[CH2:2]([OH:3])[CH:4]=[CH:5][O:6][H:8]>>[CH2:2]=[CH:4][CH:5]=[O:6]"
                + ".[OH:3][H:8]"
            ),
        ]
        self.gml = list(map(smart_to_gml, rule_smarts))
        # NOTE(review): presumably indices into ``self.gml`` — confirm against
        # ``perform_multi_step_reaction``.
        self.order = [0, 1, 0, -1]
        self.rsmi = "CC=O.CC=O.CCC=O>>CC=O.CC=C(C)C=O.O"

    def test_remove_reagent_from_smiles(self):
        """The cleaned reaction SMILES must equal the expected reduced reaction."""
        cleaned = remove_reagent_from_smiles(self.rsmi)
        self.assertEqual(cleaned, "CC=O.CCC=O>>CC=C(C)C=O.O")

    def test_perform_multi_step_reaction(self):
        """Running the four ordered rules yields four results."""
        step_results, _tree = perform_multi_step_reaction(
            self.gml, self.order, self.rsmi
        )
        self.assertEqual(len(step_results), 4)

    def test_calculate_max_depth(self):
        """The reaction tree produced for the fixture has a maximum depth of 4."""
        _results, tree = perform_multi_step_reaction(self.gml, self.order, self.rsmi)
        self.assertEqual(calculate_max_depth(tree), 4)

    def test_find_all_paths(self):
        """Exactly one path reaches the target products, spanning four steps."""
        step_results, tree = perform_multi_step_reaction(
            self.gml, self.order, self.rsmi
        )
        product_side = self.rsmi.split(">>")[1]
        target_products = sorted(product_side.split("."))
        depth_limit = len(step_results)

        all_paths = find_all_paths(tree, target_products, self.rsmi, depth_limit)
        self.assertEqual(len(all_paths), 1)

        # The first entry of a path is the original reaction; drop it.
        steps_only = all_paths[0][1:]
        self.assertEqual(len(steps_only), 4)


# Run this module's tests when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
94 changes: 94 additions & 0 deletions Test/SynIO/Format/test_chemcal_conversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import unittest
import networkx as nx

from synutility.SynChem.Reaction.standardize import Standardize
from synutility.SynIO.Format.chemical_conversion import (
smiles_to_graph,
rsmi_to_graph,
graph_to_rsmi,
smart_to_gml,
gml_to_smart,
)

from synutility.SynGraph.Morphism.misc import rule_isomorphism


class TestChemicalConversions(unittest.TestCase):
    """Tests for the SMILES / graph / GML conversion helpers."""

    def setUp(self) -> None:
        """Define one mapped reaction, its expected GML rule text, and a standardizer."""
        self.std = Standardize()
        self.rsmi = "[CH2:1]([H:4])[CH2:2][OH:3]>>[CH2:1]=[CH2:2].[H:4][OH:3]"
        # Expected GML text for ``self.rsmi``; compared verbatim in
        # ``test_smart_to_gml``, so the literals must not be reformatted.
        self.gml = (
            "rule [\n"
            ' ruleID "rule"\n'
            " left [\n"
            ' edge [ source 1 target 4 label "-" ]\n'
            ' edge [ source 1 target 2 label "-" ]\n'
            ' edge [ source 2 target 3 label "-" ]\n'
            " ]\n"
            " context [\n"
            ' node [ id 1 label "C" ]\n'
            ' node [ id 4 label "H" ]\n'
            ' node [ id 2 label "C" ]\n'
            ' node [ id 3 label "O" ]\n'
            " ]\n"
            " right [\n"
            ' edge [ source 1 target 2 label "=" ]\n'
            ' edge [ source 4 target 3 label "-" ]\n'
            " ]\n"
            "]"
        )

    def test_smiles_to_graph_valid(self):
        """A valid mapped SMILES converts to a three-node ``nx.Graph``."""
        graph = smiles_to_graph("[CH3:1][CH2:2][OH:3]", False, True, True)
        self.assertIsInstance(graph, nx.Graph)
        self.assertEqual(graph.number_of_nodes(), 3)

    def test_smiles_to_graph_invalid(self):
        """An unparsable SMILES string yields ``None``."""
        graph = smiles_to_graph("invalid_smiles", True, False, False)
        self.assertIsNone(graph)

    def test_rsmi_to_graph_valid(self):
        """Reactant and product graphs have 3 nodes when sanitized and 4 otherwise."""
        for sanitize_flag, expected_nodes in ((True, 3), (False, 4)):
            reactants_graph, products_graph = rsmi_to_graph(
                self.rsmi, sanitize=sanitize_flag
            )
            self.assertIsInstance(reactants_graph, nx.Graph)
            self.assertEqual(reactants_graph.number_of_nodes(), expected_nodes)
            self.assertIsInstance(products_graph, nx.Graph)
            self.assertEqual(products_graph.number_of_nodes(), expected_nodes)

    def test_rsmi_to_graph_invalid(self):
        """A string that is not a reaction SMILES yields ``(None, None)``."""
        outcome = rsmi_to_graph("invalid_format")
        self.assertEqual((None, None), outcome)

    def test_graph_to_rsmi(self):
        """Graphs round-trip to a reaction SMILES that standardizes like the input."""
        reactant_graph, product_graph = rsmi_to_graph(self.rsmi, sanitize=False)
        round_tripped = graph_to_rsmi(reactant_graph, product_graph)
        self.assertIsInstance(round_tripped, str)

        standardized_output = self.std.fit(round_tripped, False)
        standardized_input = self.std.fit(self.rsmi, False)
        self.assertEqual(standardized_output, standardized_input)

    def test_smart_to_gml(self):
        """Without reindexing the GML matches verbatim; with it, up to rule isomorphism."""
        plain_gml = smart_to_gml(self.rsmi, core=False, sanitize=False, reindex=False)
        self.assertIsInstance(plain_gml, str)
        self.assertEqual(plain_gml, self.gml)

        reindexed_gml = smart_to_gml(
            self.rsmi, core=False, sanitize=False, reindex=True
        )
        self.assertTrue(rule_isomorphism(reindexed_gml, self.gml))

    def test_gml_to_smart(self):
        """The GML rule converts back to SMARTS that standardizes like the input."""
        smarts, _ = gml_to_smart(self.gml)
        self.assertIsInstance(smarts, str)

        standardized_output = self.std.fit(smarts, False)
        standardized_input = self.std.fit(self.rsmi, False)
        self.assertEqual(standardized_output, standardized_input)


# Run this module's tests when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
2 changes: 1 addition & 1 deletion lint.sh
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#!/bin/bash

flake8 . --count --max-complexity=13 --max-line-length=120 \
--per-file-ignores="__init__.py:F401, chemical_reaction_visualizer.py:E501, test_reagent.py:E501" \
--per-file-ignores="__init__.py:F401, chemical_reaction_visualizer.py:E501, test_reagent.py:E501, inference.py:F401" \
--exclude venv,core_engine.py,rule_apply.py \
--statistics
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "synutility"
version = "0.0.11"
version = "0.0.12"
authors = [
{name="Tieu Long Phan", email="tieu@bioinf.uni-leipzig.de"}
]
Expand Down
Loading

0 comments on commit 965ab58

Please sign in to comment.