From e1a7149d3fdf11aa26cf6ef41a4723a752da9653 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 9 Feb 2026 09:25:35 -0800 Subject: [PATCH 1/3] Update `SDFG.add_loop` to use `LoopRegion` --- dace/sdfg/sdfg.py | 88 ++++++++++++++++++++++++++++++++++++++++------ dace/sdfg/state.py | 4 ++- 2 files changed, 81 insertions(+), 11 deletions(-) diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index d88effd594..bf2b7782f4 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -18,7 +18,7 @@ import pickle import dace -from dace.sdfg.graph import generate_element_id +from dace.sdfg.graph import generate_element_id, SubgraphView import dace.serialize from dace import (data as dt, hooks, memlet as mm, subsets as sbs, dtypes, symbolic) from dace.sdfg.replace import replace_properties_dict @@ -26,7 +26,7 @@ from dace.config import Config from dace.frontend.python import astutils from dace.sdfg import nodes as nd -from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, SDFGState, ControlFlowRegion +from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, SDFGState, ControlFlowRegion, LoopRegion from dace.sdfg.type_inference import infer_expr_type from dace.distr_types import ProcessGrid, SubArray, RedistrArray from dace.dtypes import validate_name @@ -2249,16 +2249,87 @@ def add_rdistrarray(self, array_a: str, array_b: str): self.append_exit_code(self._rdistrarrays[name].exit_code(self)) return name - def add_loop( + def add_loop(self, + before_block: ControlFlowBlock, + loop_start_block: ControlFlowBlock, + after_block: ControlFlowBlock, + loop_var: str, + initialize_expr: str, + condition_expr: str, + increment_expr: str, + loop_end_block: ControlFlowBlock = None) -> LoopRegion: + """ + Helper function that adds a looping control flow block around a + given block/state (or sequence of blocks, if ``loop_end_block`` is provided). + + :param before_block: The block after which the loop should + begin, or None if the loop is the first + block (creates an empty block). + :param loop_start_block: The block that begins the loop. See also + ``loop_end_block`` if the loop is multi-block. + :param after_block: The block that should be invoked after + the loop ends, or None if the program + should terminate (creates an empty block). + :param loop_var: A name of a symbol to use for the loop variable. + :param initialize_expr: A string expression that is assigned + to ``loop_var`` before the loop begins. + If None, does not define an expression. + :param condition_expr: A string condition that occurs every + loop iteration. If None, loops forever. + :param increment_expr: A string expression that is assigned to + ``loop_var`` after every loop iteration. + :param loop_end_block: If the loop wraps multiple blocks, the block + where the loop iteration ends. If None, sets + the end block to ``loop_start_block`` as well. + :return: The generated LoopRegion block. + """ + loop_region = LoopRegion("loop", condition_expr, loop_var, initialize_expr, f'{loop_var} = {increment_expr}') + + # Capture subgraphview of loop body + if loop_start_block in self.states(): + # Find all reachable blocks in loop body + if loop_end_block is None: + blocks = {loop_start_block} + else: + blocks = set(self.all_nodes_between(loop_start_block, loop_end_block)) + blocks.add(loop_start_block) + blocks.add(loop_end_block) + + subgraph = SubgraphView(self, blocks) + remove_subgraph = True + else: + subgraph = SubgraphView(self, [loop_start_block, loop_end_block] if loop_end_block else [loop_start_block]) + remove_subgraph = False + + # Readd subgraph to main SDFG with loop_region as parent + loop_region.add_node(loop_start_block, is_start_block=True) + loop_region.add_nodes_from(set(subgraph.nodes()) - {loop_start_block}) + for e in subgraph.edges(): + loop_region.add_edge(e.src, e.dst, e.data) + + # Remove subgraph from main SDFG, if necessary + if remove_subgraph: + self.remove_nodes_from(blocks) + + # Connect to graph + self.add_node(loop_region, is_start_block=(before_block is None)) + if before_block is not None: + self.add_edge(before_block, loop_region, InterstateEdge()) + if after_block is not None: + self.add_edge(loop_region, after_block, InterstateEdge()) + + return loop_region + + def add_loop_state_machine( self, - before_state, - loop_state, - after_state, + before_state: SDFGState, + loop_state: SDFGState, + after_state: SDFGState, loop_var: str, initialize_expr: str, condition_expr: str, increment_expr: str, - loop_end_state=None, + loop_end_state: Optional[SDFGState] = None, ): """ Helper function that adds a looping state machine around a @@ -2293,9 +2364,6 @@ def add_loop( """ from dace.frontend.python.astutils import negate_expr # Avoid import loops - warnings.warn("SDFG.add_loop is deprecated and will be removed in a future release. Use LoopRegions instead.", - DeprecationWarning) - # Argument checks if loop_var is None and (initialize_expr or increment_expr): raise ValueError("Cannot initalize or increment an empty loop variable") diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index adc77bf439..87cadf48be 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -3536,8 +3536,10 @@ def _used_symbols_internal(self, free_syms = set() if free_syms is None else free_syms used_before_assignment = set() if used_before_assignment is None else used_before_assignment - defined_syms.add(self.loop_variable) if self.init_statement is not None: + # Loops with no initialization statement do not redefine the loop variable + defined_syms.add(self.loop_variable) + free_syms |= self.init_statement.get_free_symbols() if self.update_statement is not None: free_syms |= self.update_statement.get_free_symbols() From f53f6e59542670b2f66e66ddceb08f7e33e39676 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 9 Feb 2026 22:12:28 -0800 Subject: [PATCH 2/3] Corrections to tests that need an inline state machine --- dace/sdfg/sdfg.py | 14 ++- tests/graph_test.py | 3 +- .../writeset_underapproximation_test.py | 47 ++++---- tests/schedule_tree/naming_test.py | 4 +- tests/schedule_tree/schedule_test.py | 7 +- tests/sdfg/data/container_array_test.py | 25 +++-- tests/sdfg/data/structure_test.py | 103 ++++++++++-------- tests/sdfg/reference_test.py | 2 +- tests/transformations/loop_to_map_test.py | 12 +- 9 files changed, 124 insertions(+), 93 deletions(-) diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index bf2b7782f4..ee9b080e8a 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -2257,7 +2257,8 @@ def add_loop(self, initialize_expr: str, condition_expr: str, increment_expr: str, - loop_end_block: ControlFlowBlock = None) -> LoopRegion: + loop_end_block: Optional[ControlFlowBlock] = None, + label: Optional[str] = None) -> LoopRegion: """ Helper function that adds a looping control flow block around a given block/state (or sequence of blocks, if ``loop_end_block`` is provided). @@ -2281,12 +2282,15 @@ def add_loop(self, :param loop_end_block: If the loop wraps multiple blocks, the block where the loop iteration ends. If None, sets the end block to ``loop_start_block`` as well. + :param label: An optional label for the loop region. :return: The generated LoopRegion block. """ - loop_region = LoopRegion("loop", condition_expr, loop_var, initialize_expr, f'{loop_var} = {increment_expr}') + label = self._ensure_unique_block_name(label or "loop") + loop_region = LoopRegion(label, condition_expr, loop_var, f'{loop_var} = {initialize_expr}', + f'{loop_var} = {increment_expr}') # Capture subgraphview of loop body - if loop_start_block in self.states(): + if loop_start_block in self.nodes() or loop_start_block in self.states(): # Find all reachable blocks in loop body if loop_end_block is None: blocks = {loop_start_block} @@ -2333,7 +2337,9 @@ def add_loop_state_machine( ): """ Helper function that adds a looping state machine around a - given state (or sequence of states). + given state (or sequence of states). It is recommended to use + ``add_loop`` instead of this function, unless creating a loop + state machine is explicitly requested. :param before_state: The state after which the loop should begin, or None if the loop is the first diff --git a/tests/graph_test.py b/tests/graph_test.py index 9cee2d0be1..4a3ec780c1 100644 --- a/tests/graph_test.py +++ b/tests/graph_test.py @@ -110,7 +110,8 @@ def test_ordered_multidigraph(self): def test_dfs_edges(self): sdfg = dace.SDFG('test_dfs_edges') - before, _, _ = sdfg.add_loop(sdfg.add_state(), sdfg.add_state(), sdfg.add_state(), 'i', '0', 'i < 10', 'i + 1') + before, _, _ = sdfg.add_loop_state_machine(sdfg.add_state(), sdfg.add_state(), sdfg.add_state(), 'i', '0', + 'i < 10', 'i + 1') visited_edges = list(sdfg.dfs_edges(before)) assert len(visited_edges) == len(set(visited_edges)) diff --git a/tests/passes/writeset_underapproximation_test.py b/tests/passes/writeset_underapproximation_test.py index 19e9820def..d4bc48307c 100644 --- a/tests/passes/writeset_underapproximation_test.py +++ b/tests/passes/writeset_underapproximation_test.py @@ -605,7 +605,7 @@ def test_simple_loop_overwrite(): init = sdfg.add_state("init") end = sdfg.add_state("end") loop_body = sdfg.add_state("loop_body") - _, guard, _ = sdfg.add_loop(init, loop_body, end, "i", "0", "i < N", "i + 1") + _, guard, _ = sdfg.add_loop_state_machine(init, loop_body, end, "i", "0", "i < N", "i + 1") a0 = loop_body.add_access("A") loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[i]")) @@ -629,8 +629,8 @@ def test_loop_2D_overwrite(): loop_body = sdfg.add_state("loop_body") loop_before_1 = sdfg.add_state("loop_before_1") loop_after_1 = sdfg.add_state("loop_after_1") - _, guard2, _ = sdfg.add_loop(loop_before_1, loop_body, loop_after_1, "i", "0", "i < N", "i + 1") - _, guard1, _ = sdfg.add_loop(init, loop_before_1, end, "j", "0", "j < M", "j + 1", loop_after_1) + _, guard2, _ = sdfg.add_loop_state_machine(loop_before_1, loop_body, loop_after_1, "i", "0", "i < N", "i + 1") + _, guard1, _ = sdfg.add_loop_state_machine(init, loop_before_1, end, "j", "0", "j < M", "j + 1", loop_after_1) a0 = loop_body.add_access("A") loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]")) @@ -659,10 +659,12 @@ def test_loop_2D_propagation_gap_symbolic(): loop_after_1 = sdfg.add_state("loop_after_1") loop_before_2 = sdfg.add_state("loop_before_2") loop_after_2 = sdfg.add_state("loop_after_2") - _, guard3, _ = sdfg.add_loop(loop_before_1, loop_body, loop_after_1, "i", "0", "i < N", "i + 1") # inner-most loop - _, guard2, _ = sdfg.add_loop(loop_before_2, loop_before_1, loop_after_2, "k", "0", "k < K", "k + 1", - loop_after_1) # second-inner-most loop - _, guard1, _ = sdfg.add_loop(init, loop_before_2, end, "j", "0", "j < M", "j + 1", loop_after_2) # outer-most loop + _, guard3, _ = sdfg.add_loop_state_machine(loop_before_1, loop_body, loop_after_1, "i", "0", "i < N", + "i + 1") # inner-most loop + _, guard2, _ = sdfg.add_loop_state_machine(loop_before_2, loop_before_1, loop_after_2, "k", "0", "k < K", "k + 1", + loop_after_1) # second-inner-most loop + _, guard1, _ = sdfg.add_loop_state_machine(init, loop_before_2, end, "j", "0", "j < M", "j + 1", + loop_after_2) # outer-most loop a0 = loop_body.add_access("A") loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]")) @@ -687,8 +689,8 @@ def test_2_loops_overwrite(): end = sdfg.add_state("end") loop_body_1 = sdfg.add_state("loop_body_1") loop_body_2 = sdfg.add_state("loop_body_2") - _, guard_1, after_state = sdfg.add_loop(init, loop_body_1, None, "i", "0", "i < N", "i + 1") - _, guard_2, _ = sdfg.add_loop(after_state, loop_body_2, end, "i", "0", "i < N", "i + 1") + _, guard_1, after_state = sdfg.add_loop_state_machine(init, loop_body_1, None, "i", "0", "i < N", "i + 1") + _, guard_2, _ = sdfg.add_loop_state_machine(after_state, loop_body_2, end, "i", "0", "i < N", "i + 1") a0 = loop_body_1.add_access("A") loop_tasklet_1 = loop_body_1.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body_1.add_edge(loop_tasklet_1, "a", a0, None, dace.Memlet("A[i]")) @@ -720,9 +722,10 @@ def test_loop_2D_overwrite_propagation_gap_non_empty(): loop_after_1 = sdfg.add_state("loop_after_1") loop_before_2 = sdfg.add_state("loop_before_2") loop_after_2 = sdfg.add_state("loop_after_2") - _, guard3, _ = sdfg.add_loop(loop_before_1, loop_body, loop_after_1, "i", "0", "i < N", "i + 1") - _, guard2, _ = sdfg.add_loop(loop_before_2, loop_before_1, loop_after_2, "k", "0", "k < 10", "k + 1", loop_after_1) - _, guard1, _ = sdfg.add_loop(init, loop_before_2, end, "j", "0", "j < M", "j + 1", loop_after_2) + _, guard3, _ = sdfg.add_loop_state_machine(loop_before_1, loop_body, loop_after_1, "i", "0", "i < N", "i + 1") + _, guard2, _ = sdfg.add_loop_state_machine(loop_before_2, loop_before_1, loop_after_2, "k", "0", "k < 10", "k + 1", + loop_after_1) + _, guard1, _ = sdfg.add_loop_state_machine(init, loop_before_2, end, "j", "0", "j < M", "j + 1", loop_after_2) a0 = loop_body.add_access("A") loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]")) @@ -751,9 +754,10 @@ def test_loop_nest_multiplied_indices(): loop_after_1 = sdfg.add_state("loop_after_1") loop_before_2 = sdfg.add_state("loop_before_2") loop_after_2 = sdfg.add_state("loop_after_2") - _, guard3, _ = sdfg.add_loop(loop_before_1, loop_body, loop_after_1, "i", "0", "i < N", "i + 1") - _, guard2, _ = sdfg.add_loop(loop_before_2, loop_before_1, loop_after_2, "k", "0", "k < 10", "k + 1", loop_after_1) - _, guard1, _ = sdfg.add_loop(init, loop_before_2, end, "j", "0", "j < M", "j + 1", loop_after_2) + _, guard3, _ = sdfg.add_loop_state_machine(loop_before_1, loop_body, loop_after_1, "i", "0", "i < N", "i + 1") + _, guard2, _ = sdfg.add_loop_state_machine(loop_before_2, loop_before_1, loop_after_2, "k", "0", "k < 10", "k + 1", + loop_after_1) + _, guard1, _ = sdfg.add_loop_state_machine(init, loop_before_2, end, "j", "0", "j < M", "j + 1", loop_after_2) a0 = loop_body.add_access("A") loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[i,i*j]")) @@ -783,9 +787,10 @@ def test_loop_nest_empty_nested_loop(): loop_after_1 = sdfg.add_state("loop_after_1") loop_before_2 = sdfg.add_state("loop_before_2") loop_after_2 = sdfg.add_state("loop_after_2") - _, guard3, _ = sdfg.add_loop(loop_before_1, loop_body, loop_after_1, "i", "0", "i < N", "i + 1") - _, guard2, _ = sdfg.add_loop(loop_before_2, loop_before_1, loop_after_2, "k", "0", "k < 0", "k + 1", loop_after_1) - _, guard1, _ = sdfg.add_loop(init, loop_before_2, end, "j", "0", "j < M", "j + 1", loop_after_2) + _, guard3, _ = sdfg.add_loop_state_machine(loop_before_1, loop_body, loop_after_1, "i", "0", "i < N", "i + 1") + _, guard2, _ = sdfg.add_loop_state_machine(loop_before_2, loop_before_1, loop_after_2, "k", "0", "k < 0", "k + 1", + loop_after_1) + _, guard1, _ = sdfg.add_loop_state_machine(init, loop_before_2, end, "j", "0", "j < M", "j + 1", loop_after_2) a0 = loop_body.add_access("A") loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]")) @@ -813,8 +818,8 @@ def test_loop_nest_inner_loop_conditional(): if_merge = sdfg.add_state("if_merge") loop_before_2 = sdfg.add_state("loop_before_2") loop_after_2 = sdfg.add_state("loop_after_2") - _, guard2, _ = sdfg.add_loop(loop_before_2, loop_body, loop_after_2, "k", "0", "k < N", "k + 1") - _, guard1, _ = sdfg.add_loop(init, if_guard, end, "j", "0", "j < M", "j + 1", if_merge) + _, guard2, _ = sdfg.add_loop_state_machine(loop_before_2, loop_body, loop_after_2, "k", "0", "k < N", "k + 1") + _, guard1, _ = sdfg.add_loop_state_machine(init, if_guard, end, "j", "0", "j < M", "j + 1", if_merge) sdfg.add_edge(if_guard, loop_before_2, dace.InterstateEdge(condition="j % 2 == 0")) sdfg.add_edge(if_guard, if_merge, dace.InterstateEdge(condition="j % 2 == 1")) sdfg.add_edge(loop_after_2, if_merge, dace.InterstateEdge()) @@ -909,7 +914,7 @@ def test_loop_break(): loop_body_0 = sdfg.add_state("loop_body_0") loop_body_1 = sdfg.add_state("loop_body_1") loop_after_1 = sdfg.add_state("loop_after_1") - _, guard3, _ = sdfg.add_loop(init, loop_body_0, loop_after_1, "i", "0", "i < N", "i + 1", loop_body_1) + _, guard3, _ = sdfg.add_loop_state_machine(init, loop_body_0, loop_after_1, "i", "0", "i < N", "i + 1", loop_body_1) sdfg.add_edge(loop_body_0, loop_after_1, dace.InterstateEdge(condition="i > 10")) sdfg.add_edge(loop_body_0, loop_body_1, dace.InterstateEdge(condition="not(i > 10)")) a0 = loop_body_1.add_access("A") diff --git a/tests/schedule_tree/naming_test.py b/tests/schedule_tree/naming_test.py index 8c39e3033f..e1841ac30c 100644 --- a/tests/schedule_tree/naming_test.py +++ b/tests/schedule_tree/naming_test.py @@ -18,11 +18,11 @@ def _irreducible_loop_to_loop(): # Add a loop l1 = sdfg.add_state() l2 = sdfg.add_state_after(l1) - sdfg.add_loop(s1, l1, s2, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l2) + sdfg.add_loop_state_machine(s1, l1, s2, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l2) l3 = sdfg.add_state() l4 = sdfg.add_state_after(l3) - sdfg.add_loop(s2, l3, e, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l4) + sdfg.add_loop_state_machine(s2, l3, e, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l4) # Irreducible part sdfg.add_edge(l3, l1, dace.InterstateEdge('i < 5')) diff --git a/tests/schedule_tree/schedule_test.py b/tests/schedule_tree/schedule_test.py index c15eb99f88..2946c91cd5 100644 --- a/tests/schedule_tree/schedule_test.py +++ b/tests/schedule_tree/schedule_test.py @@ -151,7 +151,7 @@ def test_irreducible_sub_sdfg(): sdfg.add_edge(s2, e, dace.InterstateEdge('b < 0')) # Add a loop following general block - sdfg.add_loop(e, sdfg.add_state(), None, 'i', '0', 'i < 10', 'i + 1') + sdfg.add_loop_state_machine(e, sdfg.add_state(), None, 'i', '0', 'i < 10', 'i + 1') FixedPointPipeline([ControlFlowRaising()]).apply_pass(sdfg, {}) @@ -171,12 +171,11 @@ def test_irreducible_in_loops(): # Add a loop l1 = sdfg.add_state() l2 = sdfg.add_state_after(l1) - sdfg.add_loop(s1, l1, s2, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l2) + sdfg.add_loop_state_machine(s1, l1, s2, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l2) l3 = sdfg.add_state() l4 = sdfg.add_state_after(l3) - sdfg.add_loop(s2, l3, e, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l4) - + sdfg.add_loop_state_machine(s2, l3, e, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l4) # Irreducible part sdfg.add_edge(l3, l1, dace.InterstateEdge('i < 5')) diff --git a/tests/sdfg/data/container_array_test.py b/tests/sdfg/data/container_array_test.py index 2773b5fec5..9919c8a410 100644 --- a/tests/sdfg/data/container_array_test.py +++ b/tests/sdfg/data/container_array_test.py @@ -126,16 +126,23 @@ def test_write_struct_array(): if_body.add_edge(indices, 'views', vcsr, None, dace.Memlet(data='vcsr.indices', subset='0:nnz')) if_body.add_edge(vcsr, 'views', B, None, dace.Memlet(data='B', subset='k')) # Make For Loop for j - j_before, j_guard, j_after = sdfg.add_loop(None, - if_before, - None, - 'j', - '0', - 'j < N', - 'j + 1', - loop_end_state=if_after) + j_before, _, j_after = sdfg.add_loop_state_machine(None, + if_before, + None, + 'j', + '0', + 'j < N', + 'j + 1', + loop_end_state=if_after) # Make For Loop for i - i_before, i_guard, i_after = sdfg.add_loop(None, j_before, None, 'i', '0', 'i < M', 'i + 1', loop_end_state=j_after) + i_before, i_guard, i_after = sdfg.add_loop_state_machine(None, + j_before, + None, + 'i', + '0', + 'i < M', + 'i + 1', + loop_end_state=j_after) sdfg.start_state = sdfg.node_id(i_before) i_before_guard = sdfg.edges_between(i_before, i_guard)[0] i_before_guard.data.assignments['idx'] = '0' diff --git a/tests/sdfg/data/structure_test.py b/tests/sdfg/data/structure_test.py index caf32dfcfb..8268cda965 100644 --- a/tests/sdfg/data/structure_test.py +++ b/tests/sdfg/data/structure_test.py @@ -98,16 +98,23 @@ def test_write_structure(): if_body.add_edge(t, '__out', indices, None, dace.Memlet(data='vindices', subset='idx')) if_body.add_edge(indices, 'views', B, None, dace.Memlet(data='B.indices', subset='0:nnz')) # Make For Loop for j - j_before, j_guard, j_after = sdfg.add_loop(None, - if_before, - None, - 'j', - '0', - 'j < N', - 'j + 1', - loop_end_state=if_after) + j_before, _, j_after = sdfg.add_loop_state_machine(None, + if_before, + None, + 'j', + '0', + 'j < N', + 'j + 1', + loop_end_state=if_after) # Make For Loop for i - i_before, i_guard, i_after = sdfg.add_loop(None, j_before, None, 'i', '0', 'i < M', 'i + 1', loop_end_state=j_after) + i_before, i_guard, i_after = sdfg.add_loop_state_machine(None, + j_before, + None, + 'i', + '0', + 'i < M', + 'i + 1', + loop_end_state=j_after) sdfg.start_state = sdfg.node_id(i_before) i_before_guard = sdfg.edges_between(i_before, i_guard)[0] i_before_guard.data.assignments['idx'] = '0' @@ -183,16 +190,23 @@ def test_local_structure(): if_body.add_edge(t, '__out', indices, None, dace.Memlet(data='tmp_vindices', subset='idx')) if_body.add_edge(indices, 'views', tmp, None, dace.Memlet(data='tmp.indices', subset='0:nnz')) # Make For Loop for j - j_before, j_guard, j_after = sdfg.add_loop(None, - if_before, - None, - 'j', - '0', - 'j < N', - 'j + 1', - loop_end_state=if_after) + j_before, _, j_after = sdfg.add_loop_state_machine(None, + if_before, + None, + 'j', + '0', + 'j < N', + 'j + 1', + loop_end_state=if_after) # Make For Loop for i - i_before, i_guard, i_after = sdfg.add_loop(None, j_before, None, 'i', '0', 'i < M', 'i + 1', loop_end_state=j_after) + i_before, i_guard, i_after = sdfg.add_loop_state_machine(None, + j_before, + None, + 'i', + '0', + 'i < M', + 'i + 1', + loop_end_state=j_after) sdfg.start_state = sdfg.node_id(i_before) i_before_guard = sdfg.edges_between(i_before, i_guard)[0] i_before_guard.data.assignments['idx'] = '0' @@ -347,16 +361,23 @@ def test_write_nested_structure(): if_body.add_edge(t, '__out', indices, None, dace.Memlet(data='vindices', subset='idx')) if_body.add_edge(indices, 'views', B, None, dace.Memlet(data='B.csr.indices', subset='0:nnz')) # Make For Loop for j - j_before, j_guard, j_after = sdfg.add_loop(None, - if_before, - None, - 'j', - '0', - 'j < N', - 'j + 1', - loop_end_state=if_after) + j_before, j_guard, j_after = sdfg.add_loop_state_machine(None, + if_before, + None, + 'j', + '0', + 'j < N', + 'j + 1', + loop_end_state=if_after) # Make For Loop for i - i_before, i_guard, i_after = sdfg.add_loop(None, j_before, None, 'i', '0', 'i < M', 'i + 1', loop_end_state=j_after) + i_before, i_guard, i_after = sdfg.add_loop_state_machine(None, + j_before, + None, + 'i', + '0', + 'i < M', + 'i + 1', + loop_end_state=j_after) sdfg.start_state = sdfg.node_id(i_before) i_before_guard = sdfg.edges_between(i_before, i_guard)[0] i_before_guard.data.assignments['idx'] = '0' @@ -465,16 +486,8 @@ def test_direct_read_structure_loops(): state.add_edge(data, None, t, '__val', dace.Memlet(data='A.data', subset='idx')) state.add_edge(t, '__out', B, None, dace.Memlet(data='B', subset='0:M, 0:N', volume=1)) - idx_before, idx_guard, idx_after = sdfg.add_loop(None, state, None, 'idx', 'A.indptr[i]', 'idx < A.indptr[i+1]', - 'idx + 1') - i_before, i_guard, i_after = sdfg.add_loop(None, - idx_before, - None, - 'i', - '0', - 'i < M', - 'i + 1', - loop_end_state=idx_after) + idx_loop = sdfg.add_loop(None, state, None, 'idx', 'A.indptr[i]', 'idx < A.indptr[i+1]', 'idx + 1') + sdfg.add_loop(None, idx_loop, None, 'i', '0', 'i < M', 'i + 1') func = sdfg.compile() @@ -613,12 +626,12 @@ def test_read_struct_member_interstate_edge(): if __name__ == "__main__": - test_read_structure() - test_write_structure() - test_local_structure() - test_read_nested_structure() - test_write_nested_structure() - test_direct_read_structure() - test_direct_read_nested_structure() + # test_read_structure() + # test_write_structure() + # test_local_structure() + # test_read_nested_structure() + # test_write_nested_structure() + # test_direct_read_structure() + # test_direct_read_nested_structure() test_direct_read_structure_loops() - test_read_struct_member_interstate_edge() + # test_read_struct_member_interstate_edge() diff --git a/tests/sdfg/reference_test.py b/tests/sdfg/reference_test.py index 79b4ceff00..2201ec6a06 100644 --- a/tests/sdfg/reference_test.py +++ b/tests/sdfg/reference_test.py @@ -336,7 +336,7 @@ def _create_loop_reference_internal_use(): state = sdfg.add_state() after = sdfg.add_state() sdfg.add_edge(state, after, dace.InterstateEdge()) - sdfg.add_loop(istate, state, None, 'i', '0', 'i < 20', 'i + 1', loop_end_state=after) + sdfg.add_loop(istate, state, None, 'i', '0', 'i < 20', 'i + 1', loop_end_block=after) # Reference set inside loop state.add_edge(state.add_read('A'), None, state.add_write('ref'), 'set', dace.Memlet('A[i]')) diff --git a/tests/transformations/loop_to_map_test.py b/tests/transformations/loop_to_map_test.py index ecb905c0f0..4077366fe3 100644 --- a/tests/transformations/loop_to_map_test.py +++ b/tests/transformations/loop_to_map_test.py @@ -303,7 +303,7 @@ def test_need_for_tasklet(): aname, _ = sdfg.add_array('A', (10, ), dace.int32) bname, _ = sdfg.add_array('B', (10, ), dace.int32) body = sdfg.add_state('body') - _, _, _ = sdfg.add_loop(None, body, None, 'i', '0', 'i < 10', 'i + 1', None) + _, _, _ = sdfg.add_loop_state_machine(None, body, None, 'i', '0', 'i < 10', 'i + 1', None) anode = body.add_access(aname) bnode = body.add_access(bname) body.add_nedge(anode, bnode, dace.Memlet(data=aname, subset='i', other_subset='9 - i')) @@ -325,7 +325,7 @@ def test_need_for_transient(): aname, _ = sdfg.add_array('A', (10, 10), dace.int32) bname, _ = sdfg.add_array('B', (10, 10), dace.int32) body = sdfg.add_state('body') - _, _, _ = sdfg.add_loop(None, body, None, 'i', '0', 'i < 10', 'i + 1', None) + _, _, _ = sdfg.add_loop_state_machine(None, body, None, 'i', '0', 'i < 10', 'i + 1', None) anode = body.add_access(aname) bnode = body.add_access(bname) body.add_nedge(anode, bnode, dace.Memlet(data=aname, subset='0:10, i', other_subset='0:10, 9 - i')) @@ -390,7 +390,7 @@ def test_symbol_write_before_read(): body_start = sdfg.add_state() body = sdfg.add_state() body_end = sdfg.add_state() - sdfg.add_loop(init, body_start, None, 'i', '0', 'i < 20', 'i + 1', loop_end_state=body_end) + sdfg.add_loop_state_machine(init, body_start, None, 'i', '0', 'i < 20', 'i + 1', loop_end_state=body_end) # Internal loop structure sdfg.add_edge(body_start, body, dace.InterstateEdge(assignments=dict(j='0'))) @@ -410,7 +410,7 @@ def test_symbol_array_mix(overwrite): body = sdfg.add_state() body_end = sdfg.add_state() after = sdfg.add_state() - sdfg.add_loop(init, body_start, after, 'i', '0', 'i < 20', 'i + 1', loop_end_state=body_end) + sdfg.add_loop_state_machine(init, body_start, after, 'i', '0', 'i < 20', 'i + 1', loop_end_state=body_end) sdfg.out_edges(init)[0].data.assignments['sym'] = '0.0' @@ -438,7 +438,7 @@ def test_symbol_array_mix_2(parallel): body_start = sdfg.add_state() body_end = sdfg.add_state() after = sdfg.add_state() - sdfg.add_loop(init, body_start, after, 'i', '1', 'i < 20', 'i + 1', loop_end_state=body_end) + sdfg.add_loop_state_machine(init, body_start, after, 'i', '1', 'i < 20', 'i + 1', loop_end_state=body_end) sdfg.out_edges(init)[0].data.assignments['sym'] = '0.0' @@ -463,7 +463,7 @@ def test_internal_symbol_used_outside(overwrite): body = sdfg.add_state() body_end = sdfg.add_state() after = sdfg.add_state() - sdfg.add_loop(init, body_start, after, 'i', '0', 'i < 20', 'i + 1', loop_end_state=body_end) + sdfg.add_loop_state_machine(init, body_start, after, 'i', '0', 'i < 20', 'i + 1', loop_end_state=body_end) # Internal loop structure sdfg.add_edge(body_start, body, dace.InterstateEdge(assignments=dict(j='0'))) From e313b09a0347bf3a93c5b0258d4624f1a15351d6 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 9 Feb 2026 22:25:10 -0800 Subject: [PATCH 3/3] Fix loop code generation --- dace/codegen/control_flow.py | 68 +++++++++++++++++++----------------- dace/sdfg/sdfg.py | 5 +-- 2 files changed, 39 insertions(+), 34 deletions(-) diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index 1b5384428e..063ca6f542 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -93,50 +93,54 @@ def _loop_region_to_code(region: LoopRegion, dispatch_state: Callable[[SDFGState expr = '' - if loop.update_statement and loop.init_statement and loop.loop_variable: - lsyms = {} - lsyms.update(symbols) - if codegen.dispatcher.defined_vars.has(loop.loop_variable) and not loop.loop_variable in lsyms: - lsyms[loop.loop_variable] = codegen.dispatcher.defined_vars.get(loop.loop_variable)[1] + lsyms = {} + lsyms.update(symbols) + if (loop.loop_variable and codegen.dispatcher.defined_vars.has(loop.loop_variable) + and not loop.loop_variable in lsyms): + lsyms[loop.loop_variable] = codegen.dispatcher.defined_vars.get(loop.loop_variable)[1] + + if loop.init_statement: init = unparse_interstate_edge(loop.init_statement.code[0], sdfg, codegen=codegen, symbols=lsyms) init = init.strip(';') + else: + init = '' + if loop.update_statement: update = unparse_interstate_edge(loop.update_statement.code[0], sdfg, codegen=codegen, symbols=lsyms) update = update.strip(';') + else: + update = '' - if loop.inverted: - if loop.update_before_condition: + if loop.inverted: + if loop.update_before_condition: + if init: expr += f'{init};\n' - expr += 'do {\n' - expr += _clean_loop_body(control_flow_region_to_code(loop, dispatch_state, codegen, symbols)) + expr += 'do {\n' + expr += _clean_loop_body(control_flow_region_to_code(loop, dispatch_state, codegen, symbols)) + if update: expr += f'{update};\n' - expr += f'}} while({cond});\n' else: - expr += f'{init};\n' - expr += 'while (1) {\n' - expr += _clean_loop_body(control_flow_region_to_code(loop, dispatch_state, codegen, symbols)) - expr += f'if (!({cond}))\n' - expr += 'break;\n' - expr += f'{update};\n' - expr += '}\n' + expr += '\n' + expr += f'}} while({cond});\n' else: - if loop.unroll: - if loop.unroll_factor >= 1: - expr += f'#pragma unroll {loop.unroll_factor}\n' - else: - expr += f'#pragma unroll\n' - expr += f'for ({init}; {cond}; {update}) {{\n' + if init: + expr += f'{init};\n' + expr += 'while (1) {\n' expr += _clean_loop_body(control_flow_region_to_code(loop, dispatch_state, codegen, symbols)) - expr += '\n}\n' + expr += f'if (!({cond}))\n' + expr += 'break;\n' + if update: + expr += f'{update};\n' + expr += '}\n' else: - if loop.inverted: - expr += 'do {\n' - expr += _clean_loop_body(control_flow_region_to_code(loop, dispatch_state, codegen, symbols)) - expr += f'\n}} while({cond});\n' - else: - expr += f'while ({cond}) {{\n' - expr += _clean_loop_body(control_flow_region_to_code(loop, dispatch_state, codegen, symbols)) - expr += '\n}\n' + if loop.unroll: + if loop.unroll_factor >= 1: + expr += f'#pragma unroll {loop.unroll_factor}\n' + else: + expr += f'#pragma unroll\n' + expr += f'for ({init}; {cond}; {update}) {{\n' + expr += _clean_loop_body(control_flow_region_to_code(loop, dispatch_state, codegen, symbols)) + expr += '\n}\n' return expr diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index ee9b080e8a..e3401ed244 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -2286,8 +2286,9 @@ def add_loop(self, :return: The generated LoopRegion block. """ label = self._ensure_unique_block_name(label or "loop") - loop_region = LoopRegion(label, condition_expr, loop_var, f'{loop_var} = {initialize_expr}', - f'{loop_var} = {increment_expr}') + loop_region = LoopRegion(label, condition_expr, loop_var, + f'{loop_var} = {initialize_expr}' if initialize_expr else None, + f'{loop_var} = {increment_expr}' if increment_expr else None) # Capture subgraphview of loop body if loop_start_block in self.nodes() or loop_start_block in self.states():