Skip to content

Commit

Permalink
Merge branch 'main' into reductions-without-inames
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer authored Dec 17, 2021
2 parents ea9a865 + da5c36e commit 8bdcaec
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 1 deletion.
5 changes: 5 additions & 0 deletions loopy/kernel/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,11 @@ def __init__(self, name, dtype=None, shape=None, dim_tags=None, offset=0,
tags=tags,
**kwargs)

# Without this __hash__ is set to None because this class overrides __eq__.
# Source: https://docs.python.org/3/reference/datamodel.html#object.__hash__
def __hash__(self):
return super().__hash__()

def __eq__(self, other):
from loopy.symbolic import (
is_tuple_of_expressions_equal as istoee,
Expand Down
3 changes: 2 additions & 1 deletion loopy/transform/iname.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,7 +935,8 @@ def duplicate_inames(kernel, inames, within, new_inames=None, suffix=None,
within=within)

def _does_access_old_inames(kernel, insn, *args):
return bool(frozenset(inames) & insn.dependency_names())
return bool(frozenset(inames) & (insn.dependency_names()
| insn.reduction_inames()))

kernel = rule_mapping_context.finish_kernel(
indup.map_kernel(kernel, within=_does_access_old_inames,
Expand Down
13 changes: 13 additions & 0 deletions test/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,6 +1324,19 @@ def test_precompute_does_not_lead_to_dep_cycle(ctx_factory):
lp.auto_test_vs_ref(knl, ctx, ref_knl)


def test_rename_inames_redn():
t_unit = lp.make_kernel(
"{[i, j0, j1]: 0<=i, j0, j1<10}",
"""
y0[i] = sum(j0, sum([j1], 2*A[i, j0, j1]))
""")

t_unit = lp.rename_iname(t_unit, "j1", "ifused")

assert "j1" not in t_unit.default_entrypoint.all_inames()
assert "ifused" in t_unit.default_entrypoint.all_inames()


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down

0 comments on commit 8bdcaec

Please sign in to comment.