From c09e48d5d82391ef87413c213e797c5a872784b7 Mon Sep 17 00:00:00 2001 From: Affifboudaoud Date: Sat, 22 Nov 2025 15:30:00 +0100 Subject: [PATCH 1/2] Add Nested SDFG initialization removal tranformation --- dace/transformation/interstate/__init__.py | 1 + .../interstate/move_reduce_init.py | 262 ++++++++++++ .../interstate/move_reduction_init_tests.py | 376 ++++++++++++++++++ 3 files changed, 639 insertions(+) create mode 100644 dace/transformation/interstate/move_reduce_init.py create mode 100644 tests/transformations/interstate/move_reduction_init_tests.py diff --git a/dace/transformation/interstate/__init__.py b/dace/transformation/interstate/__init__.py index 8464f7218f..6469033b33 100644 --- a/dace/transformation/interstate/__init__.py +++ b/dace/transformation/interstate/__init__.py @@ -20,3 +20,4 @@ from .trivial_loop_elimination import TrivialLoopElimination from .multistate_inline import InlineMultistateSDFG from .move_assignment_outside_if import MoveAssignmentOutsideIf +from .move_reduce_init import MoveReduceInitOutOfNestedSDFG diff --git a/dace/transformation/interstate/move_reduce_init.py b/dace/transformation/interstate/move_reduce_init.py new file mode 100644 index 0000000000..cbed1736b2 --- /dev/null +++ b/dace/transformation/interstate/move_reduce_init.py @@ -0,0 +1,262 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. + +import copy +from typing import Set, Dict, Optional + +import dace +from dace import symbolic, data as dt, InterstateEdge +from dace.sdfg import nodes, SDFG, SDFGState +from dace.sdfg import utils as sdutil +from dace.transformation import transformation +from dace.transformation.passes.analysis import StateReachability +from dace.properties import make_properties + + +def _is_init_state(state: SDFGState) -> bool: + """ + Check if a state is an initialization state for reduction. + An init state has: + - A single top-level map + - A single tasklet inside the map with no data inputs + - The tasklet writes to an output array + """ + scope_dict = state.scope_dict() + map_entries = [n for n in state.nodes() if isinstance(n, nodes.MapEntry) and scope_dict[n] is None] + + if len(map_entries) != 1: + return False + + map_entry = map_entries[0] + + scope_nodes = state.scope_subgraph(map_entry).nodes() + tasklets = [n for n in scope_nodes if isinstance(n, nodes.Tasklet)] + if len(tasklets) != 1: + return False + + tasklet = tasklets[0] + for e in state.in_edges(tasklet): + if e.data.data is not None and not e.data.is_empty(): + return False + + has_output = False + for e in state.out_edges(tasklet): + if e.data.data is not None and not e.data.is_empty(): + has_output = True + break + + return has_output + + +def _get_init_output_arrays(state: SDFGState) -> Set[str]: + """Get the output arrays being initialized in the state.""" + outputs = set() + for node in state.sink_nodes(): + if isinstance(node, nodes.AccessNode): + outputs.add(node.data) + return outputs + + +def _is_written_before(sdfg: SDFG, + state: SDFGState, + nsdfg_node: nodes.NestedSDFG, + array_name: str, + reachability: Optional[Dict[SDFGState, Set[SDFGState]]] = None) -> bool: + """ + Check if an array is written before the nested SDFG in the parent SDFG. + Uses StateReachability to find all states that can reach the current state, + then checks their write sets. + """ + if reachability is None: + reachability_pass = StateReachability() + all_reachability = reachability_pass.apply_pass(sdfg, {}) + reachability = all_reachability.get(sdfg.cfg_id, {}) + + # Find all states that can reach the current state + states_reaching_current = set() + for s, reachable in reachability.items(): + if state in reachable: + states_reaching_current.add(s) + + # Check if any reaching state writes to the array + for src_state in states_reaching_current: + _, write_set = src_state.read_and_write_sets() + if array_name in write_set: + return True + + # Check if another node in the same state writes to the array + for node in state.nodes(): + if node == nsdfg_node: + continue + if isinstance(node, nodes.AccessNode) and node.data == array_name: + if state.in_degree(node) > 0: + for e in state.in_edges(node): + if e.src != nsdfg_node: + return True + + return False + + +def _substitute_symbols(nsdfg_node: nodes.NestedSDFG, rng: tuple) -> tuple: + """Substitute nested SDFG symbols with outer values in a range tuple.""" + new_rng = list(rng) + for inner_sym, outer_val in nsdfg_node.symbol_mapping.items(): + for i in range(3): + if symbolic.issymbolic(new_rng[i]): + new_rng[i] = new_rng[i].subs({inner_sym: outer_val}) + return tuple(new_rng) + + +@make_properties +@transformation.explicit_cf_compatible +class MoveReduceInitOutOfNestedSDFG(transformation.SingleStateTransformation): + """ + Moves reduction initialization from a nested SDFG to a new state at the start + of the SDFG. Having these initializations in NestedSDFGs blocks inlining. + + The transformation looks for nested SDFGs that have: + 1. An initialization state as the start state (single map, single tasklet with no inputs) + 2. The initialized array is not written before the nested SDFG + + After transformation: + - A new state is added at the start of the SDFG with the initialization map + - The nested SDFG's initialization state is removed + """ + + nested_sdfg = transformation.PatternNode(nodes.NestedSDFG) + + @classmethod + def expressions(cls): + return [sdutil.node_path_graph(cls.nested_sdfg)] + + def can_be_applied(self, graph: SDFGState, expr_index, sdfg: SDFG, permissive=False): + nsdfg_node = self.nested_sdfg + nsdfg = nsdfg_node.sdfg + + start_state = nsdfg.start_state + + # Check if the start state matches the init pattern + if not _is_init_state(start_state): + return False + + init_outputs = _get_init_output_arrays(start_state) + if not init_outputs: + return False + + # Compute reachability once for all output checks + reachability_pass = StateReachability() + all_reachability = reachability_pass.apply_pass(sdfg, {}) + reachability = all_reachability.get(sdfg.cfg_id, {}) + + # Verify each output array is valid for transformation + for output in init_outputs: + # Output must be an out_connector of the nested SDFG + if output not in nsdfg_node.out_connectors: + return False + + # Find the corresponding edge in the parent state + outer_edge = None + for e in graph.out_edges(nsdfg_node): + if e.src_conn == output: + outer_edge = e + break + + if outer_edge is None: + return False + + # The outer array must not be written before this nested SDFG + outer_array = outer_edge.data.data + + if _is_written_before(sdfg, graph, nsdfg_node, outer_array, reachability): + return False + + return True + + def apply(self, state: SDFGState, sdfg: SDFG): + nsdfg_node = self.nested_sdfg + nsdfg = nsdfg_node.sdfg + + start_state = nsdfg.start_state + successors = list(nsdfg.successors(start_state)) + next_state = successors[0] if successors else None + + # Create new init state at the start of the parent SDFG + old_start = sdfg.start_state + init_state = sdfg.add_state(label='reduce_init') + sdfg.add_edge(init_state, old_start, InterstateEdge()) + sdfg.start_block = sdfg.node_id(init_state) + + # Build mapping from inner array names to outer (outside NestedSDFG) array names. + # If the destination is a View, resolve to the underlying array. + connector_map = {} + for e in state.out_edges(nsdfg_node): + if e.src_conn in _get_init_output_arrays(start_state): + if isinstance(e.dst, nodes.AccessNode): + arr_name = e.dst.data + arr = sdfg.arrays.get(arr_name) + if isinstance(arr, dt.View): + view_edge = sdutil.get_view_edge(state, e.dst) + if view_edge is not None: + arr_name = view_edge.data.data + connector_map[e.src_conn] = arr_name + else: + connector_map[e.src_conn] = e.data.data + + # Copy map nodes first + + node_map = {} + for node in start_state.nodes(): + if isinstance(node, nodes.MapEntry): + new_entry = copy.deepcopy(node) + new_range = [_substitute_symbols(nsdfg_node, rng) for rng in new_entry.map.range] + new_entry.map.range = dace.subsets.Range(new_range) + node_map[node] = new_entry + exit_node = start_state.exit_node(node) + new_exit = copy.deepcopy(exit_node) + + # MapEntry and MapExit need to share the same Map object + new_exit.map = new_entry.map + node_map[exit_node] = new_exit + + # Copy remaining nodes (AccessNodes with renamed arrays, Tasklets, etc.) + for node in start_state.nodes(): + if node in node_map: + continue + if isinstance(node, nodes.AccessNode): + inner_name = node.data + outer_name = connector_map.get(inner_name, inner_name) + new_node = nodes.AccessNode(outer_name) + else: + new_node = copy.deepcopy(node) + node_map[node] = new_node + + # Add all nodes to the new init state + for node in start_state.nodes(): + init_state.add_node(node_map[node]) + + # Copy edges with updated memlets (renamed arrays, substituted symbols) + for edge in start_state.edges(): + src = node_map.get(edge.src) + dst = node_map.get(edge.dst) + if src is None or dst is None: + continue + + new_memlet = copy.deepcopy(edge.data) + if new_memlet.data is not None: + inner_name = new_memlet.data + outer_name = connector_map.get(inner_name, inner_name) + new_memlet.data = outer_name + + if new_memlet.subset is not None: + new_subset = [_substitute_symbols(nsdfg_node, rng) for rng in new_memlet.subset] + new_memlet.subset = dace.subsets.Range(new_subset) + + init_state.add_edge(src, edge.src_conn, dst, edge.dst_conn, new_memlet) + + # Remove the init state from the nested SDFG and update its start block + nsdfg.remove_node(start_state) + + if len(nsdfg.nodes()) > 0 and next_state is not None: + nsdfg.start_block = nsdfg.node_id(next_state) + nsdfg.reset_cfg_list() + + return init_state diff --git a/tests/transformations/interstate/move_reduction_init_tests.py b/tests/transformations/interstate/move_reduction_init_tests.py new file mode 100644 index 0000000000..299c8f69d5 --- /dev/null +++ b/tests/transformations/interstate/move_reduction_init_tests.py @@ -0,0 +1,376 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. + +import numpy as np + +import dace +from dace import SDFG, Memlet, nodes +from dace.transformation.interstate import MoveReduceInitOutOfNestedSDFG, InlineSDFG + + +def create_nested_sdfg_with_reduce_init(): + """Create a nested SDFG that has an initialization state and a reduction state.""" + outer_sdfg = SDFG('outer') + outer_sdfg.add_array('A', [10, 10], dace.float64) + outer_sdfg.add_array('B', [10], dace.float64) + + nsdfg = SDFG('nested_reduce') + nsdfg.add_array('_A', [10, 10], dace.float64) + nsdfg.add_array('_B', [10], dace.float64) + + init_state = nsdfg.add_state('reduce_init') + init_state.add_mapped_tasklet('reduce_init_map', {'i': '0:10'}, {}, + '__out = 0.0', {'__out': Memlet('_B[i]')}, + external_edges=True) + + reduce_state = nsdfg.add_state('reduce') + reduce_state.add_mapped_tasklet('reduce_map', { + 'i': '0:10', + 'j': '0:10' + }, {'__in': Memlet('_A[i, j]')}, + '__out = __in', {'__out': Memlet('_B[i]', wcr='lambda a, b: a + b')}, + external_edges=True) + + nsdfg.add_edge(init_state, reduce_state, dace.InterstateEdge()) + + outer_state = outer_sdfg.add_state('main') + read_a = outer_state.add_read('A') + write_b = outer_state.add_write('B') + + nsdfg_node = outer_state.add_nested_sdfg(nsdfg, {'_A'}, {'_B'}) + outer_state.add_edge(read_a, None, nsdfg_node, '_A', Memlet('A[0:10, 0:10]')) + outer_state.add_edge(nsdfg_node, '_B', write_b, None, Memlet('B[0:10]')) + + return outer_sdfg + + +def test_move_reduce_init_basic(): + """Test basic application of the transformation.""" + sdfg = create_nested_sdfg_with_reduce_init() + + num_states_before = len(list(sdfg.states())) + assert num_states_before == 1 + + nsdfg_node = None + for node in sdfg.states()[0].nodes(): + if isinstance(node, nodes.NestedSDFG): + nsdfg_node = node + break + + assert nsdfg_node is not None + assert len(list(nsdfg_node.sdfg.nodes())) == 2 + + applied = sdfg.apply_transformations(MoveReduceInitOutOfNestedSDFG) + assert applied == 1 + + num_states_after = len(list(sdfg.states())) + assert num_states_after == 2 + + nsdfg_node = None + for state in sdfg.states(): + for node in state.nodes(): + if isinstance(node, nodes.NestedSDFG): + nsdfg_node = node + break + + assert nsdfg_node is not None + assert len(list(nsdfg_node.sdfg.nodes())) == 1 + + +def test_move_reduce_init_enables_inlining(): + """Test that after moving reduce init, the nested SDFG can be inlined.""" + sdfg = create_nested_sdfg_with_reduce_init() + + inline_before = sdfg.apply_transformations(InlineSDFG) + assert inline_before == 0 + + sdfg = create_nested_sdfg_with_reduce_init() + + applied = sdfg.apply_transformations(MoveReduceInitOutOfNestedSDFG) + assert applied == 1 + + inline_after = sdfg.apply_transformations(InlineSDFG) + assert inline_after == 1 + + has_nested = False + for state in sdfg.states(): + for node in state.nodes(): + if isinstance(node, nodes.NestedSDFG): + has_nested = True + + assert not has_nested + + +def test_move_reduce_init_correctness(): + """Test that the transformation preserves correctness.""" + sdfg = create_nested_sdfg_with_reduce_init() + + A = np.random.rand(10, 10) + B_before = np.zeros(10) + sdfg(A=A.copy(), B=B_before) + + sdfg = create_nested_sdfg_with_reduce_init() + sdfg.apply_transformations(MoveReduceInitOutOfNestedSDFG) + + B_after = np.zeros(10) + sdfg(A=A.copy(), B=B_after) + + assert np.allclose(B_before, B_after) + assert np.allclose(B_after, A.sum(axis=1)) + + +def test_move_reduce_init_not_applied_when_written_before(): + """Test that transformation is not applied when the output is written before.""" + outer_sdfg = SDFG('outer') + outer_sdfg.add_array('A', [10, 10], dace.float64) + outer_sdfg.add_array('B', [10], dace.float64) + + nsdfg = SDFG('nested_reduce') + nsdfg.add_array('_A', [10, 10], dace.float64) + nsdfg.add_array('_B', [10], dace.float64) + + init_state = nsdfg.add_state('reduce_init') + init_state.add_mapped_tasklet('reduce_init_map', {'i': '0:10'}, {}, + '__out = 0.0', {'__out': Memlet('_B[i]')}, + external_edges=True) + + reduce_state = nsdfg.add_state('reduce') + reduce_state.add_mapped_tasklet('reduce_map', { + 'i': '0:10', + 'j': '0:10' + }, {'__in': Memlet('_A[i, j]')}, + '__out = __in', {'__out': Memlet('_B[i]', wcr='lambda a, b: a + b')}, + external_edges=True) + + nsdfg.add_edge(init_state, reduce_state, dace.InterstateEdge()) + + pre_state = outer_sdfg.add_state('pre_write') + pre_state.add_mapped_tasklet('pre_init', {'i': '0:10'}, {}, + '__out = 1.0', {'__out': Memlet('B[i]')}, + external_edges=True) + + main_state = outer_sdfg.add_state('main') + read_a = main_state.add_read('A') + read_b = main_state.add_read('B') + write_b = main_state.add_write('B') + + nsdfg_node = main_state.add_nested_sdfg(nsdfg, {'_A', '_B'}, {'_B'}) + main_state.add_edge(read_a, None, nsdfg_node, '_A', Memlet('A[0:10, 0:10]')) + main_state.add_edge(read_b, None, nsdfg_node, '_B', Memlet('B[0:10]')) + main_state.add_edge(nsdfg_node, '_B', write_b, None, Memlet('B[0:10]')) + + outer_sdfg.add_edge(pre_state, main_state, dace.InterstateEdge()) + + applied = outer_sdfg.apply_transformations(MoveReduceInitOutOfNestedSDFG) + assert applied == 0 + + +def test_move_reduce_init_not_applied_non_init_first_state(): + """Test that transformation is not applied when first state is not an init state.""" + outer_sdfg = SDFG('outer') + outer_sdfg.add_array('A', [10, 10], dace.float64) + outer_sdfg.add_array('B', [10], dace.float64) + + nsdfg = SDFG('nested_reduce') + nsdfg.add_array('_A', [10, 10], dace.float64) + nsdfg.add_array('_B', [10], dace.float64) + nsdfg.add_transient('_tmp', [10, 10], dace.float64) + + state1 = nsdfg.add_state('compute1') + state1.add_mapped_tasklet('compute_map', { + 'i': '0:10', + 'j': '0:10' + }, {'__in': Memlet('_A[i, j]')}, + '__out = __in * 2', {'__out': Memlet('_tmp[i, j]')}, + external_edges=True) + + state2 = nsdfg.add_state('compute2') + state2.add_mapped_tasklet('reduce_map', { + 'i': '0:10', + 'j': '0:10' + }, {'__in': Memlet('_tmp[i, j]')}, + '__out = __in', {'__out': Memlet('_B[i]', wcr='lambda a, b: a + b')}, + external_edges=True) + + nsdfg.add_edge(state1, state2, dace.InterstateEdge()) + + outer_state = outer_sdfg.add_state('main') + read_a = outer_state.add_read('A') + write_b = outer_state.add_write('B') + + nsdfg_node = outer_state.add_nested_sdfg(nsdfg, {'_A'}, {'_B'}) + outer_state.add_edge(read_a, None, nsdfg_node, '_A', Memlet('A[0:10, 0:10]')) + outer_state.add_edge(nsdfg_node, '_B', write_b, None, Memlet('B[0:10]')) + + applied = outer_sdfg.apply_transformations(MoveReduceInitOutOfNestedSDFG) + assert applied == 0 + + +def test_move_reduce_init_from_library_node(): + """Test transformation on reduce library node expanded via ExpandReducePure.""" + + @dace.program + def reduce_sum(A: dace.float64[10, 10], B: dace.float64[10]): + B[:] = dace.reduce(lambda a, b: a + b, A, axis=1, identity=0) + + sdfg = reduce_sum.to_sdfg() + sdfg.simplify() + + sdfg.expand_library_nodes() + + nsdfg_node = None + main_state = None + for state in sdfg.states(): + for node in state.nodes(): + if isinstance(node, nodes.NestedSDFG): + nsdfg_node = node + main_state = state + break + + assert nsdfg_node is not None + num_states_before = len(list(nsdfg_node.sdfg.nodes())) + assert num_states_before >= 2 + + init_state = nsdfg_node.sdfg.start_state + found_reduce_init = False + for node in init_state.nodes(): + if isinstance(node, nodes.MapEntry): + if 'reduce_init' in node.map.label.lower(): + found_reduce_init = True + break + assert found_reduce_init + + A = np.random.rand(10, 10) + B_before = np.zeros(10) + sdfg(A=A.copy(), B=B_before) + + applied = sdfg.apply_transformations(MoveReduceInitOutOfNestedSDFG) + assert applied == 1 + + for state in sdfg.states(): + for node in state.nodes(): + if isinstance(node, nodes.NestedSDFG): + nsdfg_node = node + break + + assert len(list(nsdfg_node.sdfg.nodes())) == num_states_before - 1 + + B_after = np.zeros(10) + sdfg(A=A.copy(), B=B_after) + + assert np.allclose(B_before, B_after) + assert np.allclose(B_after, A.sum(axis=1)) + + +def test_move_reduce_init_from_library_node_enables_inlining(): + """Test that after applying transformation on expanded reduce, inlining works.""" + + @dace.program + def reduce_sum(A: dace.float64[10, 10], B: dace.float64[10]): + B[:] = dace.reduce(lambda a, b: a + b, A, axis=1, identity=0) + + sdfg = reduce_sum.to_sdfg() + sdfg.simplify() + + sdfg.expand_library_nodes() + + inline_before = sdfg.apply_transformations(InlineSDFG) + assert inline_before == 0 + + sdfg = reduce_sum.to_sdfg() + sdfg.simplify() + sdfg.expand_library_nodes() + + applied = sdfg.apply_transformations(MoveReduceInitOutOfNestedSDFG) + assert applied == 1 + + inline_after = sdfg.apply_transformations(InlineSDFG) + assert inline_after == 1 + + has_nested = False + for state in sdfg.states(): + for node in state.nodes(): + if isinstance(node, nodes.NestedSDFG): + has_nested = True + + assert not has_nested + + +def test_move_reduce_init_from_library_node_max_reduction(): + """Test transformation on max reduce library node.""" + + @dace.program + def reduce_max(A: dace.float64[10, 10], B: dace.float64[10]): + B[:] = dace.reduce(lambda a, b: max(a, b), A, axis=1, identity=-np.inf) + + sdfg = reduce_max.to_sdfg() + sdfg.simplify() + + sdfg.expand_library_nodes() + + A = np.random.rand(10, 10) + B_before = np.zeros(10) + sdfg(A=A.copy(), B=B_before) + + applied = sdfg.apply_transformations(MoveReduceInitOutOfNestedSDFG) + assert applied == 1 + + B_after = np.zeros(10) + sdfg(A=A.copy(), B=B_after) + + assert np.allclose(B_before, B_after) + assert np.allclose(B_after, A.max(axis=1)) + + +def test_move_reduce_init_multiple_reductions(): + """Test transformation with multiple reduce operations applied repeatedly.""" + + @dace.program + def multi_reduce(A: dace.float64[10, 10], B: dace.float64[10], C: dace.float64[10], D: dace.float64[1]): + B[:] = dace.reduce(lambda a, b: a + b, A, axis=1, identity=0) + C[:] = dace.reduce(lambda a, b: max(a, b), A, axis=1, identity=-np.inf) + D[0] = dace.reduce(lambda a, b: a + b, B, identity=0) + + sdfg = multi_reduce.to_sdfg() + sdfg.simplify() + sdfg.expand_library_nodes() + + nsdfg_count = 0 + for state in sdfg.states(): + for node in state.nodes(): + if isinstance(node, nodes.NestedSDFG): + nsdfg_count += 1 + + assert nsdfg_count >= 2 + + A = np.random.rand(10, 10) + B_before = np.zeros(10) + C_before = np.zeros(10) + D_before = np.zeros(1) + sdfg(A=A.copy(), B=B_before, C=C_before, D=D_before) + + applied = sdfg.apply_transformations_repeated(MoveReduceInitOutOfNestedSDFG) + assert applied >= 2 + + B_after = np.zeros(10) + C_after = np.zeros(10) + D_after = np.zeros(1) + sdfg(A=A.copy(), B=B_after, C=C_after, D=D_after) + + assert np.allclose(B_before, B_after) + assert np.allclose(C_before, C_after) + assert np.allclose(D_before, D_after) + assert np.allclose(B_after, A.sum(axis=1)) + assert np.allclose(C_after, A.max(axis=1)) + assert np.allclose(D_after, A.sum()) + + +if __name__ == '__main__': + test_move_reduce_init_basic() + test_move_reduce_init_enables_inlining() + test_move_reduce_init_correctness() + test_move_reduce_init_not_applied_when_written_before() + test_move_reduce_init_not_applied_non_init_first_state() + test_move_reduce_init_from_library_node() + test_move_reduce_init_from_library_node_enables_inlining() + test_move_reduce_init_from_library_node_max_reduction() + test_move_reduce_init_multiple_reductions() From 0365784ce84ad16b78a4627fefc7390f95844fb3 Mon Sep 17 00:00:00 2001 From: Affifboudaoud Date: Sat, 22 Nov 2025 15:53:23 +0100 Subject: [PATCH 2/2] Fix Memlet squeezing edge case --- .../interstate/move_reduce_init.py | 57 +++++++++++++++-- .../interstate/move_reduction_init_tests.py | 61 +++++++++++++++++++ 2 files changed, 113 insertions(+), 5 deletions(-) diff --git a/dace/transformation/interstate/move_reduce_init.py b/dace/transformation/interstate/move_reduce_init.py index cbed1736b2..10d1af08f9 100644 --- a/dace/transformation/interstate/move_reduce_init.py +++ b/dace/transformation/interstate/move_reduce_init.py @@ -106,6 +106,44 @@ def _substitute_symbols(nsdfg_node: nodes.NestedSDFG, rng: tuple) -> tuple: return tuple(new_rng) +def _compose_subsets(inner_subset: dace.subsets.Range, outer_subset: dace.subsets.Range, + nsdfg_node: nodes.NestedSDFG) -> dace.subsets.Range: + """ + Compose inner and outer subsets when inner array has fewer dimensions. + + For example: + - Inner subset: [_o0, _o1, _o2] (3D) + - Outer subset: [0:2, 0:8, 0:128, 0] (4D, last dim is squeezed) + - Result: [_o0, _o1, _o2, 0] (4D) + + The composition replaces ranges in the outer subset with the corresponding + inner subset indices, keeping fixed dimensions (size 1) as-is. + """ + inner_ranges = list(inner_subset) + outer_ranges = list(outer_subset) + + if len(inner_ranges) == len(outer_ranges): + return dace.subsets.Range([_substitute_symbols(nsdfg_node, rng) for rng in inner_ranges]) + + result_ranges = [] + inner_idx = 0 + + for outer_rng in outer_ranges: + start, end, _ = outer_rng + size = (end - start + 1) if not symbolic.issymbolic(end - start) else None + + if size == 1: + result_ranges.append(_substitute_symbols(nsdfg_node, outer_rng)) + else: + if inner_idx < len(inner_ranges): + result_ranges.append(_substitute_symbols(nsdfg_node, inner_ranges[inner_idx])) + inner_idx += 1 + else: + result_ranges.append(_substitute_symbols(nsdfg_node, outer_rng)) + + return dace.subsets.Range(result_ranges) + + @make_properties @transformation.explicit_cf_compatible class MoveReduceInitOutOfNestedSDFG(transformation.SingleStateTransformation): @@ -185,11 +223,13 @@ def apply(self, state: SDFGState, sdfg: SDFG): sdfg.add_edge(init_state, old_start, InterstateEdge()) sdfg.start_block = sdfg.node_id(init_state) - # Build mapping from inner array names to outer (outside NestedSDFG) array names. - # If the destination is a View, resolve to the underlying array. - connector_map = {} + # Build mapping from inner array names to outer (outside NestedSDFG) array names + # and their subsets. If the destination is a View, resolve to the underlying array. + connector_map = {} # inner_name -> outer_name + outer_subsets = {} # inner_name -> outer_subset (for dimension composition) for e in state.out_edges(nsdfg_node): if e.src_conn in _get_init_output_arrays(start_state): + outer_subset = e.data.subset if isinstance(e.dst, nodes.AccessNode): arr_name = e.dst.data arr = sdfg.arrays.get(arr_name) @@ -197,9 +237,12 @@ def apply(self, state: SDFGState, sdfg: SDFG): view_edge = sdutil.get_view_edge(state, e.dst) if view_edge is not None: arr_name = view_edge.data.data + outer_subset = view_edge.data.subset connector_map[e.src_conn] = arr_name + outer_subsets[e.src_conn] = outer_subset else: connector_map[e.src_conn] = e.data.data + outer_subsets[e.src_conn] = outer_subset # Copy map nodes first @@ -247,8 +290,12 @@ def apply(self, state: SDFGState, sdfg: SDFG): new_memlet.data = outer_name if new_memlet.subset is not None: - new_subset = [_substitute_symbols(nsdfg_node, rng) for rng in new_memlet.subset] - new_memlet.subset = dace.subsets.Range(new_subset) + outer_subset = outer_subsets.get(inner_name) + if outer_subset is not None and len(outer_subset) != len(new_memlet.subset): + new_memlet.subset = _compose_subsets(new_memlet.subset, outer_subset, nsdfg_node) + else: + new_subset = [_substitute_symbols(nsdfg_node, rng) for rng in new_memlet.subset] + new_memlet.subset = dace.subsets.Range(new_subset) init_state.add_edge(src, edge.src_conn, dst, edge.dst_conn, new_memlet) diff --git a/tests/transformations/interstate/move_reduction_init_tests.py b/tests/transformations/interstate/move_reduction_init_tests.py index 299c8f69d5..597cd0f478 100644 --- a/tests/transformations/interstate/move_reduction_init_tests.py +++ b/tests/transformations/interstate/move_reduction_init_tests.py @@ -364,6 +364,66 @@ def multi_reduce(A: dace.float64[10, 10], B: dace.float64[10], C: dace.float64[1 assert np.allclose(D_after, A.sum()) +def test_move_reduce_init_dimension_mismatch(): + """Test transformation when inner array has fewer dimensions than outer array. + + This tests the case where the nested SDFG's array is a squeezed view of the + outer array (e.g., inner 3D array mapping to outer 4D array with size-1 dim). + """ + outer_sdfg = SDFG('outer') + outer_sdfg.add_array('A', [2, 8, 128, 64], dace.float64) + outer_sdfg.add_array('B', [2, 8, 128, 1], dace.float64) + + nsdfg = SDFG('nested_reduce') + nsdfg.add_array('_A', [2, 8, 128, 64], dace.float64) + nsdfg.add_array('_B', [2, 8, 128], dace.float64) + + init_state = nsdfg.add_state('reduce_init') + init_state.add_mapped_tasklet('reduce_init_map', { + '_o0': '0:2', + '_o1': '0:8', + '_o2': '0:128' + }, {}, + '__out = 0.0', {'__out': Memlet('_B[_o0, _o1, _o2]')}, + external_edges=True) + + reduce_state = nsdfg.add_state('reduce') + reduce_state.add_mapped_tasklet('reduce_map', { + 'i': '0:2', + 'j': '0:8', + 'k': '0:128', + 'l': '0:64' + }, {'__in': Memlet('_A[i, j, k, l]')}, + '__out = __in', {'__out': Memlet('_B[i, j, k]', wcr='lambda a, b: a + b')}, + external_edges=True) + + nsdfg.add_edge(init_state, reduce_state, dace.InterstateEdge()) + + outer_state = outer_sdfg.add_state('main') + read_a = outer_state.add_read('A') + write_b = outer_state.add_write('B') + + nsdfg_node = outer_state.add_nested_sdfg(nsdfg, {'_A'}, {'_B'}) + outer_state.add_edge(read_a, None, nsdfg_node, '_A', Memlet('A[0:2, 0:8, 0:128, 0:64]')) + outer_state.add_edge(nsdfg_node, '_B', write_b, None, Memlet('B[0:2, 0:8, 0:128, 0]')) + + A = np.random.rand(2, 8, 128, 64) + B_before = np.zeros((2, 8, 128, 1)) + outer_sdfg(A=A.copy(), B=B_before) + + applied = outer_sdfg.apply_transformations(MoveReduceInitOutOfNestedSDFG) + assert applied == 1 + + outer_sdfg.validate() + + B_after = np.zeros((2, 8, 128, 1)) + outer_sdfg(A=A.copy(), B=B_after) + + assert np.allclose(B_before, B_after) + expected = A.sum(axis=3, keepdims=True) + assert np.allclose(B_after, expected) + + if __name__ == '__main__': test_move_reduce_init_basic() test_move_reduce_init_enables_inlining() @@ -374,3 +434,4 @@ def multi_reduce(A: dace.float64[10, 10], B: dace.float64[10], C: dace.float64[1 test_move_reduce_init_from_library_node_enables_inlining() test_move_reduce_init_from_library_node_max_reduction() test_move_reduce_init_multiple_reductions() + test_move_reduce_init_dimension_mismatch()