diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index 1db9047199..f421778051 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -12,6 +12,7 @@ import dace from dace import ( + data as dace_data, dtypes as dace_dtypes, properties as dace_properties, subsets as dace_sbs, @@ -154,6 +155,16 @@ def can_be_applied( if all(len(rel_df) == 0 for rel_df in relocatable_dataflow.values()): return False + # Check if relatability is possible. + if not self._check_for_data_and_symbol_conflicts( + sdfg=sdfg, + state=graph, + relocatable_dataflow=relocatable_dataflow, + enclosing_map=enclosing_map, + if_block=if_block, + ): + return False + # Because the transformation can only handle `if` expressions that # are _directly_ inside a Map, we must check if the upstream contains # suitable `if` expressions that must be processed first. The simplest way @@ -480,6 +491,74 @@ def _update_symbol_mapping( or dace_dtypes.typeclass(int), ) + def _check_for_data_and_symbol_conflicts( + self, + sdfg: dace.SDFG, + state: dace.SDFGState, + relocatable_dataflow: dict[str, set[dace_nodes.Node]], + if_block: dace_nodes.NestedSDFG, + enclosing_map: dace_nodes.MapEntry, + ) -> bool: + """Check if the relocation would cause any conflict, such as a symbol clash.""" + + # TODO(phimuell): There is an obscure case where the nested SDFG, on its own, + # defines a symbol that is also mapped, for example a dynamic Map range. + # It is probably not a problem, because of the scopes DaCe adds when + # generating the C++ code. + + # Create a subgraph to compute the free symbols, i.e. the symbols that + # need to be supplied from the outside. However, this are not all. + # Note, just adding some "well chosen" nodes to the set will not work. + all_relocated_dataflow: set[dace_nodes.Node] = functools.reduce( + lambda s1, s2: s1.union(s2), relocatable_dataflow.values(), set() + ) + requiered_symbols: set[str] = dace.sdfg.state.StateSubgraphView( + state, all_relocated_dataflow + ).free_symbols + + inner_data_names = if_block.sdfg.arrays.keys() + for node_to_check in all_relocated_dataflow: + if ( + isinstance(node_to_check, dace_nodes.AccessNode) + and node_to_check.data in inner_data_names + ): + # There is already a data descriptor that is used on the inside as on + # the outside. Thus we would have to perform some renaming, which we + # currently do not. + # TODO(phimell): Handle this case. + return False + + for iedge in state.in_edges(node_to_check): + src_node = iedge.src + if src_node not in all_relocated_dataflow: + # This means that `src_node` is not relocated but mapped into the + # `if` block. This means that `edge` is replicated as well. + # NOTE: This code is based on the one found in `DataflowGraphView`. + # TODO(phimuell): Do we have to inspect the full Memlet path here? + assert isinstance(src_node, dace_nodes.AccessNode) or src_node is enclosing_map + requiered_symbols |= iedge.data.used_symbols(True, edge=iedge) + + # The (beyond the enclosing Map) data is also mapped into the `if` block, so we + # have to consider that as well. + for iedge in state.in_edges(if_block): + if iedge.src is enclosing_map and (not iedge.data.is_empty()): + outside_desc = sdfg.arrays[iedge.data.data] + if isinstance(outside_desc, dace_data.View): + return False # Handle this case. + requiered_symbols |= outside_desc.used_symbols(True) + + # A conflicting symbol is a free symbol of the relocatable dataflow, that is not a + # direct mapping. For example if there is a symbol `n` on the inside and outside + # then everything is okay if the symbol mapping is `{n: n}` i.e. the symbol has the + # same meaning inside and outside. Everything else is not okay. + symbol_mapping = if_block.symbol_mapping + conflicting_symbols = requiered_symbols.intersection((str(k) for k in symbol_mapping)) + for conflicting_symbol in conflicting_symbols: + if conflicting_symbol != str(symbol_mapping[conflicting_symbol]): + return False + + return True + def _find_branch_for( self, if_block: dace_nodes.NestedSDFG, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index 03aba6599e..b185fc88e1 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -93,9 +93,8 @@ def _perform_test( return sdfg # General case, run the SDFG first and then compare the result. - ref, res = util.make_sdfg_args(sdfg) - if explected_applies != 0: + ref, res = util.make_sdfg_args(sdfg) util.compile_and_run_sdfg(sdfg, **ref) nb_apply = sdfg.apply_transformations_repeated( @@ -1199,3 +1198,102 @@ def test_if_mover_access_node_between(): "__cond", } assert set(top_if_block.sdfg.arrays.keys()) == expected_top_if_block_data + + +def test_if_mover_symbol_aliasing(): + """Tests if symbol clashes are detected. + + Essentially there is a symbol `n` both in the parent SDFG and the `if_block`, + however, with different meanings. Thus the relocation will lead to invalid + behaviour and should be rejected. + """ + sdfg = dace.SDFG(util.unique_name("if_mover_symbol_alias")) + state = sdfg.add_state(is_start_block=True) + + scalar_names = ["cond", "a1", "b2"] + array_names = list("abcd") + sdfg.add_symbol("n", stype=dace.int32) + for aname in array_names: + sdfg.add_array( + aname, + shape=((10, "n") if aname in "ab" else (10,)), + dtype=dace.float64, + transient=False, + ) + for sname in scalar_names: + sdfg.add_scalar( + sname, + dtype=(dace.bool_ if sname == "cond" else dace.float64), + transient=True, + ) + a, b, c, d, cond_ac, true_ac, false_ac = ( + state.add_access(name) for name in array_names + scalar_names + ) + + me, mx = state.add_map("outer_map", ndrange={"__i": "0:10"}) + + for ac in [a, b, c]: + state.add_edge( + ac, + None, + me, + f"IN_{ac.data}", + dace.Memlet(f"{ac.data}[0:10" + ("]" if ac is c else ", 0:n]")), + ) + me.add_scope_connectors(ac.data) + + # Make the condition. + cond_tlet = state.add_tasklet( + "cond_tlet", + inputs={"__in0"}, + outputs={"__out"}, + code="__out = __in0 < 0.0", + ) + state.add_edge(me, "OUT_c", cond_tlet, "__in0", dace.Memlet("c[__i]")) + state.add_edge(cond_tlet, "__out", cond_ac, None, dace.Memlet(f"{cond_ac.data}[0]")) + + # The true branch. + true_tlet = state.add_tasklet( + "true_tlet", + inputs={"__in0"}, + outputs={"__out"}, + code="__out = __in0 + 1.0", + ) + state.add_edge(me, "OUT_a", true_tlet, "__in0", dace.Memlet("a[__i, n - 1]")) + state.add_edge(true_tlet, "__out", true_ac, None, dace.Memlet(f"{true_ac.data}[0]")) + + # False branch + false_tlet = state.add_tasklet( + "false_tlet", + inputs={"__in0"}, + outputs={"__out"}, + code="__out = __in0 + 1.0", + ) + state.add_edge(me, "OUT_b", false_tlet, "__in0", dace.Memlet("b[__i, n - 3]")) + state.add_edge(false_tlet, "__out", false_ac, None, dace.Memlet(f"{false_ac.data}[0]")) + + # Create the top `if_block` + if_block = _make_if_block(state, sdfg) + + # By Adding this symbol mapping, we emulate the case where something is used + # inside and special case must be taken. + assert len(if_block.symbol_mapping) == 0 + if_block.symbol_mapping["n"] = "n - 1" + + # Connect the inputs to the if block. + state.add_edge(true_ac, None, if_block, "__arg1", dace.Memlet(f"{true_ac}[0]")) + state.add_edge(false_ac, None, if_block, "__arg2", dace.Memlet(f"{false_ac}[0]")) + state.add_edge(cond_ac, None, if_block, "__cond", dace.Memlet(f"{cond_ac}[0]")) + + state.add_edge(if_block, "__output", mx, "IN_d", dace.Memlet("d[__i]")) + state.add_edge(mx, "OUT_d", d, None, dace.Memlet("d[0:10]")) + mx.add_scope_connectors("d") + + sdfg.validate() + + # Because `n` is already taken, see above, we need an additional symbol mapping + # to account for the access on the Memlets of the `{true, false}_tlet`. + _perform_test( + sdfg=sdfg, + explected_applies=0, + )