diff --git a/pytools/graph.py b/pytools/graph.py index 91f2bffa..baed0995 100644 --- a/pytools/graph.py +++ b/pytools/graph.py @@ -5,6 +5,7 @@ Copyright (C) 2009-2013 Andreas Kloeckner Copyright (C) 2020 Matt Wala Copyright (C) 2020 James Stevens +Copyright (C) 2024 Addison Alvey-Blanco """ __license__ = """ @@ -47,6 +48,8 @@ .. autofunction:: as_graphviz_dot .. autofunction:: validate_graph .. autofunction:: is_connected +.. autofunction:: undirected_graph_from_edges +.. autofunction:: get_reachable_nodes Type Variables Used ------------------- @@ -71,13 +74,16 @@ Callable, Collection, Dict, + FrozenSet, Generic, Hashable, + Iterable, Iterator, List, Mapping, MutableSet, Optional, + Protocol, Set, Tuple, TypeVar, @@ -98,7 +104,6 @@ NodeT = TypeVar("NodeT", bound=Hashable) - GraphT: TypeAlias[NodeT] = Mapping[NodeT, Collection[NodeT]] @@ -263,8 +268,13 @@ def __init__(self, node: NodeT) -> None: self.node = node +class _SupportsLT(Protocol): + def __lt__(self, other: object) -> bool: + ... + + @dataclass(frozen=True) -class HeapEntry(Generic[NodeT]): +class _HeapEntry(Generic[NodeT]): """ Helper class to compare associated keys while comparing the elements in heap operations. @@ -273,9 +283,9 @@ class HeapEntry(Generic[NodeT]): . """ node: NodeT - key: Any + key: _SupportsLT - def __lt__(self, other: HeapEntry) -> bool: + def __lt__(self, other: _HeapEntry[NodeT]) -> bool: return self.key < other.key @@ -321,7 +331,7 @@ def compute_topological_order(graph: GraphT[NodeT], # heap: list of instances of HeapEntry(n) where 'n' is a node in # 'graph' with no predecessor. Nodes with no predecessors are the # schedulable candidates. - heap = [HeapEntry(n, keyfunc(n)) + heap = [_HeapEntry(n, keyfunc(n)) for n, num_preds in nodes_to_num_predecessors.items() if num_preds == 0] heapify(heap) @@ -336,7 +346,7 @@ def compute_topological_order(graph: GraphT[NodeT], for child in graph.get(node_to_be_scheduled, ()): nodes_to_num_predecessors[child] -= 1 if nodes_to_num_predecessors[child] == 0: - heappush(heap, HeapEntry(child, keyfunc(child))) + heappush(heap, _HeapEntry(child, keyfunc(child))) if len(order) != total_num_nodes: # any node which has a predecessor left is a part of a cycle @@ -457,11 +467,11 @@ def as_graphviz_dot(graph: GraphT[NodeT], from pytools.graphviz import dot_escape if node_labels is None: - def node_labels(x): + def node_labels(x: NodeT) -> str: return str(x) if edge_labels is None: - def edge_labels(x, y): + def edge_labels(x: NodeT, y: NodeT) -> str: return "" node_to_id = {} @@ -511,7 +521,7 @@ def validate_graph(graph: GraphT[NodeT]) -> None: # }}} -# {{{ +# {{{ is_connected def is_connected(graph: GraphT[NodeT]) -> bool: """ @@ -542,5 +552,52 @@ def dfs(node: NodeT) -> None: return visited == graph.keys() +# }}} + + +def undirected_graph_from_edges( + edges: Iterable[Tuple[NodeT, NodeT]], + ) -> GraphT[NodeT]: + """ + Constructs an undirected graph using *edges*. + + :arg edges: An :class:`Iterable` of pairs of related :class:`NodeT` s. + + :returns: A :class:`GraphT` that is the undirected graph. + """ + undirected_graph: Dict[NodeT, Set[NodeT]] = {} + + for lhs, rhs in edges: + if lhs == rhs: + raise TypeError("Found loop in edges," + f" LHS, RHS = {lhs}") + + undirected_graph.setdefault(lhs, set()).add(rhs) + undirected_graph.setdefault(rhs, set()).add(lhs) + + return undirected_graph + + +def get_reachable_nodes( + undirected_graph: GraphT[NodeT], + source_node: NodeT) -> FrozenSet[NodeT]: + """ + Returns a :class:`frozenset` of all nodes in *undirected_graph* that are + reachable from *source_node*. + """ + nodes_visited: Set[NodeT] = set() + nodes_to_visit = {source_node} + + while nodes_to_visit: + current_node = nodes_to_visit.pop() + nodes_visited.add(current_node) + + neighbors = undirected_graph[current_node] + nodes_to_visit.update({node + for node in neighbors + if node not in nodes_visited}) + + return frozenset(nodes_visited) + # vim: foldmethod=marker diff --git a/pytools/test/test_graph_tools.py b/pytools/test/test_graph_tools.py index a98986ed..57c46295 100644 --- a/pytools/test/test_graph_tools.py +++ b/pytools/test/test_graph_tools.py @@ -431,6 +431,38 @@ def test_is_connected(): assert is_connected({}) +def test_propagation_graph_tools(): + from pytools.graph import ( + get_reachable_nodes, + undirected_graph_from_edges, + ) + + vars = {"a", "b", "c", "d", "e", "f", "g"} + + constraints = [ + ("a", "b"), + ("a", "d"), + ("c", "d"), + ("e", "f"), + ("f", "g") + ] + + all_reachable_nodes = { + "a": frozenset({"a", "b", "c", "d"}), + "b": frozenset({"a", "b", "c", "d"}), + "c": frozenset({"a", "b", "c", "d"}), + "e": frozenset({"e", "f", "g"}), + "f": frozenset({"e", "f", "g"}), + "g": frozenset({"e", "f", "g"}) + } + + propagation_graph = undirected_graph_from_edges(constraints) + assert ( + all_reachable_nodes[var] == get_reachable_nodes(propagation_graph, var) + for var in vars + ) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) diff --git a/run-mypy.sh b/run-mypy.sh index 244d6cc4..0220058b 100755 --- a/run-mypy.sh +++ b/run-mypy.sh @@ -6,5 +6,6 @@ mypy --show-error-codes pytools mypy --strict --follow-imports=silent \ pytools/tag.py \ + pytools/graph.py \ pytools/datatable.py \ pytools/persistent_dict.py