Skip to content

Commit

Permalink
group functions into classes
Browse files Browse the repository at this point in the history
  • Loading branch information
Klaus Weinbauer committed Mar 13, 2024
1 parent 44dfeb4 commit eb06eb4
Show file tree
Hide file tree
Showing 9 changed files with 250 additions and 198 deletions.
3 changes: 2 additions & 1 deletion fgutils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .query import get_functional_groups, FGQuery
from .permutation import PermutationMapper
from .query import FGQuery
125 changes: 58 additions & 67 deletions fgutils/fgconfig.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations
import numpy as np

from fgutils.permutation import PermutationMapper
from fgutils.parse import parse
from fgutils.mapping import map_pattern
from fgutils.mapping import map_to_entire_graph

functional_group_config = [
_default_fg_config = [
{
"name": "carbonyl",
"pattern": "C(=O)",
Expand Down Expand Up @@ -101,9 +102,9 @@ def pattern_len(self) -> int:
)


def is_subgroup(parent: FGConfig, child: FGConfig) -> bool:
p2c = map_full(child.pattern, parent.pattern)
c2p = map_full(parent.pattern, child.pattern)
def is_subgroup(parent: FGConfig, child: FGConfig, mapper: PermutationMapper) -> bool:
p2c = map_to_entire_graph(child.pattern, parent.pattern, mapper)
c2p = map_to_entire_graph(parent.pattern, child.pattern, mapper)
if p2c:
assert c2p is False, "{} ({}) -> {} ({}) matches in both directions.".format(
parent.name, parent.pattern_str, child.name, child.pattern_str
Expand All @@ -112,26 +113,11 @@ def is_subgroup(parent: FGConfig, child: FGConfig) -> bool:
return False


class TreeNode:
def __init__(self, is_child_callback):
self.parents: list[TreeNode] = []
self.children: list[TreeNode] = []
self.is_child_callback = is_child_callback

def is_child(self, parent: TreeNode) -> bool:
return self.is_child_callback(parent, self)

def add_child(self, child: TreeNode):
child.parents.append(self)
self.children.append(child)


class FGTreeNode(TreeNode):
class FGTreeNode:
def __init__(self, fgconfig: FGConfig):
self.fgconfig = fgconfig
self.parents: list[FGTreeNode]
self.children: list[FGTreeNode]
super().__init__(lambda a, b: is_subgroup(a.fgconfig, b.fgconfig))
self.parents: list[FGTreeNode] = []
self.children: list[FGTreeNode] = []

def order_id(self):
return (
Expand All @@ -141,35 +127,12 @@ def order_id(self):
)

def add_child(self, child: FGTreeNode):
super().add_child(child)
child.parents.append(self)
self.children.append(child)
self.parents = sorted(self.parents, key=lambda x: x.order_id(), reverse=True)
self.children = sorted(self.children, key=lambda x: x.order_id(), reverse=True)


fg_configs = None


def get_FG_list() -> list[FGConfig]:
global fg_configs
if fg_configs is None:
c = []
for fgc in functional_group_config:
c.append(FGConfig(**fgc))
fg_configs = c
return fg_configs


def get_FG_by_name(name: str) -> FGConfig:
for fg in get_FG_list():
if fg.name == name:
return fg
raise KeyError("No functional group config with name '{}' found.".format(name))


def get_FG_names() -> list[str]:
return [c.name for c in get_FG_list()]


def sort_by_pattern_len(configs: list[FGConfig], reverse=False) -> list[FGConfig]:
return list(
sorted(
Expand All @@ -180,19 +143,13 @@ def sort_by_pattern_len(configs: list[FGConfig], reverse=False) -> list[FGConfig
)


def map_full(graph, pattern):
for i in range(len(graph)):
r, _ = map_pattern(graph, i, pattern)
if r is True:
return True
return False


def search_parents(roots: list[TreeNode], child: TreeNode) -> None | list[TreeNode]:
def search_parents(
roots: list[FGTreeNode], child: FGTreeNode, mapper: PermutationMapper
) -> None | list[FGTreeNode]:
parents = set()
for root in roots:
if child.is_child(root):
_parents = search_parents(root.children, child)
if is_subgroup(root.fgconfig, child.fgconfig, mapper):
_parents = search_parents(root.children, child, mapper)
if _parents is None:
parents.add(root)
else:
Expand Down Expand Up @@ -222,11 +179,13 @@ def _print(node: FGTreeNode, indent=0):
_print(root)


def build_config_tree_from_list(config_list: list[FGConfig]) -> list[FGTreeNode]:
def build_config_tree_from_list(
config_list: list[FGConfig], mapper: PermutationMapper
) -> list[FGTreeNode]:
roots = []
for config in sort_by_pattern_len(config_list):
node = FGTreeNode(config)
parents = search_parents(roots, node)
parents = search_parents(roots, node, mapper)
if parents is None:
roots.append(node)
else:
Expand All @@ -235,11 +194,43 @@ def build_config_tree_from_list(config_list: list[FGConfig]) -> list[FGTreeNode]
return roots


_fg_tree_roots = None
class FGConfigProvider:
def __init__(
self,
config: list[dict] | list[FGConfig] | None = None,
mapper: PermutationMapper | None = None,
):
self.config_list: list[FGConfig] = []
if config is None:
config = _default_fg_config
if isinstance(config, list) and len(config) > 0:
if isinstance(config[0], dict):
for fgc in config:
self.config_list.append(FGConfig(**fgc)) # type: ignore
elif isinstance(config[0], FGConfig):
self.config_list = config # type: ignore
else:
raise ValueError("Invalid config value.")
else:
raise ValueError("Invalid config value.")

self.mapper = (
mapper
if mapper is not None
else PermutationMapper(wildcard="R", ignore_case=True)
)

self.__tree_roots = None

def get_tree(self) -> list[FGTreeNode]:
if self.__tree_roots is None:
self.__tree_roots = build_config_tree_from_list(
self.config_list, self.mapper
)
return self.__tree_roots

def build_FG_tree() -> list[FGTreeNode]:
global _fg_tree_roots
if _fg_tree_roots is None:
_fg_tree_roots = build_config_tree_from_list(get_FG_list())
return _fg_tree_roots
def get_by_name(self, name: str) -> FGConfig:
for fg in self.config_list:
if fg.name == name:
return fg
raise KeyError("No functional group config with name '{}' found.".format(name))
28 changes: 21 additions & 7 deletions fgutils/mapping.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import networkx as nx

from fgutils.permutation import Mapper
from fgutils.permutation import PermutationMapper


def _get_neighbors(graph, idx, excluded_nodes=set()):
Expand All @@ -17,10 +17,12 @@ def _get_symbol(graph, idx):


def map_anchored_pattern(
graph: nx.Graph, anchor: int, pattern: nx.Graph, pattern_anchor: int
graph: nx.Graph,
anchor: int,
pattern: nx.Graph,
pattern_anchor: int,
mapper: PermutationMapper,
):
mapper = Mapper(wildcard="R", ignore_case=True)

def _fit(idx, pidx, visited_nodes=set(), visited_pnodes=set(), indent=0):
visited_nodes = copy.deepcopy(visited_nodes)
visited_nodes.add(idx)
Expand Down Expand Up @@ -90,15 +92,27 @@ def _fit(idx, pidx, visited_nodes=set(), visited_pnodes=set(), indent=0):


def map_pattern(
graph: nx.Graph, anchor: int, pattern: nx.Graph, pattern_anchor: None | int = None
graph: nx.Graph,
anchor: int,
pattern: nx.Graph,
mapper: PermutationMapper,
pattern_anchor: None | int = None,
):
if pattern_anchor is None:
if len(pattern) == 0:
return True, []
for pidx in pattern.nodes:
result = map_anchored_pattern(graph, anchor, pattern, pidx)
result = map_anchored_pattern(graph, anchor, pattern, pidx, mapper)
if result[0]:
return result
return False, []
else:
return map_anchored_pattern(graph, anchor, pattern, pattern_anchor)
return map_anchored_pattern(graph, anchor, pattern, pattern_anchor, mapper)


def map_to_entire_graph(graph: nx.Graph, pattern: nx.Graph, mapper: PermutationMapper):
for i in range(len(graph)):
r, _ = map_pattern(graph, i, pattern, mapper)
if r is True:
return True
return False
2 changes: 1 addition & 1 deletion fgutils/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def generate_mapping_permutations(pattern, structure, wildcard=None):
return mappings


class Mapper:
class PermutationMapper:
def __init__(self, wildcard=None, ignore_case=False, can_map_to_nothing=[]):
self.wildcard = wildcard
self.ignore_case = ignore_case
Expand Down
Loading

0 comments on commit eb06eb4

Please sign in to comment.