Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 36 additions & 32 deletions dace/codegen/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,50 +93,54 @@ def _loop_region_to_code(region: LoopRegion, dispatch_state: Callable[[SDFGState

expr = ''

if loop.update_statement and loop.init_statement and loop.loop_variable:
lsyms = {}
lsyms.update(symbols)
if codegen.dispatcher.defined_vars.has(loop.loop_variable) and not loop.loop_variable in lsyms:
lsyms[loop.loop_variable] = codegen.dispatcher.defined_vars.get(loop.loop_variable)[1]
lsyms = {}
lsyms.update(symbols)
if (loop.loop_variable and codegen.dispatcher.defined_vars.has(loop.loop_variable)
and not loop.loop_variable in lsyms):
lsyms[loop.loop_variable] = codegen.dispatcher.defined_vars.get(loop.loop_variable)[1]

if loop.init_statement:
init = unparse_interstate_edge(loop.init_statement.code[0], sdfg, codegen=codegen, symbols=lsyms)
init = init.strip(';')
else:
init = ''

if loop.update_statement:
update = unparse_interstate_edge(loop.update_statement.code[0], sdfg, codegen=codegen, symbols=lsyms)
update = update.strip(';')
else:
update = ''

if loop.inverted:
if loop.update_before_condition:
if loop.inverted:
if loop.update_before_condition:
if init:
expr += f'{init};\n'
expr += 'do {\n'
expr += _clean_loop_body(control_flow_region_to_code(loop, dispatch_state, codegen, symbols))
expr += 'do {\n'
expr += _clean_loop_body(control_flow_region_to_code(loop, dispatch_state, codegen, symbols))
if update:
expr += f'{update};\n'
expr += f'}} while({cond});\n'
else:
expr += f'{init};\n'
expr += 'while (1) {\n'
expr += _clean_loop_body(control_flow_region_to_code(loop, dispatch_state, codegen, symbols))
expr += f'if (!({cond}))\n'
expr += 'break;\n'
expr += f'{update};\n'
expr += '}\n'
expr += '\n'
expr += f'}} while({cond});\n'
else:
if loop.unroll:
if loop.unroll_factor >= 1:
expr += f'#pragma unroll {loop.unroll_factor}\n'
else:
expr += f'#pragma unroll\n'
expr += f'for ({init}; {cond}; {update}) {{\n'
if init:
expr += f'{init};\n'
expr += 'while (1) {\n'
expr += _clean_loop_body(control_flow_region_to_code(loop, dispatch_state, codegen, symbols))
expr += '\n}\n'
expr += f'if (!({cond}))\n'
expr += 'break;\n'
if update:
expr += f'{update};\n'
expr += '}\n'
else:
if loop.inverted:
expr += 'do {\n'
expr += _clean_loop_body(control_flow_region_to_code(loop, dispatch_state, codegen, symbols))
expr += f'\n}} while({cond});\n'
else:
expr += f'while ({cond}) {{\n'
expr += _clean_loop_body(control_flow_region_to_code(loop, dispatch_state, codegen, symbols))
expr += '\n}\n'
if loop.unroll:
if loop.unroll_factor >= 1:
expr += f'#pragma unroll {loop.unroll_factor}\n'
else:
expr += f'#pragma unroll\n'
expr += f'for ({init}; {cond}; {update}) {{\n'
expr += _clean_loop_body(control_flow_region_to_code(loop, dispatch_state, codegen, symbols))
expr += '\n}\n'

return expr

Expand Down
97 changes: 86 additions & 11 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@
import pickle

import dace
from dace.sdfg.graph import generate_element_id
from dace.sdfg.graph import generate_element_id, SubgraphView
import dace.serialize
from dace import (data as dt, hooks, memlet as mm, subsets as sbs, dtypes, symbolic)
from dace.sdfg.replace import replace_properties_dict
from dace.sdfg.validation import (InvalidSDFGError, validate_sdfg)
from dace.config import Config
from dace.frontend.python import astutils
from dace.sdfg import nodes as nd
from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, SDFGState, ControlFlowRegion
from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, SDFGState, ControlFlowRegion, LoopRegion
from dace.sdfg.type_inference import infer_expr_type
from dace.distr_types import ProcessGrid, SubArray, RedistrArray
from dace.dtypes import validate_name
Expand Down Expand Up @@ -2249,20 +2249,98 @@ def add_rdistrarray(self, array_a: str, array_b: str):
self.append_exit_code(self._rdistrarrays[name].exit_code(self))
return name

def add_loop(
def add_loop(self,
before_block: ControlFlowBlock,
loop_start_block: ControlFlowBlock,
after_block: ControlFlowBlock,
loop_var: str,
initialize_expr: str,
condition_expr: str,
increment_expr: str,
loop_end_block: Optional[ControlFlowBlock] = None,
label: Optional[str] = None) -> LoopRegion:
"""
Helper function that adds a looping control flow block around a
given block/state (or sequence of blocks, if ``loop_end_block`` is provided).

:param before_block: The block after which the loop should
begin, or None if the loop is the first
block (creates an empty block).
:param loop_start_block: The block that begins the loop. See also
``loop_end_block`` if the loop is multi-block.
:param after_block: The block that should be invoked after
the loop ends, or None if the program
should terminate (creates an empty block).
:param loop_var: A name of a symbol to use for the loop variable.
:param initialize_expr: A string expression that is assigned
to ``loop_var`` before the loop begins.
If None, does not define an expression.
:param condition_expr: A string condition that occurs every
loop iteration. If None, loops forever.
:param increment_expr: A string expression that is assigned to
``loop_var`` after every loop iteration.
:param loop_end_block: If the loop wraps multiple blocks, the block
where the loop iteration ends. If None, sets
the end block to ``loop_start_block`` as well.
:param label: An optional label for the loop region.
:return: The generated LoopRegion block.
"""
label = self._ensure_unique_block_name(label or "loop")
loop_region = LoopRegion(label, condition_expr, loop_var,
f'{loop_var} = {initialize_expr}' if initialize_expr else None,
f'{loop_var} = {increment_expr}' if increment_expr else None)

# Capture subgraphview of loop body
if loop_start_block in self.nodes() or loop_start_block in self.states():
# Find all reachable blocks in loop body
if loop_end_block is None:
blocks = {loop_start_block}
else:
blocks = set(self.all_nodes_between(loop_start_block, loop_end_block))
blocks.add(loop_start_block)
blocks.add(loop_end_block)

subgraph = SubgraphView(self, blocks)
remove_subgraph = True
else:
subgraph = SubgraphView(self, [loop_start_block, loop_end_block] if loop_end_block else [loop_start_block])
remove_subgraph = False

# Readd subgraph to main SDFG with loop_region as parent
loop_region.add_node(loop_start_block, is_start_block=True)
loop_region.add_nodes_from(set(subgraph.nodes()) - {loop_start_block})
for e in subgraph.edges():
loop_region.add_edge(e.src, e.dst, e.data)

# Remove subgraph from main SDFG, if necessary
if remove_subgraph:
self.remove_nodes_from(blocks)

# Connect to graph
self.add_node(loop_region, is_start_block=(before_block is None))
if before_block is not None:
self.add_edge(before_block, loop_region, InterstateEdge())
if after_block is not None:
self.add_edge(loop_region, after_block, InterstateEdge())

return loop_region

def add_loop_state_machine(
self,
before_state,
loop_state,
after_state,
before_state: SDFGState,
loop_state: SDFGState,
after_state: SDFGState,
loop_var: str,
initialize_expr: str,
condition_expr: str,
increment_expr: str,
loop_end_state=None,
loop_end_state: Optional[SDFGState] = None,
):
"""
Helper function that adds a looping state machine around a
given state (or sequence of states).
given state (or sequence of states). It is recommended to use
``add_loop`` instead of this function, unless creating a loop
state machine is explicitly requested.

:param before_state: The state after which the loop should
begin, or None if the loop is the first
Expand Down Expand Up @@ -2293,9 +2371,6 @@ def add_loop(
"""
from dace.frontend.python.astutils import negate_expr # Avoid import loops

warnings.warn("SDFG.add_loop is deprecated and will be removed in a future release. Use LoopRegions instead.",
DeprecationWarning)

# Argument checks
if loop_var is None and (initialize_expr or increment_expr):
raise ValueError("Cannot initalize or increment an empty loop variable")
Expand Down
4 changes: 3 additions & 1 deletion dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3536,8 +3536,10 @@ def _used_symbols_internal(self,
free_syms = set() if free_syms is None else free_syms
used_before_assignment = set() if used_before_assignment is None else used_before_assignment

defined_syms.add(self.loop_variable)
if self.init_statement is not None:
# Loops with no initialization statement do not redefine the loop variable
defined_syms.add(self.loop_variable)

free_syms |= self.init_statement.get_free_symbols()
if self.update_statement is not None:
free_syms |= self.update_statement.get_free_symbols()
Expand Down
3 changes: 2 additions & 1 deletion tests/graph_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ def test_ordered_multidigraph(self):
def test_dfs_edges(self):

sdfg = dace.SDFG('test_dfs_edges')
before, _, _ = sdfg.add_loop(sdfg.add_state(), sdfg.add_state(), sdfg.add_state(), 'i', '0', 'i < 10', 'i + 1')
before, _, _ = sdfg.add_loop_state_machine(sdfg.add_state(), sdfg.add_state(), sdfg.add_state(), 'i', '0',
'i < 10', 'i + 1')

visited_edges = list(sdfg.dfs_edges(before))
assert len(visited_edges) == len(set(visited_edges))
Expand Down
47 changes: 26 additions & 21 deletions tests/passes/writeset_underapproximation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,7 +605,7 @@ def test_simple_loop_overwrite():
init = sdfg.add_state("init")
end = sdfg.add_state("end")
loop_body = sdfg.add_state("loop_body")
_, guard, _ = sdfg.add_loop(init, loop_body, end, "i", "0", "i < N", "i + 1")
_, guard, _ = sdfg.add_loop_state_machine(init, loop_body, end, "i", "0", "i < N", "i + 1")
a0 = loop_body.add_access("A")
loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0")
loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[i]"))
Expand All @@ -629,8 +629,8 @@ def test_loop_2D_overwrite():
loop_body = sdfg.add_state("loop_body")
loop_before_1 = sdfg.add_state("loop_before_1")
loop_after_1 = sdfg.add_state("loop_after_1")
_, guard2, _ = sdfg.add_loop(loop_before_1, loop_body, loop_after_1, "i", "0", "i < N", "i + 1")
_, guard1, _ = sdfg.add_loop(init, loop_before_1, end, "j", "0", "j < M", "j + 1", loop_after_1)
_, guard2, _ = sdfg.add_loop_state_machine(loop_before_1, loop_body, loop_after_1, "i", "0", "i < N", "i + 1")
_, guard1, _ = sdfg.add_loop_state_machine(init, loop_before_1, end, "j", "0", "j < M", "j + 1", loop_after_1)
a0 = loop_body.add_access("A")
loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0")
loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]"))
Expand Down Expand Up @@ -659,10 +659,12 @@ def test_loop_2D_propagation_gap_symbolic():
loop_after_1 = sdfg.add_state("loop_after_1")
loop_before_2 = sdfg.add_state("loop_before_2")
loop_after_2 = sdfg.add_state("loop_after_2")
_, guard3, _ = sdfg.add_loop(loop_before_1, loop_body, loop_after_1, "i", "0", "i < N", "i + 1") # inner-most loop
_, guard2, _ = sdfg.add_loop(loop_before_2, loop_before_1, loop_after_2, "k", "0", "k < K", "k + 1",
loop_after_1) # second-inner-most loop
_, guard1, _ = sdfg.add_loop(init, loop_before_2, end, "j", "0", "j < M", "j + 1", loop_after_2) # outer-most loop
_, guard3, _ = sdfg.add_loop_state_machine(loop_before_1, loop_body, loop_after_1, "i", "0", "i < N",
"i + 1") # inner-most loop
_, guard2, _ = sdfg.add_loop_state_machine(loop_before_2, loop_before_1, loop_after_2, "k", "0", "k < K", "k + 1",
loop_after_1) # second-inner-most loop
_, guard1, _ = sdfg.add_loop_state_machine(init, loop_before_2, end, "j", "0", "j < M", "j + 1",
loop_after_2) # outer-most loop
a0 = loop_body.add_access("A")
loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0")
loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]"))
Expand All @@ -687,8 +689,8 @@ def test_2_loops_overwrite():
end = sdfg.add_state("end")
loop_body_1 = sdfg.add_state("loop_body_1")
loop_body_2 = sdfg.add_state("loop_body_2")
_, guard_1, after_state = sdfg.add_loop(init, loop_body_1, None, "i", "0", "i < N", "i + 1")
_, guard_2, _ = sdfg.add_loop(after_state, loop_body_2, end, "i", "0", "i < N", "i + 1")
_, guard_1, after_state = sdfg.add_loop_state_machine(init, loop_body_1, None, "i", "0", "i < N", "i + 1")
_, guard_2, _ = sdfg.add_loop_state_machine(after_state, loop_body_2, end, "i", "0", "i < N", "i + 1")
a0 = loop_body_1.add_access("A")
loop_tasklet_1 = loop_body_1.add_tasklet("overwrite", {}, {"a"}, "a = 0")
loop_body_1.add_edge(loop_tasklet_1, "a", a0, None, dace.Memlet("A[i]"))
Expand Down Expand Up @@ -720,9 +722,10 @@ def test_loop_2D_overwrite_propagation_gap_non_empty():
loop_after_1 = sdfg.add_state("loop_after_1")
loop_before_2 = sdfg.add_state("loop_before_2")
loop_after_2 = sdfg.add_state("loop_after_2")
_, guard3, _ = sdfg.add_loop(loop_before_1, loop_body, loop_after_1, "i", "0", "i < N", "i + 1")
_, guard2, _ = sdfg.add_loop(loop_before_2, loop_before_1, loop_after_2, "k", "0", "k < 10", "k + 1", loop_after_1)
_, guard1, _ = sdfg.add_loop(init, loop_before_2, end, "j", "0", "j < M", "j + 1", loop_after_2)
_, guard3, _ = sdfg.add_loop_state_machine(loop_before_1, loop_body, loop_after_1, "i", "0", "i < N", "i + 1")
_, guard2, _ = sdfg.add_loop_state_machine(loop_before_2, loop_before_1, loop_after_2, "k", "0", "k < 10", "k + 1",
loop_after_1)
_, guard1, _ = sdfg.add_loop_state_machine(init, loop_before_2, end, "j", "0", "j < M", "j + 1", loop_after_2)
a0 = loop_body.add_access("A")
loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0")
loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]"))
Expand Down Expand Up @@ -751,9 +754,10 @@ def test_loop_nest_multiplied_indices():
loop_after_1 = sdfg.add_state("loop_after_1")
loop_before_2 = sdfg.add_state("loop_before_2")
loop_after_2 = sdfg.add_state("loop_after_2")
_, guard3, _ = sdfg.add_loop(loop_before_1, loop_body, loop_after_1, "i", "0", "i < N", "i + 1")
_, guard2, _ = sdfg.add_loop(loop_before_2, loop_before_1, loop_after_2, "k", "0", "k < 10", "k + 1", loop_after_1)
_, guard1, _ = sdfg.add_loop(init, loop_before_2, end, "j", "0", "j < M", "j + 1", loop_after_2)
_, guard3, _ = sdfg.add_loop_state_machine(loop_before_1, loop_body, loop_after_1, "i", "0", "i < N", "i + 1")
_, guard2, _ = sdfg.add_loop_state_machine(loop_before_2, loop_before_1, loop_after_2, "k", "0", "k < 10", "k + 1",
loop_after_1)
_, guard1, _ = sdfg.add_loop_state_machine(init, loop_before_2, end, "j", "0", "j < M", "j + 1", loop_after_2)
a0 = loop_body.add_access("A")
loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0")
loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[i,i*j]"))
Expand Down Expand Up @@ -783,9 +787,10 @@ def test_loop_nest_empty_nested_loop():
loop_after_1 = sdfg.add_state("loop_after_1")
loop_before_2 = sdfg.add_state("loop_before_2")
loop_after_2 = sdfg.add_state("loop_after_2")
_, guard3, _ = sdfg.add_loop(loop_before_1, loop_body, loop_after_1, "i", "0", "i < N", "i + 1")
_, guard2, _ = sdfg.add_loop(loop_before_2, loop_before_1, loop_after_2, "k", "0", "k < 0", "k + 1", loop_after_1)
_, guard1, _ = sdfg.add_loop(init, loop_before_2, end, "j", "0", "j < M", "j + 1", loop_after_2)
_, guard3, _ = sdfg.add_loop_state_machine(loop_before_1, loop_body, loop_after_1, "i", "0", "i < N", "i + 1")
_, guard2, _ = sdfg.add_loop_state_machine(loop_before_2, loop_before_1, loop_after_2, "k", "0", "k < 0", "k + 1",
loop_after_1)
_, guard1, _ = sdfg.add_loop_state_machine(init, loop_before_2, end, "j", "0", "j < M", "j + 1", loop_after_2)
a0 = loop_body.add_access("A")
loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0")
loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]"))
Expand Down Expand Up @@ -813,8 +818,8 @@ def test_loop_nest_inner_loop_conditional():
if_merge = sdfg.add_state("if_merge")
loop_before_2 = sdfg.add_state("loop_before_2")
loop_after_2 = sdfg.add_state("loop_after_2")
_, guard2, _ = sdfg.add_loop(loop_before_2, loop_body, loop_after_2, "k", "0", "k < N", "k + 1")
_, guard1, _ = sdfg.add_loop(init, if_guard, end, "j", "0", "j < M", "j + 1", if_merge)
_, guard2, _ = sdfg.add_loop_state_machine(loop_before_2, loop_body, loop_after_2, "k", "0", "k < N", "k + 1")
_, guard1, _ = sdfg.add_loop_state_machine(init, if_guard, end, "j", "0", "j < M", "j + 1", if_merge)
sdfg.add_edge(if_guard, loop_before_2, dace.InterstateEdge(condition="j % 2 == 0"))
sdfg.add_edge(if_guard, if_merge, dace.InterstateEdge(condition="j % 2 == 1"))
sdfg.add_edge(loop_after_2, if_merge, dace.InterstateEdge())
Expand Down Expand Up @@ -909,7 +914,7 @@ def test_loop_break():
loop_body_0 = sdfg.add_state("loop_body_0")
loop_body_1 = sdfg.add_state("loop_body_1")
loop_after_1 = sdfg.add_state("loop_after_1")
_, guard3, _ = sdfg.add_loop(init, loop_body_0, loop_after_1, "i", "0", "i < N", "i + 1", loop_body_1)
_, guard3, _ = sdfg.add_loop_state_machine(init, loop_body_0, loop_after_1, "i", "0", "i < N", "i + 1", loop_body_1)
sdfg.add_edge(loop_body_0, loop_after_1, dace.InterstateEdge(condition="i > 10"))
sdfg.add_edge(loop_body_0, loop_body_1, dace.InterstateEdge(condition="not(i > 10)"))
a0 = loop_body_1.add_access("A")
Expand Down
4 changes: 2 additions & 2 deletions tests/schedule_tree/naming_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@ def _irreducible_loop_to_loop():
# Add a loop
l1 = sdfg.add_state()
l2 = sdfg.add_state_after(l1)
sdfg.add_loop(s1, l1, s2, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l2)
sdfg.add_loop_state_machine(s1, l1, s2, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l2)

l3 = sdfg.add_state()
l4 = sdfg.add_state_after(l3)
sdfg.add_loop(s2, l3, e, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l4)
sdfg.add_loop_state_machine(s2, l3, e, 'i', '0', 'i < 10', 'i + 1', loop_end_state=l4)

# Irreducible part
sdfg.add_edge(l3, l1, dace.InterstateEdge('i < 5'))
Expand Down
Loading