Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tag propagation graph generation #220

Merged
merged 2 commits into from
Sep 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 66 additions & 9 deletions pytools/graph.py
a-alveyblanc marked this conversation as resolved.
Show resolved Hide resolved
a-alveyblanc marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
Copyright (C) 2009-2013 Andreas Kloeckner
Copyright (C) 2020 Matt Wala
Copyright (C) 2020 James Stevens
Copyright (C) 2024 Addison Alvey-Blanco
"""

__license__ = """
Expand Down Expand Up @@ -47,6 +48,8 @@
.. autofunction:: as_graphviz_dot
.. autofunction:: validate_graph
.. autofunction:: is_connected
.. autofunction:: undirected_graph_from_edges
.. autofunction:: get_reachable_nodes

Type Variables Used
-------------------
Expand All @@ -71,13 +74,16 @@
Callable,
Collection,
Dict,
FrozenSet,
Generic,
Hashable,
Iterable,
Iterator,
List,
Mapping,
MutableSet,
Optional,
Protocol,
Set,
Tuple,
TypeVar,
Expand All @@ -98,7 +104,6 @@

NodeT = TypeVar("NodeT", bound=Hashable)


GraphT: TypeAlias[NodeT] = Mapping[NodeT, Collection[NodeT]]


Expand Down Expand Up @@ -263,8 +268,13 @@ def __init__(self, node: NodeT) -> None:
self.node = node


class _SupportsLT(Protocol):
def __lt__(self, other: object) -> bool:
...


@dataclass(frozen=True)
class HeapEntry(Generic[NodeT]):
class _HeapEntry(Generic[NodeT]):
"""
Helper class to compare associated keys while comparing the elements in
heap operations.
Expand All @@ -273,9 +283,9 @@ class HeapEntry(Generic[NodeT]):
<https://github.com/python/cpython/blob/8d21aa21f2cbc6d50aab3f420bb23be1d081dac4/Lib/heapq.py#L135-L138>.
"""
node: NodeT
key: Any
key: _SupportsLT

def __lt__(self, other: HeapEntry) -> bool:
def __lt__(self, other: _HeapEntry[NodeT]) -> bool:
return self.key < other.key


Expand Down Expand Up @@ -321,7 +331,7 @@ def compute_topological_order(graph: GraphT[NodeT],
# heap: list of instances of HeapEntry(n) where 'n' is a node in
# 'graph' with no predecessor. Nodes with no predecessors are the
# schedulable candidates.
heap = [HeapEntry(n, keyfunc(n))
heap = [_HeapEntry(n, keyfunc(n))
for n, num_preds in nodes_to_num_predecessors.items()
if num_preds == 0]
heapify(heap)
Expand All @@ -336,7 +346,7 @@ def compute_topological_order(graph: GraphT[NodeT],
for child in graph.get(node_to_be_scheduled, ()):
nodes_to_num_predecessors[child] -= 1
if nodes_to_num_predecessors[child] == 0:
heappush(heap, HeapEntry(child, keyfunc(child)))
heappush(heap, _HeapEntry(child, keyfunc(child)))

if len(order) != total_num_nodes:
# any node which has a predecessor left is a part of a cycle
Expand Down Expand Up @@ -457,11 +467,11 @@ def as_graphviz_dot(graph: GraphT[NodeT],
from pytools.graphviz import dot_escape

if node_labels is None:
def node_labels(x):
def node_labels(x: NodeT) -> str:
return str(x)

if edge_labels is None:
def edge_labels(x, y):
def edge_labels(x: NodeT, y: NodeT) -> str:
return ""

node_to_id = {}
Expand Down Expand Up @@ -511,7 +521,7 @@ def validate_graph(graph: GraphT[NodeT]) -> None:
# }}}


# {{{
# {{{ is_connected

def is_connected(graph: GraphT[NodeT]) -> bool:
"""
Expand Down Expand Up @@ -542,5 +552,52 @@ def dfs(node: NodeT) -> None:

return visited == graph.keys()

# }}}


def undirected_graph_from_edges(
edges: Iterable[Tuple[NodeT, NodeT]],
) -> GraphT[NodeT]:
"""
Constructs an undirected graph using *edges*.

:arg edges: An :class:`Iterable` of pairs of related :class:`NodeT` s.

:returns: A :class:`GraphT` that is the undirected graph.
"""
undirected_graph: Dict[NodeT, Set[NodeT]] = {}

for lhs, rhs in edges:
if lhs == rhs:
raise TypeError("Found loop in edges,"
f" LHS, RHS = {lhs}")

undirected_graph.setdefault(lhs, set()).add(rhs)
undirected_graph.setdefault(rhs, set()).add(lhs)

return undirected_graph


def get_reachable_nodes(
undirected_graph: GraphT[NodeT],
source_node: NodeT) -> FrozenSet[NodeT]:
"""
Returns a :class:`frozenset` of all nodes in *undirected_graph* that are
reachable from *source_node*.
"""
nodes_visited: Set[NodeT] = set()
nodes_to_visit = {source_node}

while nodes_to_visit:
current_node = nodes_to_visit.pop()
nodes_visited.add(current_node)

neighbors = undirected_graph[current_node]
nodes_to_visit.update({node
for node in neighbors
if node not in nodes_visited})

return frozenset(nodes_visited)


# vim: foldmethod=marker
32 changes: 32 additions & 0 deletions pytools/test/test_graph_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,6 +431,38 @@ def test_is_connected():
assert is_connected({})


def test_propagation_graph_tools():
from pytools.graph import (
get_reachable_nodes,
undirected_graph_from_edges,
)

vars = {"a", "b", "c", "d", "e", "f", "g"}

constraints = [
("a", "b"),
("a", "d"),
("c", "d"),
("e", "f"),
("f", "g")
]

all_reachable_nodes = {
"a": frozenset({"a", "b", "c", "d"}),
"b": frozenset({"a", "b", "c", "d"}),
"c": frozenset({"a", "b", "c", "d"}),
"e": frozenset({"e", "f", "g"}),
"f": frozenset({"e", "f", "g"}),
"g": frozenset({"e", "f", "g"})
}

propagation_graph = undirected_graph_from_edges(constraints)
assert (
all_reachable_nodes[var] == get_reachable_nodes(propagation_graph, var)
for var in vars
)


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down
1 change: 1 addition & 0 deletions run-mypy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ mypy --show-error-codes pytools

mypy --strict --follow-imports=silent \
pytools/tag.py \
pytools/graph.py \
pytools/datatable.py \
pytools/persistent_dict.py
Loading