diff --git a/loopy/preprocess.py b/loopy/preprocess.py index fc0e82afb..6d5346d7e 100644 --- a/loopy/preprocess.py +++ b/loopy/preprocess.py @@ -1751,6 +1751,8 @@ def map_scan_local(expr, rec, callables_table, nresults, arg_dtypes, def map_reduction(expr, rec, callables_table, guarding_predicates, nresults=1): + nonlocal made_changes + # Only expand one level of reduction at a time, going from outermost to # innermost. Otherwise we get the (iname + insn) dependencies wrong. @@ -1838,6 +1840,7 @@ def _error_if_force_scan_on(cls, msg): # to reduce over. It's rather similar to an array with () shape in # numpy.) + made_changes = True return expr.expr, callables_table # }}} @@ -1866,6 +1869,7 @@ def _error_if_force_scan_on(cls, msg): ", ".join(tag.key for tag in temp_kernel.iname_tags(sweep_iname)))) elif parallel: + made_changes = True return map_scan_local( expr, rec, callables_table, nresults, arg_dtypes, reduction_dtypes, @@ -1875,6 +1879,7 @@ def _error_if_force_scan_on(cls, msg): scan_param.stride, guarding_predicates) elif sequential: + made_changes = True return map_scan_seq( expr, rec, callables_table, nresults, arg_dtypes, reduction_dtypes, sweep_iname, @@ -1903,6 +1908,7 @@ def _error_if_force_scan_on(cls, msg): guarding_predicates) else: assert n_local_par > 0 + made_changes = True return map_reduction_local( expr, rec, callables_table, nresults, arg_dtypes, reduction_dtypes, guarding_predicates) @@ -1925,6 +1931,7 @@ def _error_if_force_scan_on(cls, msg): new_insn_add_within_inames = set() generated_insns = [] + made_changes = False insn = insn_queue.pop(0) @@ -1947,7 +1954,7 @@ def _error_if_force_scan_on(cls, msg): callables_table=cb_mapper.callables_table, guarding_predicates=insn.predicates), - if generated_insns: + if made_changes: # An expansion happened, so insert the generated stuff plus # ourselves back into the queue. 
diff --git a/test/test_reduction.py b/test/test_reduction.py
index c623c68c6..c06ec5ec7 100644
--- a/test/test_reduction.py
+++ b/test/test_reduction.py
@@ -460,6 +460,29 @@ def test_any_all(ctx_factory):
     assert not out_dict["out2"].get()
 
 
+def test_reduction_without_inames(ctx_factory):
+    """Ensure that reductions with no inames get rewritten to the element
+    being reduced over. This was sometimes erroneously eliminated because
+    reduction realization used the generation of new statements as a criterion
+    for whether work was done.
+    """
+    ctx = ctx_factory()
+    cq = cl.CommandQueue(ctx)
+
+    knl = lp.make_kernel(
+            "{:}",
+            """
+            out = reduce(any, [], 5)
+            """)
+    knl = lp.set_options(knl, return_dict=True)
+
+    _, out_dict = knl(cq)
+
+    # Running the kernel is not enough: check that the empty reduction was
+    # actually replaced by its operand, i.e. that "out" holds the value 5.
+    assert out_dict["out"].get() == 5
+
+
 if __name__ == "__main__":
     if len(sys.argv) > 1:
         exec(sys.argv[1])