Skip to content

Tag propagation graph generation #220

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 4, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 66 additions & 9 deletions pytools/graph.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
Copyright (C) 2009-2013 Andreas Kloeckner
Copyright (C) 2020 Matt Wala
Copyright (C) 2020 James Stevens
Copyright (C) 2024 Addison Alvey-Blanco
"""

__license__ = """
@@ -47,6 +48,8 @@
.. autofunction:: as_graphviz_dot
.. autofunction:: validate_graph
.. autofunction:: is_connected
.. autofunction:: undirected_graph_from_edges
.. autofunction:: get_reachable_nodes
Type Variables Used
-------------------
@@ -71,13 +74,16 @@
Callable,
Collection,
Dict,
FrozenSet,
Generic,
Hashable,
Iterable,
Iterator,
List,
Mapping,
MutableSet,
Optional,
Protocol,
Set,
Tuple,
TypeVar,
@@ -98,7 +104,6 @@

NodeT = TypeVar("NodeT", bound=Hashable)


GraphT: TypeAlias[NodeT] = Mapping[NodeT, Collection[NodeT]]


@@ -263,8 +268,13 @@ def __init__(self, node: NodeT) -> None:
self.node = node


class _SupportsLT(Protocol):
def __lt__(self, other: object) -> bool:
...


@dataclass(frozen=True)
class HeapEntry(Generic[NodeT]):
class _HeapEntry(Generic[NodeT]):
"""
Helper class to compare associated keys while comparing the elements in
heap operations.
@@ -273,9 +283,9 @@ class HeapEntry(Generic[NodeT]):
<https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Lib/heapq.py#L135-L138>.
"""
node: NodeT
key: Any
key: _SupportsLT

def __lt__(self, other: HeapEntry) -> bool:
def __lt__(self, other: _HeapEntry[NodeT]) -> bool:
return self.key < other.key


@@ -321,7 +331,7 @@ def compute_topological_order(graph: GraphT[NodeT],
# heap: list of instances of HeapEntry(n) where 'n' is a node in
# 'graph' with no predecessor. Nodes with no predecessors are the
# schedulable candidates.
heap = [HeapEntry(n, keyfunc(n))
heap = [_HeapEntry(n, keyfunc(n))
for n, num_preds in nodes_to_num_predecessors.items()
if num_preds == 0]
heapify(heap)
@@ -336,7 +346,7 @@ def compute_topological_order(graph: GraphT[NodeT],
for child in graph.get(node_to_be_scheduled, ()):
nodes_to_num_predecessors[child] -= 1
if nodes_to_num_predecessors[child] == 0:
heappush(heap, HeapEntry(child, keyfunc(child)))
heappush(heap, _HeapEntry(child, keyfunc(child)))

if len(order) != total_num_nodes:
# any node which has a predecessor left is a part of a cycle
@@ -457,11 +467,11 @@ def as_graphviz_dot(graph: GraphT[NodeT],
from pytools.graphviz import dot_escape

if node_labels is None:
def node_labels(x):
def node_labels(x: NodeT) -> str:
return str(x)

if edge_labels is None:
def edge_labels(x, y):
def edge_labels(x: NodeT, y: NodeT) -> str:
return ""

node_to_id = {}
@@ -511,7 +521,7 @@ def validate_graph(graph: GraphT[NodeT]) -> None:
# }}}


# {{{
# {{{ is_connected

def is_connected(graph: GraphT[NodeT]) -> bool:
"""
@@ -542,5 +552,52 @@ def dfs(node: NodeT) -> None:

return visited == graph.keys()

# }}}


def undirected_graph_from_edges(
edges: Iterable[Tuple[NodeT, NodeT]],
) -> GraphT[NodeT]:
"""
Constructs an undirected graph using *edges*.
:arg edges: An :class:`Iterable` of pairs of related :class:`NodeT` s.
:returns: A :class:`GraphT` that is the undirected graph.
"""
undirected_graph: Dict[NodeT, Set[NodeT]] = {}

for lhs, rhs in edges:
if lhs == rhs:
raise TypeError("Found loop in edges,"
f" LHS, RHS = {lhs}")

undirected_graph.setdefault(lhs, set()).add(rhs)
undirected_graph.setdefault(rhs, set()).add(lhs)

return undirected_graph


def get_reachable_nodes(
undirected_graph: GraphT[NodeT],
source_node: NodeT) -> FrozenSet[NodeT]:
"""
Returns a :class:`frozenset` of all nodes in *undirected_graph* that are
reachable from *source_node*.
"""
nodes_visited: Set[NodeT] = set()
nodes_to_visit = {source_node}

while nodes_to_visit:
current_node = nodes_to_visit.pop()
nodes_visited.add(current_node)

neighbors = undirected_graph[current_node]
nodes_to_visit.update({node
for node in neighbors
if node not in nodes_visited})

return frozenset(nodes_visited)


# vim: foldmethod=marker
32 changes: 32 additions & 0 deletions pytools/test/test_graph_tools.py
Original file line number Diff line number Diff line change
@@ -431,6 +431,38 @@ def test_is_connected():
assert is_connected({})


def test_propagation_graph_tools():
from pytools.graph import (
get_reachable_nodes,
undirected_graph_from_edges,
)

vars = {"a", "b", "c", "d", "e", "f", "g"}

constraints = [
("a", "b"),
("a", "d"),
("c", "d"),
("e", "f"),
("f", "g")
]

all_reachable_nodes = {
"a": frozenset({"a", "b", "c", "d"}),
"b": frozenset({"a", "b", "c", "d"}),
"c": frozenset({"a", "b", "c", "d"}),
"e": frozenset({"e", "f", "g"}),
"f": frozenset({"e", "f", "g"}),
"g": frozenset({"e", "f", "g"})
}

propagation_graph = undirected_graph_from_edges(constraints)
assert (
all_reachable_nodes[var] == get_reachable_nodes(propagation_graph, var)
for var in vars
)


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
1 change: 1 addition & 0 deletions run-mypy.sh
Original file line number Diff line number Diff line change
@@ -6,5 +6,6 @@ mypy --show-error-codes pytools

mypy --strict --follow-imports=silent \
pytools/tag.py \
pytools/graph.py \
pytools/datatable.py \
pytools/persistent_dict.py