diff --git a/loopy/preprocess.py b/loopy/preprocess.py index fc0e82afb..55b735f4a 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -1751,6 +1751,8 @@ def map_scan_local(expr, rec, callables_table, nresults, arg_dtypes, def map_reduction(expr, rec, callables_table, guarding_predicates, nresults=1): + nonlocal insn_changed + # Only expand one level of reduction at a time, going from outermost to # innermost. Otherwise we get the (iname + insn) dependencies wrong. @@ -1827,6 +1829,10 @@ def _error_if_force_scan_on(cls, msg): ", ".join(str(kernel.iname_tags(iname)) for iname in bad_inames))) + # }}} + + insn_changed = True + if n_local_par == 0 and n_sequential == 0: from loopy.diagnostic import warn_with_kernel warn_with_kernel(kernel, "empty_reduction", @@ -1840,8 +1846,6 @@ def _error_if_force_scan_on(cls, msg): return expr.expr, callables_table - # }}} - if may_be_implemented_as_scan: assert force_scan or automagic_scans_ok @@ -1916,7 +1920,7 @@ def _error_if_force_scan_on(cls, msg): domains = kernel.domains[:] temp_kernel = kernel - changed = False + kernel_changed = False import loopy as lp while insn_queue: @@ -1925,6 +1929,7 @@ def _error_if_force_scan_on(cls, msg): new_insn_add_within_inames = set() generated_insns = [] + insn_changed = False insn = insn_queue.pop(0) @@ -1947,7 +1952,7 @@ def _error_if_force_scan_on(cls, msg): callables_table=cb_mapper.callables_table, guarding_predicates=insn.predicates), - if generated_insns: + if insn_changed: # An expansion happened, so insert the generated stuff plus # ourselves back into the queue. @@ -2010,14 +2015,14 @@ def _error_if_force_scan_on(cls, msg): domains=domains) temp_kernel = lp.replace_instruction_ids( temp_kernel, insn_id_replacements) - changed = True + kernel_changed = True else: # nothing happened, we're done with insn assert not new_insn_add_depends_on new_insns.append(insn) - if changed: + if kernel_changed: kernel = kernel.copy( instructions=new_insns, temporary_variables=new_temporary_variables, diff --git a/test/test_reduction.py b/test/test_reduction.py index c623c68c6..931628a04 100644 --- a/test/test_reduction.py +++ b/test/test_reduction.py @@ -460,6 +460,27 @@ def test_any_all(ctx_factory): assert not out_dict["out2"].get() +def test_reduction_without_inames(ctx_factory): + """Ensure that reductions with no inames get rewritten to the element + being reduced over. This was sometimes erroneously eliminated because + reduction realization used the generation of new statements as a criterion + for whether work was done. + """ + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + + knl = lp.make_kernel( + "{:}", + """ + out = reduce(any, [], 5) + """) + knl = lp.set_options(knl, return_dict=True) + + _, out_dict = knl(cq) + + assert out_dict["out"].get() == 5 + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1])