Skip to content

Commit

Permalink
Test, fix reductions with no inames
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Dec 17, 2021
1 parent 752d758 commit ea9a865
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 1 deletion.
9 changes: 8 additions & 1 deletion loopy/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -1751,6 +1751,8 @@ def map_scan_local(expr, rec, callables_table, nresults, arg_dtypes,

def map_reduction(expr, rec, callables_table,
guarding_predicates, nresults=1):
nonlocal made_changes

# Only expand one level of reduction at a time, going from outermost to
# innermost. Otherwise we get the (iname + insn) dependencies wrong.

Expand Down Expand Up @@ -1838,6 +1840,7 @@ def _error_if_force_scan_on(cls, msg):
# to reduce over. It's rather similar to an array with () shape in
# numpy.)

made_changes = True
return expr.expr, callables_table

# }}}
Expand Down Expand Up @@ -1866,6 +1869,7 @@ def _error_if_force_scan_on(cls, msg):
", ".join(tag.key
for tag in temp_kernel.iname_tags(sweep_iname))))
elif parallel:
made_changes = True
return map_scan_local(
expr, rec, callables_table, nresults,
arg_dtypes, reduction_dtypes,
Expand All @@ -1875,6 +1879,7 @@ def _error_if_force_scan_on(cls, msg):
scan_param.stride,
guarding_predicates)
elif sequential:
made_changes = True
return map_scan_seq(
expr, rec, callables_table, nresults,
arg_dtypes, reduction_dtypes, sweep_iname,
Expand Down Expand Up @@ -1903,6 +1908,7 @@ def _error_if_force_scan_on(cls, msg):
guarding_predicates)
else:
assert n_local_par > 0
made_changes = True
return map_reduction_local(
expr, rec, callables_table, nresults, arg_dtypes,
reduction_dtypes, guarding_predicates)
Expand All @@ -1925,6 +1931,7 @@ def _error_if_force_scan_on(cls, msg):
new_insn_add_within_inames = set()

generated_insns = []
made_changes = False

insn = insn_queue.pop(0)

Expand All @@ -1947,7 +1954,7 @@ def _error_if_force_scan_on(cls, msg):
callables_table=cb_mapper.callables_table,
guarding_predicates=insn.predicates),

if generated_insns:
if made_changes:
# An expansion happened, so insert the generated stuff plus
# ourselves back into the queue.

Expand Down
19 changes: 19 additions & 0 deletions test/test_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,25 @@ def test_any_all(ctx_factory):
assert not out_dict["out2"].get()


def test_reduction_without_inames(ctx_factory):
    """Ensure that a reduction with no inames is rewritten to its operand.

    A reduction over an empty list of inames has nothing to reduce over, so
    reduction realization replaces it with the expression being reduced.
    This rewrite was sometimes erroneously skipped because realization used
    the generation of new statements as the criterion for whether any work
    was done, and this rewrite generates none.

    :arg ctx_factory: test fixture; called with no arguments to obtain the
        context the command queue is created on.
    """
    ctx = ctx_factory()
    cq = cl.CommandQueue(ctx)

    knl = lp.make_kernel(
            "{:}",
            """
            out = reduce(any, [], 5)
            """)
    knl = lp.set_options(knl, return_dict=True)

    _, out_dict = knl(cq)

    # Building and running the kernel is not enough: also check that the
    # reduction actually collapsed to the value being reduced over.
    assert out_dict["out"].get() == 5


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down

0 comments on commit ea9a865

Please sign in to comment.