Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Test, fix reductions with no inames #527

Merged
merged 2 commits into from
Dec 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions loopy/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -1751,6 +1751,8 @@ def map_scan_local(expr, rec, callables_table, nresults, arg_dtypes,

def map_reduction(expr, rec, callables_table,
guarding_predicates, nresults=1):
nonlocal insn_changed

# Only expand one level of reduction at a time, going from outermost to
# innermost. Otherwise we get the (iname + insn) dependencies wrong.

Expand Down Expand Up @@ -1827,6 +1829,10 @@ def _error_if_force_scan_on(cls, msg):
", ".join(str(kernel.iname_tags(iname))
for iname in bad_inames)))

# }}}

insn_changed = True

if n_local_par == 0 and n_sequential == 0:
from loopy.diagnostic import warn_with_kernel
warn_with_kernel(kernel, "empty_reduction",
Expand All @@ -1840,8 +1846,6 @@ def _error_if_force_scan_on(cls, msg):

return expr.expr, callables_table

# }}}

if may_be_implemented_as_scan:
assert force_scan or automagic_scans_ok

Expand Down Expand Up @@ -1916,7 +1920,7 @@ def _error_if_force_scan_on(cls, msg):
domains = kernel.domains[:]

temp_kernel = kernel
changed = False
kernel_changed = False

import loopy as lp
while insn_queue:
Expand All @@ -1925,6 +1929,7 @@ def _error_if_force_scan_on(cls, msg):
new_insn_add_within_inames = set()

generated_insns = []
insn_changed = False

insn = insn_queue.pop(0)

Expand All @@ -1947,7 +1952,7 @@ def _error_if_force_scan_on(cls, msg):
callables_table=cb_mapper.callables_table,
guarding_predicates=insn.predicates),

if generated_insns:
if insn_changed:
# An expansion happened, so insert the generated stuff plus
# ourselves back into the queue.

Expand Down Expand Up @@ -2010,14 +2015,14 @@ def _error_if_force_scan_on(cls, msg):
domains=domains)
temp_kernel = lp.replace_instruction_ids(
temp_kernel, insn_id_replacements)
changed = True
kernel_changed = True
else:
# nothing happened, we're done with insn
assert not new_insn_add_depends_on

new_insns.append(insn)

if changed:
if kernel_changed:
kernel = kernel.copy(
instructions=new_insns,
temporary_variables=new_temporary_variables,
Expand Down
21 changes: 21 additions & 0 deletions test/test_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,27 @@ def test_any_all(ctx_factory):
assert not out_dict["out2"].get()


def test_reduction_without_inames(ctx_factory):
"""Ensure that reductions with no inames get rewritten to the element
being reduced over. This was sometimes erroneously eliminated because
reduction realization used the generation of new statements as a criterion
for whether work was done.
"""
ctx = ctx_factory()
cq = cl.CommandQueue(ctx)

knl = lp.make_kernel(
"{:}",
"""
out = reduce(any, [], 5)
""")
knl = lp.set_options(knl, return_dict=True)

_, out_dict = knl(cq)

assert out_dict["out"].get() == 5


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down