Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update tag propagation logic #494

Merged
merged 12 commits into from
Sep 9, 2024
58 changes: 20 additions & 38 deletions pytato/transform/metadata.py
a-alveyblanc marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@
"""


from typing import (TYPE_CHECKING, Type, Set, Tuple, List, Dict, FrozenSet,
Mapping, Iterable, Any, TypeVar, cast)
from typing import (TYPE_CHECKING, Type, Set, Tuple, List, Dict, Mapping,
Iterable, Any, TypeVar, cast)
from bidict import bidict
from pytato.scalar_expr import SCALAR_CLASSES
from pytato.transform import ArrayOrNames, Mapper, CopyMapper
Expand Down Expand Up @@ -556,38 +556,7 @@ def map_named_call_result(self, expr: NamedCallResult) -> Array:
# }}}


def _get_propagation_graph_from_constraints(
equations: List[Tuple[str, str]]) -> Mapping[str, FrozenSet[str]]:
from immutabledict import immutabledict
propagation_graph: Dict[str, Set[str]] = {}
for lhs, rhs in equations:
assert lhs != rhs
propagation_graph.setdefault(lhs, set()).add(rhs)
propagation_graph.setdefault(rhs, set()).add(lhs)

return immutabledict({k: frozenset(v)
for k, v in propagation_graph.items()})


def get_reachable_nodes(undirected_graph: Mapping[GraphNodeT, Iterable[GraphNodeT]],
source_node: GraphNodeT) -> FrozenSet[GraphNodeT]:
"""
Returns a :class:`frozenset` of all nodes in *undirected_graph* that are
reachable from *source_node*.
"""
nodes_visited: Set[GraphNodeT] = set()
nodes_to_visit = {source_node}
while nodes_to_visit:
current_node = nodes_to_visit.pop()
nodes_visited.add(current_node)

neighbors = undirected_graph[current_node]
nodes_to_visit.update({node
for node in neighbors
if node not in nodes_visited})

return frozenset(nodes_visited)

# {{{ AxisTagAttacher

class AxisTagAttacher(CopyMapper):
"""
Expand Down Expand Up @@ -659,6 +628,8 @@ def __call__(self, expr: ArrayOrNames) -> ArrayOrNames: # type: ignore[override
assert isinstance(result, (Array, AbstractResultWithNamedArrays))
return result

# }}}


def unify_axes_tags(
expr: ArrayOrNames,
Expand Down Expand Up @@ -693,15 +664,26 @@ def unify_axes_tags(
# Defn. A Propagation graph is a graph where nodes denote variables and an
# edge between 2 nodes denotes an equality criterion.

propagation_graph = _get_propagation_graph_from_constraints(
equations_collector.equations)
from pytools.graph import (
get_propagation_graph_from_constraints,
get_reachable_nodes
)
from pytools.tag import IgnoredForPropagationTag

known_tag_vars = frozenset(equations_collector.known_tag_to_var.values())
axis_to_solved_tags: Dict[Tuple[Array, int], Set[Tag]] = {}

propagation_graph = get_propagation_graph_from_constraints(
equations_collector.equations,
equations_collector.known_tag_to_var
)

for tag, var in equations_collector.known_tag_to_var.items():
for reachable_var in (get_reachable_nodes(propagation_graph, var)
- known_tag_vars):
if isinstance(tag, IgnoredForPropagationTag):
continue

reachable_nodes = get_reachable_nodes(propagation_graph, var)
for reachable_var in (reachable_nodes - known_tag_vars):
axis_to_solved_tags.setdefault(
equations_collector.axis_to_var.inverse[reachable_var],
set()
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
a-alveyblanc marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
git+https://github.com/inducer/pytools.git#egg=pytools >= 2021.1
git+https://github.com/a-alveyblanc/pytools.git@tag-propagation#egg=pytools
git+https://github.com/inducer/pymbolic.git#egg=pymbolic
git+https://github.com/inducer/genpy.git#egg=genpy
git+https://github.com/inducer/loopy.git#egg=loopy
Expand Down
Loading