From 965ab58dc6a3d7dde30683bc758c6f7730e96460 Mon Sep 17 00:00:00 2001 From: Tieu Long Phan <125431507+TieuLongPhan@users.noreply.github.com> Date: Wed, 27 Nov 2024 15:04:06 +0100 Subject: [PATCH] Update features and fix bug (#13) * update * test mod compatible * add graph visualizer * fix lint * add copy right for FGUtils * prepare release * add partial map expansion * add testcase for partial expansion ver1 * update new features * prepare release --- .gitignore | 2 + Test/SynAAM/test_aam_validator.py | 97 +++++++ Test/SynAAM/test_inference.py | 25 ++ .../test_its_construction.py | 2 +- Test/SynAAM/test_normalize_aam.py | 2 +- Test/SynAAM/test_partial_expand.py | 23 +- Test/SynGraph/Transform/test_multi_step.py | 55 ++++ Test/SynIO/Format/test_chemcal_conversion.py | 94 +++++++ lint.sh | 2 +- pyproject.toml | 2 +- synutility/SynAAM/aam_validator.py | 254 ++++++++++++++++++ synutility/SynAAM/inference.py | 73 +++++ .../Format => SynAAM}/its_construction.py | 0 synutility/SynAAM/misc.py | 125 ++++++++- synutility/SynAAM/normalize_aam.py | 57 +++- synutility/SynAAM/partial_expand.py | 104 ++++--- .../{SynMOD => SynGraph/Morphism}/__init__.py | 0 synutility/SynGraph/Morphism/misc.py | 29 ++ synutility/SynGraph/Transform/core_engine.py | 58 ++-- synutility/SynGraph/Transform/multi_step.py | 223 +++++++++++++++ synutility/SynGraph/Transform/rule_apply.py | 115 +++----- .../SynIO/Format/chemical_conversion.py | 130 +++++++++ synutility/SynIO/Format/dg_to_gml.py | 2 + synutility/SynIO/Format/gml_to_nx.py | 4 +- synutility/SynIO/Format/mol_to_graph.py | 4 +- synutility/SynVis/graph_visualizer.py | 25 +- synutility/SynVis/rsmi_to_fig.py | 4 +- 27 files changed, 1345 insertions(+), 166 deletions(-) create mode 100644 Test/SynAAM/test_aam_validator.py create mode 100644 Test/SynAAM/test_inference.py rename Test/{SynIO/Format => SynAAM}/test_its_construction.py (96%) create mode 100644 Test/SynGraph/Transform/test_multi_step.py create mode 100644 Test/SynIO/Format/test_chemcal_conversion.py create mode 100644 synutility/SynAAM/aam_validator.py create mode 100644 synutility/SynAAM/inference.py rename synutility/{SynIO/Format => SynAAM}/its_construction.py (100%) rename synutility/{SynMOD => SynGraph/Morphism}/__init__.py (100%) create mode 100644 synutility/SynGraph/Morphism/misc.py create mode 100644 synutility/SynGraph/Transform/multi_step.py create mode 100644 synutility/SynIO/Format/chemical_conversion.py diff --git a/.gitignore b/.gitignore index 8e6144a..e1af3fc 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,5 @@ *.json test_mod.py test_format.py +*dev_zone +test_format.py diff --git a/Test/SynAAM/test_aam_validator.py b/Test/SynAAM/test_aam_validator.py new file mode 100644 index 0000000..94a7099 --- /dev/null +++ b/Test/SynAAM/test_aam_validator.py @@ -0,0 +1,97 @@ +import unittest +from synutility.SynAAM.aam_validator import AAMValidator + + +class TestAMMValidator(unittest.TestCase): + + def setUp(self): + self.true_pair = ( + ( + "[CH:8]=1[S:9][CH:10]=[C:6]([C:5]#[C:4][CH2:3][N:2]([C:11]2=[CH:12]" + + "[CH:13]=[CH:14][CH:15]=[CH:16]2)[CH3:1])[CH:7]=1.[OH2:17]>>[C:5]([N:2]" + + "([CH3:1])[C:11]1=[CH:12][CH:13]=[CH:14][CH:15]=[CH:16]1)([C:6]2=" + + "[CH:10][S:9][CH:8]=[CH:7]2)=[CH:4][CH:3]=[O:17]" + ), + ( + "[OH2:17].[cH:12]1[cH:13][cH:14][cH:15][cH:16][c:11]1[N:2]([CH3:1])" + + "[CH2:3][C:4]#[C:5][c:6]1[cH:10][s:9][cH:8][cH:7]1>>[cH:12]1[cH:13]" + + "[cH:14][cH:15][cH:16][c:11]1[N:2]([CH3:1])[C:5](=[CH:4][CH:3]=[O:17])" + + "[c:6]1[cH:10][s:9][cH:8][cH:7]1" + ), + ) + self.false_pair = ( + ( + "[CH:8]=1[S:9][CH:10]=[C:6]([C:5]#[C:4][CH2:3][N:2]([C:11]2=[CH:12]" + + "[CH:13]=[CH:14][CH:15]=[CH:16]2)[CH3:1])[CH:7]=1.[OH2:17]>>[C:5]" + + "([N:2]([CH3:1])[C:11]1=[CH:12][CH:13]=[CH:14][CH:15]=[CH:16]1)" + + "([C:6]2=[CH:10][S:9][CH:8]=[CH:7]2)=[CH:4][CH:3]=[O:17]" + ), + ( + "[CH3:1][N:2]([CH2:3][C:4]#[C:5][c:7]1[cH:8][cH:9][s:10][cH:11]1)" + + "[c:12]1[cH:13][cH:14][cH:15][cH:16][cH:17]1.[OH2:6]>>[CH3:1][N:2]" + + "([C:3](=[CH:4][CH:5]=[O:6])[c:7]1[cH:8][cH:9][s:10][cH:11]1)" + + "[c:12]1[cH:13][cH:14][cH:15][cH:16][cH:17]1" + ), + ) + self.tautomer = ( + "[CH3:1][C:2](=[O:3])[OH:4].[CH3:5][CH2:6][OH:7]>>[CH3:1][C:2](=[O:3])" + + "[O:7][CH2:6][CH3:5].[OH2:4]", + "[CH3:1][C:2](=[O:3])[OH:4].[CH3:5][CH2:6][OH:7]>>" + + "[CH3:1][C:2](=[O:4])[O:7][CH2:6][CH3:5].[OH2:3]", + ) + + self.data_dict_1 = {"ref": self.true_pair[0], "map": self.true_pair[1]} + self.data_dict_2 = {"ref": self.false_pair[0], "map": self.false_pair[1]} + self.data_dict_3 = {"ref": self.tautomer[0], "map": self.tautomer[1]} + self.data = [self.data_dict_1, self.data_dict_2, self.data_dict_3] + + def test_smiles_check(self): + + self.assertTrue( + AAMValidator.smiles_check( + *self.true_pair, check_method="RC", ignore_aromaticity=False + ) + ) + self.assertFalse( + AAMValidator.smiles_check( + *self.false_pair, check_method="RC", ignore_aromaticity=False + ) + ) + + def test_smiles_check_tautomer(self): + self.assertFalse( + AAMValidator.smiles_check( + self.tautomer[0], + self.tautomer[1], + check_method="RC", + ignore_aromaticity=False, + ) + ) + + self.assertTrue( + AAMValidator.smiles_check_tautomer( + self.tautomer[0], + self.tautomer[1], + check_method="RC", + ignore_aromaticity=True, + ) + ) + + def test_validate_smiles_dataframe(self): + + results = AAMValidator.validate_smiles( + data=self.data, + ground_truth_col="ref", + mapped_cols=["map"], + check_method="RC", + ignore_aromaticity=False, + n_jobs=2, + verbose=0, + ignore_tautomers=False, + ) + self.assertEqual(results[0]["accuracy"], 66.67) + self.assertEqual(results[0]["success_rate"], 100) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/SynAAM/test_inference.py b/Test/SynAAM/test_inference.py new file mode 100644 index 0000000..d83fe14 --- /dev/null +++ b/Test/SynAAM/test_inference.py @@ -0,0 +1,25 @@ +import unittest +from synutility.SynIO.Format.chemical_conversion import smart_to_gml +from synutility.SynAAM.inference import aam_infer + + +class TestAAMInference(unittest.TestCase): + + def setUp(self): + + self.rsmi = "BrCc1ccc(Br)cc1.COCCO>>Br.COCCOCc1ccc(Br)cc1" + self.gml = smart_to_gml("[Br:1][CH3:2].[OH:3][H:4]>>[Br:1][H:4].[CH3:2][OH:3]") + self.expect = ( + "[Br:1][CH2:2][C:3]1=[CH:4][CH:6]=[C:7]([Br:8])[CH:9]" + + "=[CH:5]1.[CH3:10][O:11][CH2:12][CH2:13][O:14][H:15]>>" + + "[Br:1][H:15].[CH2:2]([C:3]1=[CH:4][CH:6]=[C:7]([Br:8])" + + "[CH:9]=[CH:5]1)[O:14][CH2:13][CH2:12][O:11][CH3:10]" + ) + + def test_aam_infer(self): + result = aam_infer(self.rsmi, self.gml) + self.assertEqual(result[0], self.expect) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/SynIO/Format/test_its_construction.py b/Test/SynAAM/test_its_construction.py similarity index 96% rename from Test/SynIO/Format/test_its_construction.py rename to Test/SynAAM/test_its_construction.py index 0b34ebb..6d21afe 100644 --- a/Test/SynIO/Format/test_its_construction.py +++ b/Test/SynAAM/test_its_construction.py @@ -1,6 +1,6 @@ import unittest import networkx as nx -from synutility.SynIO.Format.its_construction import ITSConstruction +from synutility.SynAAM.its_construction import ITSConstruction class TestITSConstruction(unittest.TestCase): diff --git a/Test/SynAAM/test_normalize_aam.py b/Test/SynAAM/test_normalize_aam.py index 5ccd071..c3c022e 100644 --- a/Test/SynAAM/test_normalize_aam.py +++ b/Test/SynAAM/test_normalize_aam.py @@ -21,7 +21,7 @@ def test_fix_rsmi(self): for both reactants and products.""" input_rsmi = "[C:0]>>[C:1]" expected_rsmi = "[C:1]>>[C:2]" - self.assertEqual(self.normalizer.fix_rsmi(input_rsmi), expected_rsmi) + self.assertEqual(self.normalizer.fix_aam_rsmi(input_rsmi), expected_rsmi) def test_extract_subgraph(self): """Test extraction of a subgraph based on specified indices.""" diff --git a/Test/SynAAM/test_partial_expand.py b/Test/SynAAM/test_partial_expand.py index 0675bb9..fedc468 100644 --- a/Test/SynAAM/test_partial_expand.py +++ b/Test/SynAAM/test_partial_expand.py @@ -1,4 +1,8 @@ import unittest +from synutility.SynAAM.aam_validator import AAMValidator +from synutility.SynAAM.its_construction import ITSConstruction +from synutility.SynIO.Format.chemical_conversion import rsmi_to_graph + from synutility.SynAAM.partial_expand import PartialExpand @@ -16,7 +20,7 @@ def test_expand(self): # Perform the expansion output_rsmi = PartialExpand.expand(input_rsmi) # Assert the result matches the expected output - self.assertEqual(output_rsmi, expected_rsmi) + self.assertTrue(AAMValidator.smiles_check(output_rsmi, expected_rsmi, "ITS")) def test_expand_2(self): input_rsmi = "CC[CH2:3][Cl:1].[NH2:2][H:4]>>CC[CH2:3][NH2:2].[Cl:1][H:4]" @@ -25,7 +29,22 @@ def test_expand_2(self): "[CH3:1][CH2:2][CH2:3][Cl:4].[NH2:5][H:6]" + ">>[CH3:1][CH2:2][CH2:3][NH2:5].[Cl:4][H:6]" ) - self.assertEqual(output_rsmi, expected_rsmi) + self.assertTrue(AAMValidator.smiles_check(output_rsmi, expected_rsmi, "ITS")) + + def test_graph_expand(self): + input_rsmi = "BrCc1ccc(Br)cc1.COCCO>>Br.COCCOCc1ccc(Br)cc1" + expect = ( + "[Br:1][CH2:2][C:3]1=[CH:4][CH:6]=[C:7]([Br:8])[CH:9]=[CH:5]1." + + "[CH3:10][O:11][CH2:12][CH2:13][O:14][H:15]>>[Br:1][H:15]" + + ".[CH2:2]([C:3]1=[CH:4][CH:6]=[C:7]([Br:8])[CH:9]=[CH:5]1)" + + "[O:14][CH2:13][CH2:12][O:11][CH3:10]" + ) + r, p = rsmi_to_graph( + "[Br:1][CH3:2].[OH:3][H:4]>>[Br:1][H:4].[CH3:2][OH:3]", sanitize=False + ) + its = ITSConstruction().ITSGraph(r, p) + output = PartialExpand.graph_expand(its, input_rsmi) + self.assertTrue(AAMValidator.smiles_check(output, expect)) if __name__ == "__main__": diff --git a/Test/SynGraph/Transform/test_multi_step.py b/Test/SynGraph/Transform/test_multi_step.py new file mode 100644 index 0000000..cad6bf3 --- /dev/null +++ b/Test/SynGraph/Transform/test_multi_step.py @@ -0,0 +1,55 @@ +import unittest +from synutility.SynIO.Format.chemical_conversion import smart_to_gml +from synutility.SynGraph.Transform.multi_step import ( + perform_multi_step_reaction, + remove_reagent_from_smiles, + calculate_max_depth, + find_all_paths, +) + + +class TestMultiStep(unittest.TestCase): + def setUp(self) -> None: + smarts = [ + "[CH2:4]([CH:5]=[O:6])[H:7]>>[CH2:4]=[CH:5][O:6][H:7]", + ( + "[CH2:2]=[O:3].[CH2:4]=[CH:5][O:6][H:7]>>[CH2:2]([O:3][H:7])[CH2:4]" + + "[CH:5]=[O:6]" + ), + "[CH2:4]([CH:5]=[O:6])[H:8]>>[CH2:4]=[CH:5][O:6][H:8]", + ( + "[CH2:2]([OH:3])[CH:4]=[CH:5][O:6][H:8]>>[CH2:2]=[CH:4][CH:5]=[O:6]" + + ".[OH:3][H:8]" + ), + ] + self.gml = [smart_to_gml(value) for value in smarts] + self.order = [0, 1, 0, -1] + self.rsmi = "CC=O.CC=O.CCC=O>>CC=O.CC=C(C)C=O.O" + + def test_remove_reagent_from_smiles(self): + rsmi = remove_reagent_from_smiles(self.rsmi) + self.assertEqual(rsmi, "CC=O.CCC=O>>CC=C(C)C=O.O") + + def test_perform_multi_step_reaction(self): + results, _ = perform_multi_step_reaction(self.gml, self.order, self.rsmi) + self.assertEqual(len(results), 4) + + def test_calculate_max_depth(self): + _, reaction_tree = perform_multi_step_reaction(self.gml, self.order, self.rsmi) + max_depth = calculate_max_depth(reaction_tree) + self.assertEqual(max_depth, 4) + + def test_find_all_paths(self): + results, reaction_tree = perform_multi_step_reaction( + self.gml, self.order, self.rsmi + ) + target_products = sorted(self.rsmi.split(">>")[1].split(".")) + max_depth = len(results) + all_paths = find_all_paths(reaction_tree, target_products, self.rsmi, max_depth) + self.assertEqual(len(all_paths), 1) + real_path = all_paths[0][1:] # remove the original reaction + self.assertEqual(len(real_path), 4) + + +if __name__ == "__main__": + unittest.main() diff --git a/Test/SynIO/Format/test_chemcal_conversion.py b/Test/SynIO/Format/test_chemcal_conversion.py new file mode 100644 index 0000000..89df407 --- /dev/null +++ b/Test/SynIO/Format/test_chemcal_conversion.py @@ -0,0 +1,94 @@ +import unittest +import networkx as nx + +from synutility.SynChem.Reaction.standardize import Standardize +from synutility.SynIO.Format.chemical_conversion import ( + smiles_to_graph, + rsmi_to_graph, + graph_to_rsmi, + smart_to_gml, + gml_to_smart, +) + +from synutility.SynGraph.Morphism.misc import rule_isomorphism + + +class TestChemicalConversions(unittest.TestCase): + + def setUp(self) -> None: + self.rsmi = "[CH2:1]([H:4])[CH2:2][OH:3]>>[CH2:1]=[CH2:2].[H:4][OH:3]" + self.gml = ( + "rule [\n" + ' ruleID "rule"\n' + " left [\n" + ' edge [ source 1 target 4 label "-" ]\n' + ' edge [ source 1 target 2 label "-" ]\n' + ' edge [ source 2 target 3 label "-" ]\n' + " ]\n" + " context [\n" + ' node [ id 1 label "C" ]\n' + ' node [ id 4 label "H" ]\n' + ' node [ id 2 label "C" ]\n' + ' node [ id 3 label "O" ]\n' + " ]\n" + " right [\n" + ' edge [ source 1 target 2 label "=" ]\n' + ' edge [ source 4 target 3 label "-" ]\n' + " ]\n" + "]" + ) + + self.std = Standardize() + + def test_smiles_to_graph_valid(self): + # Test converting a valid SMILES to a graph + result = smiles_to_graph("[CH3:1][CH2:2][OH:3]", False, True, True) + self.assertIsInstance(result, nx.Graph) + self.assertEqual(result.number_of_nodes(), 3) + + def test_smiles_to_graph_invalid(self): + # Test converting an invalid SMILES string to a graph + result = smiles_to_graph("invalid_smiles", True, False, False) + self.assertIsNone(result) + + def test_rsmi_to_graph_valid(self): + # Test converting valid reaction SMILES to graphs for reactants and products + reactants_graph, products_graph = rsmi_to_graph(self.rsmi, sanitize=True) + self.assertIsInstance(reactants_graph, nx.Graph) + self.assertEqual(reactants_graph.number_of_nodes(), 3) + self.assertIsInstance(products_graph, nx.Graph) + self.assertEqual(products_graph.number_of_nodes(), 3) + + reactants_graph, products_graph = rsmi_to_graph(self.rsmi, sanitize=False) + self.assertIsInstance(reactants_graph, nx.Graph) + self.assertEqual(reactants_graph.number_of_nodes(), 4) + self.assertIsInstance(products_graph, nx.Graph) + self.assertEqual(products_graph.number_of_nodes(), 4) + + def test_rsmi_to_graph_invalid(self): + # Test handling of invalid RSMI format + result = rsmi_to_graph("invalid_format") + self.assertEqual((None, None), result) + + def test_graph_to_rsmi(self): + r, p = rsmi_to_graph(self.rsmi, sanitize=False) + rsmi = graph_to_rsmi(r, p) + self.assertIsInstance(rsmi, str) + self.assertEqual(self.std.fit(rsmi, False), self.std.fit(self.rsmi, False)) + + def test_smart_to_gml(self): + result = smart_to_gml(self.rsmi, core=False, sanitize=False, reindex=False) + self.assertIsInstance(result, str) + self.assertEqual(result, self.gml) + + result = smart_to_gml(self.rsmi, core=False, sanitize=False, reindex=True) + self.assertTrue(rule_isomorphism(result, self.gml)) + + def test_gml_to_smart(self): + smarts, _ = gml_to_smart(self.gml) + self.assertIsInstance(smarts, str) + self.assertEqual(self.std.fit(smarts, False), self.std.fit(self.rsmi, False)) + + +if __name__ == "__main__": + unittest.main() diff --git a/lint.sh b/lint.sh index c088240..b17d8cb 100755 --- a/lint.sh +++ b/lint.sh @@ -1,6 +1,6 @@ #!/bin/bash flake8 . --count --max-complexity=13 --max-line-length=120 \ - --per-file-ignores="__init__.py:F401, chemical_reaction_visualizer.py:E501, test_reagent.py:E501" \ + --per-file-ignores="__init__.py:F401, chemical_reaction_visualizer.py:E501, test_reagent.py:E501, inference.py:F401" \ --exclude venv,core_engine.py,rule_apply.py \ --statistics diff --git a/pyproject.toml b/pyproject.toml index b15c061..24607e9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "synutility" -version = "0.0.11" +version = "0.0.12" authors = [ {name="Tieu Long Phan", email="tieu@bioinf.uni-leipzig.de"} ] diff --git a/synutility/SynAAM/aam_validator.py b/synutility/SynAAM/aam_validator.py new file mode 100644 index 0000000..53e3b4c --- /dev/null +++ b/synutility/SynAAM/aam_validator.py @@ -0,0 +1,254 @@ +import pandas as pd +import networkx as nx +from operator import eq +from itertools import combinations +from joblib import Parallel, delayed +from typing import Dict, List, Tuple, Union, Optional +from networkx.algorithms.isomorphism import generic_node_match, generic_edge_match + +from synutility.SynAAM.its_construction import ITSConstruction +from synutility.SynIO.Format.chemical_conversion import rsmi_to_graph +from synutility.SynAAM.misc import get_rc, enumerate_tautomers, mapping_success_rate + + +class AAMValidator: + def __init__(self): + """Initializes the AAMValidator class.""" + pass + + @staticmethod + def check_equivariant_graph( + its_graphs: List[nx.Graph], + ) -> Tuple[List[Tuple[int, int]], int]: + """ + Checks for isomorphism among a list of ITS graphs and + identifies all pairs of isomorphic graphs. + + Parameters: + - its_graphs (List[nx.Graph]): A list of ITS graphs. + + Returns: + - List[Tuple[int, int]]: A list of tuples representing + pairs of indices of isomorphic graphs. + - int: The count of unique isomorphic graph pairs found. + """ + nodeLabelNames = ["typesGH"] + nodeLabelDefault = ["*", False, 0, 0, ()] + nodeLabelOperator = [eq, eq, eq, eq, eq] + nodeMatch = generic_node_match( + nodeLabelNames, nodeLabelDefault, nodeLabelOperator + ) + edgeMatch = generic_edge_match("order", 1, eq) + + classified = [] + for i, j in combinations(range(len(its_graphs)), 2): + if nx.is_isomorphic( + its_graphs[i], its_graphs[j], node_match=nodeMatch, edge_match=edgeMatch + ): + classified.append((i, j)) + + return classified, len(classified) + + @staticmethod + def smiles_check( + mapped_smile: str, + ground_truth: str, + check_method: str = "RC", # or 'ITS' + ignore_aromaticity: bool = False, + ) -> bool: + """ + Checks the equivalence of mapped SMILES against ground truth + using reaction center (RC) or ITS graph method. + + Parameters: + - mapped_smile (str): The mapped SMILES string. + - ground_truth (str): The ground truth SMILES string. + - check_method (str): The method used for validation ('RC' or 'ITS'). + - ignore_aromaticity (bool): Flag to ignore aromaticity in ITS graph construction. + + Returns: + - bool: True if the mapped SMILES is equivalent to the ground truth, + False otherwise. + """ + its_graphs = [] + rc_graphs = [] + try: + for rsmi in [mapped_smile, ground_truth]: + G, H = rsmi_to_graph( + rsmi=rsmi, sanitize=True, drop_non_aam=True, light_weight=True + ) + + ITS = ITSConstruction.ITSGraph(G, H, ignore_aromaticity) + its_graphs.append(ITS) + rc = get_rc(ITS) + rc_graphs.append(rc) + + _, equivariant = AAMValidator.check_equivariant_graph( + rc_graphs if check_method == "RC" else its_graphs + ) + return equivariant == 1 + + except Exception as e: + print("An error occurred:", str(e)) + return False + + @staticmethod + def smiles_check_tautomer( + mapped_smile: str, + ground_truth: str, + check_method: str = "RC", # or 'ITS' + ignore_aromaticity: bool = False, + ) -> Optional[bool]: + """ + Determines if a given mapped SMILE string is equivalent to any tautomer of + a ground truth SMILES string using a specified comparison method. + + Parameters: + - mapped_smile (str): The mapped SMILES string to check against the tautomers of + the ground truth. + - ground_truth (str): The reference SMILES string for generating possible + tautomers. + - check_method (str): The method used for checking equivalence. Default is 'RC'. + Possible values are 'RC' for reaction center or 'ITS'. + - ignore_aromaticity (bool): Flag to ignore differences in aromaticity between + the mapped SMILE and the tautomers.Default is False. + + Returns: + - Optional[bool]: True if the mapped SMILE matches any of the enumerated tautomers + of the ground truth according to the specified check method. + Returns False if no match is found. + Returns None if an error occurs during processing. + + Raises: + - Exception: If an error occurs during the tautomer enumeration + or the comparison process. + """ + try: + ground_truth_tautomers = enumerate_tautomers(ground_truth) + return any( + AAMValidator.smiles_check( + mapped_smile, t, check_method, ignore_aromaticity + ) + for t in ground_truth_tautomers + ) + except Exception as e: + print(f"An error occurred: {e}") + return None + + @staticmethod + def check_pair( + mapping: Dict[str, str], + mapped_col: str, + ground_truth_col: str, + check_method: str = "RC", + ignore_aromaticity: bool = False, + ignore_tautomers: bool = True, + ) -> bool: + """ + Checks the equivalence between the mapped and ground truth + values within a given mapping dictionary, using a specified check method. + The check can optionally ignore aromaticity. + + Parameters: + - mapping (Dict[str, str]): A dictionary containing the data entries to check. + - mapped_col (str): The key in the mapping dictionary corresponding + to the mapped value. + - ground_truth_col (str): The key in the mapping dictionary corresponding + to the ground truth value. + - check_method (str, optional): The method used for checking the equivalence. + Defaults to 'RC'. + - ignore_aromaticity (bool, optional): Flag to indicate whether aromaticity + should be ignored during the check. Defaults to False. + - ignore_tautomers (bool, optional): Flag to indicate whether tautomers + should be ignored during the check. Defaults to False. + + Returns: + - bool: The result of the check, indicating whether the mapped value is + equivalent to the ground truth according to the specified method + and considerations regarding aromaticity. + """ + if ignore_tautomers: + return AAMValidator.smiles_check( + mapping[mapped_col], + mapping[ground_truth_col], + check_method, + ignore_aromaticity, + ) + else: + return AAMValidator.smiles_check_tautomer( + mapping[mapped_col], + mapping[ground_truth_col], + check_method, + ignore_aromaticity, + ) + + @staticmethod + def validate_smiles( + data: Union[pd.DataFrame, List[Dict[str, str]]], + ground_truth_col: str = "ground_truth", + mapped_cols: List[str] = ["rxn_mapper", "graphormer", "local_mapper"], + check_method: str = "RC", + ignore_aromaticity: bool = False, + n_jobs: int = 1, + verbose: int = 0, + ignore_tautomers=True, + ) -> List[Dict[str, Union[str, float, List[bool]]]]: + """ + Validates collections of mapped SMILES against their ground truths for + multiple mappers and calculates the accuracy. + + Parameters: + - data (Union[pd.DataFrame, List[Dict[str, str]]]): + The input data containing mapped and ground truth SMILES. + - id_col (str): The name of the column or key containing the reaction ID. + - ground_truth_col (str): The name of the column or key containing + the ground truth SMILES. + - mapped_cols (List[str]): The list of columns or keys containing + the mapped SMILES for different mappers. + - check_method (str): The method used for validation ('RC' or 'ITS'). + - ignore_aromaticity (bool): Flag to ignore aromaticity in ITS graph construction. + - n_jobs (int): The number of parallel jobs to run. + - verbose (int): The verbosity level for joblib's parallel execution. + + Returns: + - List[Dict[str, Union[str, float, List[bool]]]]: A list of dictionaries, each + containing the mapper name, accuracy, and individual results for each SMILES pair. + """ + + validation_results = [] + + for mapped_col in mapped_cols: + + if isinstance(data, pd.DataFrame): + mappings = data.to_dict("records") + elif isinstance(data, list): + mappings = data + else: + raise ValueError( + "Data must be either a pandas DataFrame or a list of dictionaries." + ) + + results = Parallel(n_jobs=n_jobs, verbose=verbose)( + delayed(AAMValidator.check_pair)( + mapping, + mapped_col, + ground_truth_col, + check_method, + ignore_aromaticity, + ignore_tautomers, + ) + for mapping in mappings + ) + accuracy = sum(results) / len(mappings) if mappings else 0 + mapped_data = [value[mapped_col] for value in mappings] + + validation_results.append( + { + "mapper": mapped_col, + "accuracy": round(100 * accuracy, 2), + "results": results, + "success_rate": mapping_success_rate(mapped_data), + } + ) + + return validation_results diff --git a/synutility/SynAAM/inference.py b/synutility/SynAAM/inference.py new file mode 100644 index 0000000..24ad997 --- /dev/null +++ b/synutility/SynAAM/inference.py @@ -0,0 +1,73 @@ +import torch +from typing import List, Any +from synutility.SynIO.Format.dg_to_gml import DGToGML +from synutility.SynAAM.normalize_aam import NormalizeAAM +from synutility.SynChem.Reaction.standardize import Standardize +from synutility.SynGraph.Transform.rule_apply import rule_apply + +std = Standardize() + + +def aam_infer(rsmi: str, gml: Any) -> List[str]: + """ + Infers a set of normalized SMILES from a reaction SMILES string and a graph model (GML). + + This function takes a reaction SMILES string (rsmi) and a graph model (gml), applies the + reaction transformation using the graph model, normalizes and standardizes the resulting + SMILES, and returns a list of SMILES that match the original reaction's structure after + normalization and standardization. + + Steps: + 1. The reactants in the reaction SMILES string are separated. + 2. The transformation is applied to the reactants using the provided graph model (gml). + 3. The resulting SMILES are transformed to a canonical form. + 4. The resulting SMILES are normalized and standardized. + 5. The function returns the normalized SMILES that match the original reaction SMILES. + + Parameters: + - rsmi (str): The reaction SMILES string in the form "reactants >> products". + - gml (Any): A graph model or data structure used for applying the reaction transformation. + + Returns: + - List[str]: A list of valid, normalized, and standardized SMILES strings that match the original reaction SMILES. + """ + # Split the input reaction SMILES into reactants and products + smiles = rsmi.split(">>")[0].split(".") + + # Apply the reaction transformation based on the graph model (GML) + dg = rule_apply(smiles, gml) + + # Get the transformed reaction SMILES from the graph + transformed_rsmi = list(DGToGML.getReactionSmiles(dg).values()) + transformed_rsmi = [value[0] for value in transformed_rsmi] + + # Normalize the transformed SMILES + normalized_rsmi = [] + for value in transformed_rsmi: + try: + value = NormalizeAAM().fit(value) + normalized_rsmi.append(value) + except Exception as e: + print(e) + continue + + # Standardize the normalized SMILES + curated_smiles = [] + for value in normalized_rsmi: + try: + curated_smiles.append(std.fit(value)) + except Exception as e: + print(e) + curated_smiles.append(None) + continue + + # Standardize the original SMILES for comparison + org_smiles = std.fit(rsmi) + + # Filter out the SMILES that match the original reaction SMILES + final = [] + for key, value in enumerate(curated_smiles): + if value == org_smiles: + final.append(normalized_rsmi[key]) + + return final diff --git a/synutility/SynIO/Format/its_construction.py b/synutility/SynAAM/its_construction.py similarity index 100% rename from synutility/SynIO/Format/its_construction.py rename to synutility/SynAAM/its_construction.py diff --git a/synutility/SynAAM/misc.py b/synutility/SynAAM/misc.py index b08fd07..3efdcfa 100644 --- a/synutility/SynAAM/misc.py +++ b/synutility/SynAAM/misc.py @@ -1,9 +1,14 @@ +import re import networkx as nx +from rdkit import Chem +from rdkit.Chem.MolStandardize import rdMolStandardize + +from typing import Optional, List def get_rc( ITS: nx.Graph, - element_key: str = "element", + element_key: list = ["element", "charge", "typesGH"], bond_key: str = "order", standard_key: str = "standard_order", ) -> nx.Graph: @@ -12,23 +17,27 @@ def get_rc( where the bond order changes, indicating a reaction event. Parameters: - ITS (nx.Graph): The ITS graph to extract the RC from. - element_key (str): Node attribute key for atom symbols. Defaults to 'element'. - bond_key (str): Edge attribute key for bond order. Defaults to 'order'. - standard_key (str): Edge attribute key for standard order information. Defaults to 'standard_order'. + - ITS (nx.Graph): The ITS graph to extract the RC from. + - element_key (list): List of node attribute keys for atom properties. + Defaults to ['element', 'charge', 'typesGH']. + - bond_key (str): Edge attribute key for bond order. Defaults to 'order'. + - standard_key (str): Edge attribute key for standard order information. + Defaults to 'standard_order'. Returns: - nx.Graph: A new graph representing the reaction center of the ITS. + - nx.Graph: A new graph representing the reaction center of the ITS. """ rc = nx.Graph() for n1, n2, data in ITS.edges(data=True): - if data[bond_key][0] != data[bond_key][1]: - rc.add_node(n1, **{element_key: ITS.nodes[n1][element_key]}) - rc.add_node(n2, **{element_key: ITS.nodes[n2][element_key]}) + if data.get(bond_key, [None, None])[0] != data.get(bond_key, [None, None])[1]: + rc.add_node( + n1, **{k: ITS.nodes[n1][k] for k in element_key if k in ITS.nodes[n1]} + ) + rc.add_node( + n2, **{k: ITS.nodes[n2][k] for k in element_key if k in ITS.nodes[n2]} + ) rc.add_edge( - n1, - n2, - **{bond_key: data[bond_key], standard_key: data[standard_key]}, + n1, n2, **{bond_key: data[bond_key], standard_key: data[standard_key]} ) return rc @@ -140,3 +149,95 @@ def compare_graphs( return False return True + + +def enumerate_tautomers(reaction_smiles: str) -> Optional[List[str]]: + """ + Enumerates possible tautomers for reactants while canonicalizing the products in a + reaction SMILES string. This function first splits the reaction SMILES string into + reactants and products. It then generates all possible tautomers for the reactants and + canonicalizes the product molecule. The function returns a list of reaction SMILES + strings for each tautomer of the reactants combined with the canonical product. + + Parameters: + - reaction_smiles (str): A SMILES string of the reaction formatted as + 'reactants>>products'. + + Returns: + - List[str] | None: A list of SMILES strings for the reaction, with each string + representing a different + - tautomer of the reactants combined with the canonicalized products. Returns None if + an error occurs or if invalid SMILES strings are provided. + + Raises: + - ValueError: If the provided SMILES strings cannot be converted to molecule objects, + indicating invalid input. + """ + try: + # Split the input reaction SMILES string into reactants and products + reactants_smiles, products_smiles = reaction_smiles.split(">>") + + # Convert SMILES strings to molecule objects + reactants_mol = Chem.MolFromSmiles(reactants_smiles) + products_mol = Chem.MolFromSmiles(products_smiles) + + if reactants_mol is None or products_mol is None: + raise ValueError( + "Invalid SMILES string provided for reactants or products." + ) + + # Initialize tautomer enumerator + + enumerator = rdMolStandardize.TautomerEnumerator() + + # Enumerate tautomers for the reactants and canonicalize the products + try: + reactants_can = enumerator.Enumerate(reactants_mol) + except Exception as e: + print(f"An error occurred: {e}") + reactants_can = [reactants_mol] + products_can = products_mol + + # Convert molecule objects back to SMILES strings + reactants_can_smiles = [Chem.MolToSmiles(i) for i in reactants_can] + products_can_smiles = Chem.MolToSmiles(products_can) + + # Combine each reactant tautomer with the canonical product in SMILES format + rsmi_list = [i + ">>" + products_can_smiles for i in reactants_can_smiles] + if len(rsmi_list) == 0: + return [reaction_smiles] + else: + # rsmi_list.remove(reaction_smiles) + rsmi_list.insert(0, reaction_smiles) + return rsmi_list + + except Exception as e: + print(f"An error occurred: {e}") + return [reaction_smiles] + + +def mapping_success_rate(list_mapping_data): + """ + Calculate the success rate of entries containing atom mappings in a list of data + strings. + + Parameters: + - list_mapping_in_data (list of str): List containing strings to be searched for atom + mappings. + + Returns: + - float: The success rate of finding atom mappings in the list as a percentage. + + Raises: + - ValueError: If the input list is empty. + """ + atom_map_pattern = re.compile(r":\d+") + if not list_mapping_data: + raise ValueError("The input list is empty, cannot calculate success rate.") + + success = sum( + 1 for entry in list_mapping_data if re.search(atom_map_pattern, entry) + ) + rate = 100 * (success / len(list_mapping_data)) + + return round(rate, 2) diff --git a/synutility/SynAAM/normalize_aam.py b/synutility/SynAAM/normalize_aam.py index 6e8c2d8..f837493 100644 --- a/synutility/SynAAM/normalize_aam.py +++ b/synutility/SynAAM/normalize_aam.py @@ -3,9 +3,9 @@ from rdkit import Chem from typing import List -from synutility.SynIO.Format.smi_to_graph import rsmi_to_graph +from synutility.SynIO.Format.chemical_conversion import rsmi_to_graph from synutility.SynIO.Format.graph_to_mol import GraphToMol -from synutility.SynIO.Format.its_construction import ITSConstruction +from synutility.SynAAM.its_construction import ITSConstruction from synutility.SynAAM.misc import its_decompose, get_rc @@ -50,7 +50,7 @@ def fix_atom_mapping(smiles: str) -> str: return pattern.sub(NormalizeAAM.increment, smiles) @staticmethod - def fix_rsmi(rsmi: str) -> str: + def fix_aam_rsmi(rsmi: str) -> str: """ Adjusts atom mapping numbers in both reactant and product parts of a reaction SMILES (RSMI). @@ -63,6 +63,54 @@ def fix_rsmi(rsmi: str) -> str: r, p = rsmi.split(">>") return f"{NormalizeAAM.fix_atom_mapping(r)}>>{NormalizeAAM.fix_atom_mapping(p)}" + @staticmethod + def fix_rsmi_kekulize(rsmi: str) -> str: + """ + Filters the reactants and products of a reaction SMILES string. + + Parameters: + - rsmi (str): A string representing the reaction SMILES in the form of "reactants >> products". + + Returns: + - str: A filtered reaction SMILES string where invalid reactants/products are removed. + """ + # Split the reaction into reactants and products + reactants, products = rsmi.split(">>") + + # Filter valid reactants and products + filtered_reactants = NormalizeAAM.fix_kekulize(reactants) + filtered_products = NormalizeAAM.fix_kekulize(products) + + # Return the filtered reaction SMILES + return f"{filtered_reactants}>>{filtered_products}" + + @staticmethod + def fix_kekulize(smiles: str) -> str: + """ + Filters and returns valid SMILES strings from a string of SMILES, joined by '.'. + + This function processes a string of SMILES separated by periods (e.g., "CCO.CC=O"), + filters out invalid SMILES, and returns a string of valid SMILES joined by periods. + + Parameters: + - smiles (str): A string containing SMILES strings separated by periods ('.'). + + Returns: + - str: A string of valid SMILES, joined by periods ('.'). + """ + smiles_list = smiles.split(".") # Split SMILES by period + valid_smiles = [] # List to store valid SMILES strings + + for smile in smiles_list: + mol = Chem.MolFromSmiles(smile, sanitize=False) + if mol: # Check if molecule is valid + valid_smiles.append( + Chem.MolToSmiles( + mol, canonical=True, kekuleSmiles=True, allHsExplicit=True + ) + ) + return ".".join(valid_smiles) # Return valid SMILES joined by '.' + @staticmethod def extract_subgraph(graph: nx.Graph, indices: List[int]) -> nx.Graph: """ @@ -114,8 +162,9 @@ def fit(self, rsmi: str, fix_aam_indice: bool = True) -> str: Returns: str: The resulting reaction SMILES string with updated atom mappings. """ + rsmi = self.fix_rsmi_kekulize(rsmi) if fix_aam_indice: - rsmi = self.fix_rsmi(rsmi) + rsmi = self.fix_aam_rsmi(rsmi) r_graph, p_graph = rsmi_to_graph(rsmi, light_weight=True, sanitize=False) its = ITSConstruction().ITSGraph(r_graph, p_graph) rc = get_rc(its) diff --git a/synutility/SynAAM/partial_expand.py b/synutility/SynAAM/partial_expand.py index e26d249..71c5964 100644 --- a/synutility/SynAAM/partial_expand.py +++ b/synutility/SynAAM/partial_expand.py @@ -1,13 +1,13 @@ +import networkx as nx from synutility.SynIO.Format.nx_to_gml import NXToGML -from synutility.SynIO.Format.smi_to_graph import rsmi_to_graph -from synutility.SynIO.Format.its_construction import ITSConstruction +from synutility.SynIO.Format.chemical_conversion import rsmi_to_graph from synutility.SynAAM.misc import its_decompose, get_rc -from synutility.SynAAM.normalize_aam import NormalizeAAM - +from synutility.SynAAM.its_construction import ITSConstruction from synutility.SynChem.Reaction.standardize import Standardize +from synutility.SynAAM.inference import aam_infer -from synutility.SynGraph.Transform.rule_apply import rule_apply, getReactionSmiles +std = Standardize() class PartialExpand: @@ -15,68 +15,102 @@ class PartialExpand: A class for partially expanding reaction SMILES (RSMI) by applying transformation rules based on the reaction center (RC) graph. + This class provides methods for expanding a given RSMI by identifying the + reaction center (RC), applying transformation rules, and standardizing atom mappings + to generate a full AAM RSMI. + Methods: - expand(rsmi: str) -> str: - Expands a reaction SMILES string and returns the transformed RSMI. + - expand(rsmi: str) -> str: + Expands a reaction SMILES string by identifying the reaction center (RC), + applying transformation rules, and standardizing atom mappings. + + - graph_expand(partial_its: nx.Graph, rsmi: str) -> str: + Expands a reaction SMILES string using an Imaginary Transition State + (ITS) graph and applies the transformation rule based on the reaction center (RC). """ def __init__(self) -> None: """ Initializes the PartialExpand class. + + This constructor currently does not initialize any instance-specific attributes. """ pass @staticmethod - def expand(rsmi: str) -> str: + def graph_expand(partial_its: nx.Graph, rsmi: str) -> str: """ - Expands a reaction SMILES string by identifying the reaction center (RC), - applying transformation rules, and standardizing the atom mappings. + Expands a reaction SMILES string by applying transformation rules using an + ITS graph based on the reaction center (RC) graph. + + This method extracts the reaction center (RC) from the ITS graph, decomposes it + into reactant and product graphs, generates a GML rule for transformation, + and applies the rule to the RSMI string. Parameters: - - rsmi (str): The input reaction SMILES string. + - partial_its (nx.Graph): The Intermediate Transition State (ITS) graph. + - rsmi (str): The input reaction SMILES string to be expanded. Returns: - - str: The transformed reaction SMILES string. + - str: The transformed reaction SMILES string after applying the + transformation rules. """ - try: - # Convert RSMI to reactant and product graphs - r_graph, p_graph = rsmi_to_graph(rsmi, light_weight=True, sanitize=False) + # Extract the reaction center (RC) graph from the ITS graph + rc = get_rc(partial_its) - # Construct ITS (Intermediate Transition State) graph - its = ITSConstruction().ITSGraph(r_graph, p_graph) + # Decompose the RC into reactant and product graphs + r_graph, p_graph = its_decompose(rc) - # Extract the reaction center (RC) graph - rc = get_rc(its) + # Transform the graph into a GML rule + rule = NXToGML().transform((r_graph, p_graph, rc)) - # Decompose the RC into reactant and product graphs - r_graph, p_graph = its_decompose(rc) + # Apply the transformation rule to the RSMI + transformed_rsmi = aam_infer(rsmi, rule)[0] - # Transform the graph to a GML rule - rule = NXToGML().transform((r_graph, p_graph, rc)) + return transformed_rsmi - # Standardize the input reaction SMILES - original_rsmi = Standardize().fit(rsmi) + @staticmethod + def expand(rsmi: str) -> str: + """ + Expands a reaction SMILES string by identifying the reaction center (RC), + applying transformation rules, and standardizing the atom mappings. - # Extract reactants from the standardized RSMI - reactants = original_rsmi.split(">>")[0].split(".") + This method constructs the Intermediate Transition State (ITS) graph from the + input RSMI, applies the reaction transformation rules using `graph_expand`, + and returns the transformed reaction SMILES string. - # Apply the transformation rule to the reactants - transformed_graph = rule_apply(reactants, rule) + Parameters: + - rsmi (str): The input reaction SMILES string to be expanded. - # Extract the transformed reaction SMILES - transformed_rsmi = list(getReactionSmiles(transformed_graph).values())[0][0] + Returns: + - str: The transformed reaction SMILES string after applying the + transformation rules. - # Normalize atom mappings in the transformed RSMI - normalized_rsmi = NormalizeAAM().fit(transformed_rsmi) + Raises: + - Exception: If an error occurs during the expansion process, the original RSMI + is returned. + """ + try: + # Convert RSMI to reactant and product graphs + r_graph, p_graph = rsmi_to_graph(rsmi, light_weight=True, sanitize=False) + + # Construct the ITS graph from the reactant and product graphs + its = ITSConstruction().ITSGraph(r_graph, p_graph) - return normalized_rsmi + # Standardize smiles + rsmi = std.fit(rsmi) + # Apply graph expansion and return the result + return PartialExpand.graph_expand(its, rsmi) except Exception as e: + # Log the error and return the original RSMI if something goes wrong print(f"An error occurred during RSMI expansion: {e}") - return rsmi + return None if __name__ == "__main__": rsmi = "[CH3][CH:1]=[CH2:2].[H:3][H:4]>>[CH3][CH:1]([H:3])[CH2:2][H:4]" rsmi = "CC[CH2:3][Cl:1].[NH2:2][H:4]>>CC[CH2:3][NH2:2].[Cl:1][H:4]" print(PartialExpand.expand(rsmi)) +# self.rsmi = "BrCc1ccc(Br)cc1.COCCO>>Br.COCCOCc1ccc(Br)cc1" +# self.gml = smart_to_gml("[Br:1][CH3:2].[OH:3][H:4]>>[Br:1][H:4].[CH3:2][OH:3]") diff --git a/synutility/SynMOD/__init__.py b/synutility/SynGraph/Morphism/__init__.py similarity index 100% rename from synutility/SynMOD/__init__.py rename to synutility/SynGraph/Morphism/__init__.py diff --git a/synutility/SynGraph/Morphism/misc.py b/synutility/SynGraph/Morphism/misc.py new file mode 100644 index 0000000..2a1bb04 --- /dev/null +++ b/synutility/SynGraph/Morphism/misc.py @@ -0,0 +1,29 @@ +from mod import ruleGMLString + + +def rule_isomorphism(rule_1: str, rule_2: str) -> bool: + """ + Determines if two rule representations, given in GML format, are isomorphic. + + This function converts two GML strings into `ruleGMLString` objects and checks + if these two objects are isomorphic. Isomorphism here is determined by the method + `isomorphism` of the `ruleGMLString` class, which should return `1` for isomorphic + structures and `0` otherwise. + + Parameters: + - rule_1 (str): The GML string representation of the first rule. + - rule_2 (str): The GML string representation of the second rule. + + Returns: + - bool: `True` if the two rules are isomorphic; `False` otherwise. + + Raises: + - Any exceptions thrown by the `ruleGMLString` initialization or methods should + be documented here, if there are any known potential issues. + """ + # Create ruleGMLString objects from the GML strings + rule_obj_1 = ruleGMLString(rule_1) + rule_obj_2 = ruleGMLString(rule_2) + + # Check for isomorphism and return the result + return rule_obj_1.isomorphism(rule_obj_2) == 1 diff --git a/synutility/SynGraph/Transform/core_engine.py b/synutility/SynGraph/Transform/core_engine.py index 9b1c776..7f237a5 100644 --- a/synutility/SynGraph/Transform/core_engine.py +++ b/synutility/SynGraph/Transform/core_engine.py @@ -1,7 +1,9 @@ -from typing import List -from synutility.SynIO.data_type import load_gml_as_text from rdkit import Chem -from copy import deepcopy +from pathlib import Path +from typing import List, Union +from collections import Counter +from synutility.SynIO.data_type import load_gml_as_text + import torch from mod import * @@ -49,7 +51,7 @@ def generate_reaction_smiles( @staticmethod def perform_reaction( - rule_file_path: str, + rule_file_path: Union[str, str], initial_smiles: List[str], prediction_type: str = "forward", print_results: bool = False, @@ -94,7 +96,16 @@ def deduplicateGraphs(initial): initial_molecules, key=lambda molecule: molecule.numVertices, reverse=False ) # Load the reaction rule from the GML file - gml_content = load_gml_as_text(rule_file_path) + rule_path = Path(rule_file_path) + + try: + if rule_path.is_file(): + gml_content = load_gml_as_text(rule_file_path) + else: + gml_content = rule_file_path + except Exception as e: + # print(f"An error occurred while loading the GML file: {e}") + gml_content = rule_file_path reaction_rule = ruleGMLString(gml_content, invert=invert_rule, add=False) # Initialize the derivation graph and execute the strategy dg = DG(graphDatabase=initial_molecules) @@ -107,8 +118,10 @@ def deduplicateGraphs(initial): for e in dg.edges: productSmiles = [v.graph.smiles for v in e.targets] temp_results.append(productSmiles) + # print(productSmiles) if len(temp_results) == 0: + # print(1) dg = DG(graphDatabase=initial_molecules) # dg.build().execute(strategy, verbosity=8) config.dg.doRuleIsomorphismDuringBinding = False @@ -118,27 +131,22 @@ def deduplicateGraphs(initial): temp_results, small_educt = [], [] for edge in dg.edges: temp_results.append([vertex.graph.smiles for vertex in edge.targets]) - small_educt.extend([vertex.graph.smiles for vertex in edge.sources]) - - small_educt_set = [ - Chem.CanonSmiles(smile) for smile in small_educt if smile is not None - ] - - reagent = deepcopy(initial_smiles) - for value in small_educt_set: - if value in reagent: - reagent.remove(value) - - # Update solutions with reagents and normalize SMILES - for solution in temp_results: + small_educt.append([vertex.graph.smiles for vertex in edge.sources]) + + for key, solution in enumerate(temp_results): + educt = small_educt[key] + small_educt_counts = Counter( + Chem.CanonSmiles(smile) for smile in educt if smile is not None + ) + reagent_counts = Counter([Chem.CanonSmiles(s) for s in initial_smiles]) + reagent_counts.subtract(small_educt_counts) + reagent = [ + smile + for smile, count in reagent_counts.items() + for _ in range(count) + if count > 0 + ] solution.extend(reagent) - for i, smile in enumerate(solution): - try: - mol = Chem.MolFromSmiles(smile) - if mol: # Only convert if mol creation was successful - solution[i] = Chem.MolToSmiles(mol) - except Exception as e: - print(f"Error processing SMILES {smile}: {str(e)}") reaction_processing_map = { "forward": lambda smiles: CoreEngine.generate_reaction_smiles( diff --git a/synutility/SynGraph/Transform/multi_step.py b/synutility/SynGraph/Transform/multi_step.py new file mode 100644 index 0000000..b25f2a4 --- /dev/null +++ b/synutility/SynGraph/Transform/multi_step.py @@ -0,0 +1,223 @@ +from collections import Counter +from typing import List, Dict, Tuple +from synutility.SynChem.Reaction.standardize import Standardize +from synutility.SynGraph.Transform.core_engine import CoreEngine + +std = Standardize() + + +def remove_reagent_from_smiles(rsmi: str) -> str: + """ + Removes common molecules from the reactants and products in a SMILES reaction string. + + This function identifies the molecules that appear on both sides of the reaction + (reactants and products) and removes one occurrence of each common molecule from + both sides. + + Parameters: + - rsmi (str): A SMILES string representing a chemical reaction in the form: + 'reactant1.reactant2...>>product1.product2...' + + Returns: + - str: A new SMILES string with the common molecules removed, in the form: + 'reactant1.reactant2...>>product1.product2...' + + Example: + >>> remove_reagent_from_smiles('CC=O.CC=O.CCC=O>>CC=CO.CC=O.CC=O') + 'CCC=O>>CC=CO' + """ + + # Split the input SMILES string into reactants and products + reactants, products = rsmi.split(">>") + + # Split the reactants and products by '.' to separate molecules + reactant_molecules = reactants.split(".") + product_molecules = products.split(".") + + # Count the occurrences of each molecule in reactants and products + reactant_count = Counter(reactant_molecules) + product_count = Counter(product_molecules) + + # Find common molecules between reactants and products + common_molecules = set(reactant_count) & set(product_count) + + # Remove common molecules by the minimum occurrences in both reactants and products + for molecule in common_molecules: + common_occurrences = min(reactant_count[molecule], product_count[molecule]) + + # Decrease the count by the common occurrences + reactant_count[molecule] -= common_occurrences + product_count[molecule] -= common_occurrences + + # Rebuild the lists of reactant and product molecules after removal + filtered_reactant_molecules = [ + molecule for molecule, count in reactant_count.items() for _ in range(count) + ] + filtered_product_molecules = [ + molecule for molecule, count in product_count.items() for _ in range(count) + ] + + # Join the remaining molecules back into SMILES strings + new_reactants = ".".join(filtered_reactant_molecules) + new_products = ".".join(filtered_product_molecules) + + # Return the updated reaction string + return f"{new_reactants}>>{new_products}" + + +def perform_multi_step_reaction( + gml_list: List[str], order: List[int], rsmi: str +) -> Tuple[List[List[str]], Dict[str, List[str]]]: + """ + Applies a sequence of multi-step reactions to a starting SMILES string. The function + processes each reaction step in a specified order, and returns both the intermediate + and final products, as well as a mapping of reactant SMILES to their + corresponding products. + + Parameters: + - gml_list (List[str]): A list of reaction rules (in GML format) to be applied. + Each element corresponds to a reaction step. + - order (List[int]): A list of integers that defines the order in which the + reaction steps should be applied. Each integer is an index referring to the position + of a reaction rule in the `gml_list`. + - rsmi (str): The starting reaction SMILES string, representing the reactants for the + first reaction. + + Returns: + - Tuple[List[List[str]], Dict[str, List[str]]]: + - A list of lists of SMILES strings, where each inner list contains the + RSMI generated at each reaction step. + - A dictionary mapping each RSMI string to the resulting products after applying + the reaction rules. The keys are the input RSMIs, and the values are the + resulting product SMILES strings. + """ + + # Initialize CoreEngine for reaction processing + core = CoreEngine() + # Initialize a dictionary to hold reaction results + reaction_results = {} + + # List to store the results of each reaction step + all_steps: List[List[str]] = [] + result: List[str] = [rsmi] # Initial result is the input SMILES string + + # Loop over the reaction steps in the specified order + for i, j in enumerate(order): + # Get the reaction SMILES (RSMI) for the current step + current_step_gml = gml_list[j] + new_result: List[str] = [] # List to hold products for this step + + # Apply the reaction for each current reactant SMILES + for current_rsmi in result: + smi_lst = ( + current_rsmi.split(">>")[0].split( + "." + ) # Split reactants at the first step + if i == 0 + else current_rsmi.split(">>")[1].split( + "." + ) # Split products for subsequent steps + ) + + # Perform the reaction using the CoreEngine + o = core.perform_reaction(current_step_gml, smi_lst) + + # Apply standardization on the products + o = [std.fit(i) for i in o] + + # Collect the new results (products) from this reaction step + new_result.extend(o) + + # Record the reaction results in the dictionary, mapping input RSMI to output products + if len(o) > 0: + reaction_results[current_rsmi] = o + + # Update the result list for the next step + result = new_result + + # Append the results of this step to the overall steps list + all_steps.append(result) + + # Return the results: a list of all steps and a dictionary of reaction results + return all_steps, reaction_results + + +def calculate_max_depth(reaction_tree, current_node=None, depth=0): + """ + Calculate the maximum depth of a reaction tree. + + Parameters: + - reaction_tree (dict): A dictionary where keys are reaction SMILES (RSMI) + and values are lists of product reactions. + - current_node (str): The current node in the tree being explored (reaction SMILES). + - depth (int): The current depth of the tree. + + Returns: + - int: The maximum depth of the tree. + """ + # If current_node is None, start from the root node (first key in the reaction tree) + if current_node is None: + current_node = list(reaction_tree.keys())[0] + + # Get the products of the current node (reaction) + products = reaction_tree.get(current_node, []) + + # If no products, we are at a leaf node, return the current depth + if not products: + return depth + + # Recursively calculate the depth for each product and return the maximum + max_subtree_depth = max( + calculate_max_depth(reaction_tree, product, depth + 1) for product in products + ) + return max_subtree_depth + + +def find_all_paths( + reaction_tree, + target_products, + current_node, + target_depth, + current_depth=0, + path=None, +): + """ + Recursively find all paths from the root to the maximum depth in the reaction tree. + + Parameters: + - reaction_tree (dict): A dictionary of reaction SMILES with products. + - current_node (str): The current node (reaction SMILES). + - target_depth (int): The depth at which the product matches the root's product. + - current_depth (int): The current depth of the search. + - path (list): The current path in the tree. + + Returns: + - List of all paths to the max depth. + """ + if path is None: + path = [] + + # Add the current node (reaction SMILES) to the path + path.append(current_node) + + # If we have reached the target depth, check the product + if current_depth == target_depth: + # Extract products of the current node + products = sorted(current_node.split(">>")[1].split(".")) + return [path] if products == target_products else [] + + # If we haven't reached the target depth, recurse on the products + paths = [] + for product in reaction_tree.get(current_node, []): + paths.extend( + find_all_paths( + reaction_tree, + target_products, + product, + target_depth, + current_depth + 1, + path.copy(), + ) + ) + + return paths diff --git a/synutility/SynGraph/Transform/rule_apply.py b/synutility/SynGraph/Transform/rule_apply.py index 0dee2d3..9836e0a 100644 --- a/synutility/SynGraph/Transform/rule_apply.py +++ b/synutility/SynGraph/Transform/rule_apply.py @@ -1,107 +1,80 @@ import os -import regex +from typing import List from synutility.SynIO.debug import setup_logging import torch -from mod import smiles, ruleGMLString, DG, config, DGVertexMapper +from mod import smiles, ruleGMLString, DG, config logger = setup_logging() def deduplicateGraphs(initial): - """ - Removes duplicate graphs from a list based on graph isomorphism. + res = [] + for cand in initial: + for a in res: + if cand.isomorphism(a) != 0: + res.append(a) # the one we had already + break + else: + # didn't find any isomorphic, use the new one + res.append(cand) + return res - Parameters: - - initial (list): List of graph objects. - Returns: - - List of unique graph objects. +def rule_apply( + smiles_list: List[str], rule: str, verbose: int = 0, print_output: bool = False +) -> DG: """ - unique_graphs = [] - for candidate in initial: - # Check if candidate is isomorphic to any graph already in unique_graphs - if not any(candidate.isomorphism(existing) != 0 for existing in unique_graphs): - unique_graphs.append(candidate) - return unique_graphs + Applies a reaction rule to a list of SMILES strings and optionally prints + the derivation graph. - -def rule_apply(smiles_list, rule, print_output=False): - """ - Applies a reaction rule to a list of SMILES and optionally prints the output. + This function first converts the SMILES strings into molecular graphs, + deduplicates them, sorts them based on the number of vertices, and + then applies the provided reaction rule in the GML string format. + The resulting derivation graph (DG) is returned. Parameters: - - smiles_list (list): List of SMILES strings. - - rule (str): Reaction rule in GML string format. - - print_output (bool): If True, output will be printed to a directory. + - smiles_list (List[str]): A list of SMILES strings representing the molecules + to which the reaction rule will be applied. + - rule (str): The reaction rule in GML string format. This rule will be applied to the + molecules represented by the SMILES strings. + - verbose (int, optional): The verbosity level for logging or debugging. + Default is 0 (no verbosity). + - print_output (bool, optional): If True, the derivation graph will be printed + to the "out" directory. Default is False. Returns: - - dg (DG): The derivation graph after applying the rule. + - DG: The derivation graph (DG) after applying the reaction rule to the + initial molecules. + + Raises: + - Exception: If an error occurs during the process of applying the rule, + an exception is raised. """ try: + # Convert SMILES strings to molecular graphs and deduplicate initial_molecules = [smiles(smile, add=False) for smile in smiles_list] initial_molecules = deduplicateGraphs(initial_molecules) + + # Sort molecules based on the number of vertices initial_molecules = sorted( initial_molecules, key=lambda molecule: molecule.numVertices, reverse=False ) + # Convert the reaction rule from GML string format to a reaction rule object reaction_rule = ruleGMLString(rule) + # Create the derivation graph and apply the reaction rule dg = DG(graphDatabase=initial_molecules) config.dg.doRuleIsomorphismDuringBinding = False - dg.build().apply(initial_molecules, reaction_rule, verbosity=8) + dg.build().apply(initial_molecules, reaction_rule, verbosity=verbose) - # Optionally print the output + # Optionally print the output to a directory if print_output: os.makedirs("out", exist_ok=True) dg.print() return dg + except Exception as e: logger.error(f"An error occurred: {e}") - - -def getReactionSmiles(dg): - origSmiles = {} - for v in dg.vertices: - s = v.graph.smilesWithIds - s = regex.sub(":([0-9]+)]", ":o\\1]", s) - origSmiles[v.graph] = s - - res = {} - for e in dg.edges: - vms = DGVertexMapper(e, rightLimit=1, leftLimit=1) - # vms = DGVertexMapper(e) - eductSmiles = [origSmiles[g] for g in vms.left] - - for ev in vms.left.vertices: - s = eductSmiles[ev.graphIndex] - s = s.replace(f":o{ev.vertex.id}]", f":{ev.id}]") - eductSmiles[ev.graphIndex] = s - - strs = set() - for vm in DGVertexMapper(e, rightLimit=1, leftLimit=1): - # for vm in DGVertexMapper(e): - productSmiles = [origSmiles[g] for g in vms.right] - for ev in vms.left.vertices: - pv = vm.map[ev] - if not pv: - continue - s = productSmiles[pv.graphIndex] - s = s.replace(f":o{pv.vertex.id}]", f":{ev.id}]") - productSmiles[pv.graphIndex] = s - count = vms.left.numVertices - for pv in vms.right.vertices: - ev = vm.map.inverse(pv) - if ev: - continue - s = productSmiles[pv.graphIndex] - s = s.replace(f":o{pv.vertex.id}]", f":{count}]") - count += 1 - productSmiles[pv.graphIndex] = s - left = ".".join(eductSmiles) - right = ".".join(productSmiles) - s = f"{left}>>{right}" - assert ":o" not in s - strs.add(s) - res[e] = list(sorted(strs)) - return res + raise diff --git a/synutility/SynIO/Format/chemical_conversion.py b/synutility/SynIO/Format/chemical_conversion.py new file mode 100644 index 0000000..6755eb7 --- /dev/null +++ b/synutility/SynIO/Format/chemical_conversion.py @@ -0,0 +1,130 @@ +import networkx as nx +from rdkit import Chem +from typing import Optional, Tuple + +from synutility.SynIO.debug import setup_logging +from synutility.SynIO.Format.mol_to_graph import MolToGraph +from synutility.SynIO.Format.graph_to_mol import GraphToMol +from synutility.SynAAM.its_construction import ITSConstruction +from synutility.SynIO.Format.nx_to_gml import NXToGML +from synutility.SynIO.Format.gml_to_nx import GMLToNX +from synutility.SynAAM.misc import get_rc, its_decompose + + +logger = setup_logging() + + +def smiles_to_graph( + smiles: str, drop_non_aam: bool, light_weight: bool, sanitize: bool +) -> Optional[nx.Graph]: + """ + Helper function to convert SMILES string to a graph using MolToGraph class. + + Parameters: + - smiles (str): SMILES representation of the molecule. + - drop_non_aam (bool): Whether to drop nodes without atom mapping. + - light_weight (bool): Whether to create a light-weight graph. + - sanitize (bool): Whether to sanitize the molecule during conversion. + + Returns: + - nx.Graph or None: The networkx graph representation of the molecule, + or None if conversion fails. + """ + try: + mol = Chem.MolFromSmiles(smiles, sanitize) + if mol: + return MolToGraph().mol_to_graph(mol, drop_non_aam, light_weight) + else: + logger.warning(f"Failed to parse SMILES: {smiles}") + except Exception as e: + logger.error(f"Error converting SMILES to graph: {smiles}, Error: {str(e)}") + return None + + +def rsmi_to_graph( + rsmi: str, + drop_non_aam: bool = True, + light_weight: bool = True, + sanitize: bool = True, +) -> Tuple[Optional[nx.Graph], Optional[nx.Graph]]: + """ + Converts reactant and product SMILES strings from a reaction SMILES (RSMI) format + to graph representations. + + Parameters: + - rsmi (str): Reaction SMILES string in "reactants>>products" format. + - drop_non_aam (bool, optional): If True, nodes without atom mapping numbers + will be dropped. + - light_weight (bool, optional): If True, creates a light-weight graph. + - sanitize (bool, optional): If True, sanitizes molecules during conversion. + + Returns: + - Tuple[Optional[nx.Graph], Optional[nx.Graph]]: A tuple containing t + he graph representations of the reactants and products. + """ + try: + reactants_smiles, products_smiles = rsmi.split(">>") + r_graph = smiles_to_graph( + reactants_smiles, drop_non_aam, light_weight, sanitize + ) + p_graph = smiles_to_graph(products_smiles, drop_non_aam, light_weight, sanitize) + return (r_graph, p_graph) + except ValueError: + logger.error(f"Invalid RSMI format: {rsmi}") + return (None, None) + + +def graph_to_rsmi(r: nx.Graph, p: nx.Graph) -> str: + """ + Converts graph representations of reactants and products to a reaction SMILES string. + + Parameters: + - r (nx.Graph): Graph of the reactants. + - p (nx.Graph): Graph of the products. + + Returns: + - str: Reaction SMILES string. + """ + r = GraphToMol().graph_to_mol(r) + p = GraphToMol().graph_to_mol(p) + return f"{Chem.MolToSmiles(r)}>>{Chem.MolToSmiles(p)}" + + +def smart_to_gml( + smart: str, + core: bool = True, + sanitize: bool = False, + rule_name: str = "rule", + reindex: bool = False, +) -> str: + """ + Converts a SMARTS string to GML format, optionally focusing on the reaction core. + + Parameters: + - smart (str): The SMARTS string representing the reaction. + - core (bool): Whether to extract and focus on the reaction core. Defaults to True. + + Returns: + - str: The GML representation of the reaction. + """ + r, p = rsmi_to_graph(smart, sanitize=sanitize) + its = ITSConstruction.ITSGraph(r, p) + if core: + its = get_rc(its) + r, p = its_decompose(its) + gml = NXToGML().transform((r, p, its), reindex=reindex, rule_name=rule_name) + return gml + + +def gml_to_smart(gml: str) -> str: + """ + Converts a GML string back to a SMARTS string by interpreting the graph structures. + + Parameters: + - gml (str): The GML string to convert. + + Returns: + - str: The corresponding SMARTS string. + """ + r, p, rc = GMLToNX(gml).transform() + return graph_to_rsmi(r, p), rc diff --git a/synutility/SynIO/Format/dg_to_gml.py b/synutility/SynIO/Format/dg_to_gml.py index b2d079d..321a74b 100644 --- a/synutility/SynIO/Format/dg_to_gml.py +++ b/synutility/SynIO/Format/dg_to_gml.py @@ -22,6 +22,7 @@ def getReactionSmiles(dg): res = {} for e in dg.edges: vms = DGVertexMapper(e, rightLimit=1, leftLimit=1) + # vms = DGVertexMapper(e) eductSmiles = [origSmiles[g] for g in vms.left] for ev in vms.left.vertices: @@ -31,6 +32,7 @@ def getReactionSmiles(dg): strs = set() for vm in DGVertexMapper(e, rightLimit=1, leftLimit=1): + # for vm in DGVertexMapper(e): productSmiles = [origSmiles[g] for g in vms.right] for ev in vms.left.vertices: pv = vm.map[ev] diff --git a/synutility/SynIO/Format/gml_to_nx.py b/synutility/SynIO/Format/gml_to_nx.py index 7e6d4f9..47a70a3 100644 --- a/synutility/SynIO/Format/gml_to_nx.py +++ b/synutility/SynIO/Format/gml_to_nx.py @@ -1,7 +1,7 @@ import networkx as nx import re from typing import Tuple -from synutility.SynIO.Format.its_construction import ITSConstruction +from synutility.SynAAM.its_construction import ITSConstruction class GMLToNX: @@ -90,7 +90,7 @@ def _extract_element_and_charge(self, label: str) -> Tuple[str, int]: is conservative and primarily for error handling. """ # Regex to separate the element symbols from the optional charge and sign - match = re.match(r"([A-Za-z]+)(\d+)?([+-])?$", label) + match = re.match(r"([A-Za-z*]+)(\d+)?([+-])?$", label) if not match: return ( "X", diff --git a/synutility/SynIO/Format/mol_to_graph.py b/synutility/SynIO/Format/mol_to_graph.py index af5cd51..3c44a28 100644 --- a/synutility/SynIO/Format/mol_to_graph.py +++ b/synutility/SynIO/Format/mol_to_graph.py @@ -143,7 +143,9 @@ def _create_light_weight_graph(cls, mol: Chem.Mol, drop_non_aam: bool) -> nx.Gra aromatic=atom.GetIsAromatic(), hcount=atom.GetTotalNumHs(), charge=atom.GetFormalCharge(), - neighbors=[neighbor.GetSymbol() for neighbor in atom.GetNeighbors()], + neighbors=sorted( + neighbor.GetSymbol() for neighbor in atom.GetNeighbors() + ), atom_map=atom_map, ) for bond in atom.GetBonds(): diff --git a/synutility/SynVis/graph_visualizer.py b/synutility/SynVis/graph_visualizer.py index acf7426..84179e7 100644 --- a/synutility/SynVis/graph_visualizer.py +++ b/synutility/SynVis/graph_visualizer.py @@ -4,12 +4,14 @@ Adaptations were made to enhance functionality and integrate with other system components. """ +import networkx as nx from rdkit import Chem from rdkit.Chem import rdDepictor + +import matplotlib.pyplot as plt from typing import Dict, Optional -import networkx as nx + from synutility.SynIO.Format.graph_to_mol import GraphToMol -import matplotlib.pyplot as plt class GraphVisualizer: @@ -42,7 +44,7 @@ def _get_its_as_mol(self, its: nx.Graph) -> Optional[Chem.Mol]: _its[u][v]["order"] = 1 return GraphToMol(self.node_attributes, self.edge_attributes).graph_to_mol( _its, False, False - ) # Ensure this function is defined correctly elsewhere + ) def plot_its( self, @@ -85,7 +87,7 @@ def plot_its( Returns: - None """ - bond_char = {None: "∅", 0: "∅", 1: "—", 2: "=", 3: "≡"} + bond_char = {None: "∅", 0: "∅", 1: "—", 2: "=", 3: "≡", 1.5: ":"} positions = self._calculate_positions(its, use_mol_coords) @@ -98,9 +100,9 @@ def plot_its( if use_edge_color: edge_colors = [ ( - "green" + "red" if data.get(standard_order_key, 0) > 0 - else "red" if data.get(standard_order_key, 0) < 0 else "black" + else "green" if data.get(standard_order_key, 0) < 0 else "black" ) for _, _, data in its.edges(data=True) ] @@ -198,7 +200,7 @@ def plot_as_mol( # Set default bond characters if not provided if bond_char is None: - bond_char = {None: "∅", 1: "—", 2: "=", 3: "≡"} + bond_char = {None: "∅", 1: "—", 2: "=", 3: "≡", 1.5: ":"} # Determine positions based on use_mol_coords flag if use_mol_coords: @@ -229,7 +231,14 @@ def plot_as_mol( # Preparing labels labels = {} for n, d in g.nodes(data=True): - label = f"{d.get(symbol_key, '')}" + charge = d.get("charge", 0) + if charge == 0: + charge = "" + elif charge > 0: + charge = f"{charge}+" if charge > 1 else "+" + else: + charge = f"{-charge}-" if charge < -1 else "-" + label = f"{d.get(symbol_key, '')}{charge}" if show_atom_map: label += f" ({d.get(aam_key, '')})" labels[n] = label diff --git a/synutility/SynVis/rsmi_to_fig.py b/synutility/SynVis/rsmi_to_fig.py index 690d326..a08520d 100644 --- a/synutility/SynVis/rsmi_to_fig.py +++ b/synutility/SynVis/rsmi_to_fig.py @@ -4,8 +4,8 @@ from synutility.SynVis.graph_visualizer import GraphVisualizer -from synutility.SynIO.Format.smi_to_graph import rsmi_to_graph -from synutility.SynIO.Format.its_construction import ITSConstruction +from synutility.SynIO.Format.chemical_conversion import rsmi_to_graph +from synutility.SynAAM.its_construction import ITSConstruction vis_graph = GraphVisualizer()