From 5928fe8744289d6766507b1655dca486ea8cbb2d Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 3 Dec 2023 12:50:33 -0800 Subject: [PATCH 001/137] Add validation check for empty SDFGs --- dace/sdfg/validation.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index a3914494c3..707b51f27d 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -202,6 +202,8 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context raise InvalidSDFGError("Invalid name", sdfg, None) all_blocks = set(sdfg.all_control_flow_blocks()) + if len(all_blocks) == 0: + raise InvalidSDFGError('SDFG contains no states or control flow blocks', sdfg, None) if len(all_blocks) != len(set([s.label for s in all_blocks])): raise InvalidSDFGError('Found multiple blocks with the same name', sdfg, None) From 0d84cecd69ef547dee97551e279c242432ae3709 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 3 Dec 2023 12:51:08 -0800 Subject: [PATCH 002/137] Add explicit stree root node and SDFG conversion method --- dace/sdfg/analysis/schedule_tree/treenodes.py | 27 +++++++++++--- dace/sdfg/sdfg.py | 37 +++++++++++++++---- 2 files changed, 50 insertions(+), 14 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 99918cd2a4..8329aec84f 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -3,11 +3,11 @@ from dace import nodes, data, subsets from dace.codegen import control_flow as cf from dace.properties import CodeBlock -from dace.sdfg import InterstateEdge +from dace.sdfg.sdfg import InterstateEdge, SDFG from dace.sdfg.state import SDFGState from dace.symbolic import symbol from dace.memlet import Memlet -from typing import Dict, Iterator, List, Optional, Set, Union +from typing import Any, Dict, Iterator, List, Optional, Set, Union INDENTATION = ' ' @@ -33,11 +33,8 @@ def preorder_traversal(self) -> Iterator['ScheduleTreeNode']: @dataclass class ScheduleTreeScope(ScheduleTreeNode): children: List['ScheduleTreeNode'] - containers: Optional[Dict[str, data.Data]] = field(default_factory=dict, init=False) - symbols: Optional[Dict[str, symbol]] = field(default_factory=dict, init=False) - def __init__(self, - children: Optional[List['ScheduleTreeNode']] = None): + def __init__(self, children: Optional[List['ScheduleTreeNode']] = None): self.children = children or [] if self.children: for child in children: @@ -59,6 +56,24 @@ def preorder_traversal(self) -> Iterator['ScheduleTreeNode']: # TODO: Helper function that gets input/output memlets of the scope +@dataclass +class ScheduleTreeRoot(ScheduleTreeScope): + """ + A root of an SDFG schedule tree. This is a schedule tree scope with additional information on + the available descriptors, symbol types, and constants of the tree, aka the descriptor repository. + """ + name: str + containers: Dict[str, data.Data] = field(default_factory=dict) + symbols: Dict[str, symbol] = field(default_factory=dict) + constants: Dict[str, Any] = field(default_factory=dict) + callback_mapping: Dict[str, str] = field(default_factory=dict) + arg_names: List[str] = field(default_factory=list) + + def as_sdfg(self) -> SDFG: + from dace.sdfg.analysis.schedule_tree import tree_to_sdfg as t2s # Avoid import loop + return t2s.from_schedule_tree(self) + + @dataclass class ControlFlowScope(ScheduleTreeScope): pass diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 020fb9dbab..83d95dda58 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -46,6 +46,7 @@ from dace.codegen.instrumentation.report import InstrumentationReport from dace.codegen.instrumentation.data.data_report import InstrumentedDataReport from dace.codegen.compiled_sdfg import CompiledSDFG + from dace.sdfg.analysis.schedule_tree.treenodes import ScheduleTreeRoot class NestedDict(dict): @@ -1049,6 +1050,25 @@ def call_with_instrumented_data(self, dreport: 'InstrumentedDataReport', *args, ########################################## + def as_schedule_tree(self, in_place: bool = False) -> 'ScheduleTreeRoot': + """ + Creates a schedule tree from this SDFG and all nested SDFGs. The schedule tree is a tree of nodes that represent + the execution order of the SDFG. + Each node in the tree can either represent a single statement (symbol assignment, tasklet, copy, library node, + etc.) or a ``ScheduleTreeScope`` block (map, for-loop, pipeline, etc.) that contains other nodes. + + It can be used to generate code from an SDFG, or to perform schedule transformations on the SDFG. For example, + erasing an empty if branch, or merging two consecutive for-loops. The SDFG can then be reconstructed via the + ``as_sdfg`` method or the ``from_schedule_tree`` function in ``dace.sdfg.analysis.schedule_tree.tree_to_sdfg``. + + :param in_place: If True, the SDFG is modified in-place. Otherwise, a copy is made. Note that the SDFG might + not be usable after the conversion if ``in_place`` is True! + :return: A schedule tree representing the given SDFG. + """ + # Avoid import loop + from dace.sdfg.analysis.schedule_tree import sdfg_to_tree as s2t + return s2t.as_schedule_tree(self, in_place=in_place) + @property def build_folder(self) -> str: """ Returns a relative path to the build cache folder for this SDFG. """ @@ -1233,10 +1253,10 @@ def arrays_recursive(self): def _used_symbols_internal(self, all_symbols: bool, - defined_syms: Optional[Set]=None, - free_syms: Optional[Set]=None, - used_before_assignment: Optional[Set]=None, - keep_defined_in_mapping: bool=False) -> Tuple[Set[str], Set[str], Set[str]]: + defined_syms: Optional[Set] = None, + free_syms: Optional[Set] = None, + used_before_assignment: Optional[Set] = None, + keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: defined_syms = set() if defined_syms is None else defined_syms free_syms = set() if free_syms is None else free_syms used_before_assignment = set() if used_before_assignment is None else used_before_assignment @@ -1253,10 +1273,11 @@ def _used_symbols_internal(self, for code in self.exit_code.values(): free_syms |= symbolic.symbols_in_code(code.as_string, self.symbols.keys()) - return super()._used_symbols_internal( - all_symbols=all_symbols, keep_defined_in_mapping=keep_defined_in_mapping, - defined_syms=defined_syms, free_syms=free_syms, used_before_assignment=used_before_assignment - ) + return super()._used_symbols_internal(all_symbols=all_symbols, + keep_defined_in_mapping=keep_defined_in_mapping, + defined_syms=defined_syms, + free_syms=free_syms, + used_before_assignment=used_before_assignment) def get_all_toplevel_symbols(self) -> Set[str]: """ From 91d2bc642ac902fc0318002d7f2d6c8df236be6c Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 3 Dec 2023 13:14:47 -0800 Subject: [PATCH 003/137] Initial conversion function and test --- .../analysis/schedule_tree/sdfg_to_tree.py | 39 ++++++++++++++-- .../analysis/schedule_tree/tree_to_sdfg.py | 23 ++++++++++ dace/sdfg/analysis/schedule_tree/treenodes.py | 4 +- tests/schedule_tree/roundtrip_test.py | 46 +++++++++++++++++++ 4 files changed, 107 insertions(+), 5 deletions(-) create mode 100644 dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py create mode 100644 tests/schedule_tree/roundtrip_test.py diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index a519f24596..f682cbed4b 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -325,6 +325,29 @@ def remove_name_collisions(sdfg: SDFG): nsdfg.replace_dict(replacements) +def create_unified_descriptor_repository(sdfg: SDFG, stree: tn.ScheduleTreeRoot): + """ + Creates a single descriptor repository from an SDFG and all nested SDFGs. This includes + data containers, symbols, constants, etc. + + :param sdfg: The top-level SDFG to create the repository from. + :param stree: The tree root in which to make the unified descriptor repository. + """ + stree.containers = sdfg.arrays + stree.symbols = sdfg.symbols + stree.constants = sdfg.constants_prop + + # Since the SDFG is assumed to be de-aliased and contain unique names, we union the contents of + # the nested SDFGs' descriptor repositories + for nsdfg in sdfg.all_sdfgs_recursive(): + transients = {k: v for k, v in nsdfg.arrays.items() if v.transient} + symbols = {k: v for k, v in nsdfg.symbols.items() if k not in stree.symbols} + constants = {k: v for k, v in nsdfg.constants_prop.items() if k not in stree.constants} + stree.containers.update(transients) + stree.symbols.update(symbols) + stree.constants.update(constants) + + def _make_view_node(state: SDFGState, edge: gr.MultiConnectorEdge[Memlet], view_name: str, viewed_name: str) -> tn.ViewNode: """ @@ -608,7 +631,7 @@ def _generate_views_in_scope(edges: List[gr.MultiConnectorEdge[Memlet]], return result -def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) -> tn.ScheduleTreeScope: +def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) -> tn.ScheduleTreeRoot: """ Converts an SDFG into a schedule tree. The schedule tree is a tree of nodes that represent the execution order of the SDFG. @@ -642,7 +665,6 @@ def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) dealias_sdfg(sdfg) # Handle name collisions (in arrays, state labels, symbols) remove_name_collisions(sdfg) - ############################# # Create initial tree from CFG @@ -726,7 +748,18 @@ def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.Sche return result # Recursive traversal of the control flow tree - result = tn.ScheduleTreeScope(children=totree(cfg)) + children = totree(cfg) + + # Create the scope object + if toplevel: + # Create the root with the elements of the descriptor repository + result = tn.ScheduleTreeRoot(name=sdfg.name, + children=children, + arg_names=sdfg.arg_names, + callback_mapping=sdfg.callback_mapping) + create_unified_descriptor_repository(sdfg, result) + else: + result = tn.ScheduleTreeScope(children=children) # Clean up tree stpasses.remove_unused_and_duplicate_labels(result) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py new file mode 100644 index 0000000000..2847e7d16a --- /dev/null +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -0,0 +1,23 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import copy +from dace.sdfg.sdfg import SDFG +from dace.sdfg.analysis.schedule_tree import treenodes as tn + + +def from_schedule_tree(stree: tn.ScheduleTreeRoot) -> SDFG: + """ + Converts a schedule tree into an SDFG. + + :param stree: The schedule tree root to convert. + :return: An SDFG representing the schedule tree. + """ + # Set SDFG descriptor repository + result = SDFG(stree.name, propagate=False) + result.arg_names = copy.deepcopy(stree.arg_names) + result._arrays = copy.deepcopy(stree.containers) + result.constants_prop = copy.deepcopy(stree.constants) + result.symbols = copy.deepcopy(stree.symbols) + + # TODO: Fill SDFG contents + + return result diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 8329aec84f..2aa470ca7c 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -7,7 +7,7 @@ from dace.sdfg.state import SDFGState from dace.symbolic import symbol from dace.memlet import Memlet -from typing import Any, Dict, Iterator, List, Optional, Set, Union +from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union INDENTATION = ' ' @@ -65,7 +65,7 @@ class ScheduleTreeRoot(ScheduleTreeScope): name: str containers: Dict[str, data.Data] = field(default_factory=dict) symbols: Dict[str, symbol] = field(default_factory=dict) - constants: Dict[str, Any] = field(default_factory=dict) + constants: Dict[str, Tuple[data.Data, Any]] = field(default_factory=dict) callback_mapping: Dict[str, str] = field(default_factory=dict) arg_names: List[str] = field(default_factory=list) diff --git a/tests/schedule_tree/roundtrip_test.py b/tests/schedule_tree/roundtrip_test.py new file mode 100644 index 0000000000..7eafe63bf2 --- /dev/null +++ b/tests/schedule_tree/roundtrip_test.py @@ -0,0 +1,46 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" +Tests conversion of schedule trees to SDFGs. +""" +import dace +import numpy as np + + +def test_implicit_inline_and_constants(): + """ + Tests implicit inlining upon roundtrip conversion, as well as constants with conflicting names. + """ + + @dace + def nester(A: dace.float64[20]): + A[:] = 12 + + @dace.program + def tester(A: dace.float64[20, 20]): + for i in dace.map[0:20]: + nester(A[:, i]) + + sdfg = tester.to_sdfg(simplify=False) + + # Inject constant into nested SDFG + assert len(list(sdfg.all_sdfgs_recursive())) > 1 + sdfg.add_constant('cst', 13) # Add an unused constant + sdfg.sdfg_list[-1].add_constant('cst', 1, dace.data.Scalar(dace.float64)) + tasklet = next(n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, dace.nodes.Tasklet)) + tasklet.code.as_string = tasklet.code.as_string.replace('12', 'cst') + + # Perform a roundtrip conversion + stree = sdfg.as_schedule_tree() + new_sdfg = stree.as_sdfg() + + assert len(list(new_sdfg.all_sdfgs_recursive())) == 1 + assert new_sdfg.constants['cst_0'].dtype == np.float64 + + # Test SDFG + a = np.random.rand(20, 20) + new_sdfg(a) # Tests arg_names + assert np.allclose(a, 1) + + +if __name__ == '__main__': + test_implicit_inline_and_constants() From 33910d198bd22f61c39e5235e251e7a97cf896e7 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 3 Dec 2023 23:15:02 -0800 Subject: [PATCH 004/137] Structure conversion process and add tests --- .../analysis/schedule_tree/tree_to_sdfg.py | 58 +++++++++++- dace/sdfg/analysis/schedule_tree/treenodes.py | 11 +++ tests/schedule_tree/to_sdfg_test.py | 89 +++++++++++++++++++ 3 files changed, 156 insertions(+), 2 deletions(-) create mode 100644 tests/schedule_tree/to_sdfg_test.py diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 2847e7d16a..0f871ef569 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -1,14 +1,26 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. import copy -from dace.sdfg.sdfg import SDFG +from dace.sdfg import nodes +from dace.sdfg.sdfg import SDFG, ControlFlowRegion +from dace.sdfg.state import SDFGState from dace.sdfg.analysis.schedule_tree import treenodes as tn +from enum import Enum, auto +from typing import Optional, Sequence -def from_schedule_tree(stree: tn.ScheduleTreeRoot) -> SDFG: +class StateBoundaryBehavior(Enum): + STATE_TRANSITION = auto() #: Creates multiple states with a state transition + EMPTY_MEMLET = auto() #: Happens-before empty memlet edges in the same state + + +def from_schedule_tree(stree: tn.ScheduleTreeRoot, + state_boundary_behavior: StateBoundaryBehavior = StateBoundaryBehavior.STATE_TRANSITION) -> SDFG: """ Converts a schedule tree into an SDFG. :param stree: The schedule tree root to convert. + :param state_boundary_behavior: Sets the behavior upon encountering a state boundary (e.g., write-after-write). + See the ``StateBoundaryBehavior`` enumeration for more details. :return: An SDFG representing the schedule tree. """ # Set SDFG descriptor repository @@ -19,5 +31,47 @@ def from_schedule_tree(stree: tn.ScheduleTreeRoot) -> SDFG: result.symbols = copy.deepcopy(stree.symbols) # TODO: Fill SDFG contents + insert_state_boundaries_to_tree(stree) # after WAW, before label, etc. + + # TODO: create_state_boundary + # TODO: create_loop_block + # TODO: create_conditional_block + # TODO: create_dataflow_scope return result + + +def insert_state_boundaries_to_tree(stree: tn.ScheduleTreeRoot) -> None: + """ + Inserts StateBoundaryNode objects into a schedule tree where more than one SDFG state would be necessary. + Operates in-place on the given schedule tree. + + This happens when there is a: + * write-after-write dependency; + * write-after-read dependency that cannot be fulfilled via memlets; + * control flow block (for/if); or + * otherwise before a state label (which means a state transition could occur, e.g., in a gblock) + + :param stree: The schedule tree to operate on. + """ + pass + + +############################################################################# +# SDFG content creation functions + + +def create_state_boundary(bnode: tn.StateBoundaryNode, sdfg_region: ControlFlowRegion, state: SDFGState, + behavior: StateBoundaryBehavior) -> SDFGState: + """ + Creates a boundary between two states + + :param bnode: The state boundary node to generate. + :param sdfg_region: The control flow block in which to generate the boundary (e.g., SDFG). + :param state: The last state prior to this boundary. + :param behavior: The state boundary behavior with which to create the boundary. + :return: The newly created state. + """ + scope: tn.ControlFlowScope = bnode.parent + assert scope is not None + pass diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 2aa470ca7c..85cbb22ebc 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -372,6 +372,17 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = refset to {self.memlet}' +@dataclass +class StateBoundaryNode(ScheduleTreeNode): + """ + A node that represents a state boundary (e.g., when a write-after-write is encountered). This node + is used only during conversion from a schedule tree to an SDFG. + """ + + def as_string(self, indent: int = 0): + return indent * INDENTATION + 'state boundary' + + # Classes based on Python's AST NodeVisitor/NodeTransformer for schedule tree nodes class ScheduleNodeVisitor: diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py new file mode 100644 index 0000000000..369bb6b4cd --- /dev/null +++ b/tests/schedule_tree/to_sdfg_test.py @@ -0,0 +1,89 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" +Tests components in conversion of schedule trees to SDFGs. +""" +import dace +from dace.codegen import control_flow as cf +from dace.properties import CodeBlock +from dace.sdfg import nodes +from dace.sdfg.analysis.schedule_tree import tree_to_sdfg as t2s, treenodes as tn + + +def test_state_boundaries_none(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('A[1]')}), + ], + ) + + t2s.insert_state_boundaries_to_tree(stree) + assert tn.StateBoundaryNode not in [type(n) for n in stree.children] + + +def test_state_boundaries_waw(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + ], + ) + + t2s.insert_state_boundaries_to_tree(stree) + assert [tn.TaskletNode, tn.StateBoundaryNode, tn.TaskletNode] == [type(n) for n in stree.children] + + +def test_state_boundaries_war(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + 'B': dace.data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('B[0]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + ], + ) + + t2s.insert_state_boundaries_to_tree(stree) + assert [tn.TaskletNode, tn.StateBoundaryNode, tn.TaskletNode] == [type(n) for n in stree.children] + + +def test_state_boundaries_cfg(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla1', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + tn.ForScope([ + tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'out = i'), {}, {'out': dace.Memlet('A[1]')}), + ], cf.ForScope(None, None, 'i', None, '0', CodeBlock('i < 20'), 'i + 1', None, [])), + ], + ) + + t2s.insert_state_boundaries_to_tree(stree) + assert [tn.TaskletNode, tn.StateBoundaryNode, tn.ForScope] == [type(n) for n in stree.children] + + +if __name__ == '__main__': + test_state_boundaries_none() + test_state_boundaries_waw() + test_state_boundaries_war() + test_state_boundaries_cfg() From f74be2f6cce7d0761e65a1af64f06c95566b193f Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 3 Dec 2023 23:28:34 -0800 Subject: [PATCH 005/137] Non-data-dependency state boundary insertion --- .../analysis/schedule_tree/tree_to_sdfg.py | 24 ++++++++++++++++--- tests/schedule_tree/to_sdfg_test.py | 8 +++---- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 0f871ef569..f38a6e1aa7 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -31,7 +31,7 @@ def from_schedule_tree(stree: tn.ScheduleTreeRoot, result.symbols = copy.deepcopy(stree.symbols) # TODO: Fill SDFG contents - insert_state_boundaries_to_tree(stree) # after WAW, before label, etc. + stree = insert_state_boundaries_to_tree(stree) # after WAW, before label, etc. # TODO: create_state_boundary # TODO: create_loop_block @@ -41,7 +41,7 @@ def from_schedule_tree(stree: tn.ScheduleTreeRoot, return result -def insert_state_boundaries_to_tree(stree: tn.ScheduleTreeRoot) -> None: +def insert_state_boundaries_to_tree(stree: tn.ScheduleTreeRoot) -> tn.ScheduleTreeRoot: """ Inserts StateBoundaryNode objects into a schedule tree where more than one SDFG state would be necessary. Operates in-place on the given schedule tree. @@ -54,7 +54,25 @@ def insert_state_boundaries_to_tree(stree: tn.ScheduleTreeRoot) -> None: :param stree: The schedule tree to operate on. """ - pass + + # Simple boundary node inserter for control flow blocks and state labels + class SimpleStateBoundaryInserter(tn.ScheduleNodeTransformer): + + def visit_scope(self, scope: tn.ScheduleTreeScope): + if isinstance(scope, tn.ControlFlowScope): + return [tn.StateBoundaryNode(), self.generic_visit(scope)] + return self.generic_visit(scope) + + def visit_StateLabel(self, node: tn.StateLabel): + return [tn.StateBoundaryNode(), self.generic_visit(node)] + + # First, insert boundaries around labels and control flow + stree = SimpleStateBoundaryInserter().visit(stree) + + # TODO: Insert boundaries after unmet memory dependencies + # TODO: Implement generic methods that get input/output memlets for stree scopes and nodes + + return stree ############################################################################# diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 369bb6b4cd..7c28215264 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -23,7 +23,7 @@ def test_state_boundaries_none(): ], ) - t2s.insert_state_boundaries_to_tree(stree) + stree = t2s.insert_state_boundaries_to_tree(stree) assert tn.StateBoundaryNode not in [type(n) for n in stree.children] @@ -40,7 +40,7 @@ def test_state_boundaries_waw(): ], ) - t2s.insert_state_boundaries_to_tree(stree) + stree = t2s.insert_state_boundaries_to_tree(stree) assert [tn.TaskletNode, tn.StateBoundaryNode, tn.TaskletNode] == [type(n) for n in stree.children] @@ -59,7 +59,7 @@ def test_state_boundaries_war(): ], ) - t2s.insert_state_boundaries_to_tree(stree) + stree = t2s.insert_state_boundaries_to_tree(stree) assert [tn.TaskletNode, tn.StateBoundaryNode, tn.TaskletNode] == [type(n) for n in stree.children] @@ -78,7 +78,7 @@ def test_state_boundaries_cfg(): ], ) - t2s.insert_state_boundaries_to_tree(stree) + stree = t2s.insert_state_boundaries_to_tree(stree) assert [tn.TaskletNode, tn.StateBoundaryNode, tn.ForScope] == [type(n) for n in stree.children] From 25a72b96ef960b78ecb787b1822241503019e2dc Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 1 Jan 2024 07:43:38 -0800 Subject: [PATCH 006/137] More test cases and design --- .../analysis/schedule_tree/tree_to_sdfg.py | 5 ++ tests/schedule_tree/to_sdfg_test.py | 49 +++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index f38a6e1aa7..591a013776 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -71,6 +71,9 @@ def visit_StateLabel(self, node: tn.StateLabel): # TODO: Insert boundaries after unmet memory dependencies # TODO: Implement generic methods that get input/output memlets for stree scopes and nodes + # TODO: Implement method that searches for a memlet in a dictionary of memlets (even if that memlet + # is a subset of a dictionary key) and returns that key. If intersection indeterminate, assume + # intersects and replace key with union key. Implement in dace.sdfg.memlet_utils. return stree @@ -90,6 +93,8 @@ def create_state_boundary(bnode: tn.StateBoundaryNode, sdfg_region: ControlFlowR :param behavior: The state boundary behavior with which to create the boundary. :return: The newly created state. """ + # TODO: Some boundaries (control flow, state labels with goto) could not be fulfilled with every + # behavior. Fall back to state transition in that case. scope: tn.ControlFlowScope = bnode.parent assert scope is not None pass diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 7c28215264..0d3d9ce9fe 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -63,6 +63,53 @@ def test_state_boundaries_war(): assert [tn.TaskletNode, tn.StateBoundaryNode, tn.TaskletNode] == [type(n) for n in stree.children] +def test_state_boundaries_read_write_chain(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + 'B': dace.data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla1', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('B[0]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('B[0]')}, + {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla3', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('B[0]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert [tn.TaskletNode, tn.TaskletNode, tn.TaskletNode] == [type(n) for n in stree.children] + + +def test_state_boundaries_data_race(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + 'B': dace.data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla1', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('B[0]')}), + tn.TaskletNode(nodes.Tasklet('bla11', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('B[1]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('B[0]')}, + {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla3', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('B[0]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert [tn.TaskletNode, tn.TaskletNode, tn.StateBoundaryNode, tn.TaskletNode, + tn.TaskletNode] == [type(n) for n in stree.children] + + def test_state_boundaries_cfg(): # Manually create a schedule tree stree = tn.ScheduleTreeRoot( @@ -86,4 +133,6 @@ def test_state_boundaries_cfg(): test_state_boundaries_none() test_state_boundaries_waw() test_state_boundaries_war() + test_state_boundaries_read_write_chain() + test_state_boundaries_data_race() test_state_boundaries_cfg() From 2bf1da135462d6954feb66891bd3386f3012bdbc Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 1 Jan 2024 23:38:03 -0800 Subject: [PATCH 007/137] Implement memory-dependency state boundary pass, input/output memlet scaffolding --- .../analysis/schedule_tree/tree_to_sdfg.py | 87 +++++++++- dace/sdfg/analysis/schedule_tree/treenodes.py | 163 +++++++++++++++++- dace/sdfg/memlet_utils.py | 7 + tests/schedule_tree/to_sdfg_test.py | 23 +++ 4 files changed, 267 insertions(+), 13 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 591a013776..080374e266 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -1,11 +1,13 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. import copy -from dace.sdfg import nodes +from collections import defaultdict +from dace.memlet import Memlet +from dace.sdfg import nodes, memlet_utils as mmu from dace.sdfg.sdfg import SDFG, ControlFlowRegion from dace.sdfg.state import SDFGState from dace.sdfg.analysis.schedule_tree import treenodes as tn from enum import Enum, auto -from typing import Optional, Sequence +from typing import Dict, List, Set, Union class StateBoundaryBehavior(Enum): @@ -34,6 +36,7 @@ def from_schedule_tree(stree: tn.ScheduleTreeRoot, stree = insert_state_boundaries_to_tree(stree) # after WAW, before label, etc. # TODO: create_state_boundary + # TODO: When creating a state boundary, include all inter-state assignments that precede it. # TODO: create_loop_block # TODO: create_conditional_block # TODO: create_dataflow_scope @@ -60,24 +63,90 @@ class SimpleStateBoundaryInserter(tn.ScheduleNodeTransformer): def visit_scope(self, scope: tn.ScheduleTreeScope): if isinstance(scope, tn.ControlFlowScope): - return [tn.StateBoundaryNode(), self.generic_visit(scope)] + return [tn.StateBoundaryNode(True), self.generic_visit(scope)] return self.generic_visit(scope) def visit_StateLabel(self, node: tn.StateLabel): - return [tn.StateBoundaryNode(), self.generic_visit(node)] + return [tn.StateBoundaryNode(True), self.generic_visit(node)] # First, insert boundaries around labels and control flow stree = SimpleStateBoundaryInserter().visit(stree) - # TODO: Insert boundaries after unmet memory dependencies - # TODO: Implement generic methods that get input/output memlets for stree scopes and nodes - # TODO: Implement method that searches for a memlet in a dictionary of memlets (even if that memlet - # is a subset of a dictionary key) and returns that key. If intersection indeterminate, assume - # intersects and replace key with union key. Implement in dace.sdfg.memlet_utils. + # Then, insert boundaries after unmet memory dependencies or potential data races + _insert_memory_dependency_state_boundaries(stree) return stree +def _insert_memory_dependency_state_boundaries(scope: tn.ScheduleTreeScope): + """ + Helper function that inserts boundaries after unmet memory dependencies. + """ + reads: Dict[mmu.MemletSet, List[tn.ScheduleTreeNode]] = defaultdict(list) + writes: Dict[mmu.MemletSet, List[tn.ScheduleTreeNode]] = defaultdict(list) + parents: Dict[int, Set[int]] = defaultdict(set) + boundaries_to_insert: List[int] = [] + + for i, n in enumerate(scope.children): + if isinstance(n, (tn.StateBoundaryNode, tn.ControlFlowScope)): # Clear state + reads.clear() + writes.clear() + parents.clear() + if isinstance(n, tn.ControlFlowScope): # Insert memory boundaries recursively + _insert_memory_dependency_state_boundaries(n) + continue + + # If dataflow scope, insert state boundaries recursively and as a node + if isinstance(n, tn.DataflowScope): + _insert_memory_dependency_state_boundaries(n) + + inputs = n.input_memlets() + outputs = n.output_memlets() + + # Register reads + for inp in inputs: + reads[inp].append(n) + + # Transitively add parents + if inp in writes: + for parent in writes[inp]: + parents[id(n)].add(id(parent)) + parents[id(n)].update(parents[id(parent)]) + + # Inter-state assignment nodes with reads necessitate a state transition if they were written to. + if isinstance(n, tn.AssignNode) and any(inp in writes for inp in inputs): + boundaries_to_insert.append(i) + reads.clear() + writes.clear() + parents.clear() + continue + + # Write after write or potential write/write data race, insert state boundary + if any(o in writes and (o not in reads or any(id(r) not in parents for r in reads[o])) for o in outputs): + boundaries_to_insert.append(i) + reads.clear() + writes.clear() + parents.clear() + continue + + # Potential read/write data race: if any read is not in the parents of this node, it might + # be performed in parallel + if any(o in reads and any(id(r) not in parents for r in reads[o]) for o in outputs): + boundaries_to_insert.append(i) + reads.clear() + writes.clear() + parents.clear() + continue + + # Register writes after all hazards have been tested for + for out in outputs: + writes[out].append(n) + + # Insert memory dependency state boundaries in reverse in order to keep indices intact + for i in reversed(boundaries_to_insert): + scope.children.insert(i, tn.StateBoundaryNode()) + + ############################################################################# # SDFG content creation functions diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index eecd93ad0c..c84c3b2fb2 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -1,9 +1,11 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import ast from dataclasses import dataclass, field from dace import nodes, data, subsets from dace.codegen import control_flow as cf from dace.properties import CodeBlock -from dace.sdfg.sdfg import InterstateEdge, SDFG +from dace.sdfg.memlet_utils import MemletSet +from dace.sdfg.sdfg import InterstateEdge, SDFG, memlets_in_ast from dace.sdfg.state import SDFGState from dace.symbolic import symbol from dace.memlet import Memlet @@ -29,6 +31,31 @@ def preorder_traversal(self) -> Iterator['ScheduleTreeNode']: """ yield self + def get_root(self) -> 'ScheduleTreeRoot': + if self.parent is None: + raise ValueError('Non-root schedule tree node has no parent') + return self.parent.get_root() + + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + """ + Returns a set of inputs for this node. For scopes, returns the union of its contents. + + :param root: An optional argument specifying the schedule tree's root. If not given, + the value is computed from the current tree node. + :return: A set of memlets representing the inputs of this node. + """ + raise NotImplementedError + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + """ + Returns a set of outputs for this node. For scopes, returns the union of its contents. + + :param root: An optional argument specifying the schedule tree's root. If not given, + the value is computed from the current tree node. + :return: A set of memlets representing the inputs of this node. + """ + raise NotImplementedError + @dataclass class ScheduleTreeScope(ScheduleTreeNode): @@ -53,7 +80,13 @@ def preorder_traversal(self) -> Iterator['ScheduleTreeNode']: for child in self.children: yield from child.preorder_traversal() - # TODO: Helper function that gets input/output memlets of the scope + # TODO: Missing propagation and locals + # TODO: Add symbol ranges as an argument + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return MemletSet().union(*(c.input_memlets(root) for c in self.children)) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return MemletSet().union(*(c.output_memlets(root) for c in self.children)) @dataclass @@ -73,6 +106,9 @@ def as_sdfg(self) -> SDFG: from dace.sdfg.analysis.schedule_tree import tree_to_sdfg as t2s # Avoid import loop return t2s.from_schedule_tree(self) + def get_root(self) -> 'ScheduleTreeRoot': + return self + @dataclass class ControlFlowScope(ScheduleTreeScope): @@ -104,6 +140,12 @@ class StateLabel(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + f'label {self.state.name}:' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return set() + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return set() + @dataclass class GotoNode(ScheduleTreeNode): @@ -113,6 +155,12 @@ def as_string(self, indent: int = 0): name = self.target or 'exit' return indent * INDENTATION + f'goto {name}' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return set() + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return set() + @dataclass class AssignNode(ScheduleTreeNode): @@ -126,6 +174,13 @@ class AssignNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + f'assign {self.name} = {self.value.as_string}' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + root = root if root is not None else self.get_root() + return set(self.edge.get_read_memlets(root.containers)) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return set() + @dataclass class ForScope(ControlFlowScope): @@ -141,6 +196,15 @@ def as_string(self, indent: int = 0): f'{node.itervar} = {node.update}:\n') return result + super().as_string(indent) + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + root = root if root is not None else self.get_root() + result = set() + result.update(memlets_in_ast(ast.parse(self.header.init), root.containers)) + result.update(memlets_in_ast(self.header.condition.code[0], root.containers)) + result.update(memlets_in_ast(ast.parse(self.header.update), root.containers)) + result.update(super().input_memlets(root)) + return result + @dataclass class WhileScope(ControlFlowScope): @@ -153,6 +217,13 @@ def as_string(self, indent: int = 0): result = indent * INDENTATION + f'while {self.header.test.as_string}:\n' return result + super().as_string(indent) + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + root = root if root is not None else self.get_root() + result = set() + result.update(memlets_in_ast(self.header.test.code[0], root.containers)) + result.update(super().input_memlets(root)) + return result + @dataclass class DoWhileScope(ControlFlowScope): @@ -166,6 +237,13 @@ def as_string(self, indent: int = 0): footer = indent * INDENTATION + f'while {self.header.test.as_string}\n' return header + super().as_string(indent) + footer + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + root = root if root is not None else self.get_root() + result = set() + result.update(memlets_in_ast(self.header.test.code[0], root.containers)) + result.update(super().input_memlets(root)) + return result + @dataclass class IfScope(ControlFlowScope): @@ -178,6 +256,13 @@ def as_string(self, indent: int = 0): result = indent * INDENTATION + f'if {self.condition.as_string}:\n' return result + super().as_string(indent) + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + root = root if root is not None else self.get_root() + result = set() + result.update(memlets_in_ast(self.condition.code[0], root.containers)) + result.update(super().input_memlets(root)) + return result + @dataclass class StateIfScope(IfScope): @@ -199,6 +284,12 @@ class BreakNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + 'break' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return set() + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return set() + @dataclass class ContinueNode(ScheduleTreeNode): @@ -209,6 +300,12 @@ class ContinueNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + 'continue' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return set() + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return set() + @dataclass class ElifScope(ControlFlowScope): @@ -221,6 +318,13 @@ def as_string(self, indent: int = 0): result = indent * INDENTATION + f'elif {self.condition.as_string}:\n' return result + super().as_string(indent) + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + root = root if root is not None else self.get_root() + result = set() + result.update(memlets_in_ast(self.condition.code[0], root.containers)) + result.update(super().input_memlets(root)) + return result + @dataclass class ElseScope(ControlFlowScope): @@ -283,12 +387,18 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'tasklet({in_memlets})' return indent * INDENTATION + f'{out_memlets} = tasklet({in_memlets})' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return set(self.in_memlets.values()) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return set(self.out_memlets.values()) + @dataclass class LibraryCall(ScheduleTreeNode): node: nodes.LibraryNode - in_memlets: Union[Dict[str, Memlet], Set[Memlet]] - out_memlets: Union[Dict[str, Memlet], Set[Memlet]] + in_memlets: Union[Dict[str, Memlet], MemletSet] + out_memlets: Union[Dict[str, Memlet], MemletSet] def as_string(self, indent: int = 0): if isinstance(self.in_memlets, set): @@ -305,6 +415,16 @@ def as_string(self, indent: int = 0): if v.owner not in {nodes.Node, nodes.CodeNode, nodes.LibraryNode}) return indent * INDENTATION + f'{out_memlets} = library {libname}[{own_properties}]({in_memlets})' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + if isinstance(self.in_memlets, set): + return set(self.in_memlets) + return set(self.in_memlets.values()) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + if isinstance(self.out_memlets, set): + return set(self.out_memlets) + return set(self.out_memlets.values()) + @dataclass class CopyNode(ScheduleTreeNode): @@ -323,6 +443,16 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target}{offset} = copy {self.memlet.data}[{self.memlet.subset}]{wcr}' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return {self.memlet} + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + root = root if root is not None else self.get_root() + if self.memlet.other_subset is not None: + return {Memlet(data=self.target, subset=self.memlet.other_subset, wcr=self.memlet.wcr)} + + return {Memlet.from_array(self.target, root.containers[self.target], self.memlet.wcr)} + @dataclass class DynScopeCopyNode(ScheduleTreeNode): @@ -335,6 +465,12 @@ class DynScopeCopyNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = dscopy {self.memlet.data}[{self.memlet.subset}]' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return {self.memlet} + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return set() + @dataclass class ViewNode(ScheduleTreeNode): @@ -347,6 +483,12 @@ class ViewNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = view {self.memlet} as {self.view_desc.shape}' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return {self.memlet} + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return {Memlet.from_array(self.target, self.view_desc)} + @dataclass class NView(ViewNode): @@ -373,6 +515,12 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = refset from {type(self.src_desc).__name__.lower()}' return indent * INDENTATION + f'{self.target} = refset to {self.memlet}' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return {self.memlet} + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return {Memlet.from_array(self.target, self.ref_desc)} + @dataclass class StateBoundaryNode(ScheduleTreeNode): @@ -380,10 +528,17 @@ class StateBoundaryNode(ScheduleTreeNode): A node that represents a state boundary (e.g., when a write-after-write is encountered). This node is used only during conversion from a schedule tree to an SDFG. """ + due_to_control_flow: bool = False def as_string(self, indent: int = 0): return indent * INDENTATION + 'state boundary' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return set() + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + return set() + # Classes based on Python's AST NodeVisitor/NodeTransformer for schedule tree nodes class ScheduleNodeVisitor: diff --git a/dace/sdfg/memlet_utils.py b/dace/sdfg/memlet_utils.py index 59a2c178d2..38ae389280 100644 --- a/dace/sdfg/memlet_utils.py +++ b/dace/sdfg/memlet_utils.py @@ -77,3 +77,10 @@ def visit_Subscript(self, node: ast.Subscript): if isinstance(node.value, ast.Name) and node.value.id in self.array_filter: return self._replace(node) return self.generic_visit(node) + + +class MemletSet(Set[Memlet], set): + # TODO: Implement method that searches for a memlet in a dictionary of memlets (even if that memlet + # is a subset of a dictionary key) and returns that key. If intersection indeterminate, assume + # intersects and replace key with union key. + pass diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 0d3d9ce9fe..eb4585f6fb 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -129,6 +129,28 @@ def test_state_boundaries_cfg(): assert [tn.TaskletNode, tn.StateBoundaryNode, tn.ForScope] == [type(n) for n in stree.children] +def test_state_boundaries_state_transition(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + }, + symbols={ + 'N': dace.symbol('N'), + }, + children=[ + tn.AssignNode('irrelevant', CodeBlock('N + 1'), dace.InterstateEdge(assignments=dict(irrelevant='N + 1'))), + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + tn.AssignNode('relevant', CodeBlock('A[1] + 2'), + dace.InterstateEdge(assignments=dict(relevant='A[1] + 2'))), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert [tn.AssignNode, tn.TaskletNode, tn.StateBoundaryNode, tn.AssignNode] == [type(n) for n in stree.children] + + if __name__ == '__main__': test_state_boundaries_none() test_state_boundaries_waw() @@ -136,3 +158,4 @@ def test_state_boundaries_cfg(): test_state_boundaries_read_write_chain() test_state_boundaries_data_race() test_state_boundaries_cfg() + test_state_boundaries_state_transition() From bb58e80ad5570f82b79df694cdb1d3947fa7b140 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 2 Jan 2024 07:18:17 -0800 Subject: [PATCH 008/137] More tests --- tests/schedule_tree/to_sdfg_test.py | 61 +++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index eb4585f6fb..b3ef97e000 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -7,6 +7,7 @@ from dace.properties import CodeBlock from dace.sdfg import nodes from dace.sdfg.analysis.schedule_tree import tree_to_sdfg as t2s, treenodes as tn +import pytest def test_state_boundaries_none(): @@ -44,6 +45,30 @@ def test_state_boundaries_waw(): assert [tn.TaskletNode, tn.StateBoundaryNode, tn.TaskletNode] == [type(n) for n in stree.children] +@pytest.mark.parametrize('overlap', (False, True)) +def test_state_boundaries_waw_ranges(overlap): + # Manually create a schedule tree + N = dace.symbol('N') + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + }, + symbols={'N': N}, + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'pass'), {}, {'out': dace.Memlet('A[0:N/2]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'pass'), {}, + {'out': dace.Memlet('A[1:N]' if overlap else 'A[N/2+1:N]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + if overlap: + assert [tn.TaskletNode, tn.StateBoundaryNode, tn.TaskletNode] == [type(n) for n in stree.children] + else: + assert [tn.TaskletNode, tn.TaskletNode] == [type(n) for n in stree.children] + + def test_state_boundaries_war(): # Manually create a schedule tree stree = tn.ScheduleTreeRoot( @@ -151,11 +176,47 @@ def test_state_boundaries_state_transition(): assert [tn.AssignNode, tn.TaskletNode, tn.StateBoundaryNode, tn.AssignNode] == [type(n) for n in stree.children] +@pytest.mark.parametrize('boundary', (False, True)) +def test_state_boundaries_propagation(boundary): + # Manually create a schedule tree + N = dace.symbol('N') + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + }, + symbols={ + 'N': N, + }, + children=[ + tn.MapScope(node=dace.nodes.MapEntry(dace.nodes.Map('map', ['i'], dace.subsets.Range([(1, N - 1, 1)]))), + children=[ + tn.TaskletNode(nodes.Tasklet('inner', {}, {'out'}, 'out = 2'), {}, + {'out': dace.Memlet('A[i]')}), + ]), + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 2'), {}, + {'out': dace.Memlet('A[1]' if boundary else 'A[0]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + + node_types = [type(n) for n in stree.preorder_traversal()] + if boundary: + assert [tn.MapScope, tn.TaskletNode, tn.StateBoundaryNode, tn.TaskletNode] == node_types[1:] + else: + assert [tn.MapScope, tn.TaskletNode, tn.TaskletNode] == node_types[1:] + + if __name__ == '__main__': test_state_boundaries_none() test_state_boundaries_waw() + test_state_boundaries_waw_ranges(overlap=False) + test_state_boundaries_waw_ranges(overlap=True) test_state_boundaries_war() test_state_boundaries_read_write_chain() test_state_boundaries_data_race() test_state_boundaries_cfg() test_state_boundaries_state_transition() + test_state_boundaries_propagation(boundary=False) + test_state_boundaries_propagation(boundary=True) From 649b21637208aaa0825d38e5060648832817defb Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 2 Jan 2024 07:18:53 -0800 Subject: [PATCH 009/137] Copyright year --- dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py | 2 +- dace/sdfg/analysis/schedule_tree/treenodes.py | 2 +- dace/sdfg/memlet_utils.py | 2 +- tests/schedule_tree/to_sdfg_test.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 080374e266..128e43099f 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import copy from collections import defaultdict from dace.memlet import Memlet diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index c84c3b2fb2..66208d45e9 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import ast from dataclasses import dataclass, field from dace import nodes, data, subsets diff --git a/dace/sdfg/memlet_utils.py b/dace/sdfg/memlet_utils.py index 38ae389280..ab8f0dad69 100644 --- a/dace/sdfg/memlet_utils.py +++ b/dace/sdfg/memlet_utils.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import ast from dace.frontend.python import memlet_parser diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index b3ef97e000..6925511680 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Tests components in conversion of schedule trees to SDFGs. """ From 438683b0ed69b75498050f4a79f12c1ddaf72e9e Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 2 Jan 2024 07:20:38 -0800 Subject: [PATCH 010/137] Copyright year --- dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py | 2 +- tests/schedule_tree/roundtrip_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 6c14ef6435..6d924e00b2 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from collections import defaultdict import copy from typing import Dict, List, Set diff --git a/tests/schedule_tree/roundtrip_test.py b/tests/schedule_tree/roundtrip_test.py index 7eafe63bf2..e4aea2a56a 100644 --- a/tests/schedule_tree/roundtrip_test.py +++ b/tests/schedule_tree/roundtrip_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Tests conversion of schedule trees to SDFGs. """ From 7569304527e780996341a340c4997e6c5f489ec5 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 2 Jan 2024 07:40:29 -0800 Subject: [PATCH 011/137] Use memlet sets --- dace/sdfg/analysis/schedule_tree/treenodes.py | 64 +++++++++---------- 1 file changed, 32 insertions(+), 32 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 66208d45e9..1506da5c8e 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -141,10 +141,10 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'label {self.state.name}:' def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return set() + return MemletSet() def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return set() + return MemletSet() @dataclass @@ -156,10 +156,10 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'goto {name}' def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return set() + return MemletSet() def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return set() + return MemletSet() @dataclass @@ -176,10 +176,10 @@ def as_string(self, indent: int = 0): def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: root = root if root is not None else self.get_root() - return set(self.edge.get_read_memlets(root.containers)) + return MemletSet(self.edge.get_read_memlets(root.containers)) def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return set() + return MemletSet() @dataclass @@ -198,7 +198,7 @@ def as_string(self, indent: int = 0): def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: root = root if root is not None else self.get_root() - result = set() + result = MemletSet() result.update(memlets_in_ast(ast.parse(self.header.init), root.containers)) result.update(memlets_in_ast(self.header.condition.code[0], root.containers)) result.update(memlets_in_ast(ast.parse(self.header.update), root.containers)) @@ -219,7 +219,7 @@ def as_string(self, indent: int = 0): def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: root = root if root is not None else self.get_root() - result = set() + result = MemletSet() result.update(memlets_in_ast(self.header.test.code[0], root.containers)) result.update(super().input_memlets(root)) return result @@ -239,7 +239,7 @@ def as_string(self, indent: int = 0): def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: root = root if root is not None else self.get_root() - result = set() + result = MemletSet() result.update(memlets_in_ast(self.header.test.code[0], root.containers)) result.update(super().input_memlets(root)) return result @@ -258,7 +258,7 @@ def as_string(self, indent: int = 0): def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: root = root if root is not None else self.get_root() - result = set() + result = MemletSet() result.update(memlets_in_ast(self.condition.code[0], root.containers)) result.update(super().input_memlets(root)) return result @@ -285,10 +285,10 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + 'break' def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return set() + return MemletSet() def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return set() + return MemletSet() @dataclass @@ -301,10 +301,10 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + 'continue' def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return set() + return MemletSet() def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return set() + return MemletSet() @dataclass @@ -320,7 +320,7 @@ def as_string(self, indent: int = 0): def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: root = root if root is not None else self.get_root() - result = set() + result = MemletSet() result.update(memlets_in_ast(self.condition.code[0], root.containers)) result.update(super().input_memlets(root)) return result @@ -388,10 +388,10 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'{out_memlets} = tasklet({in_memlets})' def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return set(self.in_memlets.values()) + return MemletSet(self.in_memlets.values()) def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return set(self.out_memlets.values()) + return MemletSet(self.out_memlets.values()) @dataclass @@ -417,13 +417,13 @@ def as_string(self, indent: int = 0): def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: if isinstance(self.in_memlets, set): - return set(self.in_memlets) - return set(self.in_memlets.values()) + return MemletSet(self.in_memlets) + return MemletSet(self.in_memlets.values()) def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: if isinstance(self.out_memlets, set): - return set(self.out_memlets) - return set(self.out_memlets.values()) + return MemletSet(self.out_memlets) + return MemletSet(self.out_memlets.values()) @dataclass @@ -444,14 +444,14 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target}{offset} = copy {self.memlet.data}[{self.memlet.subset}]{wcr}' def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return {self.memlet} + return MemletSet({self.memlet}) def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: root = root if root is not None else self.get_root() if self.memlet.other_subset is not None: - return {Memlet(data=self.target, subset=self.memlet.other_subset, wcr=self.memlet.wcr)} + return MemletSet({Memlet(data=self.target, subset=self.memlet.other_subset, wcr=self.memlet.wcr)}) - return {Memlet.from_array(self.target, root.containers[self.target], self.memlet.wcr)} + return MemletSet({Memlet.from_array(self.target, root.containers[self.target], self.memlet.wcr)}) @dataclass @@ -466,10 +466,10 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = dscopy {self.memlet.data}[{self.memlet.subset}]' def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return {self.memlet} + return MemletSet({self.memlet}) def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return set() + return MemletSet() @dataclass @@ -484,10 +484,10 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = view {self.memlet} as {self.view_desc.shape}' def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return {self.memlet} + return MemletSet({self.memlet}) def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return {Memlet.from_array(self.target, self.view_desc)} + return MemletSet({Memlet.from_array(self.target, self.view_desc)}) @dataclass @@ -516,10 +516,10 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = refset to {self.memlet}' def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return {self.memlet} + return MemletSet({self.memlet}) def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return {Memlet.from_array(self.target, self.ref_desc)} + return MemletSet({Memlet.from_array(self.target, self.ref_desc)}) @dataclass @@ -534,10 +534,10 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + 'state boundary' def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return set() + return MemletSet() def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return set() + return MemletSet() # Classes based on Python's AST NodeVisitor/NodeTransformer for schedule tree nodes From 3c1a7864f4698ea4473195e3b6546854f060c190 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 2 Jan 2024 09:30:57 -0800 Subject: [PATCH 012/137] Implement stree scope memlet analysis, memlet propagation now accepts undefined variables --- dace/sdfg/analysis/schedule_tree/treenodes.py | 98 +++++++++++++++++-- dace/sdfg/propagation.py | 18 ++-- 2 files changed, 102 insertions(+), 14 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 1506da5c8e..c7b741336b 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -5,6 +5,7 @@ from dace.codegen import control_flow as cf from dace.properties import CodeBlock from dace.sdfg.memlet_utils import MemletSet +from dace.sdfg.propagation import propagate_subset from dace.sdfg.sdfg import InterstateEdge, SDFG, memlets_in_ast from dace.sdfg.state import SDFGState from dace.symbolic import symbol @@ -36,7 +37,7 @@ def get_root(self) -> 'ScheduleTreeRoot': raise ValueError('Non-root schedule tree node has no parent') return self.parent.get_root() - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: """ Returns a set of inputs for this node. For scopes, returns the union of its contents. @@ -46,7 +47,7 @@ def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: """ raise NotImplementedError - def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: """ Returns a set of outputs for this node. For scopes, returns the union of its contents. @@ -80,13 +81,94 @@ def preorder_traversal(self) -> Iterator['ScheduleTreeNode']: for child in self.children: yield from child.preorder_traversal() - # TODO: Missing propagation and locals - # TODO: Add symbol ranges as an argument - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return MemletSet().union(*(c.input_memlets(root) for c in self.children)) + def _gather_memlets_in_scope(self, inputs: bool, root: Optional['ScheduleTreeRoot'], keep_locals: bool, + propagate: Dict[str, subsets.Range], disallow_propagation: Set[str]) -> MemletSet: + gather = (lambda n, root: n.input_memlets(root)) if inputs else (lambda n, root: n.output_memlets(root)) - def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: - return MemletSet().union(*(c.output_memlets(root) for c in self.children)) + # Fast path, no propagation necessary + if keep_locals: + return MemletSet().union(*(gather(c) for c in self.children)) + + root = root if root is not None else self.get_root() + + if propagate: + to_propagate = list(propagate.items()) + propagate_keys = [a[0] for a in to_propagate] + propagate_values = subsets.Range([a[1].ndrange() for a in to_propagate]) + + current_locals = set() + current_locals |= disallow_propagation + result = MemletSet() + + # Loop over children in order, if any new symbol is defined within this scope (e.g., symbol assignment, + # dynamic map range), consider it as a new local + for c in self.children: + # Add new locals + if isinstance(c, AssignNode): + current_locals.add(c.name) + elif isinstance(c, DynScopeCopyNode): + current_locals.add(c.target) + + internal_memlets: MemletSet = gather(c, root) + if propagate: + for memlet in internal_memlets: + result.add( + propagate_subset(memlet, + root.containers[memlet.data], + propagate_keys, + propagate_values, + undefined_variables=current_locals, + use_dst=not inputs)) + + return result + + def input_memlets(self, + root: Optional['ScheduleTreeRoot'] = None, + keep_locals: bool = False, + propagate: Optional[Dict[str, subsets.Range]] = None, + disallow_propagation: Optional[Set[str]] = None) -> MemletSet: + """ + Returns a union of the set of inputs for this scope. Propagates the memlets used in the scope if ``keep_locals`` + is set to False. + + :param root: An optional argument specifying the schedule tree's root. If not given, + the value is computed from the current tree node. + :param keep_locals: If True, keeps the local symbols defined within the scope as part of the resulting memlets. + Otherwise, performs memlet propagation (see ``propagate`` and ``disallow_propagation``) or + assumes the entire container is used. + :param propagate: An optional dictionary mapping symbols to their corresponding ranges outside of this scope. + For example, the range of values a for-loop may take. + If ``keep_locals`` is False, this dictionary will be used to create projection memlets over + the ranges. See :ref:`memprop` in the documentation for more information. + :param disallow_propagation: If ``keep_locals`` is False, this optional set of strings will be considered + as additional locals. + :return: A set of memlets representing the inputs of this scope. + """ + return self._gather_memlets_in_scope(True, root, keep_locals, propagate or {}, disallow_propagation or set()) + + def output_memlets(self, + root: Optional['ScheduleTreeRoot'] = None, + keep_locals: bool = False, + propagate: Optional[Dict[str, subsets.Range]] = None, + disallow_propagation: Optional[Set[str]] = None) -> MemletSet: + """ + Returns a union of the set of outputs for this scope. Propagates the memlets used in the scope if + ``keep_locals`` is set to False. + + :param root: An optional argument specifying the schedule tree's root. If not given, + the value is computed from the current tree node. + :param keep_locals: If True, keeps the local symbols defined within the scope as part of the resulting memlets. + Otherwise, performs memlet propagation (see ``propagate`` and ``disallow_propagation``) or + assumes the entire container is used. + :param propagate: An optional dictionary mapping symbols to their corresponding ranges outside of this scope. + For example, the range of values a for-loop may take. + If ``keep_locals`` is False, this dictionary will be used to create projection memlets over + the ranges. See :ref:`memprop` in the documentation for more information. + :param disallow_propagation: If ``keep_locals`` is False, this optional set of strings will be considered + as additional locals. + :return: A set of memlets representing the inputs of this scope. + """ + return self._gather_memlets_in_scope(True, root, keep_locals, propagate or {}, disallow_propagation or set()) @dataclass diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 18c4d7a192..ec12d36d2a 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -680,6 +680,7 @@ def _annotate_loop_ranges(sdfg, unannotated_cycle_states): return condition_edges + def propagate_states(sdfg, concretize_dynamic_unbounded=False) -> None: """ Annotate the states of an SDFG with the number of executions. @@ -1389,6 +1390,7 @@ def propagate_subset(memlets: List[Memlet], params: List[str], rng: subsets.Subset, defined_variables: Set[symbolic.SymbolicType] = None, + undefined_variables: Set[symbolic.SymbolicType] = None, use_dst: bool = False) -> Memlet: """ Tries to propagate a list of memlets through a range (computes the image of the memlet function applied on an integer set of, e.g., a @@ -1401,8 +1403,12 @@ def propagate_subset(memlets: List[Memlet], range to propagate with. :param defined_variables: A set of symbols defined that will remain the same throughout propagation. If None, assumes - that all symbols outside of `params` have been - defined. + that all symbols outside of ``params``, except + for ``undefined_variables``, have been defined. + :param undefined_variables: A set of symbols that are explicitly considered + as not defined throughout propagation, such as + locals. Their existence will trigger propagating + the entire memlet. :param use_dst: Whether to propagate the memlets' dst subset or use the src instead, depending on propagation direction. :return: Memlet with propagated subset and volume. @@ -1417,6 +1423,9 @@ def propagate_subset(memlets: List[Memlet], defined_variables -= set(params) defined_variables = set(symbolic.pystr_to_symbolic(p) for p in defined_variables) + if undefined_variables: + defined_variables = defined_variables - undefined_variables + # Propagate subset variable_context = [defined_variables, [symbolic.pystr_to_symbolic(p) for p in params]] @@ -1441,10 +1450,7 @@ def propagate_subset(memlets: List[Memlet], tmp_subset = pattern.propagate(arr, [subset], rng) break else: - # No patterns found. Emit a warning and propagate the entire - # array whenever symbols are used - warnings.warn('Cannot find appropriate memlet pattern to ' - 'propagate %s through %s' % (str(subset), str(rng))) + # No patterns found. Propagate the entire array whenever symbols are used entire_array = subsets.Range.from_array(arr) paramset = set(map(str, params)) # Fill in the entire array only if one of the parameters appears in the From 62527d02d65eeb796a5126c5f0a301fc9782ddbe Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Wed, 3 Jan 2024 10:08:42 -0800 Subject: [PATCH 013/137] Fix deprecation warning --- dace/frontend/python/newast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 733c3c7f62..6266bc03f3 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -1149,7 +1149,7 @@ def __init__(self, if sym.name not in self.sdfg.symbols: self.sdfg.add_symbol(sym.name, sym.dtype) self.sdfg._temp_transients = tmp_idx - self.last_state = self.sdfg.add_state('init', is_start_state=True) + self.last_state = self.sdfg.add_state('init', is_start_block=True) self.inputs: DependencyType = {} self.outputs: DependencyType = {} From 2baa73ca421c66190103051f43af2829ef5d485e Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Wed, 3 Jan 2024 10:40:18 -0800 Subject: [PATCH 014/137] Implement union-set of memlets --- dace/codegen/control_flow.py | 24 +++- dace/sdfg/analysis/schedule_tree/treenodes.py | 122 ++++++++++++------ dace/sdfg/memlet_utils.py | 107 ++++++++++++++- tests/schedule_tree/to_sdfg_test.py | 20 +++ 4 files changed, 220 insertions(+), 53 deletions(-) diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index a198ed371b..d8e2d5723c 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -60,7 +60,7 @@ from typing import (Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union) import sympy as sp import dace -from dace import dtypes +from dace import dtypes, symbolic from dace.sdfg.state import SDFGState from dace.sdfg.sdfg import SDFG, InterstateEdge from dace.sdfg.graph import Edge @@ -234,7 +234,7 @@ def as_cpp(self, codegen, symbols) -> str: successor = self.elements[i + 1].first_state elif i == len(self.elements) - 1: # If last edge leads to first state in next block - next_block = _find_next_block(self) + next_block = _find_next_block(self) if next_block is not None: successor = next_block.first_state @@ -372,8 +372,8 @@ def as_cpp(self, codegen, symbols) -> str: init = self.itervar else: init = f'{symbols[self.itervar]} {self.itervar}' - init += ' = ' + unparse_interstate_edge(self.init_edges[0].data.assignments[self.itervar], - sdfg, codegen=codegen) + init += ' = ' + unparse_interstate_edge( + self.init_edges[0].data.assignments[self.itervar], sdfg, codegen=codegen) preinit = '' if self.init_edges: @@ -405,6 +405,22 @@ def first_state(self) -> SDFGState: def children(self) -> List[ControlFlow]: return [self.body] + def loop_range(self) -> Optional[Tuple[symbolic.SymbolicType, symbolic.SymbolicType, symbolic.SymbolicType]]: + """ + For well-formed loops, returns a tuple of (start, end, stride). Otherwise, returns None. + """ + from dace.transformation.interstate.loop_detection import find_for_loop + sdfg = self.guard.parent + for e in sdfg.out_edges(self.guard): + if e.data.condition == self.condition: + break + else: + return None # Condition edge not found + result = find_for_loop(sdfg, self.guard, e.dst, self.itervar) + if result is None: + return None + return result[1] + @dataclass class WhileScope(ControlFlow): diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index c7b741336b..5a16b3e317 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -82,8 +82,10 @@ def preorder_traversal(self) -> Iterator['ScheduleTreeNode']: yield from child.preorder_traversal() def _gather_memlets_in_scope(self, inputs: bool, root: Optional['ScheduleTreeRoot'], keep_locals: bool, - propagate: Dict[str, subsets.Range], disallow_propagation: Set[str]) -> MemletSet: - gather = (lambda n, root: n.input_memlets(root)) if inputs else (lambda n, root: n.output_memlets(root)) + propagate: Dict[str, + subsets.Range], disallow_propagation: Set[str], **kwargs) -> MemletSet: + gather = (lambda n, root: n.input_memlets(root, **kwargs)) if inputs else ( + lambda n, root: n.output_memlets(root, **kwargs)) # Fast path, no propagation necessary if keep_locals: @@ -94,7 +96,7 @@ def _gather_memlets_in_scope(self, inputs: bool, root: Optional['ScheduleTreeRoo if propagate: to_propagate = list(propagate.items()) propagate_keys = [a[0] for a in to_propagate] - propagate_values = subsets.Range([a[1].ndrange() for a in to_propagate]) + propagate_values = subsets.Range([a[1] for a in to_propagate]) current_locals = set() current_locals |= disallow_propagation @@ -113,7 +115,7 @@ def _gather_memlets_in_scope(self, inputs: bool, root: Optional['ScheduleTreeRoo if propagate: for memlet in internal_memlets: result.add( - propagate_subset(memlet, + propagate_subset([memlet], root.containers[memlet.data], propagate_keys, propagate_values, @@ -126,7 +128,8 @@ def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, keep_locals: bool = False, propagate: Optional[Dict[str, subsets.Range]] = None, - disallow_propagation: Optional[Set[str]] = None) -> MemletSet: + disallow_propagation: Optional[Set[str]] = None, + **kwargs) -> MemletSet: """ Returns a union of the set of inputs for this scope. Propagates the memlets used in the scope if ``keep_locals`` is set to False. @@ -144,13 +147,15 @@ def input_memlets(self, as additional locals. :return: A set of memlets representing the inputs of this scope. """ - return self._gather_memlets_in_scope(True, root, keep_locals, propagate or {}, disallow_propagation or set()) + return self._gather_memlets_in_scope(True, root, keep_locals, propagate or {}, disallow_propagation or set(), + **kwargs) def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, keep_locals: bool = False, propagate: Optional[Dict[str, subsets.Range]] = None, - disallow_propagation: Optional[Set[str]] = None) -> MemletSet: + disallow_propagation: Optional[Set[str]] = None, + **kwargs) -> MemletSet: """ Returns a union of the set of outputs for this scope. Propagates the memlets used in the scope if ``keep_locals`` is set to False. @@ -168,7 +173,8 @@ def output_memlets(self, as additional locals. :return: A set of memlets representing the inputs of this scope. """ - return self._gather_memlets_in_scope(True, root, keep_locals, propagate or {}, disallow_propagation or set()) + return self._gather_memlets_in_scope(False, root, keep_locals, propagate or {}, disallow_propagation or set(), + **kwargs) @dataclass @@ -222,10 +228,10 @@ class StateLabel(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + f'label {self.state.name}:' - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return MemletSet() - def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return MemletSet() @@ -237,10 +243,10 @@ def as_string(self, indent: int = 0): name = self.target or 'exit' return indent * INDENTATION + f'goto {name}' - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return MemletSet() - def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return MemletSet() @@ -256,11 +262,11 @@ class AssignNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + f'assign {self.name} = {self.value.as_string}' - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: root = root if root is not None else self.get_root() return MemletSet(self.edge.get_read_memlets(root.containers)) - def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return MemletSet() @@ -278,15 +284,33 @@ def as_string(self, indent: int = 0): f'{node.itervar} = {node.update}:\n') return result + super().as_string(indent) - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: root = root if root is not None else self.get_root() result = MemletSet() result.update(memlets_in_ast(ast.parse(self.header.init), root.containers)) result.update(memlets_in_ast(self.header.condition.code[0], root.containers)) result.update(memlets_in_ast(ast.parse(self.header.update), root.containers)) - result.update(super().input_memlets(root)) + + # If loop range is well-formed, use it in propagation + rng = self.header.loop_range() + if rng is not None: + propagate = {self.header.itervar: rng} + else: + propagate = None + + result.update(super().input_memlets(root, propagate=propagate, **kwargs)) return result + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + # If loop range is well-formed, use it in propagation + rng = self.header.loop_range() + if rng is not None: + propagate = {self.header.itervar: rng} + else: + propagate = None + + return super().output_memlets(root, propagate=propagate, **kwargs) + @dataclass class WhileScope(ControlFlowScope): @@ -299,11 +323,11 @@ def as_string(self, indent: int = 0): result = indent * INDENTATION + f'while {self.header.test.as_string}:\n' return result + super().as_string(indent) - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: root = root if root is not None else self.get_root() result = MemletSet() result.update(memlets_in_ast(self.header.test.code[0], root.containers)) - result.update(super().input_memlets(root)) + result.update(super().input_memlets(root, **kwargs)) return result @@ -319,11 +343,11 @@ def as_string(self, indent: int = 0): footer = indent * INDENTATION + f'while {self.header.test.as_string}\n' return header + super().as_string(indent) + footer - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: root = root if root is not None else self.get_root() result = MemletSet() result.update(memlets_in_ast(self.header.test.code[0], root.containers)) - result.update(super().input_memlets(root)) + result.update(super().input_memlets(root, **kwargs)) return result @@ -338,11 +362,11 @@ def as_string(self, indent: int = 0): result = indent * INDENTATION + f'if {self.condition.as_string}:\n' return result + super().as_string(indent) - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: root = root if root is not None else self.get_root() result = MemletSet() result.update(memlets_in_ast(self.condition.code[0], root.containers)) - result.update(super().input_memlets(root)) + result.update(super().input_memlets(root, **kwargs)) return result @@ -366,10 +390,10 @@ class BreakNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + 'break' - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return MemletSet() - def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return MemletSet() @@ -382,10 +406,10 @@ class ContinueNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + 'continue' - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return MemletSet() - def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return MemletSet() @@ -400,11 +424,11 @@ def as_string(self, indent: int = 0): result = indent * INDENTATION + f'elif {self.condition.as_string}:\n' return result + super().as_string(indent) - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: root = root if root is not None else self.get_root() result = MemletSet() result.update(memlets_in_ast(self.condition.code[0], root.containers)) - result.update(super().input_memlets(root)) + result.update(super().input_memlets(root, **kwargs)) return result @@ -430,6 +454,18 @@ def as_string(self, indent: int = 0): result = indent * INDENTATION + f'map {", ".join(self.node.map.params)} in [{rangestr}]:\n' return result + super().as_string(indent) + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return super().input_memlets(root, + propagate={k: v + for k, v in zip(self.node.map.params, self.node.map.range)}, + **kwargs) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return super().output_memlets(root, + propagate={k: v + for k, v in zip(self.node.map.params, self.node.map.range)}, + **kwargs) + @dataclass class ConsumeScope(DataflowScope): @@ -445,7 +481,7 @@ def as_string(self, indent: int = 0): @dataclass -class PipelineScope(DataflowScope): +class PipelineScope(MapScope): """ Pipeline scope. """ @@ -469,10 +505,10 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'tasklet({in_memlets})' return indent * INDENTATION + f'{out_memlets} = tasklet({in_memlets})' - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return MemletSet(self.in_memlets.values()) - def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return MemletSet(self.out_memlets.values()) @@ -497,12 +533,12 @@ def as_string(self, indent: int = 0): if v.owner not in {nodes.Node, nodes.CodeNode, nodes.LibraryNode}) return indent * INDENTATION + f'{out_memlets} = library {libname}[{own_properties}]({in_memlets})' - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: if isinstance(self.in_memlets, set): return MemletSet(self.in_memlets) return MemletSet(self.in_memlets.values()) - def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: if isinstance(self.out_memlets, set): return MemletSet(self.out_memlets) return MemletSet(self.out_memlets.values()) @@ -525,10 +561,10 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target}{offset} = copy {self.memlet.data}[{self.memlet.subset}]{wcr}' - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return MemletSet({self.memlet}) - def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: root = root if root is not None else self.get_root() if self.memlet.other_subset is not None: return MemletSet({Memlet(data=self.target, subset=self.memlet.other_subset, wcr=self.memlet.wcr)}) @@ -547,10 +583,10 @@ class DynScopeCopyNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = dscopy {self.memlet.data}[{self.memlet.subset}]' - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return MemletSet({self.memlet}) - def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return MemletSet() @@ -565,10 +601,10 @@ class ViewNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = view {self.memlet} as {self.view_desc.shape}' - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return MemletSet({self.memlet}) - def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return MemletSet({Memlet.from_array(self.target, self.view_desc)}) @@ -597,10 +633,10 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = refset from {type(self.src_desc).__name__.lower()}' return indent * INDENTATION + f'{self.target} = refset to {self.memlet}' - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return MemletSet({self.memlet}) - def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return MemletSet({Memlet.from_array(self.target, self.ref_desc)}) @@ -615,10 +651,10 @@ class StateBoundaryNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + 'state boundary' - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return MemletSet() - def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None) -> MemletSet: + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return MemletSet() diff --git a/dace/sdfg/memlet_utils.py b/dace/sdfg/memlet_utils.py index ab8f0dad69..9e913c24c2 100644 --- a/dace/sdfg/memlet_utils.py +++ b/dace/sdfg/memlet_utils.py @@ -1,9 +1,10 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import ast +import copy from dace.frontend.python import memlet_parser -from dace import data, Memlet -from typing import Callable, Dict, Optional, Set, Union +from dace import data, Memlet, subsets +from typing import Callable, Dict, Iterable, Optional, Set, TypeVar, Union class MemletReplacer(ast.NodeTransformer): @@ -79,8 +80,102 @@ def visit_Subscript(self, node: ast.Subscript): return self.generic_visit(node) -class MemletSet(Set[Memlet], set): - # TODO: Implement method that searches for a memlet in a dictionary of memlets (even if that memlet - # is a subset of a dictionary key) and returns that key. If intersection indeterminate, assume - # intersects and replace key with union key. +class MemletSet(Set[Memlet]): + """ + Implements a set of memlets that considers subsets that intersect or are covered by its other memlets. + Set updates and unions also perform unions on the contained memlet subsets. + """ + + def __init__(self, iterable: Optional[Iterable[Memlet]] = None, intersection_is_contained: bool = True): + """ + Initializes a memlet set. + + :param iterable: An optional iterable of memlets to initialize the set with. + :param intersection_is_contained: Whether the check ``m in memlet_set`` should return True if the memlet + only intersects with the contents of the set. If False, only completely + covered subsets would return True. + """ + self.internal_set: Dict[str, Set[Memlet]] = {} + self.intersection_is_contained = intersection_is_contained + if iterable is not None: + self.update(iterable) + + def __iter__(self): + for subset in self.internal_set.values(): + yield from subset + + def update(self, *iterable: Iterable[Memlet]): + """ + Updates set of memlets via union of existing ranges. + """ + if len(iterable) == 0: + return + if len(iterable) > 1: + for i in iterable: + self.update(i) + return + + to_update, = iterable + for elem in to_update: + self.add(elem) + + def add(self, elem: Memlet): + """ + Adds a memlet to the set, potentially performing a union of existing ranges. + """ + if elem.data not in self.internal_set: + self.internal_set[elem.data] = {elem} + return + + # Memlet is in set, either perform a union (if possible) or add to internal set + # TODO(later): Consider other_subset as well + for existing_memlet in self.internal_set[elem.data]: + if existing_memlet.subset.intersects(elem.subset) == True: # Definitely intersects + if existing_memlet.subset.covers(elem.subset): + break # Nothing to do + + # Create a new union memlet + self.internal_set[elem.data].remove(existing_memlet) + new_memlet = copy.deepcopy(existing_memlet) + new_memlet.subset = subsets.union(existing_memlet.subset, elem.subset) + self.internal_set[elem.data].add(new_memlet) + break + else: # all intersections were False or indeterminate (may or does not intersect with existing memlets) + self.internal_set[elem.data].add(elem) + + def __contains__(self, elem: Memlet) -> bool: + """ + Returns True iff the memlet or a range superset thereof exists in this set. + """ + if elem.data not in self.internal_set: + return False + for existing_memlet in self.internal_set[elem.data]: + if existing_memlet.subset.covers(elem.subset): + return True + if self.intersection_is_contained: + if existing_memlet.subset.intersects(elem.subset) == False: + continue + else: # May intersect or indeterminate + return True + + return False + + def union(self, *s: Iterable[Memlet]) -> 'MemletSet': + """ + Performs a set-union (with memlet union) over the given sets of memlets. + + :return: New memlet set containing the union of this set and the inputs. + """ + newset = MemletSet(self) + newset.update(s) + return newset + + +T = TypeVar('T') + + +class MemletDict(Dict[Memlet, T]): + """ + Implements a dictionary with memlet keys that considers subsets that intersect or are covered by its other memlets. + """ pass diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 6925511680..1fcf26624d 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -208,6 +208,25 @@ def test_state_boundaries_propagation(boundary): assert [tn.MapScope, tn.TaskletNode, tn.TaskletNode] == node_types[1:] +def test_stree_propagation_forloop(): + N = dace.symbol('N') + + @dace.program + def tester(a: dace.float64[20]): + for i in range(1, N): + a[i] = 2 + a[1] = 1 + + stree = tester.to_sdfg().as_schedule_tree() + stree = t2s.insert_state_boundaries_to_tree(stree) + + node_types = [n for n in stree.preorder_traversal()] + assert isinstance(node_types[2], tn.ForScope) + memlet = dace.Memlet('a[1:N]') + memlet._is_data_src = False + assert list(node_types[2].output_memlets()) == [memlet] + + if __name__ == '__main__': test_state_boundaries_none() test_state_boundaries_waw() @@ -220,3 +239,4 @@ def test_state_boundaries_propagation(boundary): test_state_boundaries_state_transition() test_state_boundaries_propagation(boundary=False) test_state_boundaries_propagation(boundary=True) + test_stree_propagation_forloop() From c00f8dc17c4b63fa9d80a6dfc89acd239ac53b84 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 4 Jan 2024 02:11:56 -0800 Subject: [PATCH 015/137] Implement memlet dictionary --- .../analysis/schedule_tree/tree_to_sdfg.py | 14 +++-- dace/sdfg/memlet_utils.py | 63 ++++++++++++++++++- 2 files changed, 71 insertions(+), 6 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 128e43099f..9a7c181209 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -82,8 +82,8 @@ def _insert_memory_dependency_state_boundaries(scope: tn.ScheduleTreeScope): """ Helper function that inserts boundaries after unmet memory dependencies. """ - reads: Dict[mmu.MemletSet, List[tn.ScheduleTreeNode]] = defaultdict(list) - writes: Dict[mmu.MemletSet, List[tn.ScheduleTreeNode]] = defaultdict(list) + reads: mmu.MemletDict[List[tn.ScheduleTreeNode]] = mmu.MemletDict() + writes: mmu.MemletDict[List[tn.ScheduleTreeNode]] = mmu.MemletDict() parents: Dict[int, Set[int]] = defaultdict(set) boundaries_to_insert: List[int] = [] @@ -105,7 +105,10 @@ def _insert_memory_dependency_state_boundaries(scope: tn.ScheduleTreeScope): # Register reads for inp in inputs: - reads[inp].append(n) + if inp not in reads: + reads[inp] = [n] + else: + reads[inp].append(n) # Transitively add parents if inp in writes: @@ -140,7 +143,10 @@ def _insert_memory_dependency_state_boundaries(scope: tn.ScheduleTreeScope): # Register writes after all hazards have been tested for for out in outputs: - writes[out].append(n) + if out not in writes: + writes[out] = [n] + else: + writes[out].append(n) # Insert memory dependency state boundaries in reverse in order to keep indices intact for i in reversed(boundaries_to_insert): diff --git a/dace/sdfg/memlet_utils.py b/dace/sdfg/memlet_utils.py index 9e913c24c2..a8e90cdccc 100644 --- a/dace/sdfg/memlet_utils.py +++ b/dace/sdfg/memlet_utils.py @@ -1,10 +1,11 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import ast +from collections import defaultdict import copy from dace.frontend.python import memlet_parser from dace import data, Memlet, subsets -from typing import Callable, Dict, Iterable, Optional, Set, TypeVar, Union +from typing import Callable, Dict, Iterable, Optional, Set, TypeVar, Tuple, Union class MemletReplacer(ast.NodeTransformer): @@ -178,4 +179,62 @@ class MemletDict(Dict[Memlet, T]): """ Implements a dictionary with memlet keys that considers subsets that intersect or are covered by its other memlets. """ - pass + + def __init__(self, **kwargs): + self.internal_dict: Dict[str, Dict[Memlet, T]] = defaultdict(dict) + if kwargs: + self.update(kwargs) + + def _getkey(self, elem: Memlet) -> Optional[Memlet]: + """ + Returns the corresponding key (exact, covered, intersecting, or indeterminately intersecting memlet) if + exists in the dictionary, or None if it does not. + """ + if elem.data not in self.internal_dict: + return None + for existing_memlet in self.internal_dict[elem.data]: + if existing_memlet.subset.covers(elem.subset): + return existing_memlet + try: + if existing_memlet.subset.intersects(elem.subset) == False: # Definitely does not intersect + continue + except TypeError: + pass + + # May or will intersect + return existing_memlet + + return None + + def _setkey(self, key: Memlet, value: T) -> None: + self.internal_dict[key.data][key] = value + + def clear(self): + self.internal_dict.clear() + + def update(self, mapping: Dict[Memlet, T]): + for k, v in mapping.items(): + ak = self._getkey(k) + if ak is None: + self._setkey(k, v) + else: + self._setkey(ak, v) + + def __contains__(self, elem: Memlet) -> bool: + """ + Returns True iff the memlet or a range superset thereof exists in this dictionary. + """ + return self._getkey(elem) is not None + + def __getitem__(self, key: Memlet) -> T: + actual_key = self._getkey(key) + if actual_key is None: + raise KeyError(key) + return self.internal_dict[key.data][actual_key] + + def __setitem__(self, key: Memlet, value: T) -> None: + actual_key = self._getkey(key) + if actual_key is None: + self._setkey(key, value) + else: + self._setkey(actual_key, value) From 6fe97874027ed037778c2951dded1fa9c46ad7d5 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 4 Jan 2024 02:33:11 -0800 Subject: [PATCH 016/137] Minor fixes and adding tests --- dace/sdfg/memlet_utils.py | 32 +++++---- tests/schedule_tree/propagation_test.py | 89 +++++++++++++++++++++++++ tests/schedule_tree/to_sdfg_test.py | 20 ------ 3 files changed, 108 insertions(+), 33 deletions(-) create mode 100644 tests/schedule_tree/propagation_test.py diff --git a/dace/sdfg/memlet_utils.py b/dace/sdfg/memlet_utils.py index a8e90cdccc..65b34db6f4 100644 --- a/dace/sdfg/memlet_utils.py +++ b/dace/sdfg/memlet_utils.py @@ -131,16 +131,19 @@ def add(self, elem: Memlet): # Memlet is in set, either perform a union (if possible) or add to internal set # TODO(later): Consider other_subset as well for existing_memlet in self.internal_set[elem.data]: - if existing_memlet.subset.intersects(elem.subset) == True: # Definitely intersects - if existing_memlet.subset.covers(elem.subset): - break # Nothing to do - - # Create a new union memlet - self.internal_set[elem.data].remove(existing_memlet) - new_memlet = copy.deepcopy(existing_memlet) - new_memlet.subset = subsets.union(existing_memlet.subset, elem.subset) - self.internal_set[elem.data].add(new_memlet) - break + try: + if existing_memlet.subset.intersects(elem.subset) == True: # Definitely intersects + if existing_memlet.subset.covers(elem.subset): + break # Nothing to do + + # Create a new union memlet + self.internal_set[elem.data].remove(existing_memlet) + new_memlet = copy.deepcopy(existing_memlet) + new_memlet.subset = subsets.union(existing_memlet.subset, elem.subset) + self.internal_set[elem.data].add(new_memlet) + break + except TypeError: # Indeterminate + pass else: # all intersections were False or indeterminate (may or does not intersect with existing memlets) self.internal_set[elem.data].add(elem) @@ -154,9 +157,12 @@ def __contains__(self, elem: Memlet) -> bool: if existing_memlet.subset.covers(elem.subset): return True if self.intersection_is_contained: - if existing_memlet.subset.intersects(elem.subset) == False: - continue - else: # May intersect or indeterminate + try: + if existing_memlet.subset.intersects(elem.subset) == False: + continue + else: # May intersect or indeterminate + return True + except TypeError: return True return False diff --git a/tests/schedule_tree/propagation_test.py b/tests/schedule_tree/propagation_test.py new file mode 100644 index 0000000000..872382d3d6 --- /dev/null +++ b/tests/schedule_tree/propagation_test.py @@ -0,0 +1,89 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +""" +Tests schedule tree input/output memlet computation +""" +import dace +from dace.sdfg import nodes +from dace.sdfg.analysis.schedule_tree import tree_to_sdfg as t2s, treenodes as tn +from dace.properties import CodeBlock +import numpy as np + + +def test_stree_propagation_forloop(): + N = dace.symbol('N') + + @dace.program + def tester(a: dace.float64[20]): + for i in range(1, N): + a[i] = 2 + a[1] = 1 + + stree = tester.to_sdfg().as_schedule_tree() + stree = t2s.insert_state_boundaries_to_tree(stree) + + node_types = [n for n in stree.preorder_traversal()] + assert isinstance(node_types[2], tn.ForScope) + memlet = dace.Memlet('a[1:N]') + memlet._is_data_src = False + assert list(node_types[2].output_memlets()) == [memlet] + + +def test_stree_propagation_symassign(): + # Manually create a schedule tree + N = dace.symbol('N') + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + }, + symbols={ + 'N': N, + }, + children=[ + tn.MapScope(node=dace.nodes.MapEntry(dace.nodes.Map('map', ['i'], dace.subsets.Range([(1, N - 1, 1)]))), + children=[ + tn.AssignNode('j', CodeBlock('N + i'), dace.InterstateEdge(assignments=dict(j='N + i'))), + tn.TaskletNode(nodes.Tasklet('inner', {}, {'out'}, 'out = inp + 2'), + {'inp': dace.Memlet('A[j]')}, {'out': dace.Memlet('A[j]')}), + ]), + ], + ) + stree.children[0].parent = stree + for c in stree.children[0].children: + c.parent = stree.children[0] + + assert list(stree.children[0].input_memlets()) == [dace.Memlet('A[0:20]', volume=N - 1)] + + +def test_stree_propagation_dynset(): + H = dace.symbol() + nnz = dace.symbol('nnz') + W = dace.symbol() + + @dace.program + def spmv(A_row: dace.uint32[H + 1], A_col: dace.uint32[nnz], A_val: dace.float32[nnz], x: dace.float32[W]): + b = np.zeros([H], dtype=np.float32) + + for i in dace.map[0:H]: + for j in dace.map[A_row[i]:A_row[i + 1]]: + b[i] += A_val[j] * x[A_col[j]] + + return b + + sdfg = spmv.to_sdfg() + stree = sdfg.as_schedule_tree() + assert len(stree.children) == 2 + assert all(isinstance(c, tn.MapScope) for c in stree.children) + mapscope = stree.children[1] + _, _, dynrangemap = mapscope.children + assert isinstance(dynrangemap, tn.MapScope) + print('internal:', list(dynrangemap.input_memlets())) + print('external:', list(mapscope.input_memlets())) + assert list(dynrangemap.input_memlets()) == [] + assert list(mapscope.input_memlets()) == [] + + +if __name__ == '__main__': + test_stree_propagation_forloop() + test_stree_propagation_symassign() + test_stree_propagation_dynset() diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 1fcf26624d..6925511680 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -208,25 +208,6 @@ def test_state_boundaries_propagation(boundary): assert [tn.MapScope, tn.TaskletNode, tn.TaskletNode] == node_types[1:] -def test_stree_propagation_forloop(): - N = dace.symbol('N') - - @dace.program - def tester(a: dace.float64[20]): - for i in range(1, N): - a[i] = 2 - a[1] = 1 - - stree = tester.to_sdfg().as_schedule_tree() - stree = t2s.insert_state_boundaries_to_tree(stree) - - node_types = [n for n in stree.preorder_traversal()] - assert isinstance(node_types[2], tn.ForScope) - memlet = dace.Memlet('a[1:N]') - memlet._is_data_src = False - assert list(node_types[2].output_memlets()) == [memlet] - - if __name__ == '__main__': test_state_boundaries_none() test_state_boundaries_waw() @@ -239,4 +220,3 @@ def tester(a: dace.float64[20]): test_state_boundaries_state_transition() test_state_boundaries_propagation(boundary=False) test_state_boundaries_propagation(boundary=True) - test_stree_propagation_forloop() From 5eac791ae646b11ca0ab89f295d05b3c5456f251 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 4 Jan 2024 08:17:14 -0800 Subject: [PATCH 017/137] Fix memlet propagation for undefined symbols, add stree tests --- dace/sdfg/propagation.py | 28 +++++++++++++++++++------ tests/schedule_tree/propagation_test.py | 27 ++++++++++++++++++------ 2 files changed, 43 insertions(+), 12 deletions(-) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index ec12d36d2a..1df6176b46 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -417,6 +417,10 @@ def can_be_applied(self, dim_exprs, variable_context, node_range, orig_edges, di if symbolic.issymbolic(dim): used_symbols.update(dim.free_symbols) + if any(s not in defined_vars for s in (used_symbols - set(self.params))): + # Cannot propagate symbols that are undefined outside scope (e.g., internal symbols) + return False + if (used_symbols & set(self.params) and any(symbolic.pystr_to_symbolic(s) not in defined_vars for s in node_range.free_symbols)): # Cannot propagate symbols that are undefined in the outer range @@ -1422,9 +1426,11 @@ def propagate_subset(memlets: List[Memlet], defined_variables |= memlet.free_symbols defined_variables -= set(params) defined_variables = set(symbolic.pystr_to_symbolic(p) for p in defined_variables) + else: + defined_variables = set(defined_variables) if undefined_variables: - defined_variables = defined_variables - undefined_variables + defined_variables = defined_variables - set(symbolic.pystr_to_symbolic(p) for p in undefined_variables) # Propagate subset variable_context = [defined_variables, [symbolic.pystr_to_symbolic(p) for p in params]] @@ -1454,11 +1460,21 @@ def propagate_subset(memlets: List[Memlet], entire_array = subsets.Range.from_array(arr) paramset = set(map(str, params)) # Fill in the entire array only if one of the parameters appears in the - # free symbols list of the subset dimension - tmp_subset = subsets.Range([ - ea if any(set(map(str, _freesyms(sd))) & paramset for sd in s) else s - for s, ea in zip(subset, entire_array) - ]) + # free symbols list of the subset dimension or is undefined outside + tmp_subset_rng = [] + for s, ea in zip(subset, entire_array): + contains_params = False + contains_undefs = False + for sdim in s: + fsyms = _freesyms(sdim) + fsyms_str = set(map(str, fsyms)) + contains_params |= len(fsyms_str & paramset) != 0 + contains_undefs |= len(fsyms - defined_variables) != 0 + if contains_params or contains_undefs: + tmp_subset_rng.append(ea) + else: + tmp_subset_rng.append(s) + tmp_subset = subsets.Range(tmp_subset_rng) # Union edges as necessary if new_subset is None: diff --git a/tests/schedule_tree/propagation_test.py b/tests/schedule_tree/propagation_test.py index 872382d3d6..35b05cb7c0 100644 --- a/tests/schedule_tree/propagation_test.py +++ b/tests/schedule_tree/propagation_test.py @@ -56,9 +56,9 @@ def test_stree_propagation_symassign(): def test_stree_propagation_dynset(): - H = dace.symbol() + H = dace.symbol('H') nnz = dace.symbol('nnz') - W = dace.symbol() + W = dace.symbol('W') @dace.program def spmv(A_row: dace.uint32[H + 1], A_col: dace.uint32[nnz], A_val: dace.float32[nnz], x: dace.float32[W]): @@ -77,10 +77,25 @@ def spmv(A_row: dace.uint32[H + 1], A_col: dace.uint32[nnz], A_val: dace.float32 mapscope = stree.children[1] _, _, dynrangemap = mapscope.children assert isinstance(dynrangemap, tn.MapScope) - print('internal:', list(dynrangemap.input_memlets())) - print('external:', list(mapscope.input_memlets())) - assert list(dynrangemap.input_memlets()) == [] - assert list(mapscope.input_memlets()) == [] + + # Check dynamic range map memlets + internal_memlets = list(dynrangemap.input_memlets()) + internal_memlet_data = [m.data for m in internal_memlets] + assert 'x' in internal_memlet_data + assert 'A_val' in internal_memlet_data + assert 'A_row' not in internal_memlet_data + for m in internal_memlets: + if m.data == 'A_val': + assert m.subset != dace.subsets.Range([(0, nnz - 1, 1)]) # Not propagated + + # Check top-level scope memlets + external_memlets = list(mapscope.input_memlets()) + assert dace.Memlet('A_row[0:H+1]') in external_memlets # Two memlets should be unioned + assert dace.Memlet('x[0:W]', volume=0, dynamic=True) in external_memlets + assert dace.Memlet('A_val[0:nnz]', volume=0, dynamic=True) in external_memlets + for m in external_memlets: + if m.data == 'A_val': + assert m.subset == dace.subsets.Range([(0, nnz - 1, 1)]) # Propagated if __name__ == '__main__': From 5a973348cfd7e6f93b48bc74758edc1a1fe1d312 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 5 Jan 2024 00:12:18 -0800 Subject: [PATCH 018/137] Fix test --- tests/schedule_tree/propagation_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/schedule_tree/propagation_test.py b/tests/schedule_tree/propagation_test.py index 35b05cb7c0..507a3d7226 100644 --- a/tests/schedule_tree/propagation_test.py +++ b/tests/schedule_tree/propagation_test.py @@ -90,7 +90,8 @@ def spmv(A_row: dace.uint32[H + 1], A_col: dace.uint32[nnz], A_val: dace.float32 # Check top-level scope memlets external_memlets = list(mapscope.input_memlets()) - assert dace.Memlet('A_row[0:H+1]') in external_memlets # Two memlets should be unioned + assert dace.Memlet('A_row[0:H]') in external_memlets + assert dace.Memlet('A_row[1:H+1]') in external_memlets assert dace.Memlet('x[0:W]', volume=0, dynamic=True) in external_memlets assert dace.Memlet('A_val[0:nnz]', volume=0, dynamic=True) in external_memlets for m in external_memlets: From 553d430d1ab4cb09ebf2e000347bf13f6e1dde7d Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Wed, 20 Nov 2024 07:35:28 -0800 Subject: [PATCH 019/137] Update test --- tests/schedule_tree/to_sdfg_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 6925511680..5422f94472 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -146,7 +146,7 @@ def test_state_boundaries_cfg(): tn.TaskletNode(nodes.Tasklet('bla1', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), tn.ForScope([ tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'out = i'), {}, {'out': dace.Memlet('A[1]')}), - ], cf.ForScope(None, None, 'i', None, '0', CodeBlock('i < 20'), 'i + 1', None, [])), + ], cf.ForScope(None, None, True, 'i', None, '0', CodeBlock('i < 20'), 'i + 1', None, [])), ], ) From 00f05e6a1ea3086f25cc772b24e2ecb9dacf634e Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 5 Dec 2024 22:20:04 +0100 Subject: [PATCH 020/137] Make sure CI runs for v1 maintenance branch PRs (#1810) --- .github/workflows/fpga-ci.yml | 6 +++--- .github/workflows/general-ci.yml | 6 +++--- .github/workflows/gpu-ci.yml | 6 +++--- .github/workflows/heterogeneous-ci.yml | 6 +++--- .github/workflows/pyFV3-ci.yml | 6 +++--- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/.github/workflows/fpga-ci.yml b/.github/workflows/fpga-ci.yml index 2d6d42514f..b687ecf875 100644 --- a/.github/workflows/fpga-ci.yml +++ b/.github/workflows/fpga-ci.yml @@ -2,11 +2,11 @@ name: FPGA Tests on: push: - branches: [ main, ci-fix ] + branches: [ main, v1/maintenance, ci-fix ] pull_request: - branches: [ main, ci-fix ] + branches: [ main, v1/maintenance, ci-fix ] merge_group: - branches: [ main, ci-fix ] + branches: [ main, v1/maintenance, ci-fix ] env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/general-ci.yml b/.github/workflows/general-ci.yml index cde07f0406..cac290204d 100644 --- a/.github/workflows/general-ci.yml +++ b/.github/workflows/general-ci.yml @@ -2,11 +2,11 @@ name: General Tests on: push: - branches: [ main, ci-fix ] + branches: [ main, v1/maintenance, ci-fix ] pull_request: - branches: [ main, ci-fix ] + branches: [ main, v1/maintenance, ci-fix ] merge_group: - branches: [ main, ci-fix ] + branches: [ main, v1/maintenance, ci-fix ] jobs: test: diff --git a/.github/workflows/gpu-ci.yml b/.github/workflows/gpu-ci.yml index b3af9c8c05..461c353cd4 100644 --- a/.github/workflows/gpu-ci.yml +++ b/.github/workflows/gpu-ci.yml @@ -2,11 +2,11 @@ name: GPU Tests on: push: - branches: [ main, ci-fix ] + branches: [ main, v1/maintenance, ci-fix ] pull_request: - branches: [ main, ci-fix ] + branches: [ main, v1/maintenance, ci-fix ] merge_group: - branches: [ main, ci-fix ] + branches: [ main, v1/maintenance, ci-fix ] env: CUDACXX: /usr/local/cuda/bin/nvcc diff --git a/.github/workflows/heterogeneous-ci.yml b/.github/workflows/heterogeneous-ci.yml index 62887ad208..56fd571ca1 100644 --- a/.github/workflows/heterogeneous-ci.yml +++ b/.github/workflows/heterogeneous-ci.yml @@ -2,11 +2,11 @@ name: Heterogeneous Tests on: push: - branches: [ main, ci-fix ] + branches: [ main, v1/maintenance, ci-fix ] pull_request: - branches: [ main, ci-fix ] + branches: [ main, v1/maintenance, ci-fix ] merge_group: - branches: [ main, ci-fix ] + branches: [ main, v1/maintenance, ci-fix ] env: CUDA_HOME: /usr/local/cuda diff --git a/.github/workflows/pyFV3-ci.yml b/.github/workflows/pyFV3-ci.yml index 852b887cdb..55177e2c6f 100644 --- a/.github/workflows/pyFV3-ci.yml +++ b/.github/workflows/pyFV3-ci.yml @@ -2,11 +2,11 @@ name: NASA/NOAA pyFV3 repository build test on: push: - branches: [ main, ci-fix ] + branches: [ v1/maintenance, ci-fix ] pull_request: - branches: [ main, ci-fix ] + branches: [ v1/maintenance, ci-fix ] merge_group: - branches: [ main, ci-fix ] + branches: [ v1/maintenance, ci-fix ] defaults: run: From 3466973ee16b99cd1703769afa58c3905b777e0c Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Sat, 7 Dec 2024 07:20:04 +0100 Subject: [PATCH 021/137] Unused imports backport (successor to #1808) (#1816) Author: @romanc @romanc is on leave for the next few days, thus I have replayed the changes made in https://github.com/spcl/dace/pull/1808 here for faster turnaround. --- dace/frontend/python/newast.py | 32 ++++++++++++++++++++------------ dace/frontend/python/parser.py | 22 ++++++++++++---------- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 1cbb8e67c9..d2813371c9 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -1319,7 +1319,7 @@ def _views_to_data(state: SDFGState, nodes: List[dace.nodes.AccessNode]) -> List self.sdfg.replace_dict(repl_dict) propagate_states(self.sdfg) - for state, memlet, inner_indices in itertools.chain(self.inputs.values(), self.outputs.values()): + for state, memlet, _inner_indices in itertools.chain(self.inputs.values(), self.outputs.values()): if state is not None and state.dynamic_executions: memlet.dynamic = True @@ -2366,8 +2366,11 @@ def visit_For(self, node: ast.For): init_expr='%s = %s' % (indices[0], astutils.unparse(ast_ranges[0][0])), update_expr=incr[indices[0]], inverted=False) - _, first_subblock, _, _ = self._recursive_visit(node.body, f'for_{node.lineno}', node.lineno, - extra_symbols=extra_syms, parent=loop_region, + _, first_subblock, _, _ = self._recursive_visit(node.body, + f'for_{node.lineno}', + node.lineno, + extra_symbols=extra_syms, + parent=loop_region, unconnected_last_block=False) loop_region.start_block = loop_region.node_id(first_subblock) self._connect_break_blocks(loop_region) @@ -2449,7 +2452,10 @@ def visit_While(self, node: ast.While): loop_region = self._add_loop_region(loop_cond, label=f'while_{node.lineno}', inverted=False) # Parse body - self._recursive_visit(node.body, f'while_{node.lineno}', node.lineno, parent=loop_region, + self._recursive_visit(node.body, + f'while_{node.lineno}', + node.lineno, + parent=loop_region, unconnected_last_block=False) if test_region is not None: @@ -2540,7 +2546,6 @@ def _has_loop_ancestor(self, node: ControlFlowBlock) -> bool: node = node.parent_graph return False - def visit_Break(self, node: ast.Break): if not self._has_loop_ancestor(self.cfg_target): raise DaceSyntaxError(self, node, "Break block outside loop region") @@ -2572,8 +2577,7 @@ def visit_If(self, node: ast.If): # Process 'else'/'elif' statements if len(node.orelse) > 0: - else_body = ControlFlowRegion(f'{cond_block.label}_else_{node.orelse[0].lineno}', - sdfg=self.sdfg) + else_body = ControlFlowRegion(f'{cond_block.label}_else_{node.orelse[0].lineno}', sdfg=self.sdfg) cond_block.add_branch(None, else_body) # Visit recursively self._recursive_visit(node.orelse, 'else', node.lineno, else_body, False) @@ -2934,7 +2938,6 @@ def _add_aug_assignment(self, wsqueezed = [i for i in range(len(wtarget_subset)) if i not in wsqz] rsqueezed = [i for i in range(len(rtarget_subset)) if i not in rsqz] - if (boolarr or indirect_indices or (sqz_wsub.size() == sqz_osub.size() and sqz_wsub.size() == sqz_rsub.size())): map_range = {i: rng for i, rng in all_idx_tuples} @@ -3358,8 +3361,11 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): new_data, rng = None, None dtype_keys = tuple(dtypes.dtype_to_typeclass().keys()) - if not (result in self.sdfg.symbols or symbolic.issymbolic(result) or isinstance(result, dtype_keys) or - (isinstance(result, str) and any(result in x for x in [self.sdfg.arrays, self.sdfg._pgrids, self.sdfg._subarrays, self.sdfg._rdistrarrays]))): + if not ( + result in self.sdfg.symbols or symbolic.issymbolic(result) or isinstance(result, dtype_keys) or + (isinstance(result, str) and any( + result in x + for x in [self.sdfg.arrays, self.sdfg._pgrids, self.sdfg._subarrays, self.sdfg._rdistrarrays]))): raise DaceSyntaxError( self, node, "In assignments, the rhs may only be " "data, numerical/boolean constants " @@ -3467,7 +3473,9 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): cname = self.sdfg.find_new_constant(f'__ind{i}_{true_name}') self.sdfg.add_constant(cname, carr) # Add constant to descriptor repository - self.sdfg.add_array(cname, carr.shape, dtypes.dtype_to_typeclass(carr.dtype.type), + self.sdfg.add_array(cname, + carr.shape, + dtypes.dtype_to_typeclass(carr.dtype.type), transient=True) if numpy.array(arr).dtype == numpy.bool_: boolarr = cname @@ -4769,7 +4777,7 @@ def visit_With(self, node: ast.With, is_async=False): evald = astutils.evalnode(node.items[0].context_expr, self.globals) if hasattr(evald, "name"): named_region_name: str = evald.name - else: + else: named_region_name = f"Named Region {node.lineno}" named_region = NamedRegion(named_region_name, debuginfo=self.current_lineinfo) self.cfg_target.add_node(named_region) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 20018effd0..b65e7c227d 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -59,9 +59,10 @@ def _get_locals_and_globals(f): result.update(f.__globals__) # grab the free variables (i.e. locals) if f.__closure__ is not None: - result.update( - {k: v - for k, v in zip(f.__code__.co_freevars, [_get_cell_contents_or_none(x) for x in f.__closure__])}) + result.update({ + k: v + for k, v in zip(f.__code__.co_freevars, [_get_cell_contents_or_none(x) for x in f.__closure__]) + }) return result @@ -142,6 +143,7 @@ def infer_symbols_from_datadescriptor(sdfg: SDFG, class DaceProgram(pycommon.SDFGConvertible): """ A data-centric program object, obtained by decorating a function with ``@dace.program``. """ + def __init__(self, f, args, @@ -405,9 +407,10 @@ def _create_sdfg_args(self, sdfg: SDFG, args: Tuple[Any], kwargs: Dict[str, Any] # Update arguments with symbols in data shapes result.update( - infer_symbols_from_datadescriptor( - sdfg, {k: create_datadescriptor(v) - for k, v in result.items() if k not in self.constant_args})) + infer_symbols_from_datadescriptor(sdfg, { + k: create_datadescriptor(v) + for k, v in result.items() if k not in self.constant_args + })) return result def __call__(self, *args, **kwargs): @@ -487,9 +490,6 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF :param validate: If True, validates the resulting SDFG after creation. :return: The generated SDFG object. """ - # Avoid import loop - from dace.transformation.passes import scalar_to_symbol as scal2sym - from dace.transformation import helpers as xfh # Obtain DaCe program as SDFG sdfg, cached = self._generate_pdp(args, kwargs, simplify=simplify) @@ -812,7 +812,9 @@ def get_program_hash(self, *args, **kwargs) -> cached_program.ProgramCacheKey: _, key = self._load_sdfg(None, *args, **kwargs) return key - def _generate_pdp(self, args: Tuple[Any], kwargs: Dict[str, Any], + def _generate_pdp(self, + args: Tuple[Any], + kwargs: Dict[str, Any], simplify: Optional[bool] = None) -> Tuple[SDFG, bool]: """ Generates the parsed AST representation of a DaCe program. From 41e64d4d190f6663e03b78f38a666b47302c4595 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 14 Jan 2025 10:41:13 -0800 Subject: [PATCH 022/137] Cherry pick regression fix from PR #1837 --- .../transformation/passes/scalar_to_symbol.py | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index a37729ca7c..fd256baa49 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -328,6 +328,25 @@ def __init__(self, in_edges: Dict[str, mm.Memlet], out_edges: Dict[str, mm.Memle self.out_mapping: Dict[str, Tuple[str, subsets.Range]] = {} self.do_not_remove: Set[str] = set() + def _get_requested_range(self, node: ast.Subscript, memlet_subset: subsets.Subset) -> subsets.Subset: + """ + Returns the requested range from a subscript node, which consists of the memlet subset composed with the + tasklet subset. + + :param node: The subscript node. + :param memlet_subset: The memlet subset. + :return: The requested range. + """ + arrname, tasklet_slice = astutils.subscript_to_ast_slice(node) + arrname = arrname if arrname in self.arrays else None + if len(tasklet_slice) < len(memlet_subset): + # Unsqueeze all index dimensions from orig_subset into tasklet_subset + for i, (start, end, _) in reversed(list(enumerate(memlet_subset.ndrange()))): + if start == end: + tasklet_slice.insert(i, (None, None, None)) + tasklet_subset = subsets.Range(astutils.astrange_to_symrange(tasklet_slice, self.arrays, arrname)) + return memlet_subset.compose(tasklet_subset) + def visit_Subscript(self, node: ast.Subscript) -> Any: # Convert subscript to symbol name node = self.generic_visit(node) @@ -336,8 +355,7 @@ def visit_Subscript(self, node: ast.Subscript) -> Any: new_name = dt.find_new_name(node_name, self.connector_names) self.connector_names.add(new_name) - orig_subset = self.in_edges[node_name].subset - subset = orig_subset.compose(subsets.Range(astutils.subscript_to_slice(node, self.arrays)[1])) + subset = self._get_requested_range(node, self.in_edges[node_name].subset) # Check if range can be collapsed if _range_is_promotable(subset, self.defined): self.in_mapping[new_name] = (node_name, subset) @@ -348,8 +366,7 @@ def visit_Subscript(self, node: ast.Subscript) -> Any: new_name = dt.find_new_name(node_name, self.connector_names) self.connector_names.add(new_name) - orig_subset = self.out_edges[node_name].subset - subset = orig_subset.compose(subsets.Range(astutils.subscript_to_slice(node, self.arrays)[1])) + subset = self._get_requested_range(node, self.out_edges[node_name].subset) # Check if range can be collapsed if _range_is_promotable(subset, self.defined): self.out_mapping[new_name] = (node_name, subset) From bb5fd170caa6726046cc7be8477fce33cb5fb486 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 14 Jan 2025 11:26:00 -0800 Subject: [PATCH 023/137] scal2sym: Fix incorrect dimensionality in indirection removal (#1871) Fixes a reported failure mode of scalar to symbol promotion. --- .../transformation/passes/scalar_to_symbol.py | 17 ++++++-- tests/passes/scalar_to_symbol_test.py | 39 +++++++++++++++++++ 2 files changed, 53 insertions(+), 3 deletions(-) diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index fd256baa49..cf2958c4eb 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -340,10 +340,21 @@ def _get_requested_range(self, node: ast.Subscript, memlet_subset: subsets.Subse arrname, tasklet_slice = astutils.subscript_to_ast_slice(node) arrname = arrname if arrname in self.arrays else None if len(tasklet_slice) < len(memlet_subset): + new_tasklet_slice = [(None, None, None)] * len(memlet_subset) # Unsqueeze all index dimensions from orig_subset into tasklet_subset - for i, (start, end, _) in reversed(list(enumerate(memlet_subset.ndrange()))): - if start == end: - tasklet_slice.insert(i, (None, None, None)) + j = 0 + for i, (start, end, _) in enumerate(memlet_subset.ndrange()): + if start != end: + new_tasklet_slice[i] = tasklet_slice[j] + j += 1 + + # Sanity check + if j != len(tasklet_slice): + raise IndexError(f'Only {j} out of {len(tasklet_slice)} indices were provided in subset expression ' + f'"{astutils.unparse(node)}", found during composing with memlet of subset ' + f'"{memlet_subset}".') + tasklet_slice = new_tasklet_slice + tasklet_subset = subsets.Range(astutils.astrange_to_symrange(tasklet_slice, self.arrays, arrname)) return memlet_subset.compose(tasklet_subset) diff --git a/tests/passes/scalar_to_symbol_test.py b/tests/passes/scalar_to_symbol_test.py index 36decceba2..4b1b32a9d5 100644 --- a/tests/passes/scalar_to_symbol_test.py +++ b/tests/passes/scalar_to_symbol_test.py @@ -758,6 +758,43 @@ def test_reversed_order(): sdfg.compile() +@pytest.mark.parametrize('memlet_volume_n', (False, True)) +def test_scalar_index_regression(memlet_volume_n): + """ + Tests a reported failure with an invalid promotion of a scalar index. + """ + N = dace.symbol('N') + volume = 1 if not memlet_volume_n else N + sdfg = dace.SDFG('tester') + sdfg.add_array('A', [10, 10, N], dace.float64) + sdfg.add_scalar('scal', dace.int64) + sdfg.add_scalar('tmp', dace.int64, transient=True) + + init_state = sdfg.add_state() + t = init_state.add_tasklet('set', {}, {'t'}, 't = 1') + w = init_state.add_write('tmp') + init_state.add_edge(t, 't', w, None, dace.Memlet('tmp')) + + state = sdfg.add_state_after(init_state) + r = state.add_read('scal') + rt = state.add_read('tmp') + t = state.add_tasklet('setone', {'s', 't'}, {'a'}, 'a[s + t] = -1') + w = state.add_write('A') + state.add_edge(rt, None, t, 't', dace.Memlet('tmp')) + state.add_edge(r, None, t, 's', dace.Memlet('scal')) + state.add_edge(t, 'a', w, None, dace.Memlet(data='A', subset='0, 0, 0:N', volume=volume)) + + sdfg.validate() + scalar_to_symbol.ScalarToSymbolPromotion().apply_pass(sdfg, {}) + + a = np.random.rand(10, 10, 20) + scal = np.int64(5) + ref = np.copy(a) + ref[0, 0, scal + 1] = -1 + sdfg(A=a, scal=scal, N=20) + assert np.allclose(a, ref) + + if __name__ == '__main__': test_find_promotable() test_promote_simple() @@ -783,3 +820,5 @@ def test_reversed_order(): test_ternary_expression(True) test_double_index_bug() test_reversed_order() + test_scalar_index_regression(False) + test_scalar_index_regression(True) From 03c9222bf55a0ba33f4fb155bc2223e019f9f455 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 14 Jan 2025 11:40:37 -0800 Subject: [PATCH 024/137] Fix Ubuntu version for maintenance branch --- .github/workflows/general-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/general-ci.yml b/.github/workflows/general-ci.yml index cac290204d..9ec7f2caf2 100644 --- a/.github/workflows/general-ci.yml +++ b/.github/workflows/general-ci.yml @@ -11,7 +11,7 @@ on: jobs: test: if: "!contains(github.event.pull_request.labels.*.name, 'no-ci')" - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 strategy: matrix: python-version: [3.7,'3.12'] From a3cd17bbb07fc4830baedf6eed2344330cde78d0 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 23 Jan 2025 07:36:05 -0800 Subject: [PATCH 025/137] Bump version to 1.0.1 --- dace/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/version.py b/dace/version.py index 1f356cc57b..cd7ca4980c 100644 --- a/dace/version.py +++ b/dace/version.py @@ -1 +1 @@ -__version__ = '1.0.0' +__version__ = '1.0.1' From 117dc3abdb7646c66df499f140eb85c15d7af842 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Sun, 2 Feb 2025 17:50:27 +0100 Subject: [PATCH 026/137] Fix typos (backport) (#1918) Backport of PR https://github.com/spcl/dace/pull/1917. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- dace/sdfg/sdfg.py | 2 +- dace/sdfg/validation.py | 2 +- dace/transformation/transformation.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 927f033584..e0190b5e3d 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -912,7 +912,7 @@ def prepend_exit_code(self, cpp_code: str, location: str = 'frame'): def append_transformation(self, transformation): """ - Appends a transformation to the treansformation history of this SDFG. + Appends a transformation to the transformation history of this SDFG. If this is the first transformation being applied, it also saves the initial state of the SDFG to return to and play back the history. diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index c603597fb1..04f6e1b524 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -246,7 +246,7 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None, **context elif const_name in sdfg.symbols: if const_type.dtype != sdfg.symbols[const_name]: # This should actually be an error, but there is a lots of code that depends on it. - warnings.warn(f'Mismatch between constant and symobl type of "{const_name}", ' + warnings.warn(f'Mismatch between constant and symbol type of "{const_name}", ' f'expected to find "{const_type}" but found "{sdfg.symbols[const_name]}".') else: warnings.warn(f'Found constant "{const_name}" that does not refer to an array or a symbol.') diff --git a/dace/transformation/transformation.py b/dace/transformation/transformation.py index 727ec5555b..b459535600 100644 --- a/dace/transformation/transformation.py +++ b/dace/transformation/transformation.py @@ -311,7 +311,7 @@ def _can_be_applied_and_apply( If `apply` is `True` then the function will apply the transformation, if `verify` is also `True` the function will first call `can_be_applied()` to ensure the - transformation can be applied. If not an error is genrated. + transformation can be applied. If not, an error is generated. If `apply` is `False` the function will only call `can_be_applied()` and returns its result. From 56ab279769741aa2130341cc354963cf08989c9a Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 20 Feb 2025 18:00:27 +0100 Subject: [PATCH 027/137] Fix typo (#1945) Backport of typo fix for `v1/maintenance` (see https://github.com/spcl/dace/pull/1944). @phschaad would you mind hitting the merge button? Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- dace/frontend/python/astutils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/frontend/python/astutils.py b/dace/frontend/python/astutils.py index 425e94cd9f..3e99bcdfb5 100644 --- a/dace/frontend/python/astutils.py +++ b/dace/frontend/python/astutils.py @@ -263,7 +263,7 @@ def unparse(node): # Support for numerical constants if isinstance(node, (numbers.Number, numpy.bool_)): return str(node) - # Suport for string + # Support for string if isinstance(node, str): return node From 6e2585b9a586125f931847a0a5897dd56973e46f Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 5 Mar 2025 17:33:26 +0100 Subject: [PATCH 028/137] Fix: DDE removing read from access_set in read/write nodes (#1955) ## Description For this bug to show, we need two separate states with a transient produced in one and subsequently read and written (but not read again). It is important that `StateFusion` isn't able to merge these two state. I've put a dummy if/else in the middle. Before DDE this might look like ![image](https://github.com/user-attachments/assets/4ac52b3d-8cd8-4035-bc20-fba2258a7fd7) where `tmp_computed` is transient and `tmp` is a given variable. DDE will now go and see that `tmp_computed` can be removed as an output of the `read_write` tasklet. The currently faulty update of `access_set` will remove `tmp_computed` from the list of reads in `block` state. This will then propagate (badly) up to the `start` state where `tmp_computed` is marked as never read again, removing the whole tasklet, leaving the `block` state to read an uninitialized `tmp_computed` (if we were to codegen). ![image](https://github.com/user-attachments/assets/64892133-1f2f-4bf5-bee6-8a80c192a0cc) ### Repro ```python import dace import os # Create an empty SDFG sdfg = dace.SDFG(os.path.basename(__file__).removesuffix(".py").replace("-", "_")) sdfg.add_scalar("tmp", dace.float32) sdfg.add_scalar("tmp_computed", dace.float32, transient=True) start_state = sdfg.add_state("start", is_start_block=True) read_tmp = start_state.add_read("tmp") write_computed = start_state.add_write("tmp_computed") # upstream tasklet that writes a transient (to be read in a separate state) write = start_state.add_tasklet("write", {"IN_tmp"}, {"OUT_computed"}, "OUT_computed = IN_tmp * 2 + 1") start_state.add_memlet_path(read_tmp, write, dst_conn="IN_tmp", memlet=dace.Memlet(data="tmp")) start_state.add_memlet_path(write, write_computed, src_conn="OUT_computed", memlet=dace.Memlet(data="tmp_computed")) # Add a condition to avoid fusing next_state and start_state separate guard = sdfg.add_state_after(start_state, "guard_state") true_state = sdfg.add_state("true_state") false_state = sdfg.add_state("false_state") fs_read = false_state.add_read("tmp") fs_write = false_state.add_write("tmp") fs_tasklet = false_state.add_tasklet("abs", {"IN_tmp"}, {"OUT_tmp"}, "OUT_tmp = -IN_tmp") false_state.add_memlet_path(fs_read, fs_tasklet, dst_conn="IN_tmp", memlet=dace.Memlet("tmp")) false_state.add_memlet_path(fs_tasklet, fs_write, src_conn="OUT_tmp", memlet=dace.Memlet("tmp")) merge = sdfg.add_state("merge_state") sdfg.add_edge(guard, true_state, dace.InterstateEdge("tmp >= 0")) sdfg.add_edge(guard, false_state, dace.InterstateEdge("tmp < 0")) sdfg.add_edge(true_state, merge, dace.InterstateEdge()) sdfg.add_edge(false_state, merge, dace.InterstateEdge()) next_state = sdfg.add_state_after(merge) write_computed = next_state.add_write("tmp_computed") read_computed = next_state.add_read("tmp_computed") write_tmp = next_state.add_write("tmp") # downstream tasklet that reads _and_ writes a transient consume = next_state.add_tasklet("read_write", {"IN_computed"}, {"OUT_tmp", "OUT_computed"}, "OUT_computed = 2 * IN_computed\nOUT_tmp = OUT_computed + IN_computed") next_state.add_memlet_path(read_computed, consume, dst_conn="IN_computed", memlet=dace.Memlet(data="tmp_computed")) next_state.add_memlet_path(consume, write_tmp, src_conn="OUT_tmp", memlet=dace.Memlet(data="tmp")) next_state.add_memlet_path(consume, write_computed, src_conn="OUT_computed", memlet=dace.Memlet(data="tmp_computed")) sdfg.validate() sdfg.simplify(verbose=True, validate=True) assert len(list(filter(lambda node: isinstance(node, dace.sdfg.nodes.Tasklet), sdfg.start_block.nodes()))) == 1, "write tasklet in start_block is gone" ``` ### Finishing this PR I'll need help to evaluate whether or not the proposed solution is a good one. Questions that I have: - Is this the right place to fix it? - Should we manually change `access_sets` or would it be simpler/more reliable to redo the analysis step? - The repro case translates to a unit test. Is it a good one or should I e.g. call DDE directly and write assertions for the output of the pass? --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- .../passes/dead_dataflow_elimination.py | 10 ++-- tests/passes/dead_code_elimination_test.py | 53 ++++++++++++++++++- 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/dace/transformation/passes/dead_dataflow_elimination.py b/dace/transformation/passes/dead_dataflow_elimination.py index 856924abd2..4f5d718bdc 100644 --- a/dace/transformation/passes/dead_dataflow_elimination.py +++ b/dace/transformation/passes/dead_dataflow_elimination.py @@ -159,8 +159,10 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[D for code in leaf.src.code.code: ast_find.generic_visit(code) except astutils.NameFound: - # then add the hint expression - leaf.src.code.code = ast.parse(f'{leaf.src_conn}: dace.{ctype.to_string()}\n').body + leaf.src.code.code + # then add the hint expression + leaf.src.code.code = ast.parse( + f'{leaf.src_conn}: dace.{ctype.to_string()}\n' + ).body + leaf.src.code.code else: raise NotImplementedError(f'Cannot eliminate dead connector "{leaf.src_conn}" on ' 'tasklet due to its code language.') @@ -190,8 +192,10 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[D # Update read sets for the predecessor states to reuse remaining_access_nodes = set(n for n in (access_nodes - result[state]) if state.out_degree(n) > 0) + remaining_data_containers = set(node.data for node in remaining_access_nodes) removed_data_containers = set(n.data for n in result[state] - if isinstance(n, nodes.AccessNode) and n not in remaining_access_nodes) + if isinstance(n, nodes.AccessNode) and n not in remaining_access_nodes + and n.data not in remaining_data_containers) access_sets[state] = (access_sets[state][0] - removed_data_containers, access_sets[state][1]) return result or None diff --git a/tests/passes/dead_code_elimination_test.py b/tests/passes/dead_code_elimination_test.py index 1832ad8321..28f7887775 100644 --- a/tests/passes/dead_code_elimination_test.py +++ b/tests/passes/dead_code_elimination_test.py @@ -56,8 +56,8 @@ def test_dse_edge_condition_with_integer_as_boolean_regression(): state_init = sdfg.add_state() state_middle = sdfg.add_state() state_end = sdfg.add_state() - sdfg.add_edge(state_init, state_end, dace.InterstateEdge(condition='(not ((N > 20) != 0))', - assignments={'result': 'N'})) + sdfg.add_edge(state_init, state_end, + dace.InterstateEdge(condition='(not ((N > 20) != 0))', assignments={'result': 'N'})) sdfg.add_edge(state_init, state_middle, dace.InterstateEdge(condition='((N > 20) != 0)')) sdfg.add_edge(state_middle, state_end, dace.InterstateEdge(assignments={'result': '20'})) @@ -210,6 +210,54 @@ def nested(a: dace.float64[20], b: dace.float64[20]): sdfg.validate() +def test_dde_inout_two_states(): + """Test two states with read/write in second state.""" + + sdfg = dace.SDFG("dde_inout_two_states") + sdfg.add_scalar("tmp", dace.float32) + sdfg.add_scalar("computed", dace.float32, transient=True) + + start_state = sdfg.add_state("start_state", is_start_block=True) + s1_read_tmp = start_state.add_read("tmp") + s1_write_computed = start_state.add_write("computed") + + # upstream tasklet that writes a transient (to be read in a separate state) + first_tasklet = start_state.add_tasklet("write", {"read_tmp"}, {"write_computed"}, + "write_computed = read_tmp * 2 + 1") + start_state.add_memlet_path(s1_read_tmp, first_tasklet, dst_conn="read_tmp", memlet=dace.Memlet(data="tmp")) + start_state.add_memlet_path(first_tasklet, + s1_write_computed, + src_conn="write_computed", + memlet=dace.Memlet(data="computed")) + + next_state = sdfg.add_state_after(start_state, "next_state") + s2_write_computed = next_state.add_write("computed") + s2_read_computed = next_state.add_read("computed") + s2_write_tmp = next_state.add_write("tmp") + + # downstream tasklet that reads _and_ writes a transient + second_tasklet = next_state.add_tasklet( + "read_write", {"read_computed"}, {"write_tmp", "write_computed"}, + "write_computed = 2 * read_computed\nwrite_tmp = write_computed + read_computed") + next_state.add_memlet_path(s2_read_computed, + second_tasklet, + dst_conn="read_computed", + memlet=dace.Memlet(data="computed")) + next_state.add_memlet_path(second_tasklet, s2_write_tmp, src_conn="write_tmp", memlet=dace.Memlet(data="tmp")) + next_state.add_memlet_path(second_tasklet, + s2_write_computed, + src_conn="write_computed", + memlet=dace.Memlet(data="computed")) + + results = {} + Pipeline([DeadDataflowElimination()]).apply_pass(sdfg, results) + + dde_results = results["DeadDataflowElimination"] + assert dde_results.get(start_state) is None, "No changes to `start_state` expected." + expected_cleanup = dde_results.get(next_state) + assert expected_cleanup == {s2_write_computed}, "Expected to clean up write to `computed` from `next_state`." + + def test_dce(): """ End-to-end test evaluating both dataflow and state elimination. """ # Code should end up as b[:] = a + 2; b += 1 @@ -336,6 +384,7 @@ def test_dce_add_type_hint_of_variable(dtype): test_dde_scope_reconnect() test_dde_inout(False) test_dde_inout(True) + test_dde_inout_two_states() test_dce() test_dce_callback() test_dce_callback_manual() From 11d0e330993c8406e0f9ad4d1801e8da0a17abae Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Wed, 5 Mar 2025 11:34:44 -0500 Subject: [PATCH 029/137] `StateFusion` misses read-write conflict due to early return (#1954) Check for all potential match in `_check_paths` to not miss potential `memlets_intersect` failure. Unit tests for checks the internal `_check_paths` function since it triggers on a non deterministic networkx search. :warning: This code exists as-is in `main` as of Feb 26 and showcase the same issue, cherry-picking fix will have to happen :warning: --- .../transformation/interstate/state_fusion.py | 7 +- tests/transformations/state_fusion_test.py | 94 ++++++++++++++++++- 2 files changed, 98 insertions(+), 3 deletions(-) diff --git a/dace/transformation/interstate/state_fusion.py b/dace/transformation/interstate/state_fusion.py index dbdf7642bd..b8d29f6676 100644 --- a/dace/transformation/interstate/state_fusion.py +++ b/dace/transformation/interstate/state_fusion.py @@ -153,12 +153,15 @@ def _check_paths(self, first_state: SDFGState, second_state: SDFGState, match_no path_to = nx.has_path(first_state._nx, node, match) if not path_to: continue - path_found = True + path_found |= True node2 = next(n for n in second_input if n.data == match.data) if not all(nx.has_path(second_state._nx, node2, n) for n in nodes_second): fail = True break - if fail or path_found: + # We keep looking for a potential match with a path that fail to find + # a path to the second state to make sure we test memlet_intersections + # independant of the order of the access nodes in the lists + if fail: break # Check for intersection (if None, fusion is ok) diff --git a/tests/transformations/state_fusion_test.py b/tests/transformations/state_fusion_test.py index 6fa0cc8d05..837b9abacc 100644 --- a/tests/transformations/state_fusion_test.py +++ b/tests/transformations/state_fusion_test.py @@ -397,7 +397,98 @@ def func(A: dace.float64[128, 128], B: dace.float64[128, 128]): assert sdfg.number_of_nodes() == 2 -if __name__ == '__main__': +def test_check_paths(): + # Test extracted from NASA GFDL_1M microphysics + + # Case of: + # qm -> q -> qm, m1 in Block_0 + # qm -> q and m1 -> m1 in Block_5 + # m1 has a write in both cases, leading to state not being fusable + # but original code would exist early if qm was tested _before_ m1 + + sdfg = dace.SDFG("state_fusion_check_path_test") + sdfg.add_array("m1", [1], dace.int32) + sdfg.add_array("precip_fall", [1], dace.int32) + sdfg.add_array("q", [1], dace.int32) + sdfg.add_array("qm", [1], dace.int32) + sdfg.add_array("dp1", [1], dace.int32) + + block_0 = sdfg.add_state() + q_b0_w = block_0.add_write("q") + qm_b0 = block_0.add_read("qm") + qm_b0_w = block_0.add_write("qm") + tasklet_b0_on_q = block_0.add_tasklet( + "tasklet_b0_on_q", + {"p_qm"}, + {"p_q_w"}, + "p_q_w = p_qm", + ) + block_0.add_edge(qm_b0, None, tasklet_b0_on_q, "p_qm", dace.Memlet("qm[0]")) + block_0.add_edge(tasklet_b0_on_q, "p_q_w", q_b0_w, None, dace.Memlet("q[0]")) + + m1_b0_w = block_0.add_write("m1") + tasklet_b0_on_m1 = block_0.add_tasklet( + "tasklet_b0_on_m1_qm", + {"p_q"}, + {"p_m1_w", "p_qm_w"}, + "p_m1_w = p_q", + ) + block_0.add_edge(q_b0_w, None, tasklet_b0_on_m1, "p_q", dace.Memlet("q[0]")) + block_0.add_edge(tasklet_b0_on_m1, "p_m1_w", m1_b0_w, None, dace.Memlet("m1[0]")) + block_0.add_edge(tasklet_b0_on_m1, "p_qm_w", qm_b0_w, None, dace.Memlet("qm[0]")) + + block_5 = sdfg.add_state_after(block_0) + precip_fall_b5 = block_5.add_read("precip_fall") + qm_b5 = block_5.add_read("qm") + q_b5_w = block_5.add_write("q") + tasklet_b5_on_q = block_5.add_tasklet( + "tasklet_b5_on_q", + {"p_precip_fall", "p_qm"}, + {"p_q_w"}, + "p_q_w = p_dp1 + 1", + ) + block_5.add_edge( + precip_fall_b5, + None, + tasklet_b5_on_q, + "p_precip_fall", + dace.Memlet("precip_fall[0]"), + ) + block_5.add_edge(qm_b5, None, tasklet_b5_on_q, "p_qm", dace.Memlet("qm[0]")) + block_5.add_edge(tasklet_b5_on_q, "p_q_w", q_b5_w, None, dace.Memlet("q[0]")) + + m1_b5 = block_5.add_read("m1") + m1_b5_w = block_5.add_write("m1") + tasklet_b5_on_m1 = block_5.add_tasklet( + "tasklet_b5_on_m1", + {"p_m1", "p_precip_fall"}, + {"p_m1_w"}, + "m1_w = p_m1 + 1", + ) + block_5.add_edge(m1_b5, None, tasklet_b5_on_m1, "p_m1", dace.Memlet("m1[0]")) + block_5.add_edge( + precip_fall_b5, + None, + tasklet_b5_on_m1, + "p_precip_fall", + dace.Memlet("precip_fall[0]"), + ) + block_5.add_edge(tasklet_b5_on_m1, "p_m1_w", m1_b5_w, None, dace.Memlet("m1[0]")) + + do_fuse = StateFusion()._check_paths( + first_state=block_0, + second_state=block_5, + match_nodes={qm_b0_w: qm_b5, m1_b0_w: m1_b5}, + nodes_first=[q_b0_w], + nodes_second=[q_b5_w], + second_input={precip_fall_b5, m1_b5, qm_b5}, + first_read=False, + second_read=False, + ) + assert not do_fuse + + +if __name__ == "__main__": test_fuse_assignments() test_fuse_assignments_2() test_fuse_assignment_in_use() @@ -414,3 +505,4 @@ def func(A: dace.float64[128, 128], B: dace.float64[128, 128]): test_inout_read_after_write() test_inout_second_state() test_inout_second_state_2() + test_check_paths() From fff6010f8991c2dae558dc0354f452d85b6f3ef1 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 18 Mar 2025 17:17:57 +0100 Subject: [PATCH 030/137] Generate two states from state boundary node --- .../analysis/schedule_tree/tree_to_sdfg.py | 27 +++++++++++++------ tests/schedule_tree/to_sdfg_test.py | 23 ++++++++++++++++ 2 files changed, 42 insertions(+), 8 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 9a7c181209..931f447948 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -35,11 +35,19 @@ def from_schedule_tree(stree: tn.ScheduleTreeRoot, # TODO: Fill SDFG contents stree = insert_state_boundaries_to_tree(stree) # after WAW, before label, etc. - # TODO: create_state_boundary - # TODO: When creating a state boundary, include all inter-state assignments that precede it. - # TODO: create_loop_block - # TODO: create_conditional_block - # TODO: create_dataflow_scope + # Start with an empty state, .. + current_state = result.add_state(is_start_block=True) + + # .. then add children one by one. + for child in stree.children: + # create_state_boundary + if isinstance(child, tn.StateBoundaryNode): + # TODO: When creating a state boundary, include all inter-state assignments that precede it. + current_state = create_state_boundary(child, result, current_state, state_boundary_behavior) + + # TODO: create_loop_block + # TODO: create_conditional_block + # TODO: create_dataflow_scope return result @@ -168,8 +176,11 @@ def create_state_boundary(bnode: tn.StateBoundaryNode, sdfg_region: ControlFlowR :param behavior: The state boundary behavior with which to create the boundary. :return: The newly created state. """ + if behavior != StateBoundaryBehavior.STATE_TRANSITION: + raise ValueError("Only STATE_TRANSITION is supported as StateBoundaryBehavior in this prototype.") + # TODO: Some boundaries (control flow, state labels with goto) could not be fulfilled with every # behavior. Fall back to state transition in that case. - scope: tn.ControlFlowScope = bnode.parent - assert scope is not None - pass + + label = "cf_state_boundary" if bnode.due_to_control_flow else "state_boundary" + return sdfg_region.add_state_after(state, label=label) diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 5422f94472..674efa5403 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -208,6 +208,26 @@ def test_state_boundaries_propagation(boundary): assert [tn.MapScope, tn.TaskletNode, tn.TaskletNode] == node_types[1:] +@pytest.mark.parametrize("control_flow", (True, False)) +def test_create_state_boundary_state_transition(control_flow): + sdfg = dace.SDFG("tester") + state = sdfg.add_state("start", is_start_block=True) + bnode = tn.StateBoundaryNode(control_flow) + + t2s.create_state_boundary(bnode, sdfg, state, t2s.StateBoundaryBehavior.STATE_TRANSITION) + new_label = "cf_state_boundary" if control_flow else "state_boundary" + assert ["start", new_label] == [state.label for state in sdfg.states()] + + +@pytest.mark.xfail(reason="Not yet implemented") +def test_create_state_boundary_empty_memlet(control_flow): + sdfg = dace.SDFG("tester") + state = sdfg.add_state("start", is_start_block=True) + bnode = tn.StateBoundaryNode(control_flow) + + t2s.create_state_boundary(bnode, sdfg, state, t2s.StateBoundaryBehavior.EMPTY_MEMLET) + + if __name__ == '__main__': test_state_boundaries_none() test_state_boundaries_waw() @@ -220,3 +240,6 @@ def test_state_boundaries_propagation(boundary): test_state_boundaries_state_transition() test_state_boundaries_propagation(boundary=False) test_state_boundaries_propagation(boundary=True) + test_create_state_boundary_state_transition(control_flow=True) + test_create_state_boundary_state_transition(control_flow=False) + test_create_state_boundary_empty_memlet() From efdf839b82c0379513893a66c893ad04879304c8 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 19 Mar 2025 15:59:24 +0100 Subject: [PATCH 031/137] WIP: use visitor to generate SDFG - ScheduleTreeRoot: done - StateBoundaryNode: done - TaskletNode: done --- .../analysis/schedule_tree/tree_to_sdfg.py | 78 +++++++++++++++---- dace/sdfg/analysis/schedule_tree/treenodes.py | 12 +-- tests/schedule_tree/to_sdfg_test.py | 57 ++++++++++++++ 3 files changed, 128 insertions(+), 19 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 931f447948..46605e1ca9 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -7,7 +7,7 @@ from dace.sdfg.state import SDFGState from dace.sdfg.analysis.schedule_tree import treenodes as tn from enum import Enum, auto -from typing import Dict, List, Set, Union +from typing import Dict, List, Optional, Set, Union class StateBoundaryBehavior(Enum): @@ -28,26 +28,78 @@ def from_schedule_tree(stree: tn.ScheduleTreeRoot, # Set SDFG descriptor repository result = SDFG(stree.name, propagate=False) result.arg_names = copy.deepcopy(stree.arg_names) - result._arrays = copy.deepcopy(stree.containers) + for key, container in stree.containers.items(): + result._arrays[key] = copy.deepcopy(container) result.constants_prop = copy.deepcopy(stree.constants) result.symbols = copy.deepcopy(stree.symbols) # TODO: Fill SDFG contents stree = insert_state_boundaries_to_tree(stree) # after WAW, before label, etc. - # Start with an empty state, .. - current_state = result.add_state(is_start_block=True) + class StreeToSDFG(tn.ScheduleNodeVisitor): - # .. then add children one by one. - for child in stree.children: - # create_state_boundary - if isinstance(child, tn.StateBoundaryNode): - # TODO: When creating a state boundary, include all inter-state assignments that precede it. - current_state = create_state_boundary(child, result, current_state, state_boundary_behavior) + def __init__(self) -> None: + self._state_stack: List[SDFGState] = [] + self._current_state: Optional[SDFGState] = None + self.access_cache: Dict[SDFGState, Dict[str, nodes.AccessNode]] = {} + + def _push_state(self, state: SDFGState) -> None: + """Push a state on the stack of states. Set it as self._current_state and setup an access_cache dictionary.""" + self._state_stack.append(state) + self._current_state = state + self.access_cache[state] = {} + + def _pop_state(self, label: Optional[str] = None) -> SDFGState: + """Pops the last state from the stack. Replaces `self._current_state` and cleans up the access_cache.""" + if not self._state_stack: + raise ValueError("Can't pop state from empty stack.") + + popped = self._state_stack.pop() + if label: + assert popped.label.startswith(label) + + self._current_state = None if not self._state_stack else self._state_stack[-1] - # TODO: create_loop_block - # TODO: create_conditional_block - # TODO: create_dataflow_scope + del self.access_cache[popped] + + return popped + + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot, sdfg: SDFG) -> None: + self._push_state(sdfg.add_state(label="tree_root", is_start_block=True)) + self.visit(node.children, sdfg=sdfg) + + def visit_StateBoundaryNode(self, node: tn.StateBoundaryNode, sdfg: SDFG) -> None: + # TODO: When creating a state boundary, include all inter-state assignments that precede it. + self._push_state( + create_state_boundary(node, sdfg, self._current_state, StateBoundaryBehavior.STATE_TRANSITION)) + + def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: + # Add Tasklet to current state + tasklet = node.node + self._current_state.add_node(tasklet) + + # Connect inputs and outputs + cache = self.access_cache[self._current_state] + for name, memlet in node.in_memlets.items(): + # cache read access + if memlet.data not in cache: + cache[memlet.data] = self._current_state.add_read(memlet.data) + + access_node = cache[memlet.data] + self._current_state.add_memlet_path(access_node, tasklet, dst_conn=name, memlet=memlet) + + for name, memlet in node.out_memlets.items(): + # we always write to a new access_node + access_node = self._current_state.add_write(memlet.data) + self._current_state.add_memlet_path(tasklet, access_node, src_conn=name, memlet=memlet) + + # cache write access node (or update an existing one) for read after write cases + cache[memlet.data] = access_node + + # TODO: create_loop_block + # TODO: create_conditional_block + # TODO: create_dataflow_scope + StreeToSDFG().visit(stree, sdfg=result) return result diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index eec66b0524..453d13d80a 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -702,21 +702,21 @@ def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> # Classes based on Python's AST NodeVisitor/NodeTransformer for schedule tree nodes class ScheduleNodeVisitor: - def visit(self, node: ScheduleTreeNode): + def visit(self, node: ScheduleTreeNode, **kwargs: Any): """Visit a node.""" if isinstance(node, list): - return [self.visit(snode) for snode in node] + return [self.visit(snode, **kwargs) for snode in node] if isinstance(node, ScheduleTreeScope) and hasattr(self, 'visit_scope'): - return self.visit_scope(node) + return self.visit_scope(node, **kwargs) method = 'visit_' + node.__class__.__name__ visitor = getattr(self, method, self.generic_visit) - return visitor(node) + return visitor(node, **kwargs) - def generic_visit(self, node: ScheduleTreeNode): + def generic_visit(self, node: ScheduleTreeNode, **kwargs: Any): if isinstance(node, ScheduleTreeScope): for child in node.children: - self.visit(child) + self.visit(child, **kwargs) class ScheduleNodeTransformer(ScheduleNodeVisitor): diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 674efa5403..6df8dbe4a3 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -228,6 +228,61 @@ def test_create_state_boundary_empty_memlet(control_flow): t2s.create_state_boundary(bnode, sdfg, state, t2s.StateBoundaryBehavior.EMPTY_MEMLET) +def test_create_tasklet_raw(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('A[1]')}), + ], + ) + + sdfg = stree.as_sdfg() + assert len(sdfg.states()) == 1 + state = sdfg.states()[0] + first_tasklet, write_read_node, second_tasklet, write_node = state.nodes() + + assert first_tasklet.label == "bla" + assert not first_tasklet.in_connectors + assert first_tasklet.out_connectors.keys() == {"out"} + + assert second_tasklet.label == "bla2" + assert second_tasklet.in_connectors.keys() == {"inp"} + assert second_tasklet.out_connectors.keys() == {"out"} + + assert [(first_tasklet, write_read_node), (write_read_node, second_tasklet), + (second_tasklet, write_node)] == [(edge.src, edge.dst) for edge in state.edges()] + + +def test_create_tasklet_waw(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + ], + ) + + sdfg = stree.as_sdfg() + assert len(sdfg.states()) == 2 + s1, s2 = sdfg.states() + + s1_tasklet, s1_anode = s1.nodes() + assert [(s1_tasklet, s1_anode)] == [(edge.src, edge.dst) for edge in s1.edges()] + + s2_tasklet, s2_anode = s2.nodes() + assert [(s2_tasklet, s2_anode)] == [(edge.src, edge.dst) for edge in s2.edges()] + + if __name__ == '__main__': test_state_boundaries_none() test_state_boundaries_waw() @@ -243,3 +298,5 @@ def test_create_state_boundary_empty_memlet(control_flow): test_create_state_boundary_state_transition(control_flow=True) test_create_state_boundary_state_transition(control_flow=False) test_create_state_boundary_empty_memlet() + test_create_tasklet_raw() + test_create_tasklet_waw() From 13402cbfeeb6969cbd3915acfb7a30bdb543071b Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 20 Mar 2025 08:15:19 -0700 Subject: [PATCH 032/137] Bump version to 1.0.2 --- dace/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/version.py b/dace/version.py index cd7ca4980c..a6221b3de7 100644 --- a/dace/version.py +++ b/dace/version.py @@ -1 +1 @@ -__version__ = '1.0.1' +__version__ = '1.0.2' From 45b4125ac686e2f6cd3ee77b0011610774b1412f Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Thu, 20 Mar 2025 16:32:35 +0100 Subject: [PATCH 033/137] WIP: visit_ForScope done --- .../analysis/schedule_tree/tree_to_sdfg.py | 46 +++++++++++++++++-- tests/schedule_tree/to_sdfg_test.py | 26 +++++++++++ 2 files changed, 67 insertions(+), 5 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 46605e1ca9..30e19fb6c3 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -1,13 +1,12 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import copy from collections import defaultdict -from dace.memlet import Memlet -from dace.sdfg import nodes, memlet_utils as mmu -from dace.sdfg.sdfg import SDFG, ControlFlowRegion +from dace.sdfg import nodes, memlet_utils as mmu, utils as sdfg_utils +from dace.sdfg.sdfg import SDFG, ControlFlowRegion, InterstateEdge from dace.sdfg.state import SDFGState from dace.sdfg.analysis.schedule_tree import treenodes as tn from enum import Enum, auto -from typing import Dict, List, Optional, Set, Union +from typing import Dict, List, Optional, Set class StateBoundaryBehavior(Enum): @@ -73,6 +72,39 @@ def visit_StateBoundaryNode(self, node: tn.StateBoundaryNode, sdfg: SDFG) -> Non self._push_state( create_state_boundary(node, sdfg, self._current_state, StateBoundaryBehavior.STATE_TRANSITION)) + def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: + before_state = self._current_state + self._push_state(sdfg.add_state(label="loop_guard")) + guard_state = self._current_state + sdfg.add_edge(before_state, self._current_state, + InterstateEdge(assignments=dict({node.header.itervar: node.header.init}))) + + self._push_state(sdfg.add_state(label="loop_body")) + body_state = self._current_state + sdfg.add_edge(guard_state, body_state, InterstateEdge(condition=node.header.condition)) + self.visit(node.children, sdfg=sdfg) + sdfg.add_edge(self._current_state, guard_state, + InterstateEdge(assignments=dict({node.header.itervar: node.header.update}))) + + self._push_state(sdfg.add_state(label="loop_after")) + after_state = self._current_state + negated_condition = f"not {node.header.condition}" if isinstance( + node.header.condition, str) else f"not {node.header.condition.as_string}" + sdfg.add_edge(guard_state, after_state, InterstateEdge(condition=negated_condition)) + + def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: + # TODO + pass + + def visit_DoWhileScope(self, node: tn.DoWhileScope, sdfg: SDFG) -> None: + # AFAIK we don't have support for do-while loops in the gt4py -> dace bridge. + # Implementing this is thus not necessary for the first prototype. + raise NotImplementedError + + def visit_GeneralLoopScope(self, node: tn.GeneralLoopScope, sdfg: SDFG) -> None: + # Let's see if we need this for the first prototype ... + raise NotImplementedError + def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # Add Tasklet to current state tasklet = node.node @@ -101,6 +133,9 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # TODO: create_dataflow_scope StreeToSDFG().visit(stree, sdfg=result) + # Convert LoopRegions to "normal" SDFG control flow + sdfg_utils.inline_loop_blocks(result) + return result @@ -229,7 +264,8 @@ def create_state_boundary(bnode: tn.StateBoundaryNode, sdfg_region: ControlFlowR :return: The newly created state. """ if behavior != StateBoundaryBehavior.STATE_TRANSITION: - raise ValueError("Only STATE_TRANSITION is supported as StateBoundaryBehavior in this prototype.") + # Only STATE_TRANSITION is supported as StateBoundaryBehavior in this prototype. + raise NotImplementedError # TODO: Some boundaries (control flow, state labels with goto) could not be fulfilled with every # behavior. Fall back to state transition in that case. diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 6df8dbe4a3..becdbb1d93 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -283,6 +283,31 @@ def test_create_tasklet_waw(): assert [(s2_tasklet, s2_anode)] == [(edge.src, edge.dst) for edge in s2.edges()] +def test_create_for_loop(): + # Manually create a schedule tree + # yapf: disable + loop=tn.ForScope( + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + ], + header=cf.ForScope( + itervar="i", init="0", condition=CodeBlock("i<3"), update="i+1", + dispatch_state=None, parent=None, last_block=True, guard=None, body=None, init_edges=[] + ) + ) + stree=tn.ScheduleTreeRoot( + name='tester', + containers={'A': dace.data.Array(dace.float64, [20])}, + children=[loop] + ) + # yapf: enable + assert stree is not None + + sdfg = stree.as_sdfg() + assert sdfg is not None + + if __name__ == '__main__': test_state_boundaries_none() test_state_boundaries_waw() @@ -300,3 +325,4 @@ def test_create_tasklet_waw(): test_create_state_boundary_empty_memlet() test_create_tasklet_raw() test_create_tasklet_waw() + test_create_for_loop() From 35cfd2b6d366c72018b4b8cba0d02b54e662f548 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 21 Mar 2025 14:29:40 +0100 Subject: [PATCH 034/137] WIP: visit_WhileScope done --- .../analysis/schedule_tree/tree_to_sdfg.py | 24 ++++++++------ tests/schedule_tree/to_sdfg_test.py | 32 +++++++++++++++++-- 2 files changed, 45 insertions(+), 11 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 30e19fb6c3..6a37736cde 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -88,13 +88,23 @@ def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: self._push_state(sdfg.add_state(label="loop_after")) after_state = self._current_state - negated_condition = f"not {node.header.condition}" if isinstance( - node.header.condition, str) else f"not {node.header.condition.as_string}" - sdfg.add_edge(guard_state, after_state, InterstateEdge(condition=negated_condition)) + sdfg.add_edge(guard_state, after_state, InterstateEdge(condition=f"not {node.header.condition.as_string}")) def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: - # TODO - pass + before_state = self._current_state + self._push_state(sdfg.add_state(label="guard_state")) + guard_state = self._current_state + sdfg.add_edge(before_state, guard_state, InterstateEdge()) + + self._push_state(sdfg.add_state(label="loop_body")) + body_state = self._current_state + sdfg.add_edge(guard_state, body_state, InterstateEdge(condition=node.header.test)) + self.visit(node.children, sdfg=sdfg) + sdfg.add_edge(self._current_state, guard_state, InterstateEdge()) + + self._push_state(sdfg.add_state(label="loop_after")) + after_state = self._current_state + sdfg.add_edge(guard_state, after_state, InterstateEdge(f"not {node.header.test.as_string}")) def visit_DoWhileScope(self, node: tn.DoWhileScope, sdfg: SDFG) -> None: # AFAIK we don't have support for do-while loops in the gt4py -> dace bridge. @@ -128,14 +138,10 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # cache write access node (or update an existing one) for read after write cases cache[memlet.data] = access_node - # TODO: create_loop_block # TODO: create_conditional_block # TODO: create_dataflow_scope StreeToSDFG().visit(stree, sdfg=result) - # Convert LoopRegions to "normal" SDFG control flow - sdfg_utils.inline_loop_blocks(result) - return result diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index becdbb1d93..1ea2ac15e3 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -302,10 +302,37 @@ def test_create_for_loop(): children=[loop] ) # yapf: enable - assert stree is not None sdfg = stree.as_sdfg() - assert sdfg is not None + sdfg.validate() + + +def test_create_while_loop(): + # Manually create a schedule tree + # yapf: disable + loop=tn.WhileScope( + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + ], + header=cf.WhileScope( + test=CodeBlock("A[1] > 5"), + dispatch_state=None, + last_block=True, + parent=None, + guard=None, + body=None + ) + ) + stree=tn.ScheduleTreeRoot( + name='tester', + containers={'A': dace.data.Array(dace.float64, [20])}, + children=[loop] + ) + # yapf: enable + + sdfg = stree.as_sdfg() + sdfg.validate() if __name__ == '__main__': @@ -326,3 +353,4 @@ def test_create_for_loop(): test_create_tasklet_raw() test_create_tasklet_waw() test_create_for_loop() + test_create_while_loop() From c1330d8cf90a3b2ffbce02d2216258df4504fa5a Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 21 Mar 2025 18:10:00 +0100 Subject: [PATCH 035/137] WIP: started working on conditional blocks --- .../analysis/schedule_tree/tree_to_sdfg.py | 65 +++++++++++++++++-- 1 file changed, 61 insertions(+), 4 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 6a37736cde..a0f0383ccd 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -72,6 +72,22 @@ def visit_StateBoundaryNode(self, node: tn.StateBoundaryNode, sdfg: SDFG) -> Non self._push_state( create_state_boundary(node, sdfg, self._current_state, StateBoundaryBehavior.STATE_TRANSITION)) + def visit_GBlock(self, node: tn.GBlock, sdfg: SDFG) -> None: + # Let's see if we need this for the first prototype ... + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_StateLabel(self, node: tn.StateLabel, sdfg: SDFG) -> None: + # Let's see if we need this for the first prototype ... + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_GotoNode(self, node: tn.GotoNode, sdfg: SDFG) -> None: + # Let's see if we need this for the first prototype ... + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_AssignNode(self, node: tn.AssignNode, sdfg: SDFG) -> None: + # TODO: We'll need these symbol assignments + raise NotImplementedError(f"{type(node)} not implemented") + def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: before_state = self._current_state self._push_state(sdfg.add_state(label="loop_guard")) @@ -107,13 +123,54 @@ def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: sdfg.add_edge(guard_state, after_state, InterstateEdge(f"not {node.header.test.as_string}")) def visit_DoWhileScope(self, node: tn.DoWhileScope, sdfg: SDFG) -> None: - # AFAIK we don't have support for do-while loops in the gt4py -> dace bridge. - # Implementing this is thus not necessary for the first prototype. - raise NotImplementedError + # AFAIK we don't support for do-while loops in the gt4py -> dace bridge. + raise NotImplementedError(f"{type(node)} not implemented") def visit_GeneralLoopScope(self, node: tn.GeneralLoopScope, sdfg: SDFG) -> None: # Let's see if we need this for the first prototype ... - raise NotImplementedError + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: + # TODO + # add guard state + # add true_state + # visit children + + # only add merge state and close it if there's no `ElseScope` following + # 1. find this node in node.parent.children + # 2. check if there's an `ElseScope` following this block + # The above two-step process works even if we have nested if statements because + # the nested if statement would be in the children of this node. + raise NotImplementedError("TODO: IfScope not yet implemented") + + def visit_StateIfScope(self, node: tn.StateIfScope, sdfg: SDFG) -> None: + # Let's see if we need this for the first prototype ... + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_BreakNode(self, node: tn.BreakNode, sdfg: SDFG) -> None: + # AFAIK we don't support for break statements in the gt4py/dace bridge. + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_ContinueNode(self, node: tn.ContinueNode, sdfg: SDFG) -> None: + # AFAIK we don't support for continue statements in the gt4py/dace bridge. + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_ElifScope(self, node: tn.ElifScope, sdfg: SDFG) -> None: + # AFAIK we don't support elif scopes in the gt4py/dace bridge. + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_ElseScope(self, node: tn.ElseScope, sdfg: SDFG) -> None: + # TODO + # last_state = self._current_state # the last block of the if-(elif-)branch + # add merge_state + # connect last_state -> merge_state + + # ??? How to get the guard-state? (leverage the currently unused state_stack) + # add false_state + # visit children + + # connect self._current_state to merge_state + raise NotImplementedError("TODO: ElseScope not yet implemented") def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # Add Tasklet to current state From e33a5d9f7c8895a04c5f887e23953ed73295dd32 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 24 Mar 2025 16:10:19 +0100 Subject: [PATCH 036/137] WIP: conditional blocks done --- .../analysis/schedule_tree/tree_to_sdfg.py | 117 +++++++++++------- tests/schedule_tree/to_sdfg_test.py | 56 +++++++-- 2 files changed, 116 insertions(+), 57 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index a0f0383ccd..db88365127 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -42,35 +42,31 @@ def __init__(self) -> None: self._current_state: Optional[SDFGState] = None self.access_cache: Dict[SDFGState, Dict[str, nodes.AccessNode]] = {} - def _push_state(self, state: SDFGState) -> None: - """Push a state on the stack of states. Set it as self._current_state and setup an access_cache dictionary.""" - self._state_stack.append(state) - self._current_state = state - self.access_cache[state] = {} - def _pop_state(self, label: Optional[str] = None) -> SDFGState: - """Pops the last state from the stack. Replaces `self._current_state` and cleans up the access_cache.""" + """Pops the last state from the stack. + + :param label: ensures the popped state's label starts with the given string + + :return: The popped state. + """ if not self._state_stack: raise ValueError("Can't pop state from empty stack.") popped = self._state_stack.pop() - if label: + if label is not None: assert popped.label.startswith(label) - self._current_state = None if not self._state_stack else self._state_stack[-1] - - del self.access_cache[popped] - return popped def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot, sdfg: SDFG) -> None: - self._push_state(sdfg.add_state(label="tree_root", is_start_block=True)) + self._current_state = sdfg.add_state(label="tree_root", is_start_block=True) self.visit(node.children, sdfg=sdfg) def visit_StateBoundaryNode(self, node: tn.StateBoundaryNode, sdfg: SDFG) -> None: # TODO: When creating a state boundary, include all inter-state assignments that precede it. - self._push_state( - create_state_boundary(node, sdfg, self._current_state, StateBoundaryBehavior.STATE_TRANSITION)) + + self._current_state = create_state_boundary(node, sdfg, self._current_state, + StateBoundaryBehavior.STATE_TRANSITION) def visit_GBlock(self, node: tn.GBlock, sdfg: SDFG) -> None: # Let's see if we need this for the first prototype ... @@ -90,36 +86,40 @@ def visit_AssignNode(self, node: tn.AssignNode, sdfg: SDFG) -> None: def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: before_state = self._current_state - self._push_state(sdfg.add_state(label="loop_guard")) - guard_state = self._current_state + guard_state = sdfg.add_state(label="loop_guard") + self._current_state = guard_state sdfg.add_edge(before_state, self._current_state, InterstateEdge(assignments=dict({node.header.itervar: node.header.init}))) - self._push_state(sdfg.add_state(label="loop_body")) - body_state = self._current_state + body_state = sdfg.add_state(label="loop_body") + self._current_state = body_state sdfg.add_edge(guard_state, body_state, InterstateEdge(condition=node.header.condition)) + + # visit children inside the loop self.visit(node.children, sdfg=sdfg) sdfg.add_edge(self._current_state, guard_state, InterstateEdge(assignments=dict({node.header.itervar: node.header.update}))) - self._push_state(sdfg.add_state(label="loop_after")) - after_state = self._current_state + after_state = sdfg.add_state(label="loop_after") + self._current_state = after_state sdfg.add_edge(guard_state, after_state, InterstateEdge(condition=f"not {node.header.condition.as_string}")) def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: before_state = self._current_state - self._push_state(sdfg.add_state(label="guard_state")) - guard_state = self._current_state + guard_state = sdfg.add_state(label="guard_state") + self._current_state = guard_state sdfg.add_edge(before_state, guard_state, InterstateEdge()) - self._push_state(sdfg.add_state(label="loop_body")) - body_state = self._current_state + body_state = sdfg.add_state(label="loop_body") + self._current_state = body_state sdfg.add_edge(guard_state, body_state, InterstateEdge(condition=node.header.test)) + + # visit children inside the loop self.visit(node.children, sdfg=sdfg) sdfg.add_edge(self._current_state, guard_state, InterstateEdge()) - self._push_state(sdfg.add_state(label="loop_after")) - after_state = self._current_state + after_state = sdfg.add_state(label="loop_after") + self._current_state = after_state sdfg.add_edge(guard_state, after_state, InterstateEdge(f"not {node.header.test.as_string}")) def visit_DoWhileScope(self, node: tn.DoWhileScope, sdfg: SDFG) -> None: @@ -131,17 +131,41 @@ def visit_GeneralLoopScope(self, node: tn.GeneralLoopScope, sdfg: SDFG) -> None: raise NotImplementedError(f"{type(node)} not implemented") def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: - # TODO + before_state = self._current_state + # add guard state + guard_state = sdfg.add_state(label="guard_state") + sdfg.add_edge(before_state, guard_state, InterstateEdge()) + # add true_state - # visit children + true_state = sdfg.add_state(label="true_state") + sdfg.add_edge(guard_state, true_state, InterstateEdge(condition=node.condition)) + self._current_state = true_state - # only add merge state and close it if there's no `ElseScope` following - # 1. find this node in node.parent.children - # 2. check if there's an `ElseScope` following this block - # The above two-step process works even if we have nested if statements because - # the nested if statement would be in the children of this node. - raise NotImplementedError("TODO: IfScope not yet implemented") + # visit children in the true branch + self.visit(node.children, sdfg=sdfg) + + # add merge_state + merge_state = sdfg.add_state_after(self._current_state, label="merge_state") + + # Check if there's an `ElseScope` following this node (in the parent's children). + # Filter StateBoundaryNodes, which we inserted earlier, for this analysis. + filtered = [n for n in node.parent.children if not isinstance(n, tn.StateBoundaryNode)] + if_index = filtered.index(node) + has_else_branch = len(filtered) > if_index + 1 and isinstance(filtered[if_index + 1], tn.ElseScope) + + if has_else_branch: + # push merge_state on the stack for later usage in `visit_ElseScope` + self._state_stack.append(merge_state) + false_state = sdfg.add_state(label="false_state") + + sdfg.add_edge(guard_state, false_state, InterstateEdge(condition=f"not {node.condition.as_string}")) + + # push false_state on the stack for later usage in `visit_ElseScope` + self._state_stack.append(false_state) + else: + sdfg.add_edge(guard_state, merge_state, InterstateEdge(condition=f"not {node.condition.as_string}")) + self._current_state = merge_state def visit_StateIfScope(self, node: tn.StateIfScope, sdfg: SDFG) -> None: # Let's see if we need this for the first prototype ... @@ -160,25 +184,29 @@ def visit_ElifScope(self, node: tn.ElifScope, sdfg: SDFG) -> None: raise NotImplementedError(f"{type(node)} not implemented") def visit_ElseScope(self, node: tn.ElseScope, sdfg: SDFG) -> None: - # TODO - # last_state = self._current_state # the last block of the if-(elif-)branch - # add merge_state - # connect last_state -> merge_state + # get false_state form stack + false_state = self._pop_state("false_state") + self._current_state = false_state - # ??? How to get the guard-state? (leverage the currently unused state_stack) - # add false_state # visit children + self.visit(node.children, sdfg=sdfg) - # connect self._current_state to merge_state - raise NotImplementedError("TODO: ElseScope not yet implemented") + # merge false-branch into merge_state + merge_state = self._pop_state("merge_state") + sdfg.add_edge(self._current_state, merge_state, InterstateEdge()) + self._current_state = merge_state def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # Add Tasklet to current state tasklet = node.node self._current_state.add_node(tasklet) - # Connect inputs and outputs + # Manage access cache + if not self._current_state in self.access_cache: + self.access_cache[self._current_state] = {} cache = self.access_cache[self._current_state] + + # Connect inputs and outputs for name, memlet in node.in_memlets.items(): # cache read access if memlet.data not in cache: @@ -195,7 +223,6 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # cache write access node (or update an existing one) for read after write cases cache[memlet.data] = access_node - # TODO: create_conditional_block # TODO: create_dataflow_scope StreeToSDFG().visit(stree, sdfg=result) diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 1ea2ac15e3..90f036e9f0 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -284,7 +284,6 @@ def test_create_tasklet_waw(): def test_create_for_loop(): - # Manually create a schedule tree # yapf: disable loop=tn.ForScope( children=[ @@ -296,19 +295,16 @@ def test_create_for_loop(): dispatch_state=None, parent=None, last_block=True, guard=None, body=None, init_edges=[] ) ) - stree=tn.ScheduleTreeRoot( - name='tester', - containers={'A': dace.data.Array(dace.float64, [20])}, - children=[loop] - ) # yapf: enable + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot(name='tester', containers={'A': dace.data.Array(dace.float64, [20])}, children=[loop]) + sdfg = stree.as_sdfg() sdfg.validate() def test_create_while_loop(): - # Manually create a schedule tree # yapf: disable loop=tn.WhileScope( children=[ @@ -324,13 +320,47 @@ def test_create_while_loop(): body=None ) ) - stree=tn.ScheduleTreeRoot( - name='tester', - containers={'A': dace.data.Array(dace.float64, [20])}, - children=[loop] - ) # yapf: enable + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot(name='tester', containers={'A': dace.data.Array(dace.float64, [20])}, children=[loop]) + + sdfg = stree.as_sdfg() + sdfg.validate() + + +def test_create_if_else(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot(name="tester", + containers={'A': dace.data.Array(dace.float64, [20])}, + children=[ + tn.IfScope(condition=CodeBlock("A[0] > 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("bla", {}, {"out"}, "out=1"), {}, + {"out": dace.Memlet("A[1]")}), + ]), + tn.ElseScope([ + tn.TaskletNode(nodes.Tasklet("blub", {}, {"out"}, "out=2"), {}, + {"out": dace.Memlet("A[1]")}) + ]) + ]) + + sdfg = stree.as_sdfg() + sdfg.validate() + + +def test_create_if_without_else(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot(name="tester", + containers={'A': dace.data.Array(dace.float64, [20])}, + children=[ + tn.IfScope(condition=CodeBlock("A[0] > 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("bla", {}, {"out"}, "out=1"), {}, + {"out": dace.Memlet("A[1]")}), + ]), + ]) + sdfg = stree.as_sdfg() sdfg.validate() @@ -354,3 +384,5 @@ def test_create_while_loop(): test_create_tasklet_waw() test_create_for_loop() test_create_while_loop() + test_create_if_else() + test_create_if_without_else() From 937dc7937eb3a6b55b232093edaea5ec48d4f185 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 24 Mar 2025 16:44:49 +0100 Subject: [PATCH 037/137] WIP: visit_AssignNode done Added support for symbol assignments --- .../analysis/schedule_tree/tree_to_sdfg.py | 28 ++++++++++++++----- 1 file changed, 21 insertions(+), 7 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index db88365127..2033d16f1a 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -38,8 +38,14 @@ def from_schedule_tree(stree: tn.ScheduleTreeRoot, class StreeToSDFG(tn.ScheduleNodeVisitor): def __init__(self) -> None: + # state management self._state_stack: List[SDFGState] = [] self._current_state: Optional[SDFGState] = None + + # inter-state symbol assignments + self._interstate_symbols: List[tn.AssignNode] = [] + + # caches self.access_cache: Dict[SDFGState, Dict[str, nodes.AccessNode]] = {} def _pop_state(self, label: Optional[str] = None) -> SDFGState: @@ -63,10 +69,14 @@ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot, sdfg: SDFG) -> None: self.visit(node.children, sdfg=sdfg) def visit_StateBoundaryNode(self, node: tn.StateBoundaryNode, sdfg: SDFG) -> None: - # TODO: When creating a state boundary, include all inter-state assignments that precede it. + # When creating a state boundary, include all inter-state assignments that precede it. + assignments = {} + for symbol in self._interstate_symbols: + assignments[symbol.name] = symbol.value + self._interstate_symbols.clear() self._current_state = create_state_boundary(node, sdfg, self._current_state, - StateBoundaryBehavior.STATE_TRANSITION) + StateBoundaryBehavior.STATE_TRANSITION, assignments) def visit_GBlock(self, node: tn.GBlock, sdfg: SDFG) -> None: # Let's see if we need this for the first prototype ... @@ -81,8 +91,9 @@ def visit_GotoNode(self, node: tn.GotoNode, sdfg: SDFG) -> None: raise NotImplementedError(f"{type(node)} not implemented") def visit_AssignNode(self, node: tn.AssignNode, sdfg: SDFG) -> None: - # TODO: We'll need these symbol assignments - raise NotImplementedError(f"{type(node)} not implemented") + # We just collect them here. They'll be added when state boundaries are added, + # see `visit_StateBoundaryNode()` above. + self._interstate_symbols.append(node) def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: before_state = self._current_state @@ -342,8 +353,11 @@ def _insert_memory_dependency_state_boundaries(scope: tn.ScheduleTreeScope): # SDFG content creation functions -def create_state_boundary(bnode: tn.StateBoundaryNode, sdfg_region: ControlFlowRegion, state: SDFGState, - behavior: StateBoundaryBehavior) -> SDFGState: +def create_state_boundary(bnode: tn.StateBoundaryNode, + sdfg_region: ControlFlowRegion, + state: SDFGState, + behavior: StateBoundaryBehavior, + assignments: Optional[Dict] = None) -> SDFGState: """ Creates a boundary between two states @@ -361,4 +375,4 @@ def create_state_boundary(bnode: tn.StateBoundaryNode, sdfg_region: ControlFlowR # behavior. Fall back to state transition in that case. label = "cf_state_boundary" if bnode.due_to_control_flow else "state_boundary" - return sdfg_region.add_state_after(state, label=label) + return sdfg_region.add_state_after(state, label=label, assignments=assignments) From 38e4bf72ca5386f63f096a320318f2af865e87ed Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 24 Mar 2025 17:13:59 +0100 Subject: [PATCH 038/137] WIP: Add stub for MapScopes and the remaining nodes We need at least MapScopes to be finished before we can do anything meaningful. Other visitors now all raise a "NotImplementedError", which will tell us exactly - over time - what we are missing in the first/second/thrid prototype. --- .../analysis/schedule_tree/tree_to_sdfg.py | 56 +++++++++++++++---- 1 file changed, 46 insertions(+), 10 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 2033d16f1a..54f0b6bc12 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -68,16 +68,6 @@ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot, sdfg: SDFG) -> None: self._current_state = sdfg.add_state(label="tree_root", is_start_block=True) self.visit(node.children, sdfg=sdfg) - def visit_StateBoundaryNode(self, node: tn.StateBoundaryNode, sdfg: SDFG) -> None: - # When creating a state boundary, include all inter-state assignments that precede it. - assignments = {} - for symbol in self._interstate_symbols: - assignments[symbol.name] = symbol.value - self._interstate_symbols.clear() - - self._current_state = create_state_boundary(node, sdfg, self._current_state, - StateBoundaryBehavior.STATE_TRANSITION, assignments) - def visit_GBlock(self, node: tn.GBlock, sdfg: SDFG) -> None: # Let's see if we need this for the first prototype ... raise NotImplementedError(f"{type(node)} not implemented") @@ -207,6 +197,18 @@ def visit_ElseScope(self, node: tn.ElseScope, sdfg: SDFG) -> None: sdfg.add_edge(self._current_state, merge_state, InterstateEdge()) self._current_state = merge_state + def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: + # TODO add this + raise NotImplementedError + + def visit_ConsumeScope(self, node: tn.ConsumeScope, sdfg: SDFG) -> None: + # AFAIK we don't support consume scopes in the gt4py/dace bridge. + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_PipelineScope(self, node: tn.PipelineScope, sdfg: SDFG) -> None: + # AFAIK we don't support pipeline scopes in the gt4py/dace bridge. + raise NotImplementedError(f"{type(node)} not implemented") + def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # Add Tasklet to current state tasklet = node.node @@ -234,6 +236,40 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # cache write access node (or update an existing one) for read after write cases cache[memlet.data] = access_node + def visit_LibraryCall(self, node: tn.LibraryCall, sdfg: SDFG) -> None: + # AFAIK we expand all library calls in the gt4py/dace bridge before coming here. + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_CopyNode(self, node: tn.CopyNode, sdfg: SDFG) -> None: + # AFAIK we don't support copy nodes in the gt4py/dace bridge. + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_DynScopeCopyNode(self, node: tn.DynScopeCopyNode, sdfg: SDFG) -> None: + # AFAIK we don't support dyn scope copy nodes in the gt4py/dace bridge. + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_ViewNode(self, node: tn.ViewNode, sdfg: SDFG) -> None: + # Let's see if we need this for the first prototype ... + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_NView(self, node: tn.NView, sdfg: SDFG) -> None: + # Let's see if we need this for the first prototype ... + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_RefSetNode(self, node: tn.RefSetNode, sdfg: SDFG) -> None: + # Let's see if we need this for the first prototype ... + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_StateBoundaryNode(self, node: tn.StateBoundaryNode, sdfg: SDFG) -> None: + # When creating a state boundary, include all inter-state assignments that precede it. + assignments = {} + for symbol in self._interstate_symbols: + assignments[symbol.name] = symbol.value + self._interstate_symbols.clear() + + self._current_state = create_state_boundary(node, sdfg, self._current_state, + StateBoundaryBehavior.STATE_TRANSITION, assignments) + # TODO: create_dataflow_scope StreeToSDFG().visit(stree, sdfg=result) From aae267e13897ce5b7d151a5d9d2ff61e1f46a1be Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 25 Mar 2025 15:58:00 +0100 Subject: [PATCH 039/137] unrelated: format schedule_tree/treenodes.py --- dace/sdfg/analysis/schedule_tree/treenodes.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 453d13d80a..b6aa5e8629 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -377,10 +377,8 @@ def as_string(self, indent: int = 0): footer = (indent + 1) * INDENTATION + f'{loop.update_statement.as_string}\n' return pre_header + header + super().as_string(indent) + '\n' + pre_footer + footer else: - result = (indent * INDENTATION + - f'for {loop.init_statement.as_string}; ' + - f'{loop.loop_condition.as_string}; ' + - f'{loop.update_statement.as_string}:\n') + result = (indent * INDENTATION + f'for {loop.init_statement.as_string}; ' + + f'{loop.loop_condition.as_string}; ' + f'{loop.update_statement.as_string}:\n') return result + super().as_string(indent) else: if loop.inverted: @@ -497,14 +495,18 @@ def as_string(self, indent: int = 0): def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return super().input_memlets(root, - propagate={k: v - for k, v in zip(self.node.map.params, self.node.map.range)}, + propagate={ + k: v + for k, v in zip(self.node.map.params, self.node.map.range) + }, **kwargs) def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: return super().output_memlets(root, - propagate={k: v - for k, v in zip(self.node.map.params, self.node.map.range)}, + propagate={ + k: v + for k, v in zip(self.node.map.params, self.node.map.range) + }, **kwargs) From 0477cc7324dc7756d11fe7e9bff86ef5539f1a75 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 25 Mar 2025 16:01:10 +0100 Subject: [PATCH 040/137] WIP: visit_MapScope done This should be okay for a rudimentary prototpye. Let's see if we can run anything with that. --- .../analysis/schedule_tree/tree_to_sdfg.py | 138 ++++++++++++++++-- dace/sdfg/analysis/schedule_tree/treenodes.py | 3 + tests/schedule_tree/to_sdfg_test.py | 62 ++++++++ 3 files changed, 190 insertions(+), 13 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 54f0b6bc12..c0a3f921bb 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -1,12 +1,13 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import copy from collections import defaultdict +from dace.memlet import Memlet from dace.sdfg import nodes, memlet_utils as mmu, utils as sdfg_utils from dace.sdfg.sdfg import SDFG, ControlFlowRegion, InterstateEdge from dace.sdfg.state import SDFGState from dace.sdfg.analysis.schedule_tree import treenodes as tn from enum import Enum, auto -from typing import Dict, List, Optional, Set +from typing import Dict, Final, List, Optional, Set, Tuple class StateBoundaryBehavior(Enum): @@ -14,6 +15,10 @@ class StateBoundaryBehavior(Enum): EMPTY_MEMLET = auto() #: Happens-before empty memlet edges in the same state +PREFIX_PASSTHROUGH_IN: Final[str] = "IN_" +PREFIX_PASSTHROUGH_OUT: Final[str] = "OUT_" + + def from_schedule_tree(stree: tn.ScheduleTreeRoot, state_boundary_behavior: StateBoundaryBehavior = StateBoundaryBehavior.STATE_TRANSITION) -> SDFG: """ @@ -45,13 +50,16 @@ def __init__(self) -> None: # inter-state symbol assignments self._interstate_symbols: List[tn.AssignNode] = [] + # dataflow scopes + self._dataflow_stack: List[Tuple[nodes.EntryNode, nodes.ExitNode]] = [] + # caches - self.access_cache: Dict[SDFGState, Dict[str, nodes.AccessNode]] = {} + self._access_cache: Dict[SDFGState, Dict[str, nodes.AccessNode]] = {} def _pop_state(self, label: Optional[str] = None) -> SDFGState: - """Pops the last state from the stack. + """Pops the last state from the state stack. - :param label: ensures the popped state's label starts with the given string + :param str, optional label: Ensures the popped state's label starts with the given string. :return: The popped state. """ @@ -64,10 +72,33 @@ def _pop_state(self, label: Optional[str] = None) -> SDFGState: return popped + def _ensure_access_cache(self, state: SDFGState) -> Dict[str, nodes.AccessNode]: + """Ensure an access_cache entry for the given state. + + Checks if there exists an access_cache for `state`. Creates an empty one if it doesn't exist yet. + + :param SDFGState state: The state to check. + + :return: The state's access_cache. + """ + if state not in self._access_cache: + self._access_cache[state] = {} + + return self._access_cache[state] + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot, sdfg: SDFG) -> None: + assert self._current_state is None, "Expected no 'current_state' at root." + assert not self._state_stack, "Expected empty state stack at root." + assert not self._dataflow_stack, "Expected empty dataflow stack at root." + assert not self._interstate_symbols, "Expected empty list of symbols at root." + self._current_state = sdfg.add_state(label="tree_root", is_start_block=True) self.visit(node.children, sdfg=sdfg) + assert not self._state_stack, "Expected empty state stack." + assert not self._dataflow_stack, "Expected empty dataflow stack." + assert not self._interstate_symbols, "Expected empty list of symbols to add." + def visit_GBlock(self, node: tn.GBlock, sdfg: SDFG) -> None: # Let's see if we need this for the first prototype ... raise NotImplementedError(f"{type(node)} not implemented") @@ -189,7 +220,7 @@ def visit_ElseScope(self, node: tn.ElseScope, sdfg: SDFG) -> None: false_state = self._pop_state("false_state") self._current_state = false_state - # visit children + # visit children inside the else branch self.visit(node.children, sdfg=sdfg) # merge false-branch into merge_state @@ -198,8 +229,71 @@ def visit_ElseScope(self, node: tn.ElseScope, sdfg: SDFG) -> None: self._current_state = merge_state def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: - # TODO add this - raise NotImplementedError + dataflow_stack_size = len(self._dataflow_stack) + outer_map_entry, outer_map_exit = self._dataflow_stack[-1] if dataflow_stack_size else (None, None) + + cache = self._ensure_access_cache(self._current_state) + + # map entry + map_entry = nodes.MapEntry(node.node.map) + self._current_state.add_node(map_entry) + + for memlet in node.input_memlets(): + map_entry.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") + map_entry.add_out_connector(f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}") + + if outer_map_entry is not None: + # passthrough if we are inside another map + self._current_state.add_edge(outer_map_entry, f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}", map_entry, + f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", memlet) + else: + # add access node "outside the map" and connect to it + if memlet.data not in cache: + # cache read access + cache[memlet.data] = self._current_state.add_read(memlet.data) + + self._current_state.add_edge(cache[memlet.data], None, map_entry, + f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", memlet) + + # Add empty memlet if outer_map_entry has no out_connectors to connect to + if outer_map_entry is not None and not outer_map_entry.out_connectors and self._current_state.out_degree( + outer_map_entry) < 1: + self._current_state.add_edge(outer_map_entry, None, map_entry, None, memlet=Memlet()) + + # map exit + map_exit = nodes.MapExit(node.node.map) + self._current_state.add_node(map_exit) + + for memlet in node.output_memlets(): + map_exit.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") + map_exit.add_out_connector(f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}") + + if outer_map_exit: + # passthrough if we are inside another map + self._current_state.add_edge(map_exit, f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}", outer_map_exit, + f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", memlet) + else: + # add access nodes "outside the map" and connect to it + # we always write to a new access_node + access_node = self._current_state.add_write(memlet.data) + self._current_state.add_edge(map_exit, f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}", access_node, None, + memlet) + + # cache write access node (or update an existing one) for read after write cases + cache[memlet.data] = access_node + + # Add empty memlet if outer_map_exit has no in_connectors to connect to + if outer_map_exit is not None and not outer_map_exit.in_connectors and self._current_state.in_degree( + outer_map_exit) < 1: + self._current_state.add_edge(map_exit, None, outer_map_exit, None, memlet=Memlet()) + + self._dataflow_stack.append((map_entry, map_exit)) + + # visit children inside the map + self.visit(node.children, sdfg=sdfg) + + self._dataflow_stack.pop() + assert len(self._dataflow_stack) == dataflow_stack_size # sanity check def visit_ConsumeScope(self, node: tn.ConsumeScope, sdfg: SDFG) -> None: # AFAIK we don't support consume scopes in the gt4py/dace bridge. @@ -214,13 +308,17 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: tasklet = node.node self._current_state.add_node(tasklet) - # Manage access cache - if not self._current_state in self.access_cache: - self.access_cache[self._current_state] = {} - cache = self.access_cache[self._current_state] + cache = self._ensure_access_cache(self._current_state) + map_entry, map_exit = self._dataflow_stack[-1] if self._dataflow_stack else (None, None) - # Connect inputs and outputs + # Connect input memlets for name, memlet in node.in_memlets.items(): + # connect to dataflow_stack (if applicable) + connector_name = f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}" + if map_entry is not None and connector_name in map_entry.out_connectors: + self._current_state.add_edge(map_entry, connector_name, tasklet, name, memlet) + continue + # cache read access if memlet.data not in cache: cache[memlet.data] = self._current_state.add_read(memlet.data) @@ -228,7 +326,18 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: access_node = cache[memlet.data] self._current_state.add_memlet_path(access_node, tasklet, dst_conn=name, memlet=memlet) + # Add empty memlet if map_entry has no out_connectors to connect to + if map_entry is not None and not map_entry.out_connectors and self._current_state.out_degree(map_entry) < 1: + self._current_state.add_edge(map_entry, None, tasklet, None, memlet=Memlet()) + + # Connect output memlets for name, memlet in node.out_memlets.items(): + # connect to dataflow_stack (if applicable) + connector_name = f"{PREFIX_PASSTHROUGH_IN}{memlet.data}" + if map_exit is not None and connector_name in map_exit.in_connectors: + self._current_state.add_edge(tasklet, name, map_exit, connector_name, memlet) + continue + # we always write to a new access_node access_node = self._current_state.add_write(memlet.data) self._current_state.add_memlet_path(tasklet, access_node, src_conn=name, memlet=memlet) @@ -236,6 +345,10 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # cache write access node (or update an existing one) for read after write cases cache[memlet.data] = access_node + # Add empty memlet if map_exit has no in_connectors to connect to + if map_exit is not None and not map_exit.in_connectors and self._current_state.in_degree(map_exit) < 1: + self._current_state.add_edge(tasklet, None, map_exit, None, memlet=Memlet()) + def visit_LibraryCall(self, node: tn.LibraryCall, sdfg: SDFG) -> None: # AFAIK we expand all library calls in the gt4py/dace bridge before coming here. raise NotImplementedError(f"{type(node)} not implemented") @@ -270,7 +383,6 @@ def visit_StateBoundaryNode(self, node: tn.StateBoundaryNode, sdfg: SDFG) -> Non self._current_state = create_state_boundary(node, sdfg, self._current_state, StateBoundaryBehavior.STATE_TRANSITION, assignments) - # TODO: create_dataflow_scope StreeToSDFG().visit(stree, sdfg=result) return result diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index b6aa5e8629..ad3240621d 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -487,6 +487,7 @@ class MapScope(DataflowScope): """ Map scope. """ + node: nodes.MapEntry def as_string(self, indent: int = 0): rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) @@ -515,6 +516,7 @@ class ConsumeScope(DataflowScope): """ Consume scope. """ + node: nodes.ConsumeEntry def as_string(self, indent: int = 0): node: nodes.ConsumeEntry = self.node @@ -528,6 +530,7 @@ class PipelineScope(MapScope): """ Pipeline scope. """ + node: nodes.PipelineEntry def as_string(self, indent: int = 0): rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 90f036e9f0..5011bf20de 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -3,6 +3,7 @@ Tests components in conversion of schedule trees to SDFGs. """ import dace +from dace import subsets as sbs from dace.codegen import control_flow as cf from dace.properties import CodeBlock from dace.sdfg import nodes @@ -365,6 +366,64 @@ def test_create_if_without_else(): sdfg.validate() +def test_create_map_scope_write(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot(name="tester", + containers={'A': dace.data.Array(dace.float64, [20])}, + children=[ + tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", + sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("asdf", {}, {"out"}, "out = i"), {}, + {"out": dace.Memlet("A[i]")}) + ]) + ]) + + sdfg = stree.as_sdfg() + sdfg.validate() + + +def test_create_map_scope_copy(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot(name="tester", + containers={ + 'A': dace.data.Array(dace.float64, [20]), + 'B': dace.data.Array(dace.float64, [20]), + }, + children=[ + tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", + sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("copy", {"inp"}, {"out"}, "out = inp"), + {"inp": dace.Memlet("A[i]")}, + {"out": dace.Memlet("B[i]")}) + ]) + ]) + + sdfg = stree.as_sdfg() + sdfg.validate() + + +def test_create_nested_map_scope(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name="tester", + containers={'A': dace.data.Array(dace.float64, [20])}, + children=[ + tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:2"))), + children=[ + tn.MapScope(node=nodes.MapEntry(nodes.Map("blub", "j", sbs.Range.from_string("0:10"))), + children=[ + tn.TaskletNode(nodes.Tasklet("asdf", {}, {"out"}, "out = i*10+j"), {}, + {"out": dace.Memlet("A[i*10+j]")}) + ]) + ]) + ]) + + sdfg = stree.as_sdfg() + sdfg.validate() + + if __name__ == '__main__': test_state_boundaries_none() test_state_boundaries_waw() @@ -386,3 +445,6 @@ def test_create_if_without_else(): test_create_while_loop() test_create_if_else() test_create_if_without_else() + test_create_map_scope_write() + test_create_map_scope_copy() + test_create_nested_map_scope() From b6a5958c83238fe20929fe311948150f3d787c35 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 26 Mar 2025 17:57:26 +0100 Subject: [PATCH 041/137] WIP: use nested sdfg inside map for state machine --- .../analysis/schedule_tree/tree_to_sdfg.py | 117 +++++++++++++++++- dace/sdfg/memlet_utils.py | 3 + tests/schedule_tree/to_sdfg_test.py | 20 +++ 3 files changed, 134 insertions(+), 6 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index c0a3f921bb..decdf5fc2c 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -37,15 +37,15 @@ def from_schedule_tree(stree: tn.ScheduleTreeRoot, result.constants_prop = copy.deepcopy(stree.constants) result.symbols = copy.deepcopy(stree.symbols) - # TODO: Fill SDFG contents - stree = insert_state_boundaries_to_tree(stree) # after WAW, before label, etc. + # after WAW, before label, etc. + stree = insert_state_boundaries_to_tree(stree) class StreeToSDFG(tn.ScheduleNodeVisitor): - def __init__(self) -> None: + def __init__(self, start_state: Optional[SDFGState] = None) -> None: # state management self._state_stack: List[SDFGState] = [] - self._current_state: Optional[SDFGState] = None + self._current_state = start_state # inter-state symbol assignments self._interstate_symbols: List[tn.AssignNode] = [] @@ -228,10 +228,9 @@ def visit_ElseScope(self, node: tn.ElseScope, sdfg: SDFG) -> None: sdfg.add_edge(self._current_state, merge_state, InterstateEdge()) self._current_state = merge_state - def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: + def _generate_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: dataflow_stack_size = len(self._dataflow_stack) outer_map_entry, outer_map_exit = self._dataflow_stack[-1] if dataflow_stack_size else (None, None) - cache = self._ensure_access_cache(self._current_state) # map entry @@ -295,6 +294,112 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: self._dataflow_stack.pop() assert len(self._dataflow_stack) == dataflow_stack_size # sanity check + def _generate_MapScope_with_nested_SDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: + inputs = node.input_memlets() + outputs = node.output_memlets() + + # setup nested SDFG + nsdfg = SDFG("nested_sdfg", parent=self._current_state) + start_state = nsdfg.add_state("nested_root", is_start_block=True) + for memlet in inputs: + nsdfg.add_datadesc(memlet.data, sdfg.arrays[memlet.data].clone()) + for memlet in outputs: + nsdfg.add_datadesc(memlet.data, sdfg.arrays[memlet.data].clone()) + + # visit children inside nested SDFG + inner_visitor = StreeToSDFG(start_state) + for child in node.children: + inner_visitor.visit(child, sdfg=nsdfg) + + nested_SDFG = self._current_state.add_nested_sdfg(nsdfg, + sdfg, + inputs={memlet.data + for memlet in node.input_memlets()}, + outputs={memlet.data + for memlet in node.output_memlets()}) + + dataflow_stack_size = len(self._dataflow_stack) + outer_map_entry, outer_map_exit = self._dataflow_stack[-1] if dataflow_stack_size else (None, None) + cache = self._ensure_access_cache(self._current_state) + + # map entry + map_entry = nodes.MapEntry(node.node.map) + self._current_state.add_node(map_entry) + + for memlet in inputs: + map_entry.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") + map_entry.add_out_connector(f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}") + + # connect nested SDFG to map scope + self._current_state.add_edge(map_entry, f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}", nested_SDFG, + memlet.data, Memlet.from_memlet(memlet)) + + # connect map scope to "outer world" + if outer_map_entry is not None: + # passthrough if we are inside another map + self._current_state.add_edge(outer_map_entry, f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}", map_entry, + f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", memlet) + else: + # add access node "outside the map" and connect to it + if memlet.data not in cache: + # cache read access + cache[memlet.data] = self._current_state.add_read(memlet.data) + + self._current_state.add_edge(cache[memlet.data], None, map_entry, + f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", memlet) + + # Add empty memlet if no explicit connection from map_entry to nested_SDFG has been done so far + if not inputs: + self._current_state.add_edge(map_entry, None, nested_SDFG, None, memlet=Memlet()) + + # Add empty memlet if outer_map_entry has no out_connectors to connect to + if outer_map_entry is not None and not outer_map_entry.out_connectors and self._current_state.out_degree( + outer_map_entry) < 1: + self._current_state.add_edge(outer_map_entry, None, map_entry, None, memlet=Memlet()) + + # map exit + map_exit = nodes.MapExit(node.node.map) + self._current_state.add_node(map_exit) + + for memlet in outputs: + map_exit.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") + map_exit.add_out_connector(f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}") + + # connect nested SDFG to map scope + self._current_state.add_edge(nested_SDFG, memlet.data, map_exit, + f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", Memlet.from_memlet(memlet)) + + # connect map scope to "outer world" + if outer_map_exit: + # passthrough if we are inside another map + self._current_state.add_edge(map_exit, f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}", outer_map_exit, + f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", memlet) + else: + # add access nodes "outside the map" and connect to it + # we always write to a new access_node + access_node = self._current_state.add_write(memlet.data) + self._current_state.add_edge(map_exit, f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}", access_node, None, + memlet) + + # cache write access node (or update an existing one) for read after write cases + cache[memlet.data] = access_node + + # Add empty memlet if no explicit connection from map_entry to nested_SDFG has been done so far + if not outputs: + self._current_state.add_edge(nested_SDFG, None, map_exit, None, memlet=Memlet()) + + # Add empty memlet if outer_map_exit has no in_connectors to connect to + if outer_map_exit is not None and not outer_map_exit.in_connectors and self._current_state.in_degree( + outer_map_exit) < 1: + self._current_state.add_edge(map_exit, None, outer_map_exit, None, memlet=Memlet()) + + def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: + if any([isinstance(child, tn.StateBoundaryNode) for child in node.children]): + # support multiple states within this map by inserting a nested SDFG + return self._generate_MapScope_with_nested_SDFG(node, sdfg) + + self._generate_MapScope(node, sdfg) + def visit_ConsumeScope(self, node: tn.ConsumeScope, sdfg: SDFG) -> None: # AFAIK we don't support consume scopes in the gt4py/dace bridge. raise NotImplementedError(f"{type(node)} not implemented") diff --git a/dace/sdfg/memlet_utils.py b/dace/sdfg/memlet_utils.py index 65b34db6f4..6cc1354b71 100644 --- a/dace/sdfg/memlet_utils.py +++ b/dace/sdfg/memlet_utils.py @@ -105,6 +105,9 @@ def __iter__(self): for subset in self.internal_set.values(): yield from subset + def __len__(self): + return len(self.internal_set) + def update(self, *iterable: Iterable[Memlet]): """ Updates set of memlets via union of existing ranges. diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 5011bf20de..0836073c73 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -424,6 +424,25 @@ def test_create_nested_map_scope(): sdfg.validate() +def test_map_with_two_tasklets(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot(name="tester", + containers={'A': dace.data.Array(dace.float64, [20])}, + children=[ + tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", + sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = i'), {}, + {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'out = 2*i'), {}, + {'out': dace.Memlet('A[1]')}), + ]) + ]) + + sdfg = stree.as_sdfg() + sdfg.validate() + + if __name__ == '__main__': test_state_boundaries_none() test_state_boundaries_waw() @@ -448,3 +467,4 @@ def test_create_nested_map_scope(): test_create_map_scope_write() test_create_map_scope_copy() test_create_nested_map_scope() + test_map_with_two_tasklets() From a6235835730dfb218c5c2bad04d305923f49c4bd Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 28 Mar 2025 14:41:41 +0100 Subject: [PATCH 042/137] Add options to simplify/validate generated SDFG When converting (back) from schedule tree to SDFG, allow options to validate and/or simplify the generated SDFG. --- dace/sdfg/analysis/schedule_tree/treenodes.py | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index ad3240621d..b51e35f286 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -192,9 +192,27 @@ class ScheduleTreeRoot(ScheduleTreeScope): callback_mapping: Dict[str, str] = field(default_factory=dict) arg_names: List[str] = field(default_factory=list) - def as_sdfg(self) -> SDFG: + def as_sdfg(self, validate: bool = True, simplify: bool = True) -> SDFG: + """ + Convert this schedule tree representation (back) into an SDFG. + + :param validate: If true, validate generated SDFG. + :param simplify: If true, simplify generated SDFG. The conversion might insert things like extra + empty states that can be cleaned up automatically. The value of `validate` is + passed on to `simplify()`. + + :return: SDFG version of this schedule tree. + """ from dace.sdfg.analysis.schedule_tree import tree_to_sdfg as t2s # Avoid import loop - return t2s.from_schedule_tree(self) + sdfg = t2s.from_schedule_tree(self) + + if validate: + sdfg.validate() + + if simplify: + sdfg.simplify(validate=validate) + + return sdfg def get_root(self) -> 'ScheduleTreeRoot': return self From 9ec322178e8a5955386cf4cac02f822cd28fa3bc Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 9 Apr 2025 09:17:29 +0200 Subject: [PATCH 043/137] Unrelated: fix typo --- dace/sdfg/sdfg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 94b5a5f1b3..3db8daf3a4 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -257,7 +257,7 @@ def used_symbols(self, all_symbols: bool) -> Set[str]: # - condition = 'i < 10', assignments = {'i': '3'} # - assignments = {'j': 'i + 1', 'i': '3'} # The new algorithm below addresses the issue by iterating over the edge's condition and assignments and - # exlcuding keys from being considered "defined" if they have been already read. + # excluding keys from being considered "defined" if they have been already read. # Symbols in conditions are always free, because the condition is executed before the assignments cond_symbols = set(map(str, dace.symbolic.symbols_in_ast(self.condition.code[0]))) From 214eef172ca3ae426506c8e6020e062932cd4021 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 9 Apr 2025 09:50:32 +0200 Subject: [PATCH 044/137] Fix reading twice from the same memlet in map scope --- .../analysis/schedule_tree/tree_to_sdfg.py | 10 ++++---- tests/schedule_tree/to_sdfg_test.py | 23 +++++++++++++++++++ 2 files changed, 29 insertions(+), 4 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index decdf5fc2c..27158312d8 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -251,8 +251,9 @@ def _generate_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: # cache read access cache[memlet.data] = self._current_state.add_read(memlet.data) - self._current_state.add_edge(cache[memlet.data], None, map_entry, - f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", memlet) + if not self._current_state.edges_between(cache[memlet.data], map_entry): + self._current_state.add_edge(cache[memlet.data], None, map_entry, + f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", memlet) # Add empty memlet if outer_map_entry has no out_connectors to connect to if outer_map_entry is not None and not outer_map_entry.out_connectors and self._current_state.out_degree( @@ -345,8 +346,9 @@ def _generate_MapScope_with_nested_SDFG(self, node: tn.MapScope, sdfg: SDFG) -> # cache read access cache[memlet.data] = self._current_state.add_read(memlet.data) - self._current_state.add_edge(cache[memlet.data], None, map_entry, - f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", memlet) + if not self._current_state.edges_between(cache[memlet.data], map_entry): + self._current_state.add_edge(cache[memlet.data], None, map_entry, + f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", memlet) # Add empty memlet if no explicit connection from map_entry to nested_SDFG has been done so far if not inputs: diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 0836073c73..105c35251c 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -404,6 +404,28 @@ def test_create_map_scope_copy(): sdfg.validate() +def test_create_map_scope_double_memlet(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name="tester", + containers={ + 'A': dace.data.Array(dace.float64, [20]), + 'B': dace.data.Array(dace.float64, [20]), + }, + children=[ + tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:10"))), + children=[ + tn.TaskletNode(nodes.Tasklet("sum", {"first", "second"}, {"out"}, "out = first + second"), { + "first": dace.Memlet("A[i]"), + "second": dace.Memlet("A[i+10]") + }, {"out": dace.Memlet("B[i]")}) + ]) + ]) + + sdfg = stree.as_sdfg() + sdfg.validate() + + def test_create_nested_map_scope(): # Manually create a schedule tree stree = tn.ScheduleTreeRoot( @@ -466,5 +488,6 @@ def test_map_with_two_tasklets(): test_create_if_without_else() test_create_map_scope_write() test_create_map_scope_copy() + test_create_map_scope_double_memlet() test_create_nested_map_scope() test_map_with_two_tasklets() From a864245d937af41df1079e9e6d31f7c1a3e142e0 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 9 Apr 2025 10:44:36 +0200 Subject: [PATCH 045/137] Fix: don't duplicate data descriptors --- dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 27158312d8..b5c32146bd 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -302,10 +302,9 @@ def _generate_MapScope_with_nested_SDFG(self, node: tn.MapScope, sdfg: SDFG) -> # setup nested SDFG nsdfg = SDFG("nested_sdfg", parent=self._current_state) start_state = nsdfg.add_state("nested_root", is_start_block=True) - for memlet in inputs: - nsdfg.add_datadesc(memlet.data, sdfg.arrays[memlet.data].clone()) - for memlet in outputs: - nsdfg.add_datadesc(memlet.data, sdfg.arrays[memlet.data].clone()) + for memlet in [*inputs, *outputs]: + if memlet.data not in nsdfg.arrays: + nsdfg.add_datadesc(memlet.data, sdfg.arrays[memlet.data].clone()) # visit children inside nested SDFG inner_visitor = StreeToSDFG(start_state) From c31c42063b1620b7a24bee8df99c2ebfb9746918 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 9 Apr 2025 10:48:01 +0200 Subject: [PATCH 046/137] Fix: find if scope in parent's children with is operator --- dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index b5c32146bd..b9f46de02c 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -183,7 +183,7 @@ def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: # Check if there's an `ElseScope` following this node (in the parent's children). # Filter StateBoundaryNodes, which we inserted earlier, for this analysis. filtered = [n for n in node.parent.children if not isinstance(n, tn.StateBoundaryNode)] - if_index = filtered.index(node) + if_index = _list_index(filtered, node) has_else_branch = len(filtered) > if_index + 1 and isinstance(filtered[if_index + 1], tn.ElseScope) if has_else_branch: @@ -630,3 +630,15 @@ def create_state_boundary(bnode: tn.StateBoundaryNode, label = "cf_state_boundary" if bnode.due_to_control_flow else "state_boundary" return sdfg_region.add_state_after(state, label=label, assignments=assignments) + + +def _list_index(list: List[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode) -> int: + """Check if node is in list with "is" operator.""" + index = 0 + for element in list: + # compare with "is" to get memory comparison. ".index()" uses value comparison + if element is node: + return index + index += 1 + + raise StopIteration From c0a9db576df04a57558862cedc8422abb4ddfd43 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 9 Apr 2025 10:51:18 +0200 Subject: [PATCH 047/137] Fix: symbol assignment on edges --- dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index b9f46de02c..7fb35056c4 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -483,7 +483,7 @@ def visit_StateBoundaryNode(self, node: tn.StateBoundaryNode, sdfg: SDFG) -> Non # When creating a state boundary, include all inter-state assignments that precede it. assignments = {} for symbol in self._interstate_symbols: - assignments[symbol.name] = symbol.value + assignments[symbol.name] = symbol.value.as_string self._interstate_symbols.clear() self._current_state = create_state_boundary(node, sdfg, self._current_state, From 8588a42c4d159d1c3b8877f196bce7f5fd190ead Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 9 Apr 2025 12:11:28 +0200 Subject: [PATCH 048/137] Fix passthrough connectors for multiple reads --- .../analysis/schedule_tree/tree_to_sdfg.py | 42 ++++++++++++------- tests/schedule_tree/to_sdfg_test.py | 27 ++++++++++++ 2 files changed, 55 insertions(+), 14 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 7fb35056c4..84dcb830d0 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -238,8 +238,12 @@ def _generate_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: self._current_state.add_node(map_entry) for memlet in node.input_memlets(): - map_entry.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") - map_entry.add_out_connector(f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}") + new_in_connector = map_entry.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") + new_out_connector = map_entry.add_out_connector(f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}") + assert new_in_connector == new_out_connector + + if not new_in_connector: + continue if outer_map_entry is not None: # passthrough if we are inside another map @@ -251,9 +255,8 @@ def _generate_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: # cache read access cache[memlet.data] = self._current_state.add_read(memlet.data) - if not self._current_state.edges_between(cache[memlet.data], map_entry): - self._current_state.add_edge(cache[memlet.data], None, map_entry, - f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", memlet) + self._current_state.add_edge(cache[memlet.data], None, map_entry, + f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", memlet) # Add empty memlet if outer_map_entry has no out_connectors to connect to if outer_map_entry is not None and not outer_map_entry.out_connectors and self._current_state.out_degree( @@ -265,8 +268,12 @@ def _generate_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: self._current_state.add_node(map_exit) for memlet in node.output_memlets(): - map_exit.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") - map_exit.add_out_connector(f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}") + new_in_connector = map_exit.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") + new_out_connector = map_exit.add_out_connector(f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}") + assert new_in_connector == new_out_connector + + if not new_in_connector: + continue if outer_map_exit: # passthrough if we are inside another map @@ -327,8 +334,12 @@ def _generate_MapScope_with_nested_SDFG(self, node: tn.MapScope, sdfg: SDFG) -> self._current_state.add_node(map_entry) for memlet in inputs: - map_entry.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") - map_entry.add_out_connector(f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}") + new_in_connector = map_entry.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") + new_out_connector = map_entry.add_out_connector(f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}") + assert new_in_connector == new_out_connector + + if not new_in_connector: + continue # connect nested SDFG to map scope self._current_state.add_edge(map_entry, f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}", nested_SDFG, @@ -345,9 +356,8 @@ def _generate_MapScope_with_nested_SDFG(self, node: tn.MapScope, sdfg: SDFG) -> # cache read access cache[memlet.data] = self._current_state.add_read(memlet.data) - if not self._current_state.edges_between(cache[memlet.data], map_entry): - self._current_state.add_edge(cache[memlet.data], None, map_entry, - f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", memlet) + self._current_state.add_edge(cache[memlet.data], None, map_entry, + f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", memlet) # Add empty memlet if no explicit connection from map_entry to nested_SDFG has been done so far if not inputs: @@ -363,8 +373,12 @@ def _generate_MapScope_with_nested_SDFG(self, node: tn.MapScope, sdfg: SDFG) -> self._current_state.add_node(map_exit) for memlet in outputs: - map_exit.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") - map_exit.add_out_connector(f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}") + new_in_connector = map_exit.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") + new_out_connector = map_exit.add_out_connector(f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}") + assert new_in_connector == new_out_connector + + if not new_in_connector: + continue # connect nested SDFG to map scope self._current_state.add_edge(nested_SDFG, memlet.data, map_exit, diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 105c35251c..b34d53f410 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -446,6 +446,32 @@ def test_create_nested_map_scope(): sdfg.validate() +def test_create_nested_map_scope_multi_read(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name="tester", + containers={ + 'A': dace.data.Array(dace.float64, [20]), + 'B': dace.data.Array(dace.float64, [10]) + }, + children=[ + tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:2"))), + children=[ + tn.MapScope(node=nodes.MapEntry(nodes.Map("blub", "j", sbs.Range.from_string("0:5"))), + children=[ + tn.TaskletNode( + nodes.Tasklet("asdf", {"a_1", "a_2"}, {"out"}, "out = a_1 + a_2"), { + "a_1": dace.Memlet("A[i*5+j]"), + "a_2": dace.Memlet("A[10+i*5+j]"), + }, {"out": dace.Memlet("B[i*5+j]")}) + ]) + ]) + ]) + + sdfg = stree.as_sdfg() + sdfg.validate() + + def test_map_with_two_tasklets(): # Manually create a schedule tree stree = tn.ScheduleTreeRoot(name="tester", @@ -490,4 +516,5 @@ def test_map_with_two_tasklets(): test_create_map_scope_copy() test_create_map_scope_double_memlet() test_create_nested_map_scope() + test_create_nested_map_scope_multi_read() test_map_with_two_tasklets() From 3d3f9c44aeacd4c843021a9f1fdd2598d3cb1009 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Thu, 10 Apr 2025 17:38:57 +0200 Subject: [PATCH 049/137] WIP (not working): iterating on map creation challenge is where/when to add access nodes to ensure that read after write situations are correctly captured. - seems to be fine for simple maps - might work for nested maps - needs work to fix maps with nested sdfgs (e.g. state boundaries) --- .../analysis/schedule_tree/tree_to_sdfg.py | 176 +++++++++++------- tests/schedule_tree/to_sdfg_test.py | 36 ++++ 2 files changed, 142 insertions(+), 70 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 84dcb830d0..b4dc9aa64f 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -51,7 +51,7 @@ def __init__(self, start_state: Optional[SDFGState] = None) -> None: self._interstate_symbols: List[tn.AssignNode] = [] # dataflow scopes - self._dataflow_stack: List[Tuple[nodes.EntryNode, nodes.ExitNode]] = [] + self._dataflow_stack: List[Tuple[nodes.EntryNode, Dict[str, Tuple[nodes.AccessNode, Memlet]]]] = [] # caches self._access_cache: Dict[SDFGState, Dict[str, nodes.AccessNode]] = {} @@ -230,77 +230,109 @@ def visit_ElseScope(self, node: tn.ElseScope, sdfg: SDFG) -> None: def _generate_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: dataflow_stack_size = len(self._dataflow_stack) - outer_map_entry, outer_map_exit = self._dataflow_stack[-1] if dataflow_stack_size else (None, None) - cache = self._ensure_access_cache(self._current_state) # map entry + # --------- map_entry = nodes.MapEntry(node.node.map) self._current_state.add_node(map_entry) + self._dataflow_stack.append((map_entry, dict())) - for memlet in node.input_memlets(): - new_in_connector = map_entry.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") - new_out_connector = map_entry.add_out_connector(f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}") - assert new_in_connector == new_out_connector + # keep a copy of the access cache + access_cache = self._ensure_access_cache(self._current_state) - if not new_in_connector: + # clear the access_cache before visiting children such that they have their + # own access cache (per map scope) + self._access_cache[self._current_state].clear() + + # visit children inside the map + self.visit(node.children, sdfg=sdfg) + _, to_connect = self._dataflow_stack.pop() + + # reset the access_cache + self._access_cache[self._current_state] = access_cache + + assert len(self._dataflow_stack) == dataflow_stack_size + outer_map_entry, outer_to_connect = self._dataflow_stack[-1] if dataflow_stack_size else (None, None) + + # connect potential input connectors on map_entry + input_memlets = node.input_memlets() + for connector in map_entry.in_connectors: + memlet_data = connector.removeprefix(PREFIX_PASSTHROUGH_IN) + # find input memlet + memlets = [memlet for memlet in input_memlets if memlet.data == memlet_data] + assert len(memlets) == 1 + + # connect to local access node (if available) + if memlet_data in access_cache: + cached_access = access_cache[memlet_data] + self._current_state.add_memlet_path(cached_access, + map_entry, + dst_conn=connector, + memlet=input_memlets[0]) continue if outer_map_entry is not None: - # passthrough if we are inside another map - self._current_state.add_edge(outer_map_entry, f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}", map_entry, - f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", memlet) + # get it from outside the map + connector_name = f"{PREFIX_PASSTHROUGH_OUT}{memlet_data}" + if connector_name not in outer_map_entry.out_connectors: + new_in_connector = outer_map_entry.add_in_connector(connector) + new_out_connector = outer_map_entry.add_out_connector(connector_name) + assert new_in_connector == True + assert new_in_connector == new_out_connector + + self._current_state.add_edge(outer_map_entry, connector_name, map_entry, connector, memlets[0]) else: - # add access node "outside the map" and connect to it - if memlet.data not in cache: - # cache read access - cache[memlet.data] = self._current_state.add_read(memlet.data) + # cache local read access + assert memlet_data not in access_cache + access_cache[memlet_data] = self._current_state.add_read(memlet_data) + cached_access = access_cache[memlet_data] + self._current_state.add_memlet_path(cached_access, map_entry, dst_conn=connector, memlet=memlets[0]) - self._current_state.add_edge(cache[memlet.data], None, map_entry, - f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", memlet) - - # Add empty memlet if outer_map_entry has no out_connectors to connect to - if outer_map_entry is not None and not outer_map_entry.out_connectors and self._current_state.out_degree( - outer_map_entry) < 1: + if outer_map_entry is not None and self._current_state.out_degree(outer_map_entry) < 1: self._current_state.add_edge(outer_map_entry, None, map_entry, None, memlet=Memlet()) - # map exit + # map_exit + # -------- map_exit = nodes.MapExit(node.node.map) self._current_state.add_node(map_exit) - for memlet in node.output_memlets(): - new_in_connector = map_exit.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") - new_out_connector = map_exit.add_out_connector(f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}") + # connect writes to map_exit node + output_memlets = node.output_memlets() + for name in to_connect: + in_connector_name = f"{PREFIX_PASSTHROUGH_IN}{name}" + out_connector_name = f"{PREFIX_PASSTHROUGH_OUT}{name}" + new_in_connector = map_exit.add_in_connector(in_connector_name) + new_out_connector = map_exit.add_out_connector(out_connector_name) assert new_in_connector == new_out_connector - if not new_in_connector: - continue + # connect "inside the map" + access_node, memlet = to_connect[name] + self._current_state.add_memlet_path(access_node, + map_exit, + dst_conn=in_connector_name, + memlet=copy.deepcopy(memlet)) - if outer_map_exit: - # passthrough if we are inside another map - self._current_state.add_edge(map_exit, f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}", outer_map_exit, - f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", memlet) - else: - # add access nodes "outside the map" and connect to it - # we always write to a new access_node - access_node = self._current_state.add_write(memlet.data) - self._current_state.add_edge(map_exit, f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}", access_node, None, - memlet) - - # cache write access node (or update an existing one) for read after write cases - cache[memlet.data] = access_node + # connect "outside the map" + # find output memlet + memlets = [memlet for memlet in output_memlets if memlet.data == name] + assert len(memlets) == 1 - # Add empty memlet if outer_map_exit has no in_connectors to connect to - if outer_map_exit is not None and not outer_map_exit.in_connectors and self._current_state.in_degree( - outer_map_exit) < 1: - self._current_state.add_edge(map_exit, None, outer_map_exit, None, memlet=Memlet()) + access_node = self._current_state.add_write(name) + self._current_state.add_memlet_path(map_exit, + access_node, + src_conn=out_connector_name, + memlet=memlets[0]) - self._dataflow_stack.append((map_entry, map_exit)) + # cache write access into access_cache + access_cache[name] = access_node - # visit children inside the map - self.visit(node.children, sdfg=sdfg) + if outer_to_connect is not None: + outer_to_connect[name] = (access_node, memlets[0]) - self._dataflow_stack.pop() - assert len(self._dataflow_stack) == dataflow_stack_size # sanity check + # TODO If nothing is connected at this point, figure out what's the last thing that + # we should connect to. Then, add an empty memlet from that last thing to this + # map_exit. + assert len(self._current_state.in_edges(map_exit)) > 0 def _generate_MapScope_with_nested_SDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: inputs = node.input_memlets() @@ -411,7 +443,8 @@ def _generate_MapScope_with_nested_SDFG(self, node: tn.MapScope, sdfg: SDFG) -> def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: if any([isinstance(child, tn.StateBoundaryNode) for child in node.children]): # support multiple states within this map by inserting a nested SDFG - return self._generate_MapScope_with_nested_SDFG(node, sdfg) + # return self._generate_MapScope_with_nested_SDFG(node, sdfg) + raise NotImplementedError("todo") self._generate_MapScope(node, sdfg) @@ -429,35 +462,39 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: self._current_state.add_node(tasklet) cache = self._ensure_access_cache(self._current_state) - map_entry, map_exit = self._dataflow_stack[-1] if self._dataflow_stack else (None, None) + map_entry, to_connect = self._dataflow_stack[-1] if self._dataflow_stack else (None, None) # Connect input memlets for name, memlet in node.in_memlets.items(): - # connect to dataflow_stack (if applicable) - connector_name = f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}" - if map_entry is not None and connector_name in map_entry.out_connectors: - self._current_state.add_edge(map_entry, connector_name, tasklet, name, memlet) + # connect to local access node if possible + if memlet.data in cache: + cached_access = cache[memlet.data] + self._current_state.add_memlet_path(cached_access, tasklet, dst_conn=name, memlet=memlet) continue - # cache read access - if memlet.data not in cache: - cache[memlet.data] = self._current_state.add_read(memlet.data) + if map_entry is not None: + # get it from outside the map + connector_name = f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}" + if connector_name not in map_entry.out_connectors: + new_in_connector = map_entry.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") + new_out_connector = map_entry.add_out_connector(connector_name) + assert new_in_connector == True + assert new_in_connector == new_out_connector - access_node = cache[memlet.data] - self._current_state.add_memlet_path(access_node, tasklet, dst_conn=name, memlet=memlet) + self._current_state.add_edge(map_entry, connector_name, tasklet, name, memlet) + else: + # cache local read access + assert memlet.data not in cache + cache[memlet.data] = self._current_state.add_read(memlet.data) + cached_access = cache[memlet.data] + self._current_state.add_memlet_path(cached_access, tasklet, dst_conn=name, memlet=memlet) # Add empty memlet if map_entry has no out_connectors to connect to - if map_entry is not None and not map_entry.out_connectors and self._current_state.out_degree(map_entry) < 1: + if map_entry is not None and self._current_state.out_degree(map_entry) < 1: self._current_state.add_edge(map_entry, None, tasklet, None, memlet=Memlet()) # Connect output memlets for name, memlet in node.out_memlets.items(): - # connect to dataflow_stack (if applicable) - connector_name = f"{PREFIX_PASSTHROUGH_IN}{memlet.data}" - if map_exit is not None and connector_name in map_exit.in_connectors: - self._current_state.add_edge(tasklet, name, map_exit, connector_name, memlet) - continue - # we always write to a new access_node access_node = self._current_state.add_write(memlet.data) self._current_state.add_memlet_path(tasklet, access_node, src_conn=name, memlet=memlet) @@ -465,9 +502,8 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # cache write access node (or update an existing one) for read after write cases cache[memlet.data] = access_node - # Add empty memlet if map_exit has no in_connectors to connect to - if map_exit is not None and not map_exit.in_connectors and self._current_state.in_degree(map_exit) < 1: - self._current_state.add_edge(tasklet, None, map_exit, None, memlet=Memlet()) + if to_connect is not None: + to_connect[memlet.data] = (access_node, memlet) def visit_LibraryCall(self, node: tn.LibraryCall, sdfg: SDFG) -> None: # AFAIK we expand all library calls in the gt4py/dace bridge before coming here. diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index b34d53f410..6c481a01f4 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -383,6 +383,28 @@ def test_create_map_scope_write(): sdfg.validate() +def test_create_map_scope_read_after_write(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name="tester", + containers={ + 'A': dace.data.Array(dace.float64, [20]), + 'B': dace.data.Array(dace.float64, [20]), + }, + children=[ + tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("write", {}, {"out"}, "out = i"), {}, + {"out": dace.Memlet("B[i]")}), + tn.TaskletNode(nodes.Tasklet("read", {"in_field"}, {"out_field"}, "out_field = in_field"), + {"in_field": dace.Memlet("B[i]")}, {"out_field": dace.Memlet("A[i]")}) + ]) + ]) + + sdfg = stree.as_sdfg() + sdfg.validate() + + def test_create_map_scope_copy(): # Manually create a schedule tree stree = tn.ScheduleTreeRoot(name="tester", @@ -404,6 +426,8 @@ def test_create_map_scope_copy(): sdfg.validate() +# TODO: restart testing here +# TODO 2: find an automatic way to test stuff here def test_create_map_scope_double_memlet(): # Manually create a schedule tree stree = tn.ScheduleTreeRoot( @@ -491,6 +515,18 @@ def test_map_with_two_tasklets(): sdfg.validate() +def test_xppm_tmp(): + loaded = dace.SDFG.from_file("test.sdfgz") + stree = loaded.as_schedule_tree() + + # TODO + # - fix missing data dependency with "al" + # - fix read after write issue + + sdfg = stree.as_sdfg() + sdfg.validate() + + if __name__ == '__main__': test_state_boundaries_none() test_state_boundaries_waw() From 5b6c1b6d54c4ecd5b80d6d9dbb4d260eec319ba6 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 11 Apr 2025 10:21:23 +0200 Subject: [PATCH 050/137] WIP: maps without state boundaries are now good again --- .../analysis/schedule_tree/tree_to_sdfg.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index b4dc9aa64f..2e0aa26ae7 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -1,6 +1,7 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import copy from collections import defaultdict +from dace import subsets from dace.memlet import Memlet from dace.sdfg import nodes, memlet_utils as mmu, utils as sdfg_utils from dace.sdfg.sdfg import SDFG, ControlFlowRegion, InterstateEdge @@ -260,15 +261,21 @@ def _generate_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: memlet_data = connector.removeprefix(PREFIX_PASSTHROUGH_IN) # find input memlet memlets = [memlet for memlet in input_memlets if memlet.data == memlet_data] - assert len(memlets) == 1 + assert len(memlets) > 0 + memlet = copy.deepcopy(memlets[0]) + if len(memlets) > 1: + # merge memlets + for index, element in enumerate(memlets): + if index == 0: + continue + memlet.subset = subsets.union(memlet.subset, element.subset) + # TODO(later): figure out the volume thing (also in MemletSet). Also: num_accesses (for legacy reasons) + memlet.volume += element.volume # connect to local access node (if available) if memlet_data in access_cache: cached_access = access_cache[memlet_data] - self._current_state.add_memlet_path(cached_access, - map_entry, - dst_conn=connector, - memlet=input_memlets[0]) + self._current_state.add_memlet_path(cached_access, map_entry, dst_conn=connector, memlet=memlet) continue if outer_map_entry is not None: @@ -280,13 +287,13 @@ def _generate_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: assert new_in_connector == True assert new_in_connector == new_out_connector - self._current_state.add_edge(outer_map_entry, connector_name, map_entry, connector, memlets[0]) + self._current_state.add_edge(outer_map_entry, connector_name, map_entry, connector, memlet) else: # cache local read access assert memlet_data not in access_cache access_cache[memlet_data] = self._current_state.add_read(memlet_data) cached_access = access_cache[memlet_data] - self._current_state.add_memlet_path(cached_access, map_entry, dst_conn=connector, memlet=memlets[0]) + self._current_state.add_memlet_path(cached_access, map_entry, dst_conn=connector, memlet=memlet) if outer_map_entry is not None and self._current_state.out_degree(outer_map_entry) < 1: self._current_state.add_edge(outer_map_entry, None, map_entry, None, memlet=Memlet()) From 4922784a19524b4508963b2e26ed92f7eabefa6f Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 11 Apr 2025 15:01:08 +0200 Subject: [PATCH 051/137] WIP: working map scopes with state boundaries next step is to investigate the race condition issue with state variable assignments. --- .../analysis/schedule_tree/tree_to_sdfg.py | 197 +++++++----------- tests/schedule_tree/to_sdfg_test.py | 13 +- 2 files changed, 78 insertions(+), 132 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 2e0aa26ae7..cddd131a8c 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -52,7 +52,8 @@ def __init__(self, start_state: Optional[SDFGState] = None) -> None: self._interstate_symbols: List[tn.AssignNode] = [] # dataflow scopes - self._dataflow_stack: List[Tuple[nodes.EntryNode, Dict[str, Tuple[nodes.AccessNode, Memlet]]]] = [] + self._dataflow_stack: List[Tuple[nodes.EntryNode, Dict[str, Tuple[nodes.AccessNode | nodes.NestedSDFG, + Memlet]]]] = [] # caches self._access_cache: Dict[SDFGState, Dict[str, nodes.AccessNode]] = {} @@ -229,7 +230,58 @@ def visit_ElseScope(self, node: tn.ElseScope, sdfg: SDFG) -> None: sdfg.add_edge(self._current_state, merge_state, InterstateEdge()) self._current_state = merge_state - def _generate_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: + def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: + inputs = node.input_memlets() + outputs = node.output_memlets() + + # setup nested SDFG + nsdfg = SDFG("nested_sdfg", parent=self._current_state) + start_state = nsdfg.add_state("nested_root", is_start_block=True) + for memlet in [*inputs, *outputs]: + if memlet.data not in nsdfg.arrays: + nsdfg.add_datadesc(memlet.data, sdfg.arrays[memlet.data].clone()) + + # Transients passed into a nested SDFG become non-transient inside that nested SDFG + if sdfg.arrays[memlet.data].transient: + nsdfg.arrays[memlet.data].transient = False + + # visit children inside nested SDFG + inner_visitor = StreeToSDFG(start_state) + for child in node.children: + inner_visitor.visit(child, sdfg=nsdfg) + + nested_SDFG = self._current_state.add_nested_sdfg(nsdfg, + sdfg, + inputs={memlet.data + for memlet in inputs}, + outputs={memlet.data + for memlet in outputs}) + + assert self._dataflow_stack + map_entry, to_connect = self._dataflow_stack[-1] + + # connect input memlets + for memlet in inputs: + # get it from outside the map + array_name = memlet.data + connector_name = f"{PREFIX_PASSTHROUGH_OUT}{array_name}" + if connector_name not in map_entry.out_connectors: + new_in_connector = map_entry.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{array_name}") + new_out_connector = map_entry.add_out_connector(connector_name) + assert new_in_connector == True + assert new_in_connector == new_out_connector + + self._current_state.add_edge(map_entry, connector_name, nested_SDFG, array_name, memlet) + + # Add empty memlet if we didn't add any in the loop above + if self._current_state.out_degree(map_entry) < 1: + self._current_state.add_edge(map_entry, None, nested_SDFG, None, memlet=Memlet()) + + # connect output memlets + for memlet in outputs: + to_connect[memlet.data] = (nested_SDFG, memlet) + + def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: dataflow_stack_size = len(self._dataflow_stack) # map entry @@ -241,17 +293,22 @@ def _generate_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: # keep a copy of the access cache access_cache = self._ensure_access_cache(self._current_state) - # clear the access_cache before visiting children such that they have their - # own access cache (per map scope) - self._access_cache[self._current_state].clear() + # Set a new access_cache before visiting children such that they have their + # own access cache (per map scope). + self._access_cache[self._current_state] = {} # visit children inside the map - self.visit(node.children, sdfg=sdfg) - _, to_connect = self._dataflow_stack.pop() + if any([isinstance(child, tn.StateBoundaryNode) for child in node.children]): + # to the funky stuff + self._insert_nestedSDFG(node, sdfg) + else: + self.visit(node.children, sdfg=sdfg) # reset the access_cache self._access_cache[self._current_state] = access_cache + # dataflow stack management + _, to_connect = self._dataflow_stack.pop() assert len(self._dataflow_stack) == dataflow_stack_size outer_map_entry, outer_to_connect = self._dataflow_stack[-1] if dataflow_stack_size else (None, None) @@ -314,10 +371,14 @@ def _generate_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: # connect "inside the map" access_node, memlet = to_connect[name] - self._current_state.add_memlet_path(access_node, - map_exit, - dst_conn=in_connector_name, - memlet=copy.deepcopy(memlet)) + if isinstance(access_node, nodes.NestedSDFG): + self._current_state.add_edge(access_node, name, map_exit, in_connector_name, copy.deepcopy(memlet)) + else: + assert isinstance(access_node, nodes.AccessNode) + self._current_state.add_memlet_path(access_node, + map_exit, + dst_conn=in_connector_name, + memlet=copy.deepcopy(memlet)) # connect "outside the map" # find output memlet @@ -341,120 +402,6 @@ def _generate_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: # map_exit. assert len(self._current_state.in_edges(map_exit)) > 0 - def _generate_MapScope_with_nested_SDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: - inputs = node.input_memlets() - outputs = node.output_memlets() - - # setup nested SDFG - nsdfg = SDFG("nested_sdfg", parent=self._current_state) - start_state = nsdfg.add_state("nested_root", is_start_block=True) - for memlet in [*inputs, *outputs]: - if memlet.data not in nsdfg.arrays: - nsdfg.add_datadesc(memlet.data, sdfg.arrays[memlet.data].clone()) - - # visit children inside nested SDFG - inner_visitor = StreeToSDFG(start_state) - for child in node.children: - inner_visitor.visit(child, sdfg=nsdfg) - - nested_SDFG = self._current_state.add_nested_sdfg(nsdfg, - sdfg, - inputs={memlet.data - for memlet in node.input_memlets()}, - outputs={memlet.data - for memlet in node.output_memlets()}) - - dataflow_stack_size = len(self._dataflow_stack) - outer_map_entry, outer_map_exit = self._dataflow_stack[-1] if dataflow_stack_size else (None, None) - cache = self._ensure_access_cache(self._current_state) - - # map entry - map_entry = nodes.MapEntry(node.node.map) - self._current_state.add_node(map_entry) - - for memlet in inputs: - new_in_connector = map_entry.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") - new_out_connector = map_entry.add_out_connector(f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}") - assert new_in_connector == new_out_connector - - if not new_in_connector: - continue - - # connect nested SDFG to map scope - self._current_state.add_edge(map_entry, f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}", nested_SDFG, - memlet.data, Memlet.from_memlet(memlet)) - - # connect map scope to "outer world" - if outer_map_entry is not None: - # passthrough if we are inside another map - self._current_state.add_edge(outer_map_entry, f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}", map_entry, - f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", memlet) - else: - # add access node "outside the map" and connect to it - if memlet.data not in cache: - # cache read access - cache[memlet.data] = self._current_state.add_read(memlet.data) - - self._current_state.add_edge(cache[memlet.data], None, map_entry, - f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", memlet) - - # Add empty memlet if no explicit connection from map_entry to nested_SDFG has been done so far - if not inputs: - self._current_state.add_edge(map_entry, None, nested_SDFG, None, memlet=Memlet()) - - # Add empty memlet if outer_map_entry has no out_connectors to connect to - if outer_map_entry is not None and not outer_map_entry.out_connectors and self._current_state.out_degree( - outer_map_entry) < 1: - self._current_state.add_edge(outer_map_entry, None, map_entry, None, memlet=Memlet()) - - # map exit - map_exit = nodes.MapExit(node.node.map) - self._current_state.add_node(map_exit) - - for memlet in outputs: - new_in_connector = map_exit.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") - new_out_connector = map_exit.add_out_connector(f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}") - assert new_in_connector == new_out_connector - - if not new_in_connector: - continue - - # connect nested SDFG to map scope - self._current_state.add_edge(nested_SDFG, memlet.data, map_exit, - f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", Memlet.from_memlet(memlet)) - - # connect map scope to "outer world" - if outer_map_exit: - # passthrough if we are inside another map - self._current_state.add_edge(map_exit, f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}", outer_map_exit, - f"{PREFIX_PASSTHROUGH_IN}{memlet.data}", memlet) - else: - # add access nodes "outside the map" and connect to it - # we always write to a new access_node - access_node = self._current_state.add_write(memlet.data) - self._current_state.add_edge(map_exit, f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}", access_node, None, - memlet) - - # cache write access node (or update an existing one) for read after write cases - cache[memlet.data] = access_node - - # Add empty memlet if no explicit connection from map_entry to nested_SDFG has been done so far - if not outputs: - self._current_state.add_edge(nested_SDFG, None, map_exit, None, memlet=Memlet()) - - # Add empty memlet if outer_map_exit has no in_connectors to connect to - if outer_map_exit is not None and not outer_map_exit.in_connectors and self._current_state.in_degree( - outer_map_exit) < 1: - self._current_state.add_edge(map_exit, None, outer_map_exit, None, memlet=Memlet()) - - def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: - if any([isinstance(child, tn.StateBoundaryNode) for child in node.children]): - # support multiple states within this map by inserting a nested SDFG - # return self._generate_MapScope_with_nested_SDFG(node, sdfg) - raise NotImplementedError("todo") - - self._generate_MapScope(node, sdfg) - def visit_ConsumeScope(self, node: tn.ConsumeScope, sdfg: SDFG) -> None: # AFAIK we don't support consume scopes in the gt4py/dace bridge. raise NotImplementedError(f"{type(node)} not implemented") diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 6c481a01f4..825531f8dc 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -389,7 +389,7 @@ def test_create_map_scope_read_after_write(): name="tester", containers={ 'A': dace.data.Array(dace.float64, [20]), - 'B': dace.data.Array(dace.float64, [20]), + 'B': dace.data.Array(dace.float64, [20], transient=True), }, children=[ tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), @@ -426,8 +426,6 @@ def test_create_map_scope_copy(): sdfg.validate() -# TODO: restart testing here -# TODO 2: find an automatic way to test stuff here def test_create_map_scope_double_memlet(): # Manually create a schedule tree stree = tn.ScheduleTreeRoot( @@ -496,7 +494,7 @@ def test_create_nested_map_scope_multi_read(): sdfg.validate() -def test_map_with_two_tasklets(): +def test_map_with_state_boundary_inside(): # Manually create a schedule tree stree = tn.ScheduleTreeRoot(name="tester", containers={'A': dace.data.Array(dace.float64, [20])}, @@ -520,13 +518,14 @@ def test_xppm_tmp(): stree = loaded.as_schedule_tree() # TODO - # - fix missing data dependency with "al" - # - fix read after write issue + # - fix issue with state transition variable assignments sdfg = stree.as_sdfg() sdfg.validate() +# TODO: find an automatic way to test stuff here + if __name__ == '__main__': test_state_boundaries_none() test_state_boundaries_waw() @@ -553,4 +552,4 @@ def test_xppm_tmp(): test_create_map_scope_double_memlet() test_create_nested_map_scope() test_create_nested_map_scope_multi_read() - test_map_with_two_tasklets() + test_map_with_state_boundary_inside() From e85c97a78dee04f4c395bd0a2a598664462355f9 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 11 Apr 2025 19:10:25 +0200 Subject: [PATCH 052/137] WIP: fixed isssue with edge assignments Next problem (to get XPPM working) is to check for read after write cases in nestedSDFG and filter inputs with those to make the SDFG validator a bit happier. --- .../analysis/schedule_tree/tree_to_sdfg.py | 107 +++++++++++++++--- tests/schedule_tree/to_sdfg_test.py | 18 ++- 2 files changed, 107 insertions(+), 18 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index cddd131a8c..580863346e 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -1,7 +1,7 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import copy from collections import defaultdict -from dace import subsets +from dace import subsets, symbolic from dace.memlet import Memlet from dace.sdfg import nodes, memlet_utils as mmu, utils as sdfg_utils from dace.sdfg.sdfg import SDFG, ControlFlowRegion, InterstateEdge @@ -115,15 +115,16 @@ def visit_GotoNode(self, node: tn.GotoNode, sdfg: SDFG) -> None: def visit_AssignNode(self, node: tn.AssignNode, sdfg: SDFG) -> None: # We just collect them here. They'll be added when state boundaries are added, - # see `visit_StateBoundaryNode()` above. + # see visitors below. self._interstate_symbols.append(node) def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: before_state = self._current_state - guard_state = sdfg.add_state(label="loop_guard") + pending = self._pending_interstate_assignments() + pending[node.header.itervar] = node.header.init + + guard_state = _insert_and_split_assignments(sdfg, before_state, label="loop_guard", assignments=pending) self._current_state = guard_state - sdfg.add_edge(before_state, self._current_state, - InterstateEdge(assignments=dict({node.header.itervar: node.header.init}))) body_state = sdfg.add_state(label="loop_body") self._current_state = body_state @@ -131,8 +132,10 @@ def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: # visit children inside the loop self.visit(node.children, sdfg=sdfg) - sdfg.add_edge(self._current_state, guard_state, - InterstateEdge(assignments=dict({node.header.itervar: node.header.update}))) + + pending = self._pending_interstate_assignments() + pending[node.header.itervar] = node.header.update + _insert_and_split_assignments(sdfg, self._current_state, after_state=guard_state, assignments=pending) after_state = sdfg.add_state(label="loop_after") self._current_state = after_state @@ -140,9 +143,11 @@ def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: before_state = self._current_state - guard_state = sdfg.add_state(label="guard_state") + guard_state = _insert_and_split_assignments(sdfg, + before_state, + label="guard_state", + assignments=self._pending_interstate_assignments()) self._current_state = guard_state - sdfg.add_edge(before_state, guard_state, InterstateEdge()) body_state = sdfg.add_state(label="loop_body") self._current_state = body_state @@ -150,7 +155,10 @@ def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: # visit children inside the loop self.visit(node.children, sdfg=sdfg) - sdfg.add_edge(self._current_state, guard_state, InterstateEdge()) + _insert_and_split_assignments(sdfg, + before_state=self._current_state, + after_state=guard_state, + assignments=self._pending_interstate_assignments()) after_state = sdfg.add_state(label="loop_after") self._current_state = after_state @@ -168,8 +176,10 @@ def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: before_state = self._current_state # add guard state - guard_state = sdfg.add_state(label="guard_state") - sdfg.add_edge(before_state, guard_state, InterstateEdge()) + guard_state = _insert_and_split_assignments(sdfg, + before_state, + label="guard_state", + assignments=self._pending_interstate_assignments()) # add true_state true_state = sdfg.add_state(label="true_state") @@ -180,7 +190,10 @@ def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: self.visit(node.children, sdfg=sdfg) # add merge_state - merge_state = sdfg.add_state_after(self._current_state, label="merge_state") + merge_state = _insert_and_split_assignments(sdfg, + self._current_state, + label="merge_state", + assignments=self._pending_interstate_assignments()) # Check if there's an `ElseScope` following this node (in the parent's children). # Filter StateBoundaryNodes, which we inserted earlier, for this analysis. @@ -227,7 +240,10 @@ def visit_ElseScope(self, node: tn.ElseScope, sdfg: SDFG) -> None: # merge false-branch into merge_state merge_state = self._pop_state("merge_state") - sdfg.add_edge(self._current_state, merge_state, InterstateEdge()) + _insert_and_split_assignments(sdfg, + before_state=self._current_state, + after_state=merge_state, + assignments=self._pending_interstate_assignments()) self._current_state = merge_state def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: @@ -485,13 +501,25 @@ def visit_RefSetNode(self, node: tn.RefSetNode, sdfg: SDFG) -> None: def visit_StateBoundaryNode(self, node: tn.StateBoundaryNode, sdfg: SDFG) -> None: # When creating a state boundary, include all inter-state assignments that precede it. + pending = self._pending_interstate_assignments() + + self._current_state = create_state_boundary(node, + sdfg, + self._current_state, + StateBoundaryBehavior.STATE_TRANSITION, + assignments=pending) + + def _pending_interstate_assignments(self) -> Dict: + """ + Return currently pending interstate assignments. Clears the cache. + """ assignments = {} + for symbol in self._interstate_symbols: assignments[symbol.name] = symbol.value.as_string self._interstate_symbols.clear() - self._current_state = create_state_boundary(node, sdfg, self._current_state, - StateBoundaryBehavior.STATE_TRANSITION, assignments) + return assignments StreeToSDFG().visit(stree, sdfg=result) @@ -633,7 +661,52 @@ def create_state_boundary(bnode: tn.StateBoundaryNode, # behavior. Fall back to state transition in that case. label = "cf_state_boundary" if bnode.due_to_control_flow else "state_boundary" - return sdfg_region.add_state_after(state, label=label, assignments=assignments) + return _insert_and_split_assignments(sdfg_region, state, label=label, assignments=assignments) + + +def _insert_and_split_assignments(sdfg_region: ControlFlowRegion, + before_state: SDFGState, + after_state: Optional[SDFGState] = None, + label: Optional[str] = None, + assignments: Optional[Dict] = None) -> SDFGState: + """ + Insert given assignments splitting them in case of potential race conditions. + + DaCe validation (currently) won't let us add multiple assignment with read after + write pattern on the same edge. We thus split them over multiple state transitions + (inserting empty states in between) to be safe. + + NOTE (later) This should be double-checked since python dictionaries preserve + insertion order since python 3.7 (which we rely on in this function + too). Depending on code generation it could(TM) be that we can + weaken (best case remove) the corresponding check from the sdfg + validator. + """ + has_potential_race = False + for key, value in assignments.items(): + syms = symbolic.free_symbols_and_functions(value) + also_assigned = (syms & assignments.keys()) - {key} + if also_assigned: + has_potential_race = True + break + + if not has_potential_race: + if after_state is not None: + sdfg_region.add_edge(before_state, after_state, InterstateEdge(assignments=assignments)) + return after_state + return sdfg_region.add_state_after(before_state, label=label, assignments=assignments) + + last_state = before_state + for index, assignment in enumerate(assignments.items()): + key, value = assignment + is_last_state = index == len(assignments) - 1 + if is_last_state and after_state is not None: + sdfg_region.add_edge(last_state, after_state, InterstateEdge(assignments={key: value})) + last_state = after_state + else: + last_state = sdfg_region.add_state_after(last_state, label=label, assignments={key: value}) + + return last_state def _list_index(list: List[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode) -> int: diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 825531f8dc..0dbe69b96e 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -513,12 +513,27 @@ def test_map_with_state_boundary_inside(): sdfg.validate() +def test_edge_assignment_read_after_write(): + stree = tn.ScheduleTreeRoot(name="tester", + containers={}, + children=[ + tn.AssignNode("my_condition", CodeBlock("True"), dace.InterstateEdge()), + tn.AssignNode("condition", CodeBlock("my_condition"), dace.InterstateEdge()), + tn.StateBoundaryNode() + ]) + + sdfg = stree.as_sdfg(simplify=False) + + assert [node.name for node in sdfg.nodes()] == ["tree_root", "state_boundary", "state_boundary_0"] + assert [edge.data.assignments for edge in sdfg.edges()] == [{"my_condition": "True"}, {"condition": "my_condition"}] + + def test_xppm_tmp(): loaded = dace.SDFG.from_file("test.sdfgz") stree = loaded.as_schedule_tree() # TODO - # - fix issue with state transition variable assignments + # - nestedSDFG: don't add input connections in read after write situations sdfg = stree.as_sdfg() sdfg.validate() @@ -553,3 +568,4 @@ def test_xppm_tmp(): test_create_nested_map_scope() test_create_nested_map_scope_multi_read() test_map_with_state_boundary_inside() + test_edge_assignment_read_after_write() From 1bae7fa69dae75d5d9dda540a3a348c59b296b0c Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 15 Apr 2025 14:22:10 +0200 Subject: [PATCH 053/137] Unrelated: fix typos --- dace/subsets.py | 4 ++-- dace/transformation/helpers.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/dace/subsets.py b/dace/subsets.py index 0fdc36c22e..69b236405b 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -1121,7 +1121,7 @@ def covers(self, other): if isinstance(other, SubsetUnion): for subset in self.subset_list: - # check if ther is a subset in self that covers every subset in other + # check if there is a subset in self that covers every subset in other if all(subset.covers(s) for s in other.subset_list): return True # return False if that's not the case for any of the subsets in self @@ -1139,7 +1139,7 @@ def covers_precise(self, other): if isinstance(other, SubsetUnion): for subset in self.subset_list: - # check if ther is a subset in self that covers every subset in other + # check if there is a subset in self that covers every subset in other if all(subset.covers_precise(s) for s in other.subset_list): return True # return False if that's not the case for any of the subsets in self diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 7824030e5c..dab5644978 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -235,7 +235,7 @@ def _copy_state(sdfg: SDFG, sdfg.add_edge(state_copy, state, InterstateEdge(condition=condition)) else: condition = None - # NOTE: The following should be unecessary for preserving program semantics. Therefore we comment it out to + # NOTE: The following should be unnecessary for preserving program semantics. Therefore we comment it out to # avoid the overhead of evaluating the condition. # if out_conditions: # condition = 'or'.join([f"({c})" for c in out_conditions]) @@ -267,7 +267,7 @@ def find_sdfg_control_flow(sdfg: SDFG) -> Dict[SDFGState, Set[SDFGState]]: ipostdom = utils.postdominators(sdfg) cft = cf.structured_control_flow_tree(sdfg, None) - # Iterate over the SDFG's control flow scopes and create for each an SDFG subraph. These subgraphs must be disjoint, + # Iterate over the SDFG's control flow scopes and create for each an SDFG subgraph. These subgraphs must be disjoint, # so we duplicate SDFGStates that appear in more than one scopes (guards and exits of loops and conditionals). components = {} visited = {} # Dict[SDFGState, bool]: True if SDFGState in Scope (non-SingleState) @@ -799,6 +799,7 @@ def unsqueeze_memlet(internal_memlet: Memlet, external_offset: Tuple[int] = None) -> Memlet: """ Unsqueezes and offsets a memlet, as per the semantics of nested SDFGs. + :param internal_memlet: The internal memlet (inside nested SDFG) before modification. :param external_memlet: The external memlet before modification. :param preserve_minima: Do not change the subset's minimum elements. From 7d46d2dece169389208972c71dfb888b63142f94 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 15 Apr 2025 16:50:02 +0200 Subject: [PATCH 054/137] WIP: memlet propagation seems to work like this (doesn't validate) Validation fails because of missing symbols (`__i`, `__j`, `__k`) on the nested SDFG, which are added after the nested SDFG is added to the SDFG. --- .../analysis/schedule_tree/tree_to_sdfg.py | 288 +++++++++++++----- 1 file changed, 207 insertions(+), 81 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 580863346e..c88d1ba606 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -7,6 +7,7 @@ from dace.sdfg.sdfg import SDFG, ControlFlowRegion, InterstateEdge from dace.sdfg.state import SDFGState from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.sdfg import propagation from enum import Enum, auto from typing import Dict, Final, List, Optional, Set, Tuple @@ -52,8 +53,8 @@ def __init__(self, start_state: Optional[SDFGState] = None) -> None: self._interstate_symbols: List[tn.AssignNode] = [] # dataflow scopes - self._dataflow_stack: List[Tuple[nodes.EntryNode, Dict[str, Tuple[nodes.AccessNode | nodes.NestedSDFG, - Memlet]]]] = [] + self._dataflow_stack: List[Tuple[nodes.EntryNode | nodes.NestedSDFG, + Dict[str, Tuple[nodes.AccessNode | nodes.NestedSDFG, Memlet]]]] = [] # caches self._access_cache: Dict[SDFGState, Dict[str, nodes.AccessNode]] = {} @@ -247,55 +248,125 @@ def visit_ElseScope(self, node: tn.ElseScope, sdfg: SDFG) -> None: self._current_state = merge_state def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: - inputs = node.input_memlets() - outputs = node.output_memlets() - - # setup nested SDFG - nsdfg = SDFG("nested_sdfg", parent=self._current_state) - start_state = nsdfg.add_state("nested_root", is_start_block=True) - for memlet in [*inputs, *outputs]: - if memlet.data not in nsdfg.arrays: - nsdfg.add_datadesc(memlet.data, sdfg.arrays[memlet.data].clone()) - - # Transients passed into a nested SDFG become non-transient inside that nested SDFG - if sdfg.arrays[memlet.data].transient: - nsdfg.arrays[memlet.data].transient = False - - # visit children inside nested SDFG - inner_visitor = StreeToSDFG(start_state) - for child in node.children: - inner_visitor.visit(child, sdfg=nsdfg) + dataflow_stack_size = len(self._dataflow_stack) - nested_SDFG = self._current_state.add_nested_sdfg(nsdfg, - sdfg, - inputs={memlet.data - for memlet in inputs}, - outputs={memlet.data - for memlet in outputs}) + # setup inner SDFG + inner_sdfg = SDFG("nested_sdfg", parent=self._current_state) + nsdfg = self._current_state.add_nested_sdfg(inner_sdfg, sdfg, inputs={}, outputs={}) + start_state = inner_sdfg.add_state("nested_root", is_start_block=True) + # update stacks and current state + old_state_label = self._current_state.label + self._state_stack.append(self._current_state) + self._dataflow_stack.append((nsdfg, dict())) + self._current_state = start_state + + # visit children + for child in node.children: + self.visit(child, sdfg=inner_sdfg) + + # restore current state and do stack handling + self._current_state = self._pop_state(old_state_label) + _, to_connect = self._dataflow_stack.pop() + assert not to_connect + assert len(self._dataflow_stack) == dataflow_stack_size assert self._dataflow_stack map_entry, to_connect = self._dataflow_stack[-1] - # connect input memlets - for memlet in inputs: - # get it from outside the map - array_name = memlet.data - connector_name = f"{PREFIX_PASSTHROUGH_OUT}{array_name}" - if connector_name not in map_entry.out_connectors: - new_in_connector = map_entry.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{array_name}") - new_out_connector = map_entry.add_out_connector(connector_name) - assert new_in_connector == True - assert new_in_connector == new_out_connector + # connect nsdfg input memlets (to be propagated upon completion of the SDFG) + for name in nsdfg.in_connectors: + out_connector = f"{PREFIX_PASSTHROUGH_OUT}{name}" + new_in_connector = map_entry.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{name}") + new_out_connector = map_entry.add_out_connector(out_connector) + assert new_in_connector == True + assert new_in_connector == new_out_connector + + self._current_state.add_edge(map_entry, out_connector, nsdfg, name, Memlet(name)) + + # Add empty memlet if we didn't add any in the loop above + if self._current_state.out_degree(map_entry) < 1: + self._current_state.add_nedge(map_entry, nsdfg, memlet=Memlet()) + + # connect nsdfg output memlets (to be propagated) + for name in nsdfg.out_connectors: + to_connect[name] = (nsdfg, Memlet(name)) + + assert not nsdfg.sdfg.free_symbols + return + + # connect nsdfg input memlets + for name in nsdfg.in_connectors: + memlets = [memlet for memlet in node.input_memlets() if memlet.data == name] + assert len(memlets) == 1 + + new_in_connector = map_entry.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{name}") + new_out_connector = map_entry.add_out_connector(f"{PREFIX_PASSTHROUGH_OUT}{name}") + assert new_in_connector + assert new_in_connector == new_out_connector - self._current_state.add_edge(map_entry, connector_name, nested_SDFG, array_name, memlet) + self._current_state.add_edge(map_entry, f"{PREFIX_PASSTHROUGH_OUT}{name}", nsdfg, name, memlets[0]) # Add empty memlet if we didn't add any in the loop above if self._current_state.out_degree(map_entry) < 1: - self._current_state.add_edge(map_entry, None, nested_SDFG, None, memlet=Memlet()) + self._current_state.add_edge(map_entry, None, nsdfg, None, memlet=Memlet()) + + # connect nsdfg output memlets + for name in nsdfg.out_connectors: + memlets = [memlet for memlet in node.output_memlets() if memlet.data == name] + assert len(memlets) == 1 - # connect output memlets - for memlet in outputs: - to_connect[memlet.data] = (nested_SDFG, memlet) + to_connect[memlets[0].data] = (nsdfg, memlets[0]) + + # # ------------- old + # inputs = node.input_memlets() + # outputs = node.output_memlets() + # + # # setup nested SDFG + # nsdfg = SDFG("nested_sdfg", parent=self._current_state) + # start_state = nsdfg.add_state("nested_root", is_start_block=True) + # for memlet in [*inputs, *outputs]: + # if memlet.data not in nsdfg.arrays: + # nsdfg.add_datadesc(memlet.data, sdfg.arrays[memlet.data].clone()) + # + # # Transients passed into a nested SDFG become non-transient inside that nested SDFG + # if sdfg.arrays[memlet.data].transient: + # nsdfg.arrays[memlet.data].transient = False + # + # # visit children inside nested SDFG + # inner_visitor = StreeToSDFG(start_state) + # for child in node.children: + # inner_visitor.visit(child, sdfg=nsdfg) + # + # nested_SDFG = self._current_state.add_nested_sdfg(nsdfg, + # sdfg, + # inputs={memlet.data + # for memlet in inputs}, + # outputs={memlet.data + # for memlet in outputs}) + # + # assert self._dataflow_stack + # map_entry, to_connect = self._dataflow_stack[-1] + # + # # connect input memlets + # for memlet in inputs: + # # get it from outside the map + # array_name = memlet.data + # connector_name = f"{PREFIX_PASSTHROUGH_OUT}{array_name}" + # if connector_name not in map_entry.out_connectors: + # new_in_connector = map_entry.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{array_name}") + # new_out_connector = map_entry.add_out_connector(connector_name) + # assert new_in_connector == True + # assert new_in_connector == new_out_connector + # + # self._current_state.add_edge(map_entry, connector_name, nested_SDFG, array_name, memlet) + # + # # Add empty memlet if we didn't add any in the loop above + # if self._current_state.out_degree(map_entry) < 1: + # self._current_state.add_edge(map_entry, None, nested_SDFG, None, memlet=Memlet()) + # + # # connect output memlets + # for memlet in outputs: + # to_connect[memlet.data] = (nested_SDFG, memlet) def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: dataflow_stack_size = len(self._dataflow_stack) @@ -317,6 +388,9 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: if any([isinstance(child, tn.StateBoundaryNode) for child in node.children]): # to the funky stuff self._insert_nestedSDFG(node, sdfg) + # only propagate memlets once the full SDFG is built to ensure that all memlets + # are connected to their (outermost) AccessNode. + # propagate_memlets_nested_sdfg(sdfg, self._current_state, nsdfg) else: self.visit(node.children, sdfg=sdfg) @@ -329,26 +403,29 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: outer_map_entry, outer_to_connect = self._dataflow_stack[-1] if dataflow_stack_size else (None, None) # connect potential input connectors on map_entry - input_memlets = node.input_memlets() + # input_memlets = node.input_memlets() for connector in map_entry.in_connectors: memlet_data = connector.removeprefix(PREFIX_PASSTHROUGH_IN) - # find input memlet - memlets = [memlet for memlet in input_memlets if memlet.data == memlet_data] - assert len(memlets) > 0 - memlet = copy.deepcopy(memlets[0]) - if len(memlets) > 1: - # merge memlets - for index, element in enumerate(memlets): - if index == 0: - continue - memlet.subset = subsets.union(memlet.subset, element.subset) - # TODO(later): figure out the volume thing (also in MemletSet). Also: num_accesses (for legacy reasons) - memlet.volume += element.volume + # # find input memlet + # memlets = [memlet for memlet in input_memlets if memlet.data == memlet_data] + # assert len(memlets) > 0 + # memlet = copy.deepcopy(memlets[0]) + # if len(memlets) > 1: + # # merge memlets + # for index, element in enumerate(memlets): + # if index == 0: + # continue + # memlet.subset = subsets.union(memlet.subset, element.subset) + # # TODO(later): figure out the volume thing (also in MemletSet). Also: num_accesses (for legacy reasons) + # memlet.volume += element.volume # connect to local access node (if available) if memlet_data in access_cache: cached_access = access_cache[memlet_data] - self._current_state.add_memlet_path(cached_access, map_entry, dst_conn=connector, memlet=memlet) + self._current_state.add_memlet_path(cached_access, + map_entry, + dst_conn=connector, + memlet=Memlet(memlet_data)) continue if outer_map_entry is not None: @@ -360,16 +437,20 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: assert new_in_connector == True assert new_in_connector == new_out_connector - self._current_state.add_edge(outer_map_entry, connector_name, map_entry, connector, memlet) + self._current_state.add_edge(outer_map_entry, connector_name, map_entry, connector, + Memlet(memlet_data)) else: # cache local read access assert memlet_data not in access_cache access_cache[memlet_data] = self._current_state.add_read(memlet_data) cached_access = access_cache[memlet_data] - self._current_state.add_memlet_path(cached_access, map_entry, dst_conn=connector, memlet=memlet) + self._current_state.add_memlet_path(cached_access, + map_entry, + dst_conn=connector, + memlet=Memlet(memlet_data)) if outer_map_entry is not None and self._current_state.out_degree(outer_map_entry) < 1: - self._current_state.add_edge(outer_map_entry, None, map_entry, None, memlet=Memlet()) + self._current_state.add_nedge(outer_map_entry, map_entry, memlet=Memlet()) # map_exit # -------- @@ -377,7 +458,7 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: self._current_state.add_node(map_exit) # connect writes to map_exit node - output_memlets = node.output_memlets() + # output_memlets = node.output_memlets() for name in to_connect: in_connector_name = f"{PREFIX_PASSTHROUGH_IN}{name}" out_connector_name = f"{PREFIX_PASSTHROUGH_OUT}{name}" @@ -388,30 +469,30 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: # connect "inside the map" access_node, memlet = to_connect[name] if isinstance(access_node, nodes.NestedSDFG): - self._current_state.add_edge(access_node, name, map_exit, in_connector_name, copy.deepcopy(memlet)) + self._current_state.add_edge(access_node, name, map_exit, in_connector_name, memlet) else: assert isinstance(access_node, nodes.AccessNode) self._current_state.add_memlet_path(access_node, map_exit, dst_conn=in_connector_name, - memlet=copy.deepcopy(memlet)) + memlet=memlet) # connect "outside the map" - # find output memlet - memlets = [memlet for memlet in output_memlets if memlet.data == name] - assert len(memlets) == 1 + # # find output memlet + # memlets = [memlet for memlet in output_memlets if memlet.data == name] + # assert len(memlets) == 1 access_node = self._current_state.add_write(name) self._current_state.add_memlet_path(map_exit, access_node, src_conn=out_connector_name, - memlet=memlets[0]) + memlet=Memlet(name)) # cache write access into access_cache access_cache[name] = access_node if outer_to_connect is not None: - outer_to_connect[name] = (access_node, memlets[0]) + outer_to_connect[name] = (access_node, Memlet(name)) # TODO If nothing is connected at this point, figure out what's the last thing that # we should connect to. Then, add an empty memlet from that last thing to this @@ -432,7 +513,7 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: self._current_state.add_node(tasklet) cache = self._ensure_access_cache(self._current_state) - map_entry, to_connect = self._dataflow_stack[-1] if self._dataflow_stack else (None, None) + scope_node, to_connect = self._dataflow_stack[-1] if self._dataflow_stack else (None, None) # Connect input memlets for name, memlet in node.in_memlets.items(): @@ -442,26 +523,51 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: self._current_state.add_memlet_path(cached_access, tasklet, dst_conn=name, memlet=memlet) continue - if map_entry is not None: + if isinstance(scope_node, nodes.MapEntry): # get it from outside the map connector_name = f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}" - if connector_name not in map_entry.out_connectors: - new_in_connector = map_entry.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") - new_out_connector = map_entry.add_out_connector(connector_name) + if connector_name not in scope_node.out_connectors: + new_in_connector = scope_node.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") + new_out_connector = scope_node.add_out_connector(connector_name) assert new_in_connector == True assert new_in_connector == new_out_connector - self._current_state.add_edge(map_entry, connector_name, tasklet, name, memlet) + self._current_state.add_edge(scope_node, connector_name, tasklet, name, memlet) + continue + + if isinstance(scope_node, nodes.NestedSDFG): + # connector_name = f"{PREFIX_PASSTHROUGH_IN}{name}" + # scope_node.add_in_connector(connector_name) + # self._current_state.add_edge(scope_node, connector_name, tasklet, name, memlet) + # continue + # Copy data descriptor from parent SDFG and add input connector + if memlet.data not in sdfg.arrays: + parent_sdfg = sdfg.parent.parent + sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) + + # Transients passed into a nested SDFG become non-transient inside that nested SDFG + if parent_sdfg.arrays[memlet.data].transient: + sdfg.arrays[memlet.data].transient = False + # TODO + # ... unless they are only ever used inside the nested SDFG, in which case + # we should delete them from the parent SDFG's array list. + # NOTE This can probably be done automatically by a cleanup pass in the end. + # Something like DDE should be able to do this. + + new_in_connector = scope_node.add_in_connector(memlet.data) + assert new_in_connector == True else: - # cache local read access - assert memlet.data not in cache - cache[memlet.data] = self._current_state.add_read(memlet.data) - cached_access = cache[memlet.data] - self._current_state.add_memlet_path(cached_access, tasklet, dst_conn=name, memlet=memlet) + assert scope_node is None + + # cache local read access + assert memlet.data not in cache + cache[memlet.data] = self._current_state.add_read(memlet.data) + cached_access = cache[memlet.data] + self._current_state.add_memlet_path(cached_access, tasklet, dst_conn=name, memlet=memlet) # Add empty memlet if map_entry has no out_connectors to connect to - if map_entry is not None and self._current_state.out_degree(map_entry) < 1: - self._current_state.add_edge(map_entry, None, tasklet, None, memlet=Memlet()) + if isinstance(scope_node, nodes.MapEntry) and self._current_state.out_degree(scope_node) < 1: + self._current_state.add_nedge(scope_node, tasklet, Memlet()) # Connect output memlets for name, memlet in node.out_memlets.items(): @@ -472,8 +578,26 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # cache write access node (or update an existing one) for read after write cases cache[memlet.data] = access_node - if to_connect is not None: - to_connect[memlet.data] = (access_node, memlet) + if isinstance(scope_node, nodes.MapEntry): + # copy the memlet since we already used it in the memlet path above + to_connect[memlet.data] = (access_node, copy.deepcopy(memlet)) + continue + + if isinstance(scope_node, nodes.NestedSDFG): + if memlet.data not in sdfg.arrays: + parent_sdfg = sdfg.parent.parent + sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) + + # Transients passed into a nested SDFG become non-transient inside that nested SDFG + if parent_sdfg.arrays[memlet.data].transient: + sdfg.arrays[memlet.data].transient = False + + # Add out_connector in any case if not yet present, e.g. write after read + scope_node.add_out_connector(memlet.data) + # self._current_state.add_memlet_path(access_node, scope_node, dst_conn=memlet.data, memlet=copy.deepcopy(memlet)) + + else: + assert scope_node is None def visit_LibraryCall(self, node: tn.LibraryCall, sdfg: SDFG) -> None: # AFAIK we expand all library calls in the gt4py/dace bridge before coming here. @@ -523,6 +647,8 @@ def _pending_interstate_assignments(self) -> Dict: StreeToSDFG().visit(stree, sdfg=result) + propagation.propagate_memlets_sdfg(result) + return result From ed5e48ed1669be2ac015158253627e5d3dbf829e Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 15 Apr 2025 16:54:53 +0200 Subject: [PATCH 055/137] WIP: delete unreachable/commented code --- .../analysis/schedule_tree/tree_to_sdfg.py | 104 +----------------- 1 file changed, 1 insertion(+), 103 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index c88d1ba606..a85b494489 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -3,7 +3,7 @@ from collections import defaultdict from dace import subsets, symbolic from dace.memlet import Memlet -from dace.sdfg import nodes, memlet_utils as mmu, utils as sdfg_utils +from dace.sdfg import nodes, memlet_utils as mmu from dace.sdfg.sdfg import SDFG, ControlFlowRegion, InterstateEdge from dace.sdfg.state import SDFGState from dace.sdfg.analysis.schedule_tree import treenodes as tn @@ -292,81 +292,6 @@ def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: to_connect[name] = (nsdfg, Memlet(name)) assert not nsdfg.sdfg.free_symbols - return - - # connect nsdfg input memlets - for name in nsdfg.in_connectors: - memlets = [memlet for memlet in node.input_memlets() if memlet.data == name] - assert len(memlets) == 1 - - new_in_connector = map_entry.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{name}") - new_out_connector = map_entry.add_out_connector(f"{PREFIX_PASSTHROUGH_OUT}{name}") - assert new_in_connector - assert new_in_connector == new_out_connector - - self._current_state.add_edge(map_entry, f"{PREFIX_PASSTHROUGH_OUT}{name}", nsdfg, name, memlets[0]) - - # Add empty memlet if we didn't add any in the loop above - if self._current_state.out_degree(map_entry) < 1: - self._current_state.add_edge(map_entry, None, nsdfg, None, memlet=Memlet()) - - # connect nsdfg output memlets - for name in nsdfg.out_connectors: - memlets = [memlet for memlet in node.output_memlets() if memlet.data == name] - assert len(memlets) == 1 - - to_connect[memlets[0].data] = (nsdfg, memlets[0]) - - # # ------------- old - # inputs = node.input_memlets() - # outputs = node.output_memlets() - # - # # setup nested SDFG - # nsdfg = SDFG("nested_sdfg", parent=self._current_state) - # start_state = nsdfg.add_state("nested_root", is_start_block=True) - # for memlet in [*inputs, *outputs]: - # if memlet.data not in nsdfg.arrays: - # nsdfg.add_datadesc(memlet.data, sdfg.arrays[memlet.data].clone()) - # - # # Transients passed into a nested SDFG become non-transient inside that nested SDFG - # if sdfg.arrays[memlet.data].transient: - # nsdfg.arrays[memlet.data].transient = False - # - # # visit children inside nested SDFG - # inner_visitor = StreeToSDFG(start_state) - # for child in node.children: - # inner_visitor.visit(child, sdfg=nsdfg) - # - # nested_SDFG = self._current_state.add_nested_sdfg(nsdfg, - # sdfg, - # inputs={memlet.data - # for memlet in inputs}, - # outputs={memlet.data - # for memlet in outputs}) - # - # assert self._dataflow_stack - # map_entry, to_connect = self._dataflow_stack[-1] - # - # # connect input memlets - # for memlet in inputs: - # # get it from outside the map - # array_name = memlet.data - # connector_name = f"{PREFIX_PASSTHROUGH_OUT}{array_name}" - # if connector_name not in map_entry.out_connectors: - # new_in_connector = map_entry.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{array_name}") - # new_out_connector = map_entry.add_out_connector(connector_name) - # assert new_in_connector == True - # assert new_in_connector == new_out_connector - # - # self._current_state.add_edge(map_entry, connector_name, nested_SDFG, array_name, memlet) - # - # # Add empty memlet if we didn't add any in the loop above - # if self._current_state.out_degree(map_entry) < 1: - # self._current_state.add_edge(map_entry, None, nested_SDFG, None, memlet=Memlet()) - # - # # connect output memlets - # for memlet in outputs: - # to_connect[memlet.data] = (nested_SDFG, memlet) def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: dataflow_stack_size = len(self._dataflow_stack) @@ -386,11 +311,7 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: # visit children inside the map if any([isinstance(child, tn.StateBoundaryNode) for child in node.children]): - # to the funky stuff self._insert_nestedSDFG(node, sdfg) - # only propagate memlets once the full SDFG is built to ensure that all memlets - # are connected to their (outermost) AccessNode. - # propagate_memlets_nested_sdfg(sdfg, self._current_state, nsdfg) else: self.visit(node.children, sdfg=sdfg) @@ -403,21 +324,8 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: outer_map_entry, outer_to_connect = self._dataflow_stack[-1] if dataflow_stack_size else (None, None) # connect potential input connectors on map_entry - # input_memlets = node.input_memlets() for connector in map_entry.in_connectors: memlet_data = connector.removeprefix(PREFIX_PASSTHROUGH_IN) - # # find input memlet - # memlets = [memlet for memlet in input_memlets if memlet.data == memlet_data] - # assert len(memlets) > 0 - # memlet = copy.deepcopy(memlets[0]) - # if len(memlets) > 1: - # # merge memlets - # for index, element in enumerate(memlets): - # if index == 0: - # continue - # memlet.subset = subsets.union(memlet.subset, element.subset) - # # TODO(later): figure out the volume thing (also in MemletSet). Also: num_accesses (for legacy reasons) - # memlet.volume += element.volume # connect to local access node (if available) if memlet_data in access_cache: @@ -458,7 +366,6 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: self._current_state.add_node(map_exit) # connect writes to map_exit node - # output_memlets = node.output_memlets() for name in to_connect: in_connector_name = f"{PREFIX_PASSTHROUGH_IN}{name}" out_connector_name = f"{PREFIX_PASSTHROUGH_OUT}{name}" @@ -478,10 +385,6 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: memlet=memlet) # connect "outside the map" - # # find output memlet - # memlets = [memlet for memlet in output_memlets if memlet.data == name] - # assert len(memlets) == 1 - access_node = self._current_state.add_write(name) self._current_state.add_memlet_path(map_exit, access_node, @@ -536,10 +439,6 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: continue if isinstance(scope_node, nodes.NestedSDFG): - # connector_name = f"{PREFIX_PASSTHROUGH_IN}{name}" - # scope_node.add_in_connector(connector_name) - # self._current_state.add_edge(scope_node, connector_name, tasklet, name, memlet) - # continue # Copy data descriptor from parent SDFG and add input connector if memlet.data not in sdfg.arrays: parent_sdfg = sdfg.parent.parent @@ -594,7 +493,6 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # Add out_connector in any case if not yet present, e.g. write after read scope_node.add_out_connector(memlet.data) - # self._current_state.add_memlet_path(access_node, scope_node, dst_conn=memlet.data, memlet=copy.deepcopy(memlet)) else: assert scope_node is None From 65706822749118d236f878b121ad17a2b6aff1e0 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 16 Apr 2025 10:51:33 +0200 Subject: [PATCH 056/137] XPPM test generates an SDFG that validates \o/ This just generates a validating SDFG. No other testing was done. And this needs a bit of cleanup too. Steps to make validation / memlet propagation happy included: - use Memlet.from_array() to generate a memlet of the full data - add nested SDFG only once fully built (fixed the missing symbols) --- .../analysis/schedule_tree/tree_to_sdfg.py | 65 +++++++++++++------ tests/schedule_tree/to_sdfg_test.py | 3 - 2 files changed, 44 insertions(+), 24 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index a85b494489..0889434f8a 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -1,7 +1,7 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import copy from collections import defaultdict -from dace import subsets, symbolic +from dace import dtypes, subsets, symbolic from dace.memlet import Memlet from dace.sdfg import nodes, memlet_utils as mmu from dace.sdfg.sdfg import SDFG, ControlFlowRegion, InterstateEdge @@ -53,8 +53,9 @@ def __init__(self, start_state: Optional[SDFGState] = None) -> None: self._interstate_symbols: List[tn.AssignNode] = [] # dataflow scopes - self._dataflow_stack: List[Tuple[nodes.EntryNode | nodes.NestedSDFG, - Dict[str, Tuple[nodes.AccessNode | nodes.NestedSDFG, Memlet]]]] = [] + # List[ (MapEntryNode, ToConnect) | (SDFG, (inputs, outputs))] + self._dataflow_stack: List[Tuple[nodes.EntryNode, Dict[str, Tuple[nodes.AccessNode, Memlet]]] + | Tuple[SDFG, Tuple[Set[str], Set[str]]]] = [] # caches self._access_cache: Dict[SDFGState, Dict[str, nodes.AccessNode]] = {} @@ -249,27 +250,32 @@ def visit_ElseScope(self, node: tn.ElseScope, sdfg: SDFG) -> None: def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: dataflow_stack_size = len(self._dataflow_stack) + state_stack_size = len(self._state_stack) # setup inner SDFG inner_sdfg = SDFG("nested_sdfg", parent=self._current_state) - nsdfg = self._current_state.add_nested_sdfg(inner_sdfg, sdfg, inputs={}, outputs={}) start_state = inner_sdfg.add_state("nested_root", is_start_block=True) # update stacks and current state old_state_label = self._current_state.label self._state_stack.append(self._current_state) - self._dataflow_stack.append((nsdfg, dict())) + self._dataflow_stack.append((inner_sdfg, (set(), set()))) self._current_state = start_state # visit children for child in node.children: self.visit(child, sdfg=inner_sdfg) - # restore current state and do stack handling + # restore current state and stacks self._current_state = self._pop_state(old_state_label) - _, to_connect = self._dataflow_stack.pop() - assert not to_connect + assert len(self._state_stack) == state_stack_size + _, connectors = self._dataflow_stack.pop() assert len(self._dataflow_stack) == dataflow_stack_size + + # insert nested SDFG + nsdfg = self._current_state.add_nested_sdfg(inner_sdfg, sdfg, inputs=connectors[0], outputs=connectors[1]) + + # connect nested SDFG to surrounding map scope assert self._dataflow_stack map_entry, to_connect = self._dataflow_stack[-1] @@ -281,7 +287,8 @@ def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: assert new_in_connector == True assert new_in_connector == new_out_connector - self._current_state.add_edge(map_entry, out_connector, nsdfg, name, Memlet(name)) + self._current_state.add_edge(map_entry, out_connector, nsdfg, name, + Memlet.from_array(name, nsdfg.sdfg.arrays[name])) # Add empty memlet if we didn't add any in the loop above if self._current_state.out_degree(map_entry) < 1: @@ -289,9 +296,18 @@ def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: # connect nsdfg output memlets (to be propagated) for name in nsdfg.out_connectors: - to_connect[name] = (nsdfg, Memlet(name)) + to_connect[name] = (nsdfg, Memlet.from_array(name, nsdfg.sdfg.arrays[name])) - assert not nsdfg.sdfg.free_symbols + # # Add new global symbols to nested SDFG + # new_symbols = {sym: sym for sym in nsdfg.sdfg.free_symbols if sym not in nsdfg.symbol_mapping} + # nsdfg.symbol_mapping.update(new_symbols) + # + # from dace.codegen.tools.type_inference import infer_expr_type + # for sym, symval in new_symbols.items(): + # if sym not in nsdfg.sdfg.symbols: + # nsdfg.sdfg.add_symbol(sym, infer_expr_type(symval, sdfg.symbols) or dtypes.typeclass(int)) + + # assert not nsdfg.sdfg.free_symbols def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: dataflow_stack_size = len(self._dataflow_stack) @@ -299,6 +315,9 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: # map entry # --------- map_entry = nodes.MapEntry(node.node.map) + # # Add node.node.map.params as symbols here? + # for param in node.node.map.params: + # sdfg.add_symbol(param) self._current_state.add_node(map_entry) self._dataflow_stack.append((map_entry, dict())) @@ -333,7 +352,7 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: self._current_state.add_memlet_path(cached_access, map_entry, dst_conn=connector, - memlet=Memlet(memlet_data)) + memlet=Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) continue if outer_map_entry is not None: @@ -346,7 +365,7 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: assert new_in_connector == new_out_connector self._current_state.add_edge(outer_map_entry, connector_name, map_entry, connector, - Memlet(memlet_data)) + Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) else: # cache local read access assert memlet_data not in access_cache @@ -355,7 +374,7 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: self._current_state.add_memlet_path(cached_access, map_entry, dst_conn=connector, - memlet=Memlet(memlet_data)) + memlet=Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) if outer_map_entry is not None and self._current_state.out_degree(outer_map_entry) < 1: self._current_state.add_nedge(outer_map_entry, map_entry, memlet=Memlet()) @@ -389,13 +408,13 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: self._current_state.add_memlet_path(map_exit, access_node, src_conn=out_connector_name, - memlet=Memlet(name)) + memlet=Memlet.from_array(name, sdfg.arrays[name])) # cache write access into access_cache access_cache[name] = access_node if outer_to_connect is not None: - outer_to_connect[name] = (access_node, Memlet(name)) + outer_to_connect[name] = (access_node, Memlet.from_array(name, sdfg.arrays[name])) # TODO If nothing is connected at this point, figure out what's the last thing that # we should connect to. Then, add an empty memlet from that last thing to this @@ -438,7 +457,7 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: self._current_state.add_edge(scope_node, connector_name, tasklet, name, memlet) continue - if isinstance(scope_node, nodes.NestedSDFG): + if isinstance(scope_node, SDFG): # Copy data descriptor from parent SDFG and add input connector if memlet.data not in sdfg.arrays: parent_sdfg = sdfg.parent.parent @@ -453,8 +472,10 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # NOTE This can probably be done automatically by a cleanup pass in the end. # Something like DDE should be able to do this. - new_in_connector = scope_node.add_in_connector(memlet.data) - assert new_in_connector == True + assert memlet.data not in to_connect[0] + to_connect[0].add(memlet.data) + # new_in_connector = scope_node.add_in_connector(memlet.data) + # assert new_in_connector == True else: assert scope_node is None @@ -482,7 +503,7 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: to_connect[memlet.data] = (access_node, copy.deepcopy(memlet)) continue - if isinstance(scope_node, nodes.NestedSDFG): + if isinstance(scope_node, SDFG): if memlet.data not in sdfg.arrays: parent_sdfg = sdfg.parent.parent sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) @@ -492,7 +513,9 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: sdfg.arrays[memlet.data].transient = False # Add out_connector in any case if not yet present, e.g. write after read - scope_node.add_out_connector(memlet.data) + # assert memlet.data not in to_connect[1] + to_connect[1].add(memlet.data) + # scope_node.add_out_connector(memlet.data) else: assert scope_node is None diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 0dbe69b96e..c211101323 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -532,9 +532,6 @@ def test_xppm_tmp(): loaded = dace.SDFG.from_file("test.sdfgz") stree = loaded.as_schedule_tree() - # TODO - # - nestedSDFG: don't add input connections in read after write situations - sdfg = stree.as_sdfg() sdfg.validate() From 45d43c68b8a2710e220ae1db88a1be3314381511 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 16 Apr 2025 11:41:12 +0200 Subject: [PATCH 057/137] fixup: minor cleanups after passing xxpm test --- .../analysis/schedule_tree/tree_to_sdfg.py | 38 +++++-------------- 1 file changed, 9 insertions(+), 29 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 0889434f8a..cb7268fba1 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -53,9 +53,9 @@ def __init__(self, start_state: Optional[SDFGState] = None) -> None: self._interstate_symbols: List[tn.AssignNode] = [] # dataflow scopes - # List[ (MapEntryNode, ToConnect) | (SDFG, (inputs, outputs))] + # List[ (MapEntryNode, ToConnect) | (SDFG, {"inputs": set(), "outputs": set()}) ] self._dataflow_stack: List[Tuple[nodes.EntryNode, Dict[str, Tuple[nodes.AccessNode, Memlet]]] - | Tuple[SDFG, Tuple[Set[str], Set[str]]]] = [] + | Tuple[SDFG, Dict[str, Set[str]]]] = [] # caches self._access_cache: Dict[SDFGState, Dict[str, nodes.AccessNode]] = {} @@ -252,14 +252,14 @@ def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: dataflow_stack_size = len(self._dataflow_stack) state_stack_size = len(self._state_stack) - # setup inner SDFG + # prepare inner SDFG inner_sdfg = SDFG("nested_sdfg", parent=self._current_state) start_state = inner_sdfg.add_state("nested_root", is_start_block=True) # update stacks and current state old_state_label = self._current_state.label self._state_stack.append(self._current_state) - self._dataflow_stack.append((inner_sdfg, (set(), set()))) + self._dataflow_stack.append((inner_sdfg, {"inputs": set(), "outputs": set()})) self._current_state = start_state # visit children @@ -273,7 +273,7 @@ def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: assert len(self._dataflow_stack) == dataflow_stack_size # insert nested SDFG - nsdfg = self._current_state.add_nested_sdfg(inner_sdfg, sdfg, inputs=connectors[0], outputs=connectors[1]) + nsdfg = self._current_state.add_nested_sdfg(inner_sdfg, sdfg, inputs=connectors["inputs"], outputs=connectors["outputs"]) # connect nested SDFG to surrounding map scope assert self._dataflow_stack @@ -298,34 +298,18 @@ def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: for name in nsdfg.out_connectors: to_connect[name] = (nsdfg, Memlet.from_array(name, nsdfg.sdfg.arrays[name])) - # # Add new global symbols to nested SDFG - # new_symbols = {sym: sym for sym in nsdfg.sdfg.free_symbols if sym not in nsdfg.symbol_mapping} - # nsdfg.symbol_mapping.update(new_symbols) - # - # from dace.codegen.tools.type_inference import infer_expr_type - # for sym, symval in new_symbols.items(): - # if sym not in nsdfg.sdfg.symbols: - # nsdfg.sdfg.add_symbol(sym, infer_expr_type(symval, sdfg.symbols) or dtypes.typeclass(int)) - - # assert not nsdfg.sdfg.free_symbols - def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: dataflow_stack_size = len(self._dataflow_stack) # map entry # --------- map_entry = nodes.MapEntry(node.node.map) - # # Add node.node.map.params as symbols here? - # for param in node.node.map.params: - # sdfg.add_symbol(param) self._current_state.add_node(map_entry) self._dataflow_stack.append((map_entry, dict())) - # keep a copy of the access cache - access_cache = self._ensure_access_cache(self._current_state) - # Set a new access_cache before visiting children such that they have their # own access cache (per map scope). + access_cache = self._ensure_access_cache(self._current_state) self._access_cache[self._current_state] = {} # visit children inside the map @@ -472,10 +456,8 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # NOTE This can probably be done automatically by a cleanup pass in the end. # Something like DDE should be able to do this. - assert memlet.data not in to_connect[0] - to_connect[0].add(memlet.data) - # new_in_connector = scope_node.add_in_connector(memlet.data) - # assert new_in_connector == True + assert memlet.data not in to_connect["inputs"] + to_connect["inputs"].add(memlet.data) else: assert scope_node is None @@ -513,9 +495,7 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: sdfg.arrays[memlet.data].transient = False # Add out_connector in any case if not yet present, e.g. write after read - # assert memlet.data not in to_connect[1] - to_connect[1].add(memlet.data) - # scope_node.add_out_connector(memlet.data) + to_connect["outputs"].add(memlet.data) else: assert scope_node is None From dab3fa70f625dfe7409f085ba6df8d64f3f2a3ec Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 23 Apr 2025 15:15:56 +0200 Subject: [PATCH 058/137] Avoid "passthrough write nodes" --- .../analysis/schedule_tree/tree_to_sdfg.py | 26 +++++++++++++++---- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index cb7268fba1..e91755a4e5 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -273,7 +273,10 @@ def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: assert len(self._dataflow_stack) == dataflow_stack_size # insert nested SDFG - nsdfg = self._current_state.add_nested_sdfg(inner_sdfg, sdfg, inputs=connectors["inputs"], outputs=connectors["outputs"]) + nsdfg = self._current_state.add_nested_sdfg(inner_sdfg, + sdfg, + inputs=connectors["inputs"], + outputs=connectors["outputs"]) # connect nested SDFG to surrounding map scope assert self._dataflow_stack @@ -382,10 +385,23 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: self._current_state.add_edge(access_node, name, map_exit, in_connector_name, memlet) else: assert isinstance(access_node, nodes.AccessNode) - self._current_state.add_memlet_path(access_node, - map_exit, - dst_conn=in_connector_name, - memlet=memlet) + if self._current_state.out_degree(access_node) == 0 and self._current_state.in_degree( + access_node) == 1: + # this access_node is not used for anything else. + # let's remove it and add a direct connection instead + edges = [edge for edge in self._current_state.edges() if edge.dst == access_node] + assert len(edges) == 1 + self._current_state.add_memlet_path(edges[0].src, + map_exit, + src_conn=edges[0].src_conn, + dst_conn=in_connector_name, + memlet=edges[0].data) + self._current_state.remove_node(access_node) # edge is remove automatically + else: + self._current_state.add_memlet_path(access_node, + map_exit, + dst_conn=in_connector_name, + memlet=memlet) # connect "outside the map" access_node = self._current_state.add_write(name) From bf5b9e1352ab3c04a2c0bde39df4e0f144c19ba4 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 23 Apr 2025 18:15:23 +0200 Subject: [PATCH 059/137] Allow nested SDFGs inside other. Add CopyNode translation --- .../analysis/schedule_tree/tree_to_sdfg.py | 25 ++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index e91755a4e5..43032b3959 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -316,7 +316,10 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: self._access_cache[self._current_state] = {} # visit children inside the map - if any([isinstance(child, tn.StateBoundaryNode) for child in node.children]): + if [type(child) for child in node.children] == [tn.StateBoundaryNode, tn.MapScope]: + # skip weirdly added StateBoundaryNode + self.visit(node.children[1], sdfg=sdfg) + elif any([isinstance(child, tn.StateBoundaryNode) for child in node.children]): self._insert_nestedSDFG(node, sdfg) else: self.visit(node.children, sdfg=sdfg) @@ -343,6 +346,8 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: continue if outer_map_entry is not None: + assert isinstance(outer_map_entry, nodes.EntryNode) + # get it from outside the map connector_name = f"{PREFIX_PASSTHROUGH_OUT}{memlet_data}" if connector_name not in outer_map_entry.out_connectors: @@ -461,6 +466,9 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # Copy data descriptor from parent SDFG and add input connector if memlet.data not in sdfg.arrays: parent_sdfg = sdfg.parent.parent + while memlet.data not in parent_sdfg.arrays: + parent_sdfg = parent_sdfg.parent.parent + assert isinstance(parent_sdfg, SDFG) sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) # Transients passed into a nested SDFG become non-transient inside that nested SDFG @@ -521,8 +529,19 @@ def visit_LibraryCall(self, node: tn.LibraryCall, sdfg: SDFG) -> None: raise NotImplementedError(f"{type(node)} not implemented") def visit_CopyNode(self, node: tn.CopyNode, sdfg: SDFG) -> None: - # AFAIK we don't support copy nodes in the gt4py/dace bridge. - raise NotImplementedError(f"{type(node)} not implemented") + # apparently we need this for the first prototype + self._ensure_access_cache(self._current_state) + access_cache = self._access_cache[self._current_state] + + # assumption source access may or may not yet exist (in this state) + src_name = node.memlet.data + source = access_cache[src_name] if src_name in access_cache else self._current_state.add_read(src_name) + + # assumption: target access node doesn't exist yet + assert node.target not in access_cache + target = self._current_state.add_write(node.target) + + self._current_state.add_memlet_path(source, target, memlet=node.memlet) def visit_DynScopeCopyNode(self, node: tn.DynScopeCopyNode, sdfg: SDFG) -> None: # AFAIK we don't support dyn scope copy nodes in the gt4py/dace bridge. From aa613dd99c91b04e2186a1892f4f9ff0bfb35e43 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Thu, 24 Apr 2025 12:20:26 +0200 Subject: [PATCH 060/137] [to be reverted] test_delnflux_tmp --- tests/schedule_tree/to_sdfg_test.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index c211101323..faa040c5c8 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -536,6 +536,14 @@ def test_xppm_tmp(): sdfg.validate() +def test_delnflux_tmp(): + loaded = dace.SDFG.from_file("tmp_delnflux.sdfgz") + stree = loaded.as_schedule_tree() + + sdfg = stree.as_sdfg() + sdfg.validate() + + # TODO: find an automatic way to test stuff here if __name__ == '__main__': From 901475243b1702993b33c2043c03d9a7c48d3232 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Thu, 24 Apr 2025 18:14:56 +0200 Subject: [PATCH 061/137] Fix syntax issue (unexpected keyword arg "memlet") --- dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 43032b3959..76ec893ba7 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -295,7 +295,7 @@ def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: # Add empty memlet if we didn't add any in the loop above if self._current_state.out_degree(map_entry) < 1: - self._current_state.add_nedge(map_entry, nsdfg, memlet=Memlet()) + self._current_state.add_nedge(map_entry, nsdfg, Memlet()) # connect nsdfg output memlets (to be propagated) for name in nsdfg.out_connectors: @@ -369,7 +369,7 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: memlet=Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) if outer_map_entry is not None and self._current_state.out_degree(outer_map_entry) < 1: - self._current_state.add_nedge(outer_map_entry, map_entry, memlet=Memlet()) + self._current_state.add_nedge(outer_map_entry, map_entry, Memlet()) # map_exit # -------- From f775b1779d95dff147dacbfec2fdd26761e0ae44 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 25 Apr 2025 09:08:07 +0200 Subject: [PATCH 062/137] Todo: write Nview visitor (for Fillz, Ray_Fast) Just logging a todo for after the DelnFlux debugging. --- dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 76ec893ba7..c0f53b576f 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -552,7 +552,7 @@ def visit_ViewNode(self, node: tn.ViewNode, sdfg: SDFG) -> None: raise NotImplementedError(f"{type(node)} not implemented") def visit_NView(self, node: tn.NView, sdfg: SDFG) -> None: - # Let's see if we need this for the first prototype ... + # TODO: Fillz and Ray_Fast will need these ... raise NotImplementedError(f"{type(node)} not implemented") def visit_RefSetNode(self, node: tn.RefSetNode, sdfg: SDFG) -> None: From be05ba5518bbe1b0d3024e6447306bed458533d3 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 5 May 2025 15:56:13 +0200 Subject: [PATCH 063/137] "back propagate" state boundaries from nested SDFGs through maps --- dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py | 13 +++++++++++++ tests/schedule_tree/to_sdfg_test.py | 4 ++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index c0f53b576f..88bf60a734 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -318,6 +318,7 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: # visit children inside the map if [type(child) for child in node.children] == [tn.StateBoundaryNode, tn.MapScope]: # skip weirdly added StateBoundaryNode + # tmp: use this - for now - to "backprop-insert" extra state boundaries for nested SDFGs self.visit(node.children[1], sdfg=sdfg) elif any([isinstance(child, tn.StateBoundaryNode) for child in node.children]): self._insert_nestedSDFG(node, sdfg) @@ -619,6 +620,18 @@ def visit_StateLabel(self, node: tn.StateLabel): # Then, insert boundaries after unmet memory dependencies or potential data races _insert_memory_dependency_state_boundaries(stree) + # Hack: "backprop-insert" state boundaries from nested SDFGs + class NestedSDFGStateBoundaryInserter(tn.ScheduleNodeTransformer): + + def visit_scope(self, scope: tn.ScheduleTreeScope): + visited = self.generic_visit(scope) + if isinstance(scope, tn.MapScope) and any( + [isinstance(child, tn.StateBoundaryNode) for child in scope.children]): + return [tn.StateBoundaryNode(), visited] + return visited + + stree = NestedSDFGStateBoundaryInserter().visit(stree) + return stree diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index faa040c5c8..21954dfaaf 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -536,8 +536,8 @@ def test_xppm_tmp(): sdfg.validate() -def test_delnflux_tmp(): - loaded = dace.SDFG.from_file("tmp_delnflux.sdfgz") +def test_DelnFluxNoSG_tmp(): + loaded = dace.SDFG.from_file("tmp_DelnFluxNoSG.sdfgz") stree = loaded.as_schedule_tree() sdfg = stree.as_sdfg() From c2e53c340073af22d5a41d5f413c580248ce5004 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 6 May 2025 10:30:35 +0200 Subject: [PATCH 064/137] WIP: This seems to fix DelnFlux --- dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py | 7 +++++-- tests/schedule_tree/to_sdfg_test.py | 8 ++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 88bf60a734..4883ff06c8 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -316,10 +316,13 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: self._access_cache[self._current_state] = {} # visit children inside the map - if [type(child) for child in node.children] == [tn.StateBoundaryNode, tn.MapScope]: + type_of_children = [type(child) for child in node.children] + last_child_is_MapScope = type_of_children[-1] == tn.MapScope + all_others_are_Boundaries = type_of_children.count(tn.StateBoundaryNode) == len(type_of_children) - 1 + if last_child_is_MapScope and all_others_are_Boundaries: # skip weirdly added StateBoundaryNode # tmp: use this - for now - to "backprop-insert" extra state boundaries for nested SDFGs - self.visit(node.children[1], sdfg=sdfg) + self.visit(node.children[-1], sdfg=sdfg) elif any([isinstance(child, tn.StateBoundaryNode) for child in node.children]): self._insert_nestedSDFG(node, sdfg) else: diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 21954dfaaf..364219faa2 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -544,6 +544,14 @@ def test_DelnFluxNoSG_tmp(): sdfg.validate() +def test_DelnFlux_tmp(): + loaded = dace.SDFG.from_file("tmp_DelnFlux.sdfgz") + stree = loaded.as_schedule_tree() + + sdfg = stree.as_sdfg() + sdfg.validate() + + # TODO: find an automatic way to test stuff here if __name__ == '__main__': From 4accdb9ec5d562b109541bc5a072aff5c7135c62 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 6 May 2025 17:44:35 +0200 Subject: [PATCH 065/137] WIP: This seems to fix FvTp2d For D_SW we still have duplicate names for "dt". --- .../analysis/schedule_tree/tree_to_sdfg.py | 61 +++++++++++++++++-- tests/schedule_tree/to_sdfg_test.py | 8 +++ 2 files changed, 64 insertions(+), 5 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 4883ff06c8..826664e4ad 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -19,6 +19,7 @@ class StateBoundaryBehavior(Enum): PREFIX_PASSTHROUGH_IN: Final[str] = "IN_" PREFIX_PASSTHROUGH_OUT: Final[str] = "OUT_" +MAX_NESTED_SDFGS: Final[int] = 1000 def from_schedule_tree(stree: tn.ScheduleTreeRoot, @@ -349,8 +350,7 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: memlet=Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) continue - if outer_map_entry is not None: - assert isinstance(outer_map_entry, nodes.EntryNode) + if isinstance(outer_map_entry, nodes.EntryNode): # get it from outside the map connector_name = f"{PREFIX_PASSTHROUGH_OUT}{memlet_data}" @@ -363,6 +363,31 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: self._current_state.add_edge(outer_map_entry, connector_name, map_entry, connector, Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) else: + if isinstance(outer_map_entry, SDFG): + # Copy data descriptor from parent SDFG and add input connector + if memlet_data not in sdfg.arrays: + parent_sdfg = sdfg.parent.parent + sdfg_counter = 1 + while memlet_data not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: + parent_sdfg = parent_sdfg.parent.parent + assert isinstance(parent_sdfg, SDFG) + sdfg_counter += 1 + sdfg.add_datadesc(memlet_data, parent_sdfg.arrays[memlet_data].clone()) + + # Transients passed into a nested SDFG become non-transient inside that nested SDFG + if parent_sdfg.arrays[memlet_data].transient: + sdfg.arrays[memlet_data].transient = False + # TODO + # ... unless they are only ever used inside the nested SDFG, in which case + # we should delete them from the parent SDFG's array list. + # NOTE This can probably be done automatically by a cleanup pass in the end. + # Something like DDE should be able to do this. + + assert memlet_data not in outer_to_connect["inputs"] + outer_to_connect["inputs"].add(memlet_data) + else: + assert outer_map_entry is None + # cache local read access assert memlet_data not in access_cache access_cache[memlet_data] = self._current_state.add_read(memlet_data) @@ -372,7 +397,7 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: dst_conn=connector, memlet=Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) - if outer_map_entry is not None and self._current_state.out_degree(outer_map_entry) < 1: + if isinstance(outer_map_entry, nodes.EntryNode) and self._current_state.out_degree(outer_map_entry) < 1: self._current_state.add_nedge(outer_map_entry, map_entry, Memlet()) # map_exit @@ -412,6 +437,23 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: dst_conn=in_connector_name, memlet=memlet) + if isinstance(outer_map_entry, SDFG): + if name not in sdfg.arrays: + parent_sdfg = sdfg.parent.parent + sdfg_counter = 1 + while name not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: + parent_sdfg = parent_sdfg.parent.parent + assert isinstance(parent_sdfg, SDFG) + sdfg_counter += 1 + sdfg.add_datadesc(name, parent_sdfg.arrays[name].clone()) + + # Transients passed into a nested SDFG become non-transient inside that nested SDFG + if parent_sdfg.arrays[name].transient: + sdfg.arrays[name].transient = False + + # Add out_connector in any case if not yet present, e.g. write after read + outer_to_connect["outputs"].add(name) + # connect "outside the map" access_node = self._current_state.add_write(name) self._current_state.add_memlet_path(map_exit, @@ -422,8 +464,10 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: # cache write access into access_cache access_cache[name] = access_node - if outer_to_connect is not None: + if isinstance(outer_map_entry, nodes.EntryNode): outer_to_connect[name] = (access_node, Memlet.from_array(name, sdfg.arrays[name])) + else: + assert isinstance(outer_map_entry, SDFG) or outer_map_entry is None # TODO If nothing is connected at this point, figure out what's the last thing that # we should connect to. Then, add an empty memlet from that last thing to this @@ -470,9 +514,11 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # Copy data descriptor from parent SDFG and add input connector if memlet.data not in sdfg.arrays: parent_sdfg = sdfg.parent.parent - while memlet.data not in parent_sdfg.arrays: + sdfg_counter = 1 + while memlet.data not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: parent_sdfg = parent_sdfg.parent.parent assert isinstance(parent_sdfg, SDFG) + sdfg_counter += 1 sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) # Transients passed into a nested SDFG become non-transient inside that nested SDFG @@ -516,6 +562,11 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: if isinstance(scope_node, SDFG): if memlet.data not in sdfg.arrays: parent_sdfg = sdfg.parent.parent + sdfg_counter = 1 + while memlet.data not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: + parent_sdfg = parent_sdfg.parent.parent + assert isinstance(parent_sdfg, SDFG) + sdfg_counter += 1 sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) # Transients passed into a nested SDFG become non-transient inside that nested SDFG diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 364219faa2..7f8a1c4c8b 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -552,6 +552,14 @@ def test_DelnFlux_tmp(): sdfg.validate() +def test_FvTp2d_tmp(): + loaded = dace.SDFG.from_file("tmp_FvTp2d.sdfgz") + stree = loaded.as_schedule_tree() + + sdfg = stree.as_sdfg() + sdfg.validate() + + # TODO: find an automatic way to test stuff here if __name__ == '__main__': From 98071f729e74cd06b10836141116990bdb423d3c Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 7 May 2025 18:50:53 +0200 Subject: [PATCH 066/137] WIP: force boundary after assigns, timings also avoid duplicate StateBoundary nodes before maps (if possible) --- .../analysis/schedule_tree/tree_to_sdfg.py | 58 +++++++++++++- tests/schedule_tree/to_sdfg_test.py | 80 +++++++++++++++++++ 2 files changed, 135 insertions(+), 3 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 826664e4ad..02bbca6f06 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -1,6 +1,7 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import copy from collections import defaultdict +import time from dace import dtypes, subsets, symbolic from dace.memlet import Memlet from dace.sdfg import nodes, memlet_utils as mmu @@ -33,15 +34,23 @@ def from_schedule_tree(stree: tn.ScheduleTreeRoot, :return: An SDFG representing the schedule tree. """ # Set SDFG descriptor repository + s = time.time() result = SDFG(stree.name, propagate=False) result.arg_names = copy.deepcopy(stree.arg_names) for key, container in stree.containers.items(): result._arrays[key] = copy.deepcopy(container) result.constants_prop = copy.deepcopy(stree.constants) result.symbols = copy.deepcopy(stree.symbols) + print("\n") + print(f"Setup SDFG descriptor repository in {(time.time() - s):.3f} seconds.") # after WAW, before label, etc. + s = time.time() stree = insert_state_boundaries_to_tree(stree) + print(f"Inserted state boundaries in {(time.time() - s):.3f} seconds.") + + # main visitor + s = time.time() class StreeToSDFG(tn.ScheduleNodeVisitor): @@ -637,8 +646,12 @@ def _pending_interstate_assignments(self) -> Dict: return assignments StreeToSDFG().visit(stree, sdfg=result) + print(f"Main visitor took {(time.time() - s):.3f} seconds.") + # memlet propagation + s = time.time() propagation.propagate_memlets_sdfg(result) + print(f"Memlet propagation took {(time.time() - s):.3f} seconds.") return result @@ -657,6 +670,8 @@ def insert_state_boundaries_to_tree(stree: tn.ScheduleTreeRoot) -> tn.ScheduleTr :param stree: The schedule tree to operate on. """ + s = time.time() + # Simple boundary node inserter for control flow blocks and state labels class SimpleStateBoundaryInserter(tn.ScheduleNodeTransformer): @@ -670,21 +685,58 @@ def visit_StateLabel(self, node: tn.StateLabel): # First, insert boundaries around labels and control flow stree = SimpleStateBoundaryInserter().visit(stree) + print(f"\tSimpleStateBoundaryInserter took {(time.time() - s):.3f} seconds.") + s = time.time() # Then, insert boundaries after unmet memory dependencies or potential data races _insert_memory_dependency_state_boundaries(stree) + print(f"\tMemory dependency analysis took {(time.time() - s):.3f} seconds.") + + s = time.time() + + # Insert a state boundary after every symbol assignment to ensure symbols are assigned before usage + class SymbolAssignmentBoundaryInserter(tn.ScheduleNodeTransformer): + + def visit_AssignNode(self, node: tn.AssignNode): + # We can assume that assignment nodes are at least contained in the root scope. + assert node.parent, "Expected assignment nodes live a parent scope." + + # Find this node in the parent's children. + node_index = _list_index(node.parent.children, node) + + # Don't add boundary if there's already one or for immediately following assignment nodes. + if node_index < len(node.parent.children) - 1 and isinstance(node.parent.children[node_index + 1], + (tn.StateBoundaryNode, tn.AssignNode)): + return self.generic_visit(node) + + return [self.generic_visit(node), tn.StateBoundaryNode()] + + stree = SymbolAssignmentBoundaryInserter().visit(stree) + print(f"\tSymbolAssignmentBoundaryInserter took {(time.time() - s):.3f} seconds.") # Hack: "backprop-insert" state boundaries from nested SDFGs + s = time.time() + class NestedSDFGStateBoundaryInserter(tn.ScheduleNodeTransformer): - def visit_scope(self, scope: tn.ScheduleTreeScope): + def visit_MapScope(self, scope: tn.MapScope): visited = self.generic_visit(scope) - if isinstance(scope, tn.MapScope) and any( - [isinstance(child, tn.StateBoundaryNode) for child in scope.children]): + if any([isinstance(child, tn.StateBoundaryNode) for child in scope.children]): + # We can assume that map nodes are at least contained in the root scope. + assert scope.parent is not None + + # Find this scope in its parent's children + node_index = _list_index(scope.parent.children, scope) + + # If there's already a state boundary before the map, don't add another one + if node_index > 0 and isinstance(scope.parent.children[node_index - 1], tn.StateBoundaryNode): + return visited + return [tn.StateBoundaryNode(), visited] return visited stree = NestedSDFGStateBoundaryInserter().visit(stree) + print(f"\tNestedSDFGStateBoundaryInserter took {(time.time() - s):.3f} seconds.") return stree diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 7f8a1c4c8b..d9ebabe1c9 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -528,6 +528,62 @@ def test_edge_assignment_read_after_write(): assert [edge.data.assignments for edge in sdfg.edges()] == [{"my_condition": "True"}, {"condition": "my_condition"}] +def test_assign_nodes_force_state_transition(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + }, + children=[ + tn.AssignNode("mySymbol", CodeBlock("1"), dace.InterstateEdge()), + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = mySymbol'), {}, {'out': dace.Memlet('A[1]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert [type(child) for child in stree.children] == [tn.AssignNode, tn.StateBoundaryNode, tn.TaskletNode] + + +def test_assign_nodes_multiple_force_one_transition(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + }, + children=[ + tn.AssignNode("mySymbol", CodeBlock("1"), dace.InterstateEdge()), + tn.AssignNode("myOtherSymbol", CodeBlock("2"), dace.InterstateEdge()), + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = mySymbol + myOtherSymbol'), {}, + {'out': dace.Memlet('A[1]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert [type(child) + for child in stree.children] == [tn.AssignNode, tn.AssignNode, tn.StateBoundaryNode, tn.TaskletNode] + + +def test_assign_nodes_avoid_duplicate_boundaries(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + }, + children=[ + tn.AssignNode("mySymbol", CodeBlock("1"), dace.InterstateEdge()), + tn.StateBoundaryNode(), + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = mySymbol + myOtherSymbol'), {}, + {'out': dace.Memlet('A[1]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert [type(child) for child in stree.children] == [tn.AssignNode, tn.StateBoundaryNode, tn.TaskletNode] + + def test_xppm_tmp(): loaded = dace.SDFG.from_file("test.sdfgz") stree = loaded.as_schedule_tree() @@ -560,6 +616,30 @@ def test_FvTp2d_tmp(): sdfg.validate() +def test_FxAdv_tmp(): + loaded = dace.SDFG.from_file("tmp_FxAdv.sdfgz") + stree = loaded.as_schedule_tree() + + sdfg = stree.as_sdfg() + sdfg.validate() + + +def test_D_SW_tmp(): + loaded = dace.SDFG.from_file("tmp_D_SW.sdfgz") + stree = loaded.as_schedule_tree() + + sdfg = stree.as_sdfg() + sdfg.validate() + + +def test_UpdateDzD(): + loaded = dace.SDFG.from_file("tmp_UpdateDzD.sdfgz") + stree = loaded.as_schedule_tree() + + sdfg = stree.as_sdfg() + sdfg.validate() + + # TODO: find an automatic way to test stuff here if __name__ == '__main__': From b4a9ecd4be8e518cc172d6b875c8da41e8ba9d7e Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Thu, 8 May 2025 10:35:05 +0200 Subject: [PATCH 067/137] Perf: cache if one subset is contained in another --- dace/sdfg/memlet_utils.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dace/sdfg/memlet_utils.py b/dace/sdfg/memlet_utils.py index 6cc1354b71..897dfb90f9 100644 --- a/dace/sdfg/memlet_utils.py +++ b/dace/sdfg/memlet_utils.py @@ -188,6 +188,7 @@ class MemletDict(Dict[Memlet, T]): """ Implements a dictionary with memlet keys that considers subsets that intersect or are covered by its other memlets. """ + covers_cache: Dict[Tuple, bool] = {} def __init__(self, **kwargs): self.internal_dict: Dict[str, Dict[Memlet, T]] = defaultdict(dict) @@ -202,7 +203,12 @@ def _getkey(self, elem: Memlet) -> Optional[Memlet]: if elem.data not in self.internal_dict: return None for existing_memlet in self.internal_dict[elem.data]: - if existing_memlet.subset.covers(elem.subset): + key = (existing_memlet.subset, elem.subset) + is_covered = self.covers_cache.get(key, None) + if is_covered is None: + is_covered = existing_memlet.subset.covers(elem.subset) + self.covers_cache[key] = is_covered + if is_covered: return existing_memlet try: if existing_memlet.subset.intersects(elem.subset) == False: # Definitely does not intersect From 4c60fdd8d51405a9e272fbb5c72a2f6fb063aeb7 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Thu, 8 May 2025 18:07:58 +0200 Subject: [PATCH 068/137] tmp: just added a couple roundtrip tests --- tests/schedule_tree/to_sdfg_test.py | 22 +++++++++++++++++++--- 1 file changed, 19 insertions(+), 3 deletions(-) diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index d9ebabe1c9..f1e154fc02 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -584,8 +584,8 @@ def test_assign_nodes_avoid_duplicate_boundaries(): assert [type(child) for child in stree.children] == [tn.AssignNode, tn.StateBoundaryNode, tn.TaskletNode] -def test_xppm_tmp(): - loaded = dace.SDFG.from_file("test.sdfgz") +def test_XPPM_tmp(): + loaded = dace.SDFG.from_file("tmp_XPPM.sdfgz") stree = loaded.as_schedule_tree() sdfg = stree.as_sdfg() @@ -632,7 +632,7 @@ def test_D_SW_tmp(): sdfg.validate() -def test_UpdateDzD(): +def test_UpdateDzD_tmp(): loaded = dace.SDFG.from_file("tmp_UpdateDzD.sdfgz") stree = loaded.as_schedule_tree() @@ -640,6 +640,22 @@ def test_UpdateDzD(): sdfg.validate() +def test_Fillz_tmp(): + loaded = dace.SDFG.from_file("tmp_Fillz.sdfgz") + stree = loaded.as_schedule_tree() + + sdfg = stree.as_sdfg() + sdfg.validate() + + +def test_Ray_Fast_tmp(): + loaded = dace.SDFG.from_file("tmp_Ray_Fast.sdfgz") + stree = loaded.as_schedule_tree() + + sdfg = stree.as_sdfg() + sdfg.validate() + + # TODO: find an automatic way to test stuff here if __name__ == '__main__': From a14103d12770342f4e86b972102a65e8b0603660 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Sun, 11 May 2025 20:49:19 -0700 Subject: [PATCH 069/137] Remove unnecessary symbols from the schedule tree descriptor repository --- dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index c3a9ea049b..7c6b2f88ec 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -341,7 +341,13 @@ def create_unified_descriptor_repository(sdfg: SDFG, stree: tn.ScheduleTreeRoot) # the nested SDFGs' descriptor repositories for nsdfg in sdfg.all_sdfgs_recursive(): transients = {k: v for k, v in nsdfg.arrays.items() if v.transient} - symbols = {k: v for k, v in nsdfg.symbols.items() if k not in stree.symbols} + + # Get all symbols that are not participating in nested SDFG symbol mappings (they will be removed) + syms_to_ignore = set() + if nsdfg.parent_nsdfg_node is not None: + syms_to_ignore = nsdfg.parent_nsdfg_node.symbol_mapping.keys() + symbols = {k: v for k, v in nsdfg.symbols.items() if k not in stree.symbols and k not in syms_to_ignore} + constants = {k: v for k, v in nsdfg.constants_prop.items() if k not in stree.constants} stree.containers.update(transients) stree.symbols.update(symbols) From 806aa5ea768358274bd65afac57b930a29bb0c3e Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 12 May 2025 15:48:49 +0200 Subject: [PATCH 070/137] WIP: AssignNode array access inside nested SDFG In case AssignNodes access arrays, make sure the arrays are available in the nested SDFG. This allows D_SW's translate tests in PyFV3 to validate. It also allows UpdateDzD's translate test to validate independetly of ConstantPropagation in the orchestration pipeline. --- .../analysis/schedule_tree/tree_to_sdfg.py | 40 +++++++++++++++++++ dace/sdfg/analysis/schedule_tree/treenodes.py | 7 +++- tests/schedule_tree/to_sdfg_test.py | 2 +- 3 files changed, 47 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 02bbca6f06..9cf3b0747a 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -130,6 +130,46 @@ def visit_AssignNode(self, node: tn.AssignNode, sdfg: SDFG) -> None: # see visitors below. self._interstate_symbols.append(node) + # If AssignNode depends on arrays, e.g. `my_sym = my_array[__k] > 0`, make sure array accesses can be resolved. + input_memlets = node.input_memlets() + if not input_memlets: + return + + for entry in reversed(self._dataflow_stack): + scope_node, to_connect = entry + if isinstance(scope_node, SDFG): + # In case we are inside a nested SDFG, make sure memlet data can be + # resolved by explicitly adding inputs. + for memlet in input_memlets: + # Copy data descriptor from parent SDFG and add input connector + if memlet.data not in sdfg.arrays: + parent_sdfg = sdfg.parent.parent + sdfg_counter = 1 + while memlet.data not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: + parent_sdfg = parent_sdfg.parent.parent + assert isinstance(parent_sdfg, SDFG) + sdfg_counter += 1 + sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) + + # Transients passed into a nested SDFG become non-transient inside that nested SDFG + if parent_sdfg.arrays[memlet.data].transient: + sdfg.arrays[memlet.data].transient = False + # TODO + # ... unless they are only ever used inside the nested SDFG, in which case + # we should delete them from the parent SDFG's array list. + # NOTE This can probably be done automatically by a cleanup pass in the end. + # Something like DDE should be able to do this. + + assert memlet.data not in to_connect["inputs"] + to_connect["inputs"].add(memlet.data) + return + + for memlet in input_memlets: + # If we aren't inside a nested SDFG, make sure all memlets can be resolved. + # Imo, this should always be the case. It not, raise an error. + if memlet.data not in sdfg.arrays: + raise ValueError(f"Parsing AssignNode {node} failed. Can't find {memlet.data} in {sdfg}.") + def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: before_state = self._current_state pending = self._pending_interstate_assignments() diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index b51e35f286..9a3cad19dc 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -12,6 +12,8 @@ from dace.memlet import Memlet from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union +from dace.transformation.passes.simplify import SimplifyPass + INDENTATION = ' ' @@ -209,8 +211,11 @@ def as_sdfg(self, validate: bool = True, simplify: bool = True) -> SDFG: if validate: sdfg.validate() + # TODO + # UpdateDzD-ConstantPropagation.sdfgz generates an SDFG here that validates, but that doesn't + # simplify. Simplification fails in constant propagation with `__k_6` not found. if simplify: - sdfg.simplify(validate=validate) + SimplifyPass(validate=validate, skip=["ConstantPropagation"]).apply_pass(sdfg, {}) return sdfg diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index f1e154fc02..099257e0a0 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -633,7 +633,7 @@ def test_D_SW_tmp(): def test_UpdateDzD_tmp(): - loaded = dace.SDFG.from_file("tmp_UpdateDzD.sdfgz") + loaded = dace.SDFG.from_file("tmp_UpdateDzD-ConstantPropagation.sdfgz") stree = loaded.as_schedule_tree() sdfg = stree.as_sdfg() From 40c00eab5dce917b50825ad74186dba24378475a Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 14 May 2025 15:18:08 +0200 Subject: [PATCH 071/137] Unrelated: Fixing typos in comments --- dace/subsets.py | 2 +- dace/transformation/dataflow/map_expansion.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/subsets.py b/dace/subsets.py index 69b236405b..c5e0debc0b 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -565,7 +565,7 @@ def from_string(string): # Open parenthesis found, increase count by 1 if token[i] == '(': count += 1 - # Closing parenthesis found, decrease cound by 1 + # Closing parenthesis found, decrease count by 1 elif token[i] == ')': count -= 1 # Move to the next character diff --git a/dace/transformation/dataflow/map_expansion.py b/dace/transformation/dataflow/map_expansion.py index 8bc14213b0..bef3b930e2 100644 --- a/dace/transformation/dataflow/map_expansion.py +++ b/dace/transformation/dataflow/map_expansion.py @@ -34,7 +34,7 @@ class MapExpansion(pm.SingleStateTransformation): dtype=dtypes.ScheduleType, default=dtypes.ScheduleType.Sequential, allow_none=True) - expansion_limit = Property(desc="How many unidimensional maps will be creaed, known as k. " + expansion_limit = Property(desc="How many unidimensional maps will be created, known as k. " "If None, the default no limit is in place.", dtype=int, allow_none=True, From 45e5d0413b216d1c2001a541c50fab8ef019a940 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Thu, 15 May 2025 11:05:08 +0200 Subject: [PATCH 072/137] Go back to just simplify. Can't repo the problem anymore. --- dace/sdfg/analysis/schedule_tree/treenodes.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 9a3cad19dc..448c4e09cb 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -211,11 +211,8 @@ def as_sdfg(self, validate: bool = True, simplify: bool = True) -> SDFG: if validate: sdfg.validate() - # TODO - # UpdateDzD-ConstantPropagation.sdfgz generates an SDFG here that validates, but that doesn't - # simplify. Simplification fails in constant propagation with `__k_6` not found. if simplify: - SimplifyPass(validate=validate, skip=["ConstantPropagation"]).apply_pass(sdfg, {}) + sdfg.simplify(validate=validate) return sdfg From 04d442bbfc3cad024255e0878d30fe5c78da2365 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Thu, 15 May 2025 11:19:03 +0200 Subject: [PATCH 073/137] Unrelated: return empty set, not dict --- dace/sdfg/propagation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index cc279479ff..a5d6d7bf95 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -1519,11 +1519,11 @@ def propagate_subset(memlets: List[Memlet], return new_memlet -def _freesyms(expr): +def _freesyms(expr) -> Set: """ Helper function that either returns free symbols for sympy expressions or an empty set if constant. """ if isinstance(expr, sympy.Basic): return expr.free_symbols - return {} + return set() From 3ae906f90dbffeff7997fed8f70e02d35e9b6277 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Fri, 16 May 2025 13:04:55 +0200 Subject: [PATCH 074/137] Backport: Memlet propagation: return set - as promised (#2008) The function is supposed to return an empty set, but instead returned an empty dict. Now it returns a set as promised. Added a return type such that mypy would complain. This is a backport of https://github.com/spcl/dace/pull/2007 to `v1/maintenance`. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- dace/sdfg/propagation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index f048389421..f1b7d3baf7 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -1497,11 +1497,11 @@ def propagate_subset(memlets: List[Memlet], return new_memlet -def _freesyms(expr): - """ +def _freesyms(expr) -> Set: + """ Helper function that either returns free symbols for sympy expressions or an empty set if constant. """ if isinstance(expr, sympy.Basic): return expr.free_symbols - return {} + return set() From 954c088d26276d214f26f3f4fdee4f74088e9da9 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 6 Jun 2025 08:12:19 +0200 Subject: [PATCH 075/137] [florian] fix type of symbols directory --- dace/sdfg/analysis/schedule_tree/treenodes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 448c4e09cb..f942f5ff0c 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -1,7 +1,7 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import ast from dataclasses import dataclass, field -from dace import nodes, data, subsets +from dace import nodes, data, subsets, dtypes from dace.codegen import control_flow as cf from dace.properties import CodeBlock from dace.sdfg.memlet_utils import MemletSet @@ -189,7 +189,7 @@ class ScheduleTreeRoot(ScheduleTreeScope): """ name: str containers: Dict[str, data.Data] = field(default_factory=dict) - symbols: Dict[str, symbol] = field(default_factory=dict) + symbols: Dict[str, dtypes.typeclass] = field(default_factory=dict) constants: Dict[str, Tuple[data.Data, Any]] = field(default_factory=dict) callback_mapping: Dict[str, str] = field(default_factory=dict) arg_names: List[str] = field(default_factory=list) From 8d710cdf93d35980e98edb0dfcb4564d06afa934 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 6 Jun 2025 10:16:58 +0200 Subject: [PATCH 076/137] unrelated: memlet propagation with indices --- dace/sdfg/propagation.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index a5d6d7bf95..0755c43999 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -1472,13 +1472,19 @@ def propagate_subset(memlets: List[Memlet], # free symbols list of the subset dimension or is undefined outside tmp_subset_rng = [] for s, ea in zip(subset, entire_array): - contains_params = False - contains_undefs = False - for sdim in s: - fsyms = _freesyms(sdim) + if isinstance(subset, subsets.Indices): + fsyms = _freesyms(s) fsyms_str = set(map(str, fsyms)) - contains_params |= len(fsyms_str & paramset) != 0 - contains_undefs |= len(fsyms - defined_variables) != 0 + contains_params = len(fsyms_str & paramset) != 0 + contains_undefs = len(fsyms - defined_variables) != 0 + else: + contains_params = False + contains_undefs = False + for sdim in s: + fsyms = _freesyms(sdim) + fsyms_str = set(map(str, fsyms)) + contains_params |= len(fsyms_str & paramset) != 0 + contains_undefs |= len(fsyms - defined_variables) != 0 if contains_params or contains_undefs: tmp_subset_rng.append(ea) else: From 13ff3ece03d0405aabc76ea70f665ee196589802 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 6 Jun 2025 17:04:57 +0200 Subject: [PATCH 077/137] Unrelated: subset intersection between ranges and indices --- dace/sdfg/memlet_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dace/sdfg/memlet_utils.py b/dace/sdfg/memlet_utils.py index 897dfb90f9..cf886573f1 100644 --- a/dace/sdfg/memlet_utils.py +++ b/dace/sdfg/memlet_utils.py @@ -135,7 +135,7 @@ def add(self, elem: Memlet): # TODO(later): Consider other_subset as well for existing_memlet in self.internal_set[elem.data]: try: - if existing_memlet.subset.intersects(elem.subset) == True: # Definitely intersects + if subsets.intersects(existing_memlet.subset, elem.subset) == True: # Definitely intersects if existing_memlet.subset.covers(elem.subset): break # Nothing to do @@ -161,7 +161,7 @@ def __contains__(self, elem: Memlet) -> bool: return True if self.intersection_is_contained: try: - if existing_memlet.subset.intersects(elem.subset) == False: + if subsets.intersects(existing_memlet.subset, elem.subset) == False: continue else: # May intersect or indeterminate return True @@ -211,7 +211,7 @@ def _getkey(self, elem: Memlet) -> Optional[Memlet]: if is_covered: return existing_memlet try: - if existing_memlet.subset.intersects(elem.subset) == False: # Definitely does not intersect + if subsets.intersects(existing_memlet.subset, elem.subset) == False: # Definitely does not intersect continue except TypeError: pass From f78ea5ddf8e5f79aee67612a0ab2ba5a0e5925e2 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 13 Jun 2025 09:46:31 +0200 Subject: [PATCH 078/137] Stree to SDFG: allow to configure simplify() --- dace/sdfg/analysis/schedule_tree/treenodes.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index f942f5ff0c..4bb11ec880 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -8,12 +8,9 @@ from dace.sdfg.propagation import propagate_subset from dace.sdfg.sdfg import InterstateEdge, SDFG, memlets_in_ast from dace.sdfg.state import SDFGState -from dace.symbolic import symbol from dace.memlet import Memlet from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union -from dace.transformation.passes.simplify import SimplifyPass - INDENTATION = ' ' @@ -194,7 +191,12 @@ class ScheduleTreeRoot(ScheduleTreeScope): callback_mapping: Dict[str, str] = field(default_factory=dict) arg_names: List[str] = field(default_factory=list) - def as_sdfg(self, validate: bool = True, simplify: bool = True) -> SDFG: + def as_sdfg(self, + validate: bool = True, + simplify: bool = True, + validate_all: bool = False, + skip: Set[str] = set(), + verbose: bool = False) -> SDFG: """ Convert this schedule tree representation (back) into an SDFG. @@ -202,6 +204,9 @@ def as_sdfg(self, validate: bool = True, simplify: bool = True) -> SDFG: :param simplify: If true, simplify generated SDFG. The conversion might insert things like extra empty states that can be cleaned up automatically. The value of `validate` is passed on to `simplify()`. + :param validate_all: When simplifying, validate all intermediate SDFGs. Unused if simplify is False. + :param skip: Set of names of simplify passes to skip. Unused if simplify is False. + :param verbose: Turn on verbose logging of simplify. Unused if simplify is False. :return: SDFG version of this schedule tree. """ @@ -212,7 +217,8 @@ def as_sdfg(self, validate: bool = True, simplify: bool = True) -> SDFG: sdfg.validate() if simplify: - sdfg.simplify(validate=validate) + from dace.transformation.passes.simplify import SimplifyPass + SimplifyPass(validate=validate, validate_all=validate_all, skip=skip, verbose=verbose).apply_pass(self, {}) return sdfg From 8a76ed82bad510c576c41d8f2c94e52ed5877f77 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 13 Jun 2025 10:09:25 +0200 Subject: [PATCH 079/137] Fixup: simplify sdfg, not stree --- dace/sdfg/analysis/schedule_tree/treenodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 4bb11ec880..e9faa9c141 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -218,7 +218,7 @@ def as_sdfg(self, if simplify: from dace.transformation.passes.simplify import SimplifyPass - SimplifyPass(validate=validate, validate_all=validate_all, skip=skip, verbose=verbose).apply_pass(self, {}) + SimplifyPass(validate=validate, validate_all=validate_all, skip=skip, verbose=verbose).apply_pass(sdfg, {}) return sdfg From 89efd9af34d904fadb60c8a9691b2dc2179b35db Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 17 Jun 2025 15:45:34 +0200 Subject: [PATCH 080/137] Nested SDFGs inside maps inherit their schedule --- dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 9cf3b0747a..c377c642b8 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -326,7 +326,8 @@ def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: nsdfg = self._current_state.add_nested_sdfg(inner_sdfg, sdfg, inputs=connectors["inputs"], - outputs=connectors["outputs"]) + outputs=connectors["outputs"], + schedule=node.node.map.schedule) # connect nested SDFG to surrounding map scope assert self._dataflow_stack From b9bcc3d4b24984b14eb8d25a86fee71cd11dfefb Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 18 Jun 2025 09:37:50 +0200 Subject: [PATCH 081/137] Don't loose symbols in tswds Don't (ever) loose already existing symbols in `traverse_sdfg_with_defined_symbols`. Imo the symbol list should only ever grow. --- dace/sdfg/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 623a82c5bf..e2400f4089 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1564,7 +1564,7 @@ def _tswds_cf_region( if edge.src not in visited: visited.add(edge.src) if isinstance(edge.src, SDFGState): - yield from _tswds_state(sdfg, edge.src, {}, recursive) + yield from _tswds_state(sdfg, edge.src, symbols, recursive) elif isinstance(edge.src, ControlFlowRegion): yield from _tswds_cf_region(sdfg, edge.src, symbols, recursive) From b29a263af5523222049e31564b58b62d39554917 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 18 Jun 2025 10:45:00 +0200 Subject: [PATCH 082/137] Only report cycles if we actually find them The issue here is that `find_cycles()` returns a generator, which is truthy even if there are no cycles found. The proposed fix is to evaluate the generator into a list. That way, if the list is empty, we don't report having found a cycle if there are none. --- dace/sdfg/state.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index c99f45bb44..4cb9f27091 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -617,16 +617,16 @@ def scope_children(self, # Sanity checks if validate and len(eq) != 0: - cycles = self.find_cycles() + cycles = list(self.find_cycles()) if cycles: - raise ValueError('Found cycles in state %s: %s' % (self.label, list(cycles))) + raise ValueError('Found cycles in state %s: %s' % (self.label, cycles)) raise RuntimeError("Leftover nodes in queue: {}".format(eq)) entry_nodes = set(n for n in self.nodes() if isinstance(n, nd.EntryNode)) | {None} if (validate and len(result) != len(entry_nodes)): - cycles = self.find_cycles() + cycles = list(self.find_cycles()) if cycles: - raise ValueError('Found cycles in state %s: %s' % (self.label, list(cycles))) + raise ValueError('Found cycles in state %s: %s' % (self.label, cycles)) raise RuntimeError("Some nodes were not processed: {}".format(entry_nodes - result.keys())) # Cache result From a0d912571982d5896c4518fe2bdf3403b98b6fa6 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Thu, 19 Jun 2025 06:42:04 +0200 Subject: [PATCH 083/137] Only report cycles if we found some (#2055) ## Description The issue here is that `find_cycles()` returns a generator, which is truthy even if there are no cycles. The proposed solution is to evaluate the generator. That way, if the generated list is empty, it will be falsy and we won't report cycles if there are none. This allows developers to see the other error message, which is currently unreachable. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- dace/sdfg/state.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index c99f45bb44..4cb9f27091 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -617,16 +617,16 @@ def scope_children(self, # Sanity checks if validate and len(eq) != 0: - cycles = self.find_cycles() + cycles = list(self.find_cycles()) if cycles: - raise ValueError('Found cycles in state %s: %s' % (self.label, list(cycles))) + raise ValueError('Found cycles in state %s: %s' % (self.label, cycles)) raise RuntimeError("Leftover nodes in queue: {}".format(eq)) entry_nodes = set(n for n in self.nodes() if isinstance(n, nd.EntryNode)) | {None} if (validate and len(result) != len(entry_nodes)): - cycles = self.find_cycles() + cycles = list(self.find_cycles()) if cycles: - raise ValueError('Found cycles in state %s: %s' % (self.label, list(cycles))) + raise ValueError('Found cycles in state %s: %s' % (self.label, cycles)) raise RuntimeError("Some nodes were not processed: {}".format(entry_nodes - result.keys())) # Cache result From 439b929ac6d431ada18c5e51364c6a8e083dda40 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 25 Jun 2025 05:20:13 +0200 Subject: [PATCH 084/137] Avoid losing symbols in `traverse_sdfg_with_defined_symbols` (#2054) ## Description From what I can tell, `traverse_sdfg_with_defined_symbols()` is supposed to traverse the SDFG, accumulating symbols as it runs down the tree. The current implementation resets the dict of "already known symbols" when recursing down into a state that is the source of an edge. That seems odd. In particular, not even the globally available `sdfg.symbols` are marked as defined anymore with that reset. The proposed fix passes along already known symbols like in all other cases, ensuring that the dict of known symbols only ever grows. --------- Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- dace/sdfg/utils.py | 2 +- tests/sdfg/utils_test.py | 21 +++++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) create mode 100644 tests/sdfg/utils_test.py diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 623a82c5bf..e2400f4089 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1564,7 +1564,7 @@ def _tswds_cf_region( if edge.src not in visited: visited.add(edge.src) if isinstance(edge.src, SDFGState): - yield from _tswds_state(sdfg, edge.src, {}, recursive) + yield from _tswds_state(sdfg, edge.src, symbols, recursive) elif isinstance(edge.src, ControlFlowRegion): yield from _tswds_cf_region(sdfg, edge.src, symbols, recursive) diff --git a/tests/sdfg/utils_test.py b/tests/sdfg/utils_test.py new file mode 100644 index 0000000000..fe3ae0abdf --- /dev/null +++ b/tests/sdfg/utils_test.py @@ -0,0 +1,21 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. +import dace + +from dace.sdfg import utils + + +def test_traverse_sdfg_with_defined_symbols(): + pass + sdfg = dace.SDFG("tester") + sdfg.add_symbol("my_symbol", dace.int32) + + start = sdfg.add_state("start", is_start_block=True) + start.add_tasklet("noop", set(), set(), "") + sdfg.add_state_after(start, "next") + + for _state, _node, defined_symbols in utils.traverse_sdfg_with_defined_symbols(sdfg): + assert "my_symbol" in defined_symbols + + +if __name__ == "__main": + test_traverse_sdfg_with_defined_symbols() From 1bf841113c0b4fedbef3cbe5b4cdbbe1b1fc787f Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 8 Jul 2025 17:40:46 +0200 Subject: [PATCH 085/137] Move main visitor out and remove print statements --- .../analysis/schedule_tree/tree_to_sdfg.py | 1186 ++++++++--------- 1 file changed, 581 insertions(+), 605 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index c377c642b8..8dc55c8c16 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -1,8 +1,7 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import copy from collections import defaultdict -import time -from dace import dtypes, subsets, symbolic +from dace import symbolic from dace.memlet import Memlet from dace.sdfg import nodes, memlet_utils as mmu from dace.sdfg.sdfg import SDFG, ControlFlowRegion, InterstateEdge @@ -23,676 +22,664 @@ class StateBoundaryBehavior(Enum): MAX_NESTED_SDFGS: Final[int] = 1000 -def from_schedule_tree(stree: tn.ScheduleTreeRoot, - state_boundary_behavior: StateBoundaryBehavior = StateBoundaryBehavior.STATE_TRANSITION) -> SDFG: - """ - Converts a schedule tree into an SDFG. - - :param stree: The schedule tree root to convert. - :param state_boundary_behavior: Sets the behavior upon encountering a state boundary (e.g., write-after-write). - See the ``StateBoundaryBehavior`` enumeration for more details. - :return: An SDFG representing the schedule tree. - """ - # Set SDFG descriptor repository - s = time.time() - result = SDFG(stree.name, propagate=False) - result.arg_names = copy.deepcopy(stree.arg_names) - for key, container in stree.containers.items(): - result._arrays[key] = copy.deepcopy(container) - result.constants_prop = copy.deepcopy(stree.constants) - result.symbols = copy.deepcopy(stree.symbols) - print("\n") - print(f"Setup SDFG descriptor repository in {(time.time() - s):.3f} seconds.") - - # after WAW, before label, etc. - s = time.time() - stree = insert_state_boundaries_to_tree(stree) - print(f"Inserted state boundaries in {(time.time() - s):.3f} seconds.") - - # main visitor - s = time.time() - - class StreeToSDFG(tn.ScheduleNodeVisitor): +class StreeToSDFG(tn.ScheduleNodeVisitor): - def __init__(self, start_state: Optional[SDFGState] = None) -> None: - # state management - self._state_stack: List[SDFGState] = [] - self._current_state = start_state + def __init__(self, start_state: Optional[SDFGState] = None) -> None: + # state management + self._state_stack: List[SDFGState] = [] + self._current_state = start_state - # inter-state symbol assignments - self._interstate_symbols: List[tn.AssignNode] = [] + # inter-state symbol assignments + self._interstate_symbols: List[tn.AssignNode] = [] - # dataflow scopes - # List[ (MapEntryNode, ToConnect) | (SDFG, {"inputs": set(), "outputs": set()}) ] - self._dataflow_stack: List[Tuple[nodes.EntryNode, Dict[str, Tuple[nodes.AccessNode, Memlet]]] - | Tuple[SDFG, Dict[str, Set[str]]]] = [] + # dataflow scopes + # List[ (MapEntryNode, ToConnect) | (SDFG, {"inputs": set(), "outputs": set()}) ] + self._dataflow_stack: List[Tuple[nodes.EntryNode, Dict[str, Tuple[nodes.AccessNode, Memlet]]] + | Tuple[SDFG, Dict[str, Set[str]]]] = [] - # caches - self._access_cache: Dict[SDFGState, Dict[str, nodes.AccessNode]] = {} + # caches + self._access_cache: Dict[SDFGState, Dict[str, nodes.AccessNode]] = {} - def _pop_state(self, label: Optional[str] = None) -> SDFGState: - """Pops the last state from the state stack. + def _pop_state(self, label: Optional[str] = None) -> SDFGState: + """Pops the last state from the state stack. - :param str, optional label: Ensures the popped state's label starts with the given string. + :param str, optional label: Ensures the popped state's label starts with the given string. - :return: The popped state. - """ - if not self._state_stack: - raise ValueError("Can't pop state from empty stack.") + :return: The popped state. + """ + if not self._state_stack: + raise ValueError("Can't pop state from empty stack.") - popped = self._state_stack.pop() - if label is not None: - assert popped.label.startswith(label) + popped = self._state_stack.pop() + if label is not None: + assert popped.label.startswith(label) - return popped + return popped - def _ensure_access_cache(self, state: SDFGState) -> Dict[str, nodes.AccessNode]: - """Ensure an access_cache entry for the given state. + def _ensure_access_cache(self, state: SDFGState) -> Dict[str, nodes.AccessNode]: + """Ensure an access_cache entry for the given state. - Checks if there exists an access_cache for `state`. Creates an empty one if it doesn't exist yet. + Checks if there exists an access_cache for `state`. Creates an empty one if it doesn't exist yet. - :param SDFGState state: The state to check. + :param SDFGState state: The state to check. - :return: The state's access_cache. - """ - if state not in self._access_cache: - self._access_cache[state] = {} + :return: The state's access_cache. + """ + if state not in self._access_cache: + self._access_cache[state] = {} - return self._access_cache[state] + return self._access_cache[state] - def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot, sdfg: SDFG) -> None: - assert self._current_state is None, "Expected no 'current_state' at root." - assert not self._state_stack, "Expected empty state stack at root." - assert not self._dataflow_stack, "Expected empty dataflow stack at root." - assert not self._interstate_symbols, "Expected empty list of symbols at root." - - self._current_state = sdfg.add_state(label="tree_root", is_start_block=True) - self.visit(node.children, sdfg=sdfg) + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot, sdfg: SDFG) -> None: + assert self._current_state is None, "Expected no 'current_state' at root." + assert not self._state_stack, "Expected empty state stack at root." + assert not self._dataflow_stack, "Expected empty dataflow stack at root." + assert not self._interstate_symbols, "Expected empty list of symbols at root." - assert not self._state_stack, "Expected empty state stack." - assert not self._dataflow_stack, "Expected empty dataflow stack." - assert not self._interstate_symbols, "Expected empty list of symbols to add." + self._current_state = sdfg.add_state(label="tree_root", is_start_block=True) + self.visit(node.children, sdfg=sdfg) - def visit_GBlock(self, node: tn.GBlock, sdfg: SDFG) -> None: - # Let's see if we need this for the first prototype ... - raise NotImplementedError(f"{type(node)} not implemented") + assert not self._state_stack, "Expected empty state stack." + assert not self._dataflow_stack, "Expected empty dataflow stack." + assert not self._interstate_symbols, "Expected empty list of symbols to add." - def visit_StateLabel(self, node: tn.StateLabel, sdfg: SDFG) -> None: - # Let's see if we need this for the first prototype ... - raise NotImplementedError(f"{type(node)} not implemented") + def visit_GBlock(self, node: tn.GBlock, sdfg: SDFG) -> None: + # Let's see if we need this for the first prototype ... + raise NotImplementedError(f"{type(node)} not implemented") - def visit_GotoNode(self, node: tn.GotoNode, sdfg: SDFG) -> None: - # Let's see if we need this for the first prototype ... - raise NotImplementedError(f"{type(node)} not implemented") + def visit_StateLabel(self, node: tn.StateLabel, sdfg: SDFG) -> None: + # Let's see if we need this for the first prototype ... + raise NotImplementedError(f"{type(node)} not implemented") - def visit_AssignNode(self, node: tn.AssignNode, sdfg: SDFG) -> None: - # We just collect them here. They'll be added when state boundaries are added, - # see visitors below. - self._interstate_symbols.append(node) - - # If AssignNode depends on arrays, e.g. `my_sym = my_array[__k] > 0`, make sure array accesses can be resolved. - input_memlets = node.input_memlets() - if not input_memlets: - return - - for entry in reversed(self._dataflow_stack): - scope_node, to_connect = entry - if isinstance(scope_node, SDFG): - # In case we are inside a nested SDFG, make sure memlet data can be - # resolved by explicitly adding inputs. - for memlet in input_memlets: - # Copy data descriptor from parent SDFG and add input connector - if memlet.data not in sdfg.arrays: - parent_sdfg = sdfg.parent.parent - sdfg_counter = 1 - while memlet.data not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: - parent_sdfg = parent_sdfg.parent.parent - assert isinstance(parent_sdfg, SDFG) - sdfg_counter += 1 - sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) - - # Transients passed into a nested SDFG become non-transient inside that nested SDFG - if parent_sdfg.arrays[memlet.data].transient: - sdfg.arrays[memlet.data].transient = False - # TODO - # ... unless they are only ever used inside the nested SDFG, in which case - # we should delete them from the parent SDFG's array list. - # NOTE This can probably be done automatically by a cleanup pass in the end. - # Something like DDE should be able to do this. - - assert memlet.data not in to_connect["inputs"] - to_connect["inputs"].add(memlet.data) - return - - for memlet in input_memlets: - # If we aren't inside a nested SDFG, make sure all memlets can be resolved. - # Imo, this should always be the case. It not, raise an error. - if memlet.data not in sdfg.arrays: - raise ValueError(f"Parsing AssignNode {node} failed. Can't find {memlet.data} in {sdfg}.") + def visit_GotoNode(self, node: tn.GotoNode, sdfg: SDFG) -> None: + # Let's see if we need this for the first prototype ... + raise NotImplementedError(f"{type(node)} not implemented") - def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: - before_state = self._current_state - pending = self._pending_interstate_assignments() - pending[node.header.itervar] = node.header.init + def visit_AssignNode(self, node: tn.AssignNode, sdfg: SDFG) -> None: + # We just collect them here. They'll be added when state boundaries are added, + # see visitors below. + self._interstate_symbols.append(node) - guard_state = _insert_and_split_assignments(sdfg, before_state, label="loop_guard", assignments=pending) - self._current_state = guard_state + # If AssignNode depends on arrays, e.g. `my_sym = my_array[__k] > 0`, make sure array accesses can be resolved. + input_memlets = node.input_memlets() + if not input_memlets: + return - body_state = sdfg.add_state(label="loop_body") - self._current_state = body_state - sdfg.add_edge(guard_state, body_state, InterstateEdge(condition=node.header.condition)) - - # visit children inside the loop - self.visit(node.children, sdfg=sdfg) - - pending = self._pending_interstate_assignments() - pending[node.header.itervar] = node.header.update - _insert_and_split_assignments(sdfg, self._current_state, after_state=guard_state, assignments=pending) - - after_state = sdfg.add_state(label="loop_after") - self._current_state = after_state - sdfg.add_edge(guard_state, after_state, InterstateEdge(condition=f"not {node.header.condition.as_string}")) - - def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: - before_state = self._current_state - guard_state = _insert_and_split_assignments(sdfg, - before_state, - label="guard_state", - assignments=self._pending_interstate_assignments()) - self._current_state = guard_state - - body_state = sdfg.add_state(label="loop_body") - self._current_state = body_state - sdfg.add_edge(guard_state, body_state, InterstateEdge(condition=node.header.test)) - - # visit children inside the loop - self.visit(node.children, sdfg=sdfg) - _insert_and_split_assignments(sdfg, - before_state=self._current_state, - after_state=guard_state, - assignments=self._pending_interstate_assignments()) - - after_state = sdfg.add_state(label="loop_after") - self._current_state = after_state - sdfg.add_edge(guard_state, after_state, InterstateEdge(f"not {node.header.test.as_string}")) - - def visit_DoWhileScope(self, node: tn.DoWhileScope, sdfg: SDFG) -> None: - # AFAIK we don't support for do-while loops in the gt4py -> dace bridge. - raise NotImplementedError(f"{type(node)} not implemented") - - def visit_GeneralLoopScope(self, node: tn.GeneralLoopScope, sdfg: SDFG) -> None: - # Let's see if we need this for the first prototype ... - raise NotImplementedError(f"{type(node)} not implemented") - - def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: - before_state = self._current_state - - # add guard state - guard_state = _insert_and_split_assignments(sdfg, - before_state, - label="guard_state", - assignments=self._pending_interstate_assignments()) - - # add true_state - true_state = sdfg.add_state(label="true_state") - sdfg.add_edge(guard_state, true_state, InterstateEdge(condition=node.condition)) - self._current_state = true_state - - # visit children in the true branch - self.visit(node.children, sdfg=sdfg) - - # add merge_state - merge_state = _insert_and_split_assignments(sdfg, - self._current_state, - label="merge_state", - assignments=self._pending_interstate_assignments()) - - # Check if there's an `ElseScope` following this node (in the parent's children). - # Filter StateBoundaryNodes, which we inserted earlier, for this analysis. - filtered = [n for n in node.parent.children if not isinstance(n, tn.StateBoundaryNode)] - if_index = _list_index(filtered, node) - has_else_branch = len(filtered) > if_index + 1 and isinstance(filtered[if_index + 1], tn.ElseScope) - - if has_else_branch: - # push merge_state on the stack for later usage in `visit_ElseScope` - self._state_stack.append(merge_state) - false_state = sdfg.add_state(label="false_state") - - sdfg.add_edge(guard_state, false_state, InterstateEdge(condition=f"not {node.condition.as_string}")) - - # push false_state on the stack for later usage in `visit_ElseScope` - self._state_stack.append(false_state) - else: - sdfg.add_edge(guard_state, merge_state, InterstateEdge(condition=f"not {node.condition.as_string}")) - self._current_state = merge_state - - def visit_StateIfScope(self, node: tn.StateIfScope, sdfg: SDFG) -> None: - # Let's see if we need this for the first prototype ... - raise NotImplementedError(f"{type(node)} not implemented") - - def visit_BreakNode(self, node: tn.BreakNode, sdfg: SDFG) -> None: - # AFAIK we don't support for break statements in the gt4py/dace bridge. - raise NotImplementedError(f"{type(node)} not implemented") - - def visit_ContinueNode(self, node: tn.ContinueNode, sdfg: SDFG) -> None: - # AFAIK we don't support for continue statements in the gt4py/dace bridge. - raise NotImplementedError(f"{type(node)} not implemented") - - def visit_ElifScope(self, node: tn.ElifScope, sdfg: SDFG) -> None: - # AFAIK we don't support elif scopes in the gt4py/dace bridge. - raise NotImplementedError(f"{type(node)} not implemented") - - def visit_ElseScope(self, node: tn.ElseScope, sdfg: SDFG) -> None: - # get false_state form stack - false_state = self._pop_state("false_state") - self._current_state = false_state - - # visit children inside the else branch - self.visit(node.children, sdfg=sdfg) - - # merge false-branch into merge_state - merge_state = self._pop_state("merge_state") - _insert_and_split_assignments(sdfg, - before_state=self._current_state, - after_state=merge_state, - assignments=self._pending_interstate_assignments()) - self._current_state = merge_state - - def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: - dataflow_stack_size = len(self._dataflow_stack) - state_stack_size = len(self._state_stack) - - # prepare inner SDFG - inner_sdfg = SDFG("nested_sdfg", parent=self._current_state) - start_state = inner_sdfg.add_state("nested_root", is_start_block=True) - - # update stacks and current state - old_state_label = self._current_state.label - self._state_stack.append(self._current_state) - self._dataflow_stack.append((inner_sdfg, {"inputs": set(), "outputs": set()})) - self._current_state = start_state - - # visit children - for child in node.children: - self.visit(child, sdfg=inner_sdfg) - - # restore current state and stacks - self._current_state = self._pop_state(old_state_label) - assert len(self._state_stack) == state_stack_size - _, connectors = self._dataflow_stack.pop() - assert len(self._dataflow_stack) == dataflow_stack_size - - # insert nested SDFG - nsdfg = self._current_state.add_nested_sdfg(inner_sdfg, - sdfg, - inputs=connectors["inputs"], - outputs=connectors["outputs"], - schedule=node.node.map.schedule) - - # connect nested SDFG to surrounding map scope - assert self._dataflow_stack - map_entry, to_connect = self._dataflow_stack[-1] - - # connect nsdfg input memlets (to be propagated upon completion of the SDFG) - for name in nsdfg.in_connectors: - out_connector = f"{PREFIX_PASSTHROUGH_OUT}{name}" - new_in_connector = map_entry.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{name}") - new_out_connector = map_entry.add_out_connector(out_connector) - assert new_in_connector == True - assert new_in_connector == new_out_connector - - self._current_state.add_edge(map_entry, out_connector, nsdfg, name, - Memlet.from_array(name, nsdfg.sdfg.arrays[name])) - - # Add empty memlet if we didn't add any in the loop above - if self._current_state.out_degree(map_entry) < 1: - self._current_state.add_nedge(map_entry, nsdfg, Memlet()) - - # connect nsdfg output memlets (to be propagated) - for name in nsdfg.out_connectors: - to_connect[name] = (nsdfg, Memlet.from_array(name, nsdfg.sdfg.arrays[name])) - - def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: - dataflow_stack_size = len(self._dataflow_stack) - - # map entry - # --------- - map_entry = nodes.MapEntry(node.node.map) - self._current_state.add_node(map_entry) - self._dataflow_stack.append((map_entry, dict())) - - # Set a new access_cache before visiting children such that they have their - # own access cache (per map scope). - access_cache = self._ensure_access_cache(self._current_state) - self._access_cache[self._current_state] = {} - - # visit children inside the map - type_of_children = [type(child) for child in node.children] - last_child_is_MapScope = type_of_children[-1] == tn.MapScope - all_others_are_Boundaries = type_of_children.count(tn.StateBoundaryNode) == len(type_of_children) - 1 - if last_child_is_MapScope and all_others_are_Boundaries: - # skip weirdly added StateBoundaryNode - # tmp: use this - for now - to "backprop-insert" extra state boundaries for nested SDFGs - self.visit(node.children[-1], sdfg=sdfg) - elif any([isinstance(child, tn.StateBoundaryNode) for child in node.children]): - self._insert_nestedSDFG(node, sdfg) - else: - self.visit(node.children, sdfg=sdfg) - - # reset the access_cache - self._access_cache[self._current_state] = access_cache - - # dataflow stack management - _, to_connect = self._dataflow_stack.pop() - assert len(self._dataflow_stack) == dataflow_stack_size - outer_map_entry, outer_to_connect = self._dataflow_stack[-1] if dataflow_stack_size else (None, None) - - # connect potential input connectors on map_entry - for connector in map_entry.in_connectors: - memlet_data = connector.removeprefix(PREFIX_PASSTHROUGH_IN) - - # connect to local access node (if available) - if memlet_data in access_cache: - cached_access = access_cache[memlet_data] - self._current_state.add_memlet_path(cached_access, - map_entry, - dst_conn=connector, - memlet=Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) - continue - - if isinstance(outer_map_entry, nodes.EntryNode): - - # get it from outside the map - connector_name = f"{PREFIX_PASSTHROUGH_OUT}{memlet_data}" - if connector_name not in outer_map_entry.out_connectors: - new_in_connector = outer_map_entry.add_in_connector(connector) - new_out_connector = outer_map_entry.add_out_connector(connector_name) - assert new_in_connector == True - assert new_in_connector == new_out_connector - - self._current_state.add_edge(outer_map_entry, connector_name, map_entry, connector, - Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) - else: - if isinstance(outer_map_entry, SDFG): - # Copy data descriptor from parent SDFG and add input connector - if memlet_data not in sdfg.arrays: - parent_sdfg = sdfg.parent.parent - sdfg_counter = 1 - while memlet_data not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: - parent_sdfg = parent_sdfg.parent.parent - assert isinstance(parent_sdfg, SDFG) - sdfg_counter += 1 - sdfg.add_datadesc(memlet_data, parent_sdfg.arrays[memlet_data].clone()) - - # Transients passed into a nested SDFG become non-transient inside that nested SDFG - if parent_sdfg.arrays[memlet_data].transient: - sdfg.arrays[memlet_data].transient = False - # TODO - # ... unless they are only ever used inside the nested SDFG, in which case - # we should delete them from the parent SDFG's array list. - # NOTE This can probably be done automatically by a cleanup pass in the end. - # Something like DDE should be able to do this. - - assert memlet_data not in outer_to_connect["inputs"] - outer_to_connect["inputs"].add(memlet_data) - else: - assert outer_map_entry is None - - # cache local read access - assert memlet_data not in access_cache - access_cache[memlet_data] = self._current_state.add_read(memlet_data) - cached_access = access_cache[memlet_data] - self._current_state.add_memlet_path(cached_access, - map_entry, - dst_conn=connector, - memlet=Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) - - if isinstance(outer_map_entry, nodes.EntryNode) and self._current_state.out_degree(outer_map_entry) < 1: - self._current_state.add_nedge(outer_map_entry, map_entry, Memlet()) - - # map_exit - # -------- - map_exit = nodes.MapExit(node.node.map) - self._current_state.add_node(map_exit) - - # connect writes to map_exit node - for name in to_connect: - in_connector_name = f"{PREFIX_PASSTHROUGH_IN}{name}" - out_connector_name = f"{PREFIX_PASSTHROUGH_OUT}{name}" - new_in_connector = map_exit.add_in_connector(in_connector_name) - new_out_connector = map_exit.add_out_connector(out_connector_name) - assert new_in_connector == new_out_connector - - # connect "inside the map" - access_node, memlet = to_connect[name] - if isinstance(access_node, nodes.NestedSDFG): - self._current_state.add_edge(access_node, name, map_exit, in_connector_name, memlet) - else: - assert isinstance(access_node, nodes.AccessNode) - if self._current_state.out_degree(access_node) == 0 and self._current_state.in_degree( - access_node) == 1: - # this access_node is not used for anything else. - # let's remove it and add a direct connection instead - edges = [edge for edge in self._current_state.edges() if edge.dst == access_node] - assert len(edges) == 1 - self._current_state.add_memlet_path(edges[0].src, - map_exit, - src_conn=edges[0].src_conn, - dst_conn=in_connector_name, - memlet=edges[0].data) - self._current_state.remove_node(access_node) # edge is remove automatically - else: - self._current_state.add_memlet_path(access_node, - map_exit, - dst_conn=in_connector_name, - memlet=memlet) - - if isinstance(outer_map_entry, SDFG): - if name not in sdfg.arrays: + for entry in reversed(self._dataflow_stack): + scope_node, to_connect = entry + if isinstance(scope_node, SDFG): + # In case we are inside a nested SDFG, make sure memlet data can be + # resolved by explicitly adding inputs. + for memlet in input_memlets: + # Copy data descriptor from parent SDFG and add input connector + if memlet.data not in sdfg.arrays: parent_sdfg = sdfg.parent.parent sdfg_counter = 1 - while name not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: + while memlet.data not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: parent_sdfg = parent_sdfg.parent.parent assert isinstance(parent_sdfg, SDFG) sdfg_counter += 1 - sdfg.add_datadesc(name, parent_sdfg.arrays[name].clone()) + sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) # Transients passed into a nested SDFG become non-transient inside that nested SDFG - if parent_sdfg.arrays[name].transient: - sdfg.arrays[name].transient = False + if parent_sdfg.arrays[memlet.data].transient: + sdfg.arrays[memlet.data].transient = False + # TODO + # ... unless they are only ever used inside the nested SDFG, in which case + # we should delete them from the parent SDFG's array list. + # NOTE This can probably be done automatically by a cleanup pass in the end. + # Something like DDE should be able to do this. - # Add out_connector in any case if not yet present, e.g. write after read - outer_to_connect["outputs"].add(name) + assert memlet.data not in to_connect["inputs"] + to_connect["inputs"].add(memlet.data) + return - # connect "outside the map" - access_node = self._current_state.add_write(name) - self._current_state.add_memlet_path(map_exit, - access_node, - src_conn=out_connector_name, - memlet=Memlet.from_array(name, sdfg.arrays[name])) + for memlet in input_memlets: + # If we aren't inside a nested SDFG, make sure all memlets can be resolved. + # Imo, this should always be the case. It not, raise an error. + if memlet.data not in sdfg.arrays: + raise ValueError(f"Parsing AssignNode {node} failed. Can't find {memlet.data} in {sdfg}.") + + def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: + before_state = self._current_state + pending = self._pending_interstate_assignments() + pending[node.header.itervar] = node.header.init + + guard_state = _insert_and_split_assignments(sdfg, before_state, label="loop_guard", assignments=pending) + self._current_state = guard_state + + body_state = sdfg.add_state(label="loop_body") + self._current_state = body_state + sdfg.add_edge(guard_state, body_state, InterstateEdge(condition=node.header.condition)) + + # visit children inside the loop + self.visit(node.children, sdfg=sdfg) + + pending = self._pending_interstate_assignments() + pending[node.header.itervar] = node.header.update + _insert_and_split_assignments(sdfg, self._current_state, after_state=guard_state, assignments=pending) + + after_state = sdfg.add_state(label="loop_after") + self._current_state = after_state + sdfg.add_edge(guard_state, after_state, InterstateEdge(condition=f"not {node.header.condition.as_string}")) + + def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: + before_state = self._current_state + guard_state = _insert_and_split_assignments(sdfg, + before_state, + label="guard_state", + assignments=self._pending_interstate_assignments()) + self._current_state = guard_state + + body_state = sdfg.add_state(label="loop_body") + self._current_state = body_state + sdfg.add_edge(guard_state, body_state, InterstateEdge(condition=node.header.test)) + + # visit children inside the loop + self.visit(node.children, sdfg=sdfg) + _insert_and_split_assignments(sdfg, + before_state=self._current_state, + after_state=guard_state, + assignments=self._pending_interstate_assignments()) + + after_state = sdfg.add_state(label="loop_after") + self._current_state = after_state + sdfg.add_edge(guard_state, after_state, InterstateEdge(f"not {node.header.test.as_string}")) + + def visit_DoWhileScope(self, node: tn.DoWhileScope, sdfg: SDFG) -> None: + # AFAIK we don't support for do-while loops in the gt4py -> dace bridge. + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_GeneralLoopScope(self, node: tn.GeneralLoopScope, sdfg: SDFG) -> None: + # Let's see if we need this for the first prototype ... + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: + before_state = self._current_state + + # add guard state + guard_state = _insert_and_split_assignments(sdfg, + before_state, + label="guard_state", + assignments=self._pending_interstate_assignments()) + + # add true_state + true_state = sdfg.add_state(label="true_state") + sdfg.add_edge(guard_state, true_state, InterstateEdge(condition=node.condition)) + self._current_state = true_state + + # visit children in the true branch + self.visit(node.children, sdfg=sdfg) + + # add merge_state + merge_state = _insert_and_split_assignments(sdfg, + self._current_state, + label="merge_state", + assignments=self._pending_interstate_assignments()) + + # Check if there's an `ElseScope` following this node (in the parent's children). + # Filter StateBoundaryNodes, which we inserted earlier, for this analysis. + filtered = [n for n in node.parent.children if not isinstance(n, tn.StateBoundaryNode)] + if_index = _list_index(filtered, node) + has_else_branch = len(filtered) > if_index + 1 and isinstance(filtered[if_index + 1], tn.ElseScope) + + if has_else_branch: + # push merge_state on the stack for later usage in `visit_ElseScope` + self._state_stack.append(merge_state) + false_state = sdfg.add_state(label="false_state") + + sdfg.add_edge(guard_state, false_state, InterstateEdge(condition=f"not {node.condition.as_string}")) + + # push false_state on the stack for later usage in `visit_ElseScope` + self._state_stack.append(false_state) + else: + sdfg.add_edge(guard_state, merge_state, InterstateEdge(condition=f"not {node.condition.as_string}")) + self._current_state = merge_state - # cache write access into access_cache - access_cache[name] = access_node + def visit_StateIfScope(self, node: tn.StateIfScope, sdfg: SDFG) -> None: + # Let's see if we need this for the first prototype ... + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_BreakNode(self, node: tn.BreakNode, sdfg: SDFG) -> None: + # AFAIK we don't support for break statements in the gt4py/dace bridge. + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_ContinueNode(self, node: tn.ContinueNode, sdfg: SDFG) -> None: + # AFAIK we don't support for continue statements in the gt4py/dace bridge. + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_ElifScope(self, node: tn.ElifScope, sdfg: SDFG) -> None: + # AFAIK we don't support elif scopes in the gt4py/dace bridge. + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_ElseScope(self, node: tn.ElseScope, sdfg: SDFG) -> None: + # get false_state form stack + false_state = self._pop_state("false_state") + self._current_state = false_state + + # visit children inside the else branch + self.visit(node.children, sdfg=sdfg) + + # merge false-branch into merge_state + merge_state = self._pop_state("merge_state") + _insert_and_split_assignments(sdfg, + before_state=self._current_state, + after_state=merge_state, + assignments=self._pending_interstate_assignments()) + self._current_state = merge_state + + def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: + dataflow_stack_size = len(self._dataflow_stack) + state_stack_size = len(self._state_stack) + + # prepare inner SDFG + inner_sdfg = SDFG("nested_sdfg", parent=self._current_state) + start_state = inner_sdfg.add_state("nested_root", is_start_block=True) + + # update stacks and current state + old_state_label = self._current_state.label + self._state_stack.append(self._current_state) + self._dataflow_stack.append((inner_sdfg, {"inputs": set(), "outputs": set()})) + self._current_state = start_state + + # visit children + for child in node.children: + self.visit(child, sdfg=inner_sdfg) + + # restore current state and stacks + self._current_state = self._pop_state(old_state_label) + assert len(self._state_stack) == state_stack_size + _, connectors = self._dataflow_stack.pop() + assert len(self._dataflow_stack) == dataflow_stack_size + + # insert nested SDFG + nsdfg = self._current_state.add_nested_sdfg(inner_sdfg, + sdfg, + inputs=connectors["inputs"], + outputs=connectors["outputs"], + schedule=node.node.map.schedule) + + # connect nested SDFG to surrounding map scope + assert self._dataflow_stack + map_entry, to_connect = self._dataflow_stack[-1] + + # connect nsdfg input memlets (to be propagated upon completion of the SDFG) + for name in nsdfg.in_connectors: + out_connector = f"{PREFIX_PASSTHROUGH_OUT}{name}" + new_in_connector = map_entry.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{name}") + new_out_connector = map_entry.add_out_connector(out_connector) + assert new_in_connector == True + assert new_in_connector == new_out_connector + + self._current_state.add_edge(map_entry, out_connector, nsdfg, name, + Memlet.from_array(name, nsdfg.sdfg.arrays[name])) + + # Add empty memlet if we didn't add any in the loop above + if self._current_state.out_degree(map_entry) < 1: + self._current_state.add_nedge(map_entry, nsdfg, Memlet()) + + # connect nsdfg output memlets (to be propagated) + for name in nsdfg.out_connectors: + to_connect[name] = (nsdfg, Memlet.from_array(name, nsdfg.sdfg.arrays[name])) + + def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: + dataflow_stack_size = len(self._dataflow_stack) + + # map entry + # --------- + map_entry = nodes.MapEntry(node.node.map) + self._current_state.add_node(map_entry) + self._dataflow_stack.append((map_entry, dict())) + + # Set a new access_cache before visiting children such that they have their + # own access cache (per map scope). + access_cache = self._ensure_access_cache(self._current_state) + self._access_cache[self._current_state] = {} + + # visit children inside the map + type_of_children = [type(child) for child in node.children] + last_child_is_MapScope = type_of_children[-1] == tn.MapScope + all_others_are_Boundaries = type_of_children.count(tn.StateBoundaryNode) == len(type_of_children) - 1 + if last_child_is_MapScope and all_others_are_Boundaries: + # skip weirdly added StateBoundaryNode + # tmp: use this - for now - to "backprop-insert" extra state boundaries for nested SDFGs + self.visit(node.children[-1], sdfg=sdfg) + elif any([isinstance(child, tn.StateBoundaryNode) for child in node.children]): + self._insert_nestedSDFG(node, sdfg) + else: + self.visit(node.children, sdfg=sdfg) - if isinstance(outer_map_entry, nodes.EntryNode): - outer_to_connect[name] = (access_node, Memlet.from_array(name, sdfg.arrays[name])) - else: - assert isinstance(outer_map_entry, SDFG) or outer_map_entry is None - - # TODO If nothing is connected at this point, figure out what's the last thing that - # we should connect to. Then, add an empty memlet from that last thing to this - # map_exit. - assert len(self._current_state.in_edges(map_exit)) > 0 - - def visit_ConsumeScope(self, node: tn.ConsumeScope, sdfg: SDFG) -> None: - # AFAIK we don't support consume scopes in the gt4py/dace bridge. - raise NotImplementedError(f"{type(node)} not implemented") - - def visit_PipelineScope(self, node: tn.PipelineScope, sdfg: SDFG) -> None: - # AFAIK we don't support pipeline scopes in the gt4py/dace bridge. - raise NotImplementedError(f"{type(node)} not implemented") - - def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: - # Add Tasklet to current state - tasklet = node.node - self._current_state.add_node(tasklet) - - cache = self._ensure_access_cache(self._current_state) - scope_node, to_connect = self._dataflow_stack[-1] if self._dataflow_stack else (None, None) - - # Connect input memlets - for name, memlet in node.in_memlets.items(): - # connect to local access node if possible - if memlet.data in cache: - cached_access = cache[memlet.data] - self._current_state.add_memlet_path(cached_access, tasklet, dst_conn=name, memlet=memlet) - continue - - if isinstance(scope_node, nodes.MapEntry): - # get it from outside the map - connector_name = f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}" - if connector_name not in scope_node.out_connectors: - new_in_connector = scope_node.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") - new_out_connector = scope_node.add_out_connector(connector_name) - assert new_in_connector == True - assert new_in_connector == new_out_connector - - self._current_state.add_edge(scope_node, connector_name, tasklet, name, memlet) - continue - - if isinstance(scope_node, SDFG): + # reset the access_cache + self._access_cache[self._current_state] = access_cache + + # dataflow stack management + _, to_connect = self._dataflow_stack.pop() + assert len(self._dataflow_stack) == dataflow_stack_size + outer_map_entry, outer_to_connect = self._dataflow_stack[-1] if dataflow_stack_size else (None, None) + + # connect potential input connectors on map_entry + for connector in map_entry.in_connectors: + memlet_data = connector.removeprefix(PREFIX_PASSTHROUGH_IN) + + # connect to local access node (if available) + if memlet_data in access_cache: + cached_access = access_cache[memlet_data] + self._current_state.add_memlet_path(cached_access, + map_entry, + dst_conn=connector, + memlet=Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) + continue + + if isinstance(outer_map_entry, nodes.EntryNode): + + # get it from outside the map + connector_name = f"{PREFIX_PASSTHROUGH_OUT}{memlet_data}" + if connector_name not in outer_map_entry.out_connectors: + new_in_connector = outer_map_entry.add_in_connector(connector) + new_out_connector = outer_map_entry.add_out_connector(connector_name) + assert new_in_connector == True + assert new_in_connector == new_out_connector + + self._current_state.add_edge(outer_map_entry, connector_name, map_entry, connector, + Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) + else: + if isinstance(outer_map_entry, SDFG): # Copy data descriptor from parent SDFG and add input connector - if memlet.data not in sdfg.arrays: + if memlet_data not in sdfg.arrays: parent_sdfg = sdfg.parent.parent sdfg_counter = 1 - while memlet.data not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: + while memlet_data not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: parent_sdfg = parent_sdfg.parent.parent assert isinstance(parent_sdfg, SDFG) sdfg_counter += 1 - sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) + sdfg.add_datadesc(memlet_data, parent_sdfg.arrays[memlet_data].clone()) # Transients passed into a nested SDFG become non-transient inside that nested SDFG - if parent_sdfg.arrays[memlet.data].transient: - sdfg.arrays[memlet.data].transient = False + if parent_sdfg.arrays[memlet_data].transient: + sdfg.arrays[memlet_data].transient = False # TODO # ... unless they are only ever used inside the nested SDFG, in which case # we should delete them from the parent SDFG's array list. # NOTE This can probably be done automatically by a cleanup pass in the end. # Something like DDE should be able to do this. - assert memlet.data not in to_connect["inputs"] - to_connect["inputs"].add(memlet.data) + assert memlet_data not in outer_to_connect["inputs"] + outer_to_connect["inputs"].add(memlet_data) else: - assert scope_node is None + assert outer_map_entry is None # cache local read access - assert memlet.data not in cache - cache[memlet.data] = self._current_state.add_read(memlet.data) + assert memlet_data not in access_cache + access_cache[memlet_data] = self._current_state.add_read(memlet_data) + cached_access = access_cache[memlet_data] + self._current_state.add_memlet_path(cached_access, + map_entry, + dst_conn=connector, + memlet=Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) + + if isinstance(outer_map_entry, nodes.EntryNode) and self._current_state.out_degree(outer_map_entry) < 1: + self._current_state.add_nedge(outer_map_entry, map_entry, Memlet()) + + # map_exit + # -------- + map_exit = nodes.MapExit(node.node.map) + self._current_state.add_node(map_exit) + + # connect writes to map_exit node + for name in to_connect: + in_connector_name = f"{PREFIX_PASSTHROUGH_IN}{name}" + out_connector_name = f"{PREFIX_PASSTHROUGH_OUT}{name}" + new_in_connector = map_exit.add_in_connector(in_connector_name) + new_out_connector = map_exit.add_out_connector(out_connector_name) + assert new_in_connector == new_out_connector + + # connect "inside the map" + access_node, memlet = to_connect[name] + if isinstance(access_node, nodes.NestedSDFG): + self._current_state.add_edge(access_node, name, map_exit, in_connector_name, memlet) + else: + assert isinstance(access_node, nodes.AccessNode) + if self._current_state.out_degree(access_node) == 0 and self._current_state.in_degree(access_node) == 1: + # this access_node is not used for anything else. + # let's remove it and add a direct connection instead + edges = [edge for edge in self._current_state.edges() if edge.dst == access_node] + assert len(edges) == 1 + self._current_state.add_memlet_path(edges[0].src, + map_exit, + src_conn=edges[0].src_conn, + dst_conn=in_connector_name, + memlet=edges[0].data) + self._current_state.remove_node(access_node) # edge is remove automatically + else: + self._current_state.add_memlet_path(access_node, + map_exit, + dst_conn=in_connector_name, + memlet=memlet) + + if isinstance(outer_map_entry, SDFG): + if name not in sdfg.arrays: + parent_sdfg = sdfg.parent.parent + sdfg_counter = 1 + while name not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: + parent_sdfg = parent_sdfg.parent.parent + assert isinstance(parent_sdfg, SDFG) + sdfg_counter += 1 + sdfg.add_datadesc(name, parent_sdfg.arrays[name].clone()) + + # Transients passed into a nested SDFG become non-transient inside that nested SDFG + if parent_sdfg.arrays[name].transient: + sdfg.arrays[name].transient = False + + # Add out_connector in any case if not yet present, e.g. write after read + outer_to_connect["outputs"].add(name) + + # connect "outside the map" + access_node = self._current_state.add_write(name) + self._current_state.add_memlet_path(map_exit, + access_node, + src_conn=out_connector_name, + memlet=Memlet.from_array(name, sdfg.arrays[name])) + + # cache write access into access_cache + access_cache[name] = access_node + + if isinstance(outer_map_entry, nodes.EntryNode): + outer_to_connect[name] = (access_node, Memlet.from_array(name, sdfg.arrays[name])) + else: + assert isinstance(outer_map_entry, SDFG) or outer_map_entry is None + + # TODO If nothing is connected at this point, figure out what's the last thing that + # we should connect to. Then, add an empty memlet from that last thing to this + # map_exit. + assert len(self._current_state.in_edges(map_exit)) > 0 + + def visit_ConsumeScope(self, node: tn.ConsumeScope, sdfg: SDFG) -> None: + # AFAIK we don't support consume scopes in the gt4py/dace bridge. + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_PipelineScope(self, node: tn.PipelineScope, sdfg: SDFG) -> None: + # AFAIK we don't support pipeline scopes in the gt4py/dace bridge. + raise NotImplementedError(f"{type(node)} not implemented") + + def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: + # Add Tasklet to current state + tasklet = node.node + self._current_state.add_node(tasklet) + + cache = self._ensure_access_cache(self._current_state) + scope_node, to_connect = self._dataflow_stack[-1] if self._dataflow_stack else (None, None) + + # Connect input memlets + for name, memlet in node.in_memlets.items(): + # connect to local access node if possible + if memlet.data in cache: cached_access = cache[memlet.data] self._current_state.add_memlet_path(cached_access, tasklet, dst_conn=name, memlet=memlet) + continue + + if isinstance(scope_node, nodes.MapEntry): + # get it from outside the map + connector_name = f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}" + if connector_name not in scope_node.out_connectors: + new_in_connector = scope_node.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}") + new_out_connector = scope_node.add_out_connector(connector_name) + assert new_in_connector == True + assert new_in_connector == new_out_connector + + self._current_state.add_edge(scope_node, connector_name, tasklet, name, memlet) + continue + + if isinstance(scope_node, SDFG): + # Copy data descriptor from parent SDFG and add input connector + if memlet.data not in sdfg.arrays: + parent_sdfg = sdfg.parent.parent + sdfg_counter = 1 + while memlet.data not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: + parent_sdfg = parent_sdfg.parent.parent + assert isinstance(parent_sdfg, SDFG) + sdfg_counter += 1 + sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) + + # Transients passed into a nested SDFG become non-transient inside that nested SDFG + if parent_sdfg.arrays[memlet.data].transient: + sdfg.arrays[memlet.data].transient = False + # TODO + # ... unless they are only ever used inside the nested SDFG, in which case + # we should delete them from the parent SDFG's array list. + # NOTE This can probably be done automatically by a cleanup pass in the end. + # Something like DDE should be able to do this. + + assert memlet.data not in to_connect["inputs"] + to_connect["inputs"].add(memlet.data) + else: + assert scope_node is None - # Add empty memlet if map_entry has no out_connectors to connect to - if isinstance(scope_node, nodes.MapEntry) and self._current_state.out_degree(scope_node) < 1: - self._current_state.add_nedge(scope_node, tasklet, Memlet()) + # cache local read access + assert memlet.data not in cache + cache[memlet.data] = self._current_state.add_read(memlet.data) + cached_access = cache[memlet.data] + self._current_state.add_memlet_path(cached_access, tasklet, dst_conn=name, memlet=memlet) - # Connect output memlets - for name, memlet in node.out_memlets.items(): - # we always write to a new access_node - access_node = self._current_state.add_write(memlet.data) - self._current_state.add_memlet_path(tasklet, access_node, src_conn=name, memlet=memlet) + # Add empty memlet if map_entry has no out_connectors to connect to + if isinstance(scope_node, nodes.MapEntry) and self._current_state.out_degree(scope_node) < 1: + self._current_state.add_nedge(scope_node, tasklet, Memlet()) - # cache write access node (or update an existing one) for read after write cases - cache[memlet.data] = access_node + # Connect output memlets + for name, memlet in node.out_memlets.items(): + # we always write to a new access_node + access_node = self._current_state.add_write(memlet.data) + self._current_state.add_memlet_path(tasklet, access_node, src_conn=name, memlet=memlet) - if isinstance(scope_node, nodes.MapEntry): - # copy the memlet since we already used it in the memlet path above - to_connect[memlet.data] = (access_node, copy.deepcopy(memlet)) - continue + # cache write access node (or update an existing one) for read after write cases + cache[memlet.data] = access_node - if isinstance(scope_node, SDFG): - if memlet.data not in sdfg.arrays: - parent_sdfg = sdfg.parent.parent - sdfg_counter = 1 - while memlet.data not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: - parent_sdfg = parent_sdfg.parent.parent - assert isinstance(parent_sdfg, SDFG) - sdfg_counter += 1 - sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) + if isinstance(scope_node, nodes.MapEntry): + # copy the memlet since we already used it in the memlet path above + to_connect[memlet.data] = (access_node, copy.deepcopy(memlet)) + continue - # Transients passed into a nested SDFG become non-transient inside that nested SDFG - if parent_sdfg.arrays[memlet.data].transient: - sdfg.arrays[memlet.data].transient = False + if isinstance(scope_node, SDFG): + if memlet.data not in sdfg.arrays: + parent_sdfg = sdfg.parent.parent + sdfg_counter = 1 + while memlet.data not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: + parent_sdfg = parent_sdfg.parent.parent + assert isinstance(parent_sdfg, SDFG) + sdfg_counter += 1 + sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) - # Add out_connector in any case if not yet present, e.g. write after read - to_connect["outputs"].add(memlet.data) + # Transients passed into a nested SDFG become non-transient inside that nested SDFG + if parent_sdfg.arrays[memlet.data].transient: + sdfg.arrays[memlet.data].transient = False - else: - assert scope_node is None + # Add out_connector in any case if not yet present, e.g. write after read + to_connect["outputs"].add(memlet.data) - def visit_LibraryCall(self, node: tn.LibraryCall, sdfg: SDFG) -> None: - # AFAIK we expand all library calls in the gt4py/dace bridge before coming here. - raise NotImplementedError(f"{type(node)} not implemented") + else: + assert scope_node is None - def visit_CopyNode(self, node: tn.CopyNode, sdfg: SDFG) -> None: - # apparently we need this for the first prototype - self._ensure_access_cache(self._current_state) - access_cache = self._access_cache[self._current_state] + def visit_LibraryCall(self, node: tn.LibraryCall, sdfg: SDFG) -> None: + # AFAIK we expand all library calls in the gt4py/dace bridge before coming here. + raise NotImplementedError(f"{type(node)} not implemented") - # assumption source access may or may not yet exist (in this state) - src_name = node.memlet.data - source = access_cache[src_name] if src_name in access_cache else self._current_state.add_read(src_name) + def visit_CopyNode(self, node: tn.CopyNode, sdfg: SDFG) -> None: + # apparently we need this for the first prototype + self._ensure_access_cache(self._current_state) + access_cache = self._access_cache[self._current_state] - # assumption: target access node doesn't exist yet - assert node.target not in access_cache - target = self._current_state.add_write(node.target) + # assumption source access may or may not yet exist (in this state) + src_name = node.memlet.data + source = access_cache[src_name] if src_name in access_cache else self._current_state.add_read(src_name) - self._current_state.add_memlet_path(source, target, memlet=node.memlet) + # assumption: target access node doesn't exist yet + assert node.target not in access_cache + target = self._current_state.add_write(node.target) - def visit_DynScopeCopyNode(self, node: tn.DynScopeCopyNode, sdfg: SDFG) -> None: - # AFAIK we don't support dyn scope copy nodes in the gt4py/dace bridge. - raise NotImplementedError(f"{type(node)} not implemented") + self._current_state.add_memlet_path(source, target, memlet=node.memlet) - def visit_ViewNode(self, node: tn.ViewNode, sdfg: SDFG) -> None: - # Let's see if we need this for the first prototype ... - raise NotImplementedError(f"{type(node)} not implemented") + def visit_DynScopeCopyNode(self, node: tn.DynScopeCopyNode, sdfg: SDFG) -> None: + # AFAIK we don't support dyn scope copy nodes in the gt4py/dace bridge. + raise NotImplementedError(f"{type(node)} not implemented") - def visit_NView(self, node: tn.NView, sdfg: SDFG) -> None: - # TODO: Fillz and Ray_Fast will need these ... - raise NotImplementedError(f"{type(node)} not implemented") + def visit_ViewNode(self, node: tn.ViewNode, sdfg: SDFG) -> None: + # Let's see if we need this for the first prototype ... + raise NotImplementedError(f"{type(node)} not implemented") - def visit_RefSetNode(self, node: tn.RefSetNode, sdfg: SDFG) -> None: - # Let's see if we need this for the first prototype ... - raise NotImplementedError(f"{type(node)} not implemented") + def visit_NView(self, node: tn.NView, sdfg: SDFG) -> None: + # TODO: Fillz and Ray_Fast will need these ... + raise NotImplementedError(f"{type(node)} not implemented") - def visit_StateBoundaryNode(self, node: tn.StateBoundaryNode, sdfg: SDFG) -> None: - # When creating a state boundary, include all inter-state assignments that precede it. - pending = self._pending_interstate_assignments() + def visit_RefSetNode(self, node: tn.RefSetNode, sdfg: SDFG) -> None: + # Let's see if we need this for the first prototype ... + raise NotImplementedError(f"{type(node)} not implemented") - self._current_state = create_state_boundary(node, - sdfg, - self._current_state, - StateBoundaryBehavior.STATE_TRANSITION, - assignments=pending) + def visit_StateBoundaryNode(self, node: tn.StateBoundaryNode, sdfg: SDFG) -> None: + # When creating a state boundary, include all inter-state assignments that precede it. + pending = self._pending_interstate_assignments() - def _pending_interstate_assignments(self) -> Dict: - """ - Return currently pending interstate assignments. Clears the cache. - """ - assignments = {} + self._current_state = create_state_boundary(node, + sdfg, + self._current_state, + StateBoundaryBehavior.STATE_TRANSITION, + assignments=pending) - for symbol in self._interstate_symbols: - assignments[symbol.name] = symbol.value.as_string - self._interstate_symbols.clear() + def _pending_interstate_assignments(self) -> Dict: + """ + Return currently pending interstate assignments. Clears the cache. + """ + assignments = {} - return assignments + for symbol in self._interstate_symbols: + assignments[symbol.name] = symbol.value.as_string + self._interstate_symbols.clear() - StreeToSDFG().visit(stree, sdfg=result) - print(f"Main visitor took {(time.time() - s):.3f} seconds.") + return assignments - # memlet propagation - s = time.time() + +def from_schedule_tree(stree: tn.ScheduleTreeRoot, + state_boundary_behavior: StateBoundaryBehavior = StateBoundaryBehavior.STATE_TRANSITION) -> SDFG: + """ + Converts a schedule tree into an SDFG. + + :param stree: The schedule tree root to convert. + :param state_boundary_behavior: Sets the behavior upon encountering a state boundary (e.g., write-after-write). + See the ``StateBoundaryBehavior`` enumeration for more details. + :return: An SDFG representing the schedule tree. + """ + # Setup SDFG descriptor repository + result = SDFG(stree.name, propagate=False) + result.arg_names = copy.deepcopy(stree.arg_names) + for key, container in stree.containers.items(): + result._arrays[key] = copy.deepcopy(container) + result.constants_prop = copy.deepcopy(stree.constants) + result.symbols = copy.deepcopy(stree.symbols) + + # Insert artificial state boundaries after WAW, before label, etc. + stree = insert_state_boundaries_to_tree(stree) + + # Traverse tree and incrementally build SDFG, finally propagate memlets + StreeToSDFG().visit(stree, sdfg=result) propagation.propagate_memlets_sdfg(result) - print(f"Memlet propagation took {(time.time() - s):.3f} seconds.") return result @@ -711,8 +698,6 @@ def insert_state_boundaries_to_tree(stree: tn.ScheduleTreeRoot) -> tn.ScheduleTr :param stree: The schedule tree to operate on. """ - s = time.time() - # Simple boundary node inserter for control flow blocks and state labels class SimpleStateBoundaryInserter(tn.ScheduleNodeTransformer): @@ -726,14 +711,9 @@ def visit_StateLabel(self, node: tn.StateLabel): # First, insert boundaries around labels and control flow stree = SimpleStateBoundaryInserter().visit(stree) - print(f"\tSimpleStateBoundaryInserter took {(time.time() - s):.3f} seconds.") - s = time.time() # Then, insert boundaries after unmet memory dependencies or potential data races _insert_memory_dependency_state_boundaries(stree) - print(f"\tMemory dependency analysis took {(time.time() - s):.3f} seconds.") - - s = time.time() # Insert a state boundary after every symbol assignment to ensure symbols are assigned before usage class SymbolAssignmentBoundaryInserter(tn.ScheduleNodeTransformer): @@ -753,11 +733,8 @@ def visit_AssignNode(self, node: tn.AssignNode): return [self.generic_visit(node), tn.StateBoundaryNode()] stree = SymbolAssignmentBoundaryInserter().visit(stree) - print(f"\tSymbolAssignmentBoundaryInserter took {(time.time() - s):.3f} seconds.") # Hack: "backprop-insert" state boundaries from nested SDFGs - s = time.time() - class NestedSDFGStateBoundaryInserter(tn.ScheduleNodeTransformer): def visit_MapScope(self, scope: tn.MapScope): @@ -777,7 +754,6 @@ def visit_MapScope(self, scope: tn.MapScope): return visited stree = NestedSDFGStateBoundaryInserter().visit(stree) - print(f"\tNestedSDFGStateBoundaryInserter took {(time.time() - s):.3f} seconds.") return stree From 3a3852407515184c8d4b044868cb4d69838b28b4 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 9 Jul 2025 20:58:44 +0200 Subject: [PATCH 086/137] Fix networkx state space explosion for cycle detection in state propagation (#2078) Networkx runs into a state space explosion for large loop graphs with many branches when querying simple cycles in the graph. Since state propagation (used by memlet propagation) makes use of that, this causes the operation to not complete in any reasonable time frame. This PR back-ports the improved loop detection from DaCe 2.0 and changes state propagation's loop range annotation to make use of that, instead of relying on Networkx's cycle querying. --- dace/sdfg/propagation.py | 133 +++----------- .../interstate/loop_detection.py | 166 +++++++++++++++--- tests/sdfg/work_depth_test.py | 4 +- 3 files changed, 165 insertions(+), 138 deletions(-) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index f1b7d3baf7..07bfa754a3 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. """ Functionality relating to Memlet propagation (deducing external memlets from internal memory accesses and scope ranges). @@ -571,114 +571,33 @@ def _annotate_loop_ranges(sdfg, unannotated_cycle_states): """ # We import here to avoid cyclic imports. - from dace.sdfg import utils as sdutils - from dace.transformation.interstate.loop_detection import find_for_loop + from dace.transformation.passes.pattern_matching import match_patterns + from dace.transformation.interstate.loop_detection import LoopRangeAnnotator + from dace.sdfg.utils import dfs_conditional + from dace.sdfg.analysis import cfg as cfg_analysis condition_edges = {} - - for cycle in sdfg.find_cycles(): - # In each cycle, try to identify a valid loop guard state. - guard = None - begin = None - itvar = None - for v in cycle: - # Try to identify a valid for-loop guard. - in_edges = sdfg.in_edges(v) - out_edges = sdfg.out_edges(v) - - # A for-loop guard has two or more incoming edges (1 increment and - # n init, all identical), and exactly two outgoing edges (loop and - # exit loop). - if len(in_edges) < 2 or len(out_edges) != 2: - continue - - # All incoming guard edges must set exactly one variable and it must - # be the same for all of them. - itvars = set() - for iedge in in_edges: - if len(iedge.data.assignments) > 0: - if not itvars: - itvars = set(iedge.data.assignments.keys()) - else: - itvars &= set(iedge.data.assignments.keys()) - else: - itvars = None - break - if not itvars or len(itvars) > 1: - continue - itvar = next(iter(itvars)) - itvarsym = pystr_to_symbolic(itvar) - - # The outgoing edges must be negations of one another. - if out_edges[0].data.condition_sympy() != (sympy.Not(out_edges[1].data.condition_sympy())): - continue - - # Make sure the last state of the loop (i.e. the state leading back - # to the guard via 'increment' edge) is part of this cycle. If not, - # we're looking at the guard for a nested cycle, which we ignore for - # this cycle. - increment_edge = None - for iedge in in_edges: - if itvarsym in pystr_to_symbolic(iedge.data.assignments[itvar]).free_symbols: - increment_edge = iedge - break - if increment_edge is None: - continue - if increment_edge.src not in cycle: - continue - - # One of the child states must be in the loop (loop begin), and the - # other one must be outside the cycle (loop exit). - loop_state = None - exit_state = None - if out_edges[0].dst in cycle and out_edges[1].dst not in cycle: - loop_state = out_edges[0].dst - exit_state = out_edges[1].dst - elif out_edges[1].dst in cycle and out_edges[0].dst not in cycle: - loop_state = out_edges[1].dst - exit_state = out_edges[0].dst - if loop_state is None or exit_state is None: - continue - - # This is a valid guard state candidate. - guard = v - begin = loop_state - break - - if guard is not None and begin is not None and itvar is not None: - # A guard state was identified, see if it has valid for-loop ranges - # and annotate the loop as such. - - # Ensure that this guard's loop wasn't annotated yet. - if itvar in begin.ranges: - continue - - res = find_for_loop(sdfg, guard, begin, itervar=itvar) - if res is None: - # No range detected, mark as unbounded. - unannotated_cycle_states.append(cycle) - else: - itervar, rng, _ = res - - # Make sure the range is flipped in a direction such that the - # stride is positive (in order to match subsets.Range). - start, stop, stride = rng - # This inequality needs to be checked exactly like this due to - # constraints in sympy/symbolic expressions, do not simplify!!! - if (stride < 0) == True: - rng = (stop, start, -stride) - - loop_states = sdutils.dfs_conditional(sdfg, sources=[begin], condition=lambda _, child: child != guard) - for v in loop_states: - v.ranges[itervar] = subsets.Range([rng]) - guard.ranges[itervar] = subsets.Range([rng]) - condition_edges[guard] = sdfg.edges_between(guard, begin)[0] - guard.is_loop_guard = True - guard.itvar = itervar - else: - # There's no guard state, so this cycle marks all states in it as - # dynamically unbounded. - unannotated_cycle_states.append(cycle) + loop_back_edges = set() + + for match in match_patterns(sdfg, LoopRangeAnnotator): + annotator: LoopRangeAnnotator = match + cond_edge = annotator.loop_condition_edge() + guard_state = annotator.loop_guard_state() + loop_back_edge = annotator.loop_increment_edge() + if cond_edge is not None and guard_state is not None: + condition_edges[guard_state] = cond_edge + if loop_back_edge is not None: + loop_back_edges.add(loop_back_edge) + annotator.apply(sdfg, sdfg) + + for be in cfg_analysis.back_edges(sdfg): + if be not in loop_back_edges: + # This backedge closes a loop that was not annotated, and thus is not a proper for-loop. The states in this + # cycle are thus unannotated. + cycle_states = set() + for cycle_state in dfs_conditional(sdfg, [be.src], lambda p, _: p is not be.dst, reverse=True): + cycle_states.add(cycle_state) + unannotated_cycle_states.append(cycle_states) return condition_edges diff --git a/dace/transformation/interstate/loop_detection.py b/dace/transformation/interstate/loop_detection.py index 8081447132..2a56a3d93b 100644 --- a/dace/transformation/interstate/loop_detection.py +++ b/dace/transformation/interstate/loop_detection.py @@ -1,4 +1,4 @@ -# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. """ Loop detection transformation """ import sympy as sp @@ -8,6 +8,7 @@ from dace import sdfg as sd, symbolic from dace.sdfg import graph as gr, utils as sdutil, InterstateEdge from dace.sdfg.state import ControlFlowRegion, ControlFlowBlock +from dace.subsets import Range from dace.transformation import transformation @@ -17,20 +18,23 @@ class DetectLoop(transformation.PatternTransformation): """ Detects a for-loop construct from an SDFG. """ # Always available - loop_begin = transformation.PatternNode(sd.SDFGState) - exit_state = transformation.PatternNode(sd.SDFGState) + loop_begin = transformation.PatternNode(ControlFlowBlock) + exit_state = transformation.PatternNode(ControlFlowBlock) # Available for natural loops - loop_guard = transformation.PatternNode(sd.SDFGState) + loop_guard = transformation.PatternNode(ControlFlowBlock) # Available for rotated loops - loop_latch = transformation.PatternNode(sd.SDFGState) + loop_latch = transformation.PatternNode(ControlFlowBlock) # Available for rotated and self loops - entry_state = transformation.PatternNode(sd.SDFGState) + entry_state = transformation.PatternNode(ControlFlowBlock) # Available for explicit-latch rotated loops - loop_break = transformation.PatternNode(sd.SDFGState) + loop_break = transformation.PatternNode(ControlFlowBlock) + + break_edges: Set[gr.Edge[InterstateEdge]] = set() + continue_edges: Set[gr.Edge[InterstateEdge]] = set() @classmethod def expressions(cls): @@ -129,15 +133,21 @@ def can_be_applied(self, elif expr_index == 4: return self.detect_self_loop(graph, accept_missing_itvar=permissive) is not None elif expr_index in (5, 7): - return self.detect_rotated_loop(graph, multistate_loop=True, accept_missing_itvar=permissive, + return self.detect_rotated_loop(graph, + multistate_loop=True, + accept_missing_itvar=permissive, separate_latch=True) is not None elif expr_index == 6: - return self.detect_rotated_loop(graph, multistate_loop=False, accept_missing_itvar=permissive, + return self.detect_rotated_loop(graph, + multistate_loop=False, + accept_missing_itvar=permissive, separate_latch=True) is not None raise ValueError(f'Invalid expression index {expr_index}') - def detect_loop(self, graph: ControlFlowRegion, multistate_loop: bool, + def detect_loop(self, + graph: ControlFlowRegion, + multistate_loop: bool, accept_missing_itvar: bool = False) -> Optional[str]: """ Detects a loop of the form: @@ -183,7 +193,14 @@ def detect_loop(self, graph: ControlFlowRegion, multistate_loop: bool, # All nodes inside loop must be dominated by loop guard dominators = nx.dominance.immediate_dominators(graph.nx, graph.start_block) - loop_nodes = sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != guard) + postdominators = sdutil.postdominators(graph, True) + loop_nodes = self.loop_body() + # If the exit state is in the loop nodes, this is not a valid loop + if self.exit_state in loop_nodes: + return None + elif any(self.exit_state not in postdominators[1][n] for n in loop_nodes): + # The loop exit must post-dominate all loop nodes + return None backedge = None for node in loop_nodes: for e in graph.out_edges(node): @@ -219,8 +236,11 @@ def detect_loop(self, graph: ControlFlowRegion, multistate_loop: bool, return next(iter(itvar)) - def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool, - accept_missing_itvar: bool = False, separate_latch: bool = False) -> Optional[str]: + def detect_rotated_loop(self, + graph: ControlFlowRegion, + multistate_loop: bool, + accept_missing_itvar: bool = False, + separate_latch: bool = False) -> Optional[str]: """ Detects a loop of the form: @@ -260,15 +280,23 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool, if latch_outedges[0].data.condition_sympy() != (sp.Not(latch_outedges[1].data.condition_sympy())): return None + # Make sure the backedge (i.e, one of the condition edges) goes from the latch to the beginning state. + if latch_outedges[0].dst is not self.loop_begin and latch_outedges[1].dst is not self.loop_begin: + return None + # All nodes inside loop must be dominated by loop start dominators = nx.dominance.immediate_dominators(graph.nx, graph.start_block) if begin is ltest: loop_nodes = [begin] else: - loop_nodes = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != ltest)) + loop_nodes = self.loop_body() loop_nodes.append(latch) if ltest is not latch and ltest is not begin: loop_nodes.append(ltest) + postdominators = sdutil.postdominators(graph, True) + if any(self.exit_state not in postdominators[1][n] for n in loop_nodes): + # The loop exit must post-dominate all loop nodes + return None backedge = None for node in loop_nodes: for e in graph.out_edges(node): @@ -369,37 +397,72 @@ def loop_information( return find_for_loop(guard.parent_graph, guard, entry, itervar) elif self.expr_index in (2, 3, 5, 6, 7): latch = self.loop_latch - return find_rotated_for_loop(latch.parent_graph, latch, entry, itervar, + return find_rotated_for_loop(latch.parent_graph, + latch, + entry, + itervar, separate_latch=(self.expr_index in (5, 6, 7))) elif self.expr_index == 4: return find_rotated_for_loop(entry.parent_graph, entry, entry, itervar) raise ValueError(f'Invalid expression index {self.expr_index}') + def _loop_body_dfs(self, terminator: ControlFlowBlock) -> Iterable[ControlFlowBlock]: + self.break_edges.clear() + visited = set() + start = self.loop_begin + graph = start.parent_graph + exit_state = self.exit_state + yield start + visited.add(start) + stack = [(start, iter(graph.successors(start)))] + while stack: + parent, children = stack[-1] + try: + child = next(children) + if child not in visited: + visited.add(child) + if child == exit_state: + # If the exit state is reachable from the loop body, that counts as a break edge. + for e in graph.edges_between(parent, child): + self.break_edges.add(e) + elif child != terminator: + try: + yield child + stack.append((child, iter(graph.successors(child)))) + except sdutil.StopTraversal: + pass + else: + # If we reached the terminator, we do not traverse further. All edges reaching the terminator + # are marked as continue edges. If there is only one continue edge int the end, it can be + # discarded (not actually a continue, simply the edge closing the loop). + for e in graph.edges_between(parent, child): + self.continue_edges.add(e) + except StopIteration: + stack.pop() + def loop_body(self) -> List[ControlFlowBlock]: """ Returns a list of all control flow blocks (or states) contained in the loop. """ - begin = self.loop_begin - graph = begin.parent_graph if self.expr_index in (0, 1): guard = self.loop_guard - return list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != guard)) + return list(self._loop_body_dfs(guard)) elif self.expr_index in (2, 3): latch = self.loop_latch - loop_nodes = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != latch)) + loop_nodes = list(self._loop_body_dfs(latch)) loop_nodes += [latch] return loop_nodes elif self.expr_index == 4: - return [begin] + return [self.loop_begin] elif self.expr_index in (5, 7): ltest = self.loop_break latch = self.loop_latch - loop_nodes = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != ltest)) + loop_nodes = list(self._loop_body_dfs(ltest)) loop_nodes += [ltest, latch] return loop_nodes elif self.expr_index == 6: - return [begin, self.loop_latch] + return [self.loop_begin, self.loop_latch] return [] @@ -496,11 +559,12 @@ def loop_increment_edge(self) -> gr.Edge[InterstateEdge]: raise ValueError(f'Invalid expression index {self.expr_index}') -def rotated_loop_find_itvar(begin_inedges: List[gr.Edge[InterstateEdge]], - latch_inedges: List[gr.Edge[InterstateEdge]], - backedge: gr.Edge[InterstateEdge], latch: ControlFlowBlock, - accept_missing_itvar: bool = False) -> Tuple[Optional[str], - Optional[gr.Edge[InterstateEdge]]]: +def rotated_loop_find_itvar( + begin_inedges: List[gr.Edge[InterstateEdge]], + latch_inedges: List[gr.Edge[InterstateEdge]], + backedge: gr.Edge[InterstateEdge], + latch: ControlFlowBlock, + accept_missing_itvar: bool = False) -> Tuple[Optional[str], Optional[gr.Edge[InterstateEdge]]]: # The iteration variable must be assigned (initialized) on all edges leading into the beginning block, which # are not the backedge. Gather all variabes for which that holds - they are all candidates for the iteration # variable (Phase 1). Said iteration variable must then be incremented: @@ -582,7 +646,7 @@ def find_for_loop( List[sd.SDFGState], sd.SDFGState]]]: """ Finds loop range from state machine. - + :param guard: State from which the outgoing edges detect whether to exit the loop or not. :param entry: First state in the loop body. @@ -694,7 +758,7 @@ def find_rotated_for_loop( List[sd.SDFGState], sd.SDFGState]]]: """ Finds rotated loop range from state machine. - + :param latch: State from which the outgoing edges detect whether to reenter the loop or not. :param entry: First state in the loop body. :param itervar: An optional field that overrides the analyzed iteration variable. @@ -791,3 +855,47 @@ def find_rotated_for_loop( return None return itervar, (start, end, stride), (start_states, last_loop_state) + + +class LoopRangeAnnotator(DetectLoop, transformation.MultiStateTransformation): + + def can_be_applied(self, graph, expr_index, sdfg, permissive = False): + if super().can_be_applied(graph, expr_index, sdfg, permissive): + loop_info = self.loop_information() + if loop_info is None: + return False + return True + return False + + def loop_guard_state(self): + """ + Returns the loop guard state of this loop (i.e., latch state or begin state for inverted or self loops). + """ + if self.expr_index in (0, 1): + return self.loop_guard + elif self.expr_index in (2, 3, 5, 6, 7): + return self.loop_latch + else: + return self.loop_begin + + def apply(self, graph, sdfg): + itvar, rng, _ = self.loop_information() + + body = self.loop_body() + meta = self.loop_meta_states() + full_body = set(body) + full_body.update(meta) + + # Make sure the range is flipped such that the stride is positive (in order to match subsets.Range). + start, stop, stride = rng + # ===== + # NOTE: This inequality needs to be checked exactly like this due to sympy limitations, do not simplify! + if (stride < 0) == True: + rng = (stop, start, -stride) + # ===== + + for v in full_body: + v.ranges[itvar] = Range([rng]) + guard_state = self.loop_guard_state() + guard_state.is_loop_guard = True + guard_state.itvar = itvar diff --git a/tests/sdfg/work_depth_test.py b/tests/sdfg/work_depth_test.py index e677cca752..808a76f84b 100644 --- a/tests/sdfg/work_depth_test.py +++ b/tests/sdfg/work_depth_test.py @@ -320,8 +320,8 @@ def test_assumption_system_contradictions(assumptions): for test_name in work_depth_test_cases.keys(): test_work_depth(test_name) - for test, correct in tests_cases_avg_par: - test_avg_par(test, correct) + for test in tests_cases_avg_par.keys(): + test_avg_par(test) for expr, assums, res in assumptions_tests: test_assumption_system(expr, assums, res) From a180a7e6b26b161b1758f351fdc45dd57b29da9d Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 14 Jul 2025 16:34:45 +0200 Subject: [PATCH 087/137] Testing per scope access cache --- .../analysis/schedule_tree/tree_to_sdfg.py | 83 ++++++++++++------- dace/sdfg/analysis/schedule_tree/treenodes.py | 49 +++++++++++ tests/schedule_tree/to_sdfg_test.py | 4 +- 3 files changed, 107 insertions(+), 29 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 8dc55c8c16..a4dd220a37 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -25,20 +25,27 @@ class StateBoundaryBehavior(Enum): class StreeToSDFG(tn.ScheduleNodeVisitor): def __init__(self, start_state: Optional[SDFGState] = None) -> None: - # state management - self._state_stack: List[SDFGState] = [] + self._ctx: tn.Context + """Context information like tree root and current scope.""" + self._current_state = start_state + """Current SDFGState in the SDFG that we are building.""" - # inter-state symbol assignments self._interstate_symbols: List[tn.AssignNode] = [] + """Interstate symbol assignments. Will be assigned with the next state transition.""" + + # state management + self._state_stack: List[SDFGState] = [] # dataflow scopes # List[ (MapEntryNode, ToConnect) | (SDFG, {"inputs": set(), "outputs": set()}) ] self._dataflow_stack: List[Tuple[nodes.EntryNode, Dict[str, Tuple[nodes.AccessNode, Memlet]]] | Tuple[SDFG, Dict[str, Set[str]]]] = [] + # -- to be torched -- # caches - self._access_cache: Dict[SDFGState, Dict[str, nodes.AccessNode]] = {} + # self._access_cache: Dict[SDFGState, Dict[str, nodes.AccessNode]] = {} + # end -- to be torched -- def _pop_state(self, label: Optional[str] = None) -> SDFGState: """Pops the last state from the state stack. @@ -56,32 +63,42 @@ def _pop_state(self, label: Optional[str] = None) -> SDFGState: return popped - def _ensure_access_cache(self, state: SDFGState) -> Dict[str, nodes.AccessNode]: - """Ensure an access_cache entry for the given state. - - Checks if there exists an access_cache for `state`. Creates an empty one if it doesn't exist yet. + # def _ensure_access_cache(self, state: SDFGState) -> Dict[str, nodes.AccessNode]: + # """Ensure an access_cache entry for the given state. - :param SDFGState state: The state to check. - :return: The state's access_cache. - """ - if state not in self._access_cache: - self._access_cache[state] = {} - - return self._access_cache[state] +# +# Checks if there exists an access_cache for `state`. Creates an empty one if it doesn't exist yet. +# +# :param SDFGState state: The state to check. +# +# :return: The state's access_cache. +# """ +# # -- to be torched +# raise RuntimeError("We shouldn't end up here anymore.") +# if state not in self._access_cache: +# self._access_cache[state] = {} +# +# return self._access_cache[state] def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot, sdfg: SDFG) -> None: + # -- to be torched -- assert self._current_state is None, "Expected no 'current_state' at root." assert not self._state_stack, "Expected empty state stack at root." assert not self._dataflow_stack, "Expected empty dataflow stack at root." assert not self._interstate_symbols, "Expected empty list of symbols at root." + # end -- to be torched -- self._current_state = sdfg.add_state(label="tree_root", is_start_block=True) - self.visit(node.children, sdfg=sdfg) + self._ctx = tn.Context(root=node, access_cache={}, current_scope=None) + with node.scope(self._ctx): + self.visit(node.children, sdfg=sdfg) + # -- to be torched -- assert not self._state_stack, "Expected empty state stack." assert not self._dataflow_stack, "Expected empty dataflow stack." assert not self._interstate_symbols, "Expected empty list of symbols to add." + # end -- to be torched -- def visit_GBlock(self, node: tn.GBlock, sdfg: SDFG) -> None: # Let's see if we need this for the first prototype ... @@ -283,8 +300,10 @@ def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: self._current_state = start_state # visit children - for child in node.children: - self.visit(child, sdfg=inner_sdfg) + with node.scope(self._ctx): + self.visit(node.children, sdfg=inner_sdfg) + # for child in node.children: + # self.visit(child, sdfg=inner_sdfg) # restore current state and stacks self._current_state = self._pop_state(old_state_label) @@ -333,8 +352,8 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: # Set a new access_cache before visiting children such that they have their # own access cache (per map scope). - access_cache = self._ensure_access_cache(self._current_state) - self._access_cache[self._current_state] = {} + # access_cache = self._ensure_access_cache(self._current_state) + # self._access_cache[self._current_state] = {} # visit children inside the map type_of_children = [type(child) for child in node.children] @@ -343,14 +362,18 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: if last_child_is_MapScope and all_others_are_Boundaries: # skip weirdly added StateBoundaryNode # tmp: use this - for now - to "backprop-insert" extra state boundaries for nested SDFGs - self.visit(node.children[-1], sdfg=sdfg) + with node.scope(self._ctx): + self.visit(node.children[-1], sdfg=sdfg) elif any([isinstance(child, tn.StateBoundaryNode) for child in node.children]): self._insert_nestedSDFG(node, sdfg) else: - self.visit(node.children, sdfg=sdfg) + with node.scope(self._ctx): + self.visit(node.children, sdfg=sdfg) + + # # reset the access_cache + # self._access_cache[self._current_state] = access_cache - # reset the access_cache - self._access_cache[self._current_state] = access_cache + access_cache = self._ctx.access_cache[id(self._ctx.current_scope)] # dataflow stack management _, to_connect = self._dataflow_stack.pop() @@ -506,7 +529,8 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: tasklet = node.node self._current_state.add_node(tasklet) - cache = self._ensure_access_cache(self._current_state) + cache = self._ctx.access_cache[id(self._ctx.current_scope)] + assert cache is not None scope_node, to_connect = self._dataflow_stack[-1] if self._dataflow_stack else (None, None) # Connect input memlets @@ -603,6 +627,8 @@ def visit_LibraryCall(self, node: tn.LibraryCall, sdfg: SDFG) -> None: raise NotImplementedError(f"{type(node)} not implemented") def visit_CopyNode(self, node: tn.CopyNode, sdfg: SDFG) -> None: + raise NotImplementedError("Not yet ported to new stree bridge") + # apparently we need this for the first prototype self._ensure_access_cache(self._current_state) access_cache = self._access_cache[self._current_state] @@ -837,7 +863,7 @@ def _insert_memory_dependency_state_boundaries(scope: tn.ScheduleTreeScope): # SDFG content creation functions -def create_state_boundary(bnode: tn.StateBoundaryNode, +def create_state_boundary(boundary_node: tn.StateBoundaryNode, sdfg_region: ControlFlowRegion, state: SDFGState, behavior: StateBoundaryBehavior, @@ -845,7 +871,7 @@ def create_state_boundary(bnode: tn.StateBoundaryNode, """ Creates a boundary between two states - :param bnode: The state boundary node to generate. + :param boundary_node: The state boundary node to generate. :param sdfg_region: The control flow block in which to generate the boundary (e.g., SDFG). :param state: The last state prior to this boundary. :param behavior: The state boundary behavior with which to create the boundary. @@ -858,7 +884,8 @@ def create_state_boundary(bnode: tn.StateBoundaryNode, # TODO: Some boundaries (control flow, state labels with goto) could not be fulfilled with every # behavior. Fall back to state transition in that case. - label = "cf_state_boundary" if bnode.due_to_control_flow else "state_boundary" + label = "cf_state_boundary" if boundary_node.due_to_control_flow else "state_boundary" + assignments = assignments if assignments is not None else {} return _insert_and_split_assignments(sdfg_region, state, label=label, assignments=assignments) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index e9faa9c141..14c887ec07 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -9,6 +9,7 @@ from dace.sdfg.sdfg import InterstateEdge, SDFG, memlets_in_ast from dace.sdfg.state import SDFGState from dace.memlet import Memlet +from types import TracebackType from typing import Any, Dict, Iterator, List, Optional, Set, Tuple, Union INDENTATION = ' ' @@ -18,6 +19,48 @@ class UnsupportedScopeException(Exception): pass +@dataclass +class Context: + root: 'ScheduleTreeRoot' + current_scope: Optional['ScheduleTreeScope'] + + access_cache: Dict[str, Dict[str, nodes.AccessNode]] + """Per scope (hashed by id(scope_node) access_cache.""" + + +class ContextPushPop: + """Append the given node to the scope, then push/pop the scope.""" + + def __init__(self, ctx: Context, node: 'ScheduleTreeScope') -> None: + if ctx.current_scope is None and not isinstance(node, ScheduleTreeRoot): + raise ValueError("ctx.current_scope is only allowed to be 'None' when node it tree root.") + + self._ctx = ctx + self._parent_scope = ctx.current_scope + self._node = node + + assert id(node) not in self._ctx.access_cache + self._ctx.access_cache[id(node)] = {} + + def __enter__(self) -> None: + assert not self._ctx.access_cache[id(self._node)], "Expecting an empty access_cache when entering the context." + # self._node.parent = self._parent_scope + # if self._parent_scope is not None: # Exception for ScheduleTreeRoot + # self._parent_scope.children.append(self._node) + self._ctx.current_scope = self._node + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + assert id(self._node) in self._ctx.access_cache + self._ctx.access_cache[id(self._node)].clear() + + self._ctx.current_scope = self._parent_scope + + @dataclass class ScheduleTreeNode: parent: Optional['ScheduleTreeScope'] = field(default=None, init=False) @@ -225,6 +268,9 @@ def as_sdfg(self, def get_root(self) -> 'ScheduleTreeRoot': return self + def scope(self, ctx: Context) -> ContextPushPop: + return ContextPushPop(ctx, self) + @dataclass class ControlFlowScope(ScheduleTreeScope): @@ -235,6 +281,9 @@ class ControlFlowScope(ScheduleTreeScope): class DataflowScope(ScheduleTreeScope): node: nodes.EntryNode + def scope(self, ctx: Context) -> ContextPushPop: + return ContextPushPop(ctx, self) + @dataclass class GBlock(ControlFlowScope): diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 099257e0a0..949bd29d80 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -174,7 +174,9 @@ def test_state_boundaries_state_transition(): ) stree = t2s.insert_state_boundaries_to_tree(stree) - assert [tn.AssignNode, tn.TaskletNode, tn.StateBoundaryNode, tn.AssignNode] == [type(n) for n in stree.children] + assert [ + tn.AssignNode, tn.StateBoundaryNode, tn.TaskletNode, tn.StateBoundaryNode, tn.AssignNode, tn.StateBoundaryNode + ] == [type(n) for n in stree.children] @pytest.mark.parametrize('boundary', (False, True)) From f014559c362e272d2aa45272282c258cbd95ca71 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 14 Jul 2025 19:07:24 +0200 Subject: [PATCH 088/137] Cache per state & scope. Special case for write after read after write --- .../analysis/schedule_tree/tree_to_sdfg.py | 63 +++++++++++-------- dace/sdfg/analysis/schedule_tree/treenodes.py | 26 ++++---- tests/schedule_tree/to_sdfg_test.py | 33 ++++++++++ 3 files changed, 86 insertions(+), 36 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index a4dd220a37..cdd3bf8200 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -91,7 +91,7 @@ def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot, sdfg: SDFG) -> None: self._current_state = sdfg.add_state(label="tree_root", is_start_block=True) self._ctx = tn.Context(root=node, access_cache={}, current_scope=None) - with node.scope(self._ctx): + with node.scope(self._current_state, self._ctx): self.visit(node.children, sdfg=sdfg) # -- to be torched -- @@ -300,10 +300,8 @@ def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: self._current_state = start_state # visit children - with node.scope(self._ctx): + with node.scope(self._current_state, self._ctx): self.visit(node.children, sdfg=inner_sdfg) - # for child in node.children: - # self.visit(child, sdfg=inner_sdfg) # restore current state and stacks self._current_state = self._pop_state(old_state_label) @@ -343,6 +341,7 @@ def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: dataflow_stack_size = len(self._dataflow_stack) + cache_state = self._current_state # map entry # --------- @@ -350,11 +349,6 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: self._current_state.add_node(map_entry) self._dataflow_stack.append((map_entry, dict())) - # Set a new access_cache before visiting children such that they have their - # own access cache (per map scope). - # access_cache = self._ensure_access_cache(self._current_state) - # self._access_cache[self._current_state] = {} - # visit children inside the map type_of_children = [type(child) for child in node.children] last_child_is_MapScope = type_of_children[-1] == tn.MapScope @@ -362,18 +356,21 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: if last_child_is_MapScope and all_others_are_Boundaries: # skip weirdly added StateBoundaryNode # tmp: use this - for now - to "backprop-insert" extra state boundaries for nested SDFGs - with node.scope(self._ctx): + with node.scope(self._current_state, self._ctx): self.visit(node.children[-1], sdfg=sdfg) elif any([isinstance(child, tn.StateBoundaryNode) for child in node.children]): self._insert_nestedSDFG(node, sdfg) else: - with node.scope(self._ctx): + with node.scope(self._current_state, self._ctx): self.visit(node.children, sdfg=sdfg) - # # reset the access_cache - # self._access_cache[self._current_state] = access_cache + if cache_state != self._current_state: + breakpoint - access_cache = self._ctx.access_cache[id(self._ctx.current_scope)] + cache_key = (cache_state, id(self._ctx.current_scope)) + if cache_key not in self._ctx.access_cache: + self._ctx.access_cache[cache_key] = {} + access_cache = self._ctx.access_cache[cache_key] # dataflow stack management _, to_connect = self._dataflow_stack.pop() @@ -497,15 +494,17 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: outer_to_connect["outputs"].add(name) # connect "outside the map" - access_node = self._current_state.add_write(name) + if name not in access_cache: + # cache write access into access_cache + write_access_node = self._current_state.add_write(name) + access_cache[name] = write_access_node + + access_node = access_cache[name] self._current_state.add_memlet_path(map_exit, access_node, src_conn=out_connector_name, memlet=Memlet.from_array(name, sdfg.arrays[name])) - # cache write access into access_cache - access_cache[name] = access_node - if isinstance(outer_map_entry, nodes.EntryNode): outer_to_connect[name] = (access_node, Memlet.from_array(name, sdfg.arrays[name])) else: @@ -529,8 +528,10 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: tasklet = node.node self._current_state.add_node(tasklet) - cache = self._ctx.access_cache[id(self._ctx.current_scope)] - assert cache is not None + cache_key = (self._current_state, id(self._ctx.current_scope)) + if cache_key not in self._ctx.access_cache: + self._ctx.access_cache[cache_key] = {} + cache = self._ctx.access_cache[cache_key] scope_node, to_connect = self._dataflow_stack[-1] if self._dataflow_stack else (None, None) # Connect input memlets @@ -590,12 +591,24 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # Connect output memlets for name, memlet in node.out_memlets.items(): - # we always write to a new access_node - access_node = self._current_state.add_write(memlet.data) - self._current_state.add_memlet_path(tasklet, access_node, src_conn=name, memlet=memlet) + # don't use cached access node, if it was an input, e.g. + # A[1] = tasklet() + # A[1] = tasklet(A[1]) + # TODO / Question: Do I need port to "port this up" to the MapScope level? I guess so? + cached_node_is_input = False + if memlet.data in cache: + for _name, in_memlet in node.in_memlets.items(): + if memlet.data == in_memlet.data: + cached_node_is_input = True + break + + if memlet.data not in cache or cached_node_is_input: + # cache write access node + write_access_node = self._current_state.add_write(memlet.data) + cache[memlet.data] = write_access_node - # cache write access node (or update an existing one) for read after write cases - cache[memlet.data] = access_node + access_node = cache[memlet.data] + self._current_state.add_memlet_path(tasklet, access_node, src_conn=name, memlet=memlet) if isinstance(scope_node, nodes.MapEntry): # copy the memlet since we already used it in the memlet path above diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 14c887ec07..aa6eb58951 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -24,26 +24,29 @@ class Context: root: 'ScheduleTreeRoot' current_scope: Optional['ScheduleTreeScope'] - access_cache: Dict[str, Dict[str, nodes.AccessNode]] + access_cache: Dict[Tuple[SDFGState, str], Dict[str, nodes.AccessNode]] """Per scope (hashed by id(scope_node) access_cache.""" class ContextPushPop: """Append the given node to the scope, then push/pop the scope.""" - def __init__(self, ctx: Context, node: 'ScheduleTreeScope') -> None: + def __init__(self, ctx: Context, state: SDFGState, node: 'ScheduleTreeScope') -> None: if ctx.current_scope is None and not isinstance(node, ScheduleTreeRoot): raise ValueError("ctx.current_scope is only allowed to be 'None' when node it tree root.") self._ctx = ctx self._parent_scope = ctx.current_scope self._node = node + self._state = state - assert id(node) not in self._ctx.access_cache - self._ctx.access_cache[id(node)] = {} + cache_key = (state, id(node)) + assert cache_key not in self._ctx.access_cache + self._ctx.access_cache[cache_key] = {} def __enter__(self) -> None: - assert not self._ctx.access_cache[id(self._node)], "Expecting an empty access_cache when entering the context." + assert not self._ctx.access_cache[(self._state, id( + self._node))], "Expecting an empty access_cache when entering the context." # self._node.parent = self._parent_scope # if self._parent_scope is not None: # Exception for ScheduleTreeRoot # self._parent_scope.children.append(self._node) @@ -55,8 +58,9 @@ def __exit__( exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ) -> None: - assert id(self._node) in self._ctx.access_cache - self._ctx.access_cache[id(self._node)].clear() + cache_key = (self._state, id(self._node)) + assert cache_key in self._ctx.access_cache + # self._ctx.access_cache[cache_key].clear() self._ctx.current_scope = self._parent_scope @@ -268,8 +272,8 @@ def as_sdfg(self, def get_root(self) -> 'ScheduleTreeRoot': return self - def scope(self, ctx: Context) -> ContextPushPop: - return ContextPushPop(ctx, self) + def scope(self, state: SDFGState, ctx: Context) -> ContextPushPop: + return ContextPushPop(ctx, state, self) @dataclass @@ -281,8 +285,8 @@ class ControlFlowScope(ScheduleTreeScope): class DataflowScope(ScheduleTreeScope): node: nodes.EntryNode - def scope(self, ctx: Context) -> ContextPushPop: - return ContextPushPop(ctx, self) + def scope(self, state: SDFGState, ctx: Context) -> ContextPushPop: + return ContextPushPop(ctx, state, self) @dataclass diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 949bd29d80..9d607704d0 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -515,6 +515,39 @@ def test_map_with_state_boundary_inside(): sdfg.validate() +def test_map_calculate_temporary_in_two_loops(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name="tester", + containers={ + "A": dace.data.Array(dace.float64, [20]), + "tmp": dace.data.Array(dace.float64, [20], transient=True) + }, + children=[ + tn.MapScope(node=nodes.MapEntry(nodes.Map("first_half", "i", sbs.Range.from_string("0:10"))), + children=[ + tn.TaskletNode(nodes.Tasklet("beginning", {}, {'out'}, 'out = i'), {}, + {'out': dace.Memlet("tmp[i]")}) + ]), + tn.MapScope(node=nodes.MapEntry(nodes.Map("second_half", "i", sbs.Range.from_string("10:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("end", {}, {'out'}, 'out = i'), {}, + {'out': dace.Memlet("tmp[i]")}) + ]), + tn.MapScope(node=nodes.MapEntry(nodes.Map("read_tmp", "i", sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("read_temp", {"tmp"}, {"out"}, "out = tmp + 1"), + {"tmp": dace.Memlet("tmp[i]")}, {"out": dace.Memlet("A[i]")}) + ]) + ]) + + sdfg = stree.as_sdfg(simplify=True) + sdfg.validate() + + assert [node.name for node, _ in sdfg.all_nodes_recursive() + if isinstance(node, nodes.Tasklet)] == ["beginning", "end", "read_temp"] + + def test_edge_assignment_read_after_write(): stree = tn.ScheduleTreeRoot(name="tester", containers={}, From 3af927f3ac39f4f5e1ab06e4a5091d2bfb90d6d4 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 15 Jul 2025 14:30:15 +0200 Subject: [PATCH 089/137] WIP: revert special case for cached node is input --- .../analysis/schedule_tree/tree_to_sdfg.py | 24 ++++++++++--------- tests/schedule_tree/to_sdfg_test.py | 4 +++- 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index cdd3bf8200..d37f1355ac 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -594,18 +594,20 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # don't use cached access node, if it was an input, e.g. # A[1] = tasklet() # A[1] = tasklet(A[1]) + # TODO: is this really necessary? + # TODO: this extends more generally to all "parents" # TODO / Question: Do I need port to "port this up" to the MapScope level? I guess so? - cached_node_is_input = False - if memlet.data in cache: - for _name, in_memlet in node.in_memlets.items(): - if memlet.data == in_memlet.data: - cached_node_is_input = True - break - - if memlet.data not in cache or cached_node_is_input: - # cache write access node - write_access_node = self._current_state.add_write(memlet.data) - cache[memlet.data] = write_access_node + # cached_node_is_input = False + # if memlet.data in cache: + # for _name, in_memlet in node.in_memlets.items(): + # if memlet.data == in_memlet.data: + # cached_node_is_input = True + # break + + # if memlet.data not in cache or cached_node_is_input: + # cache write access node + write_access_node = self._current_state.add_write(memlet.data) + cache[memlet.data] = write_access_node access_node = cache[memlet.data] self._current_state.add_memlet_path(tasklet, access_node, src_conn=name, memlet=memlet) diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 9d607704d0..ab7bad004a 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -223,6 +223,7 @@ def test_create_state_boundary_state_transition(control_flow): @pytest.mark.xfail(reason="Not yet implemented") +@pytest.mark.parametrize("control_flow", (True, False)) def test_create_state_boundary_empty_memlet(control_flow): sdfg = dace.SDFG("tester") state = sdfg.add_state("start", is_start_block=True) @@ -707,7 +708,8 @@ def test_Ray_Fast_tmp(): test_state_boundaries_propagation(boundary=True) test_create_state_boundary_state_transition(control_flow=True) test_create_state_boundary_state_transition(control_flow=False) - test_create_state_boundary_empty_memlet() + # test_create_state_boundary_empty_memlet(control_flow=True) + # test_create_state_boundary_empty_memlet(control_flow=False) test_create_tasklet_raw() test_create_tasklet_waw() test_create_for_loop() From 37c56968f9b78517cdc4c3df1305236e763f0435 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 15 Jul 2025 15:13:12 +0200 Subject: [PATCH 090/137] Fix write access caching (hopefully) --- .../analysis/schedule_tree/tree_to_sdfg.py | 26 ++++------- tests/schedule_tree/to_sdfg_test.py | 43 +++++++++++++++++++ 2 files changed, 52 insertions(+), 17 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index d37f1355ac..35cef27002 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -494,7 +494,10 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: outer_to_connect["outputs"].add(name) # connect "outside the map" - if name not in access_cache: + # only re-use cached write-only nodes, e.g. don't create a cycle for + # map i=0:20: + # A[i] = tasklet(A[i]) + if name not in access_cache or self._current_state.out_degree(access_cache[name]) > 0: # cache write access into access_cache write_access_node = self._current_state.add_write(name) access_cache[name] = write_access_node @@ -591,23 +594,12 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # Connect output memlets for name, memlet in node.out_memlets.items(): - # don't use cached access node, if it was an input, e.g. - # A[1] = tasklet() + # only re-use cached write-only nodes, e.g. don't create a cycle for # A[1] = tasklet(A[1]) - # TODO: is this really necessary? - # TODO: this extends more generally to all "parents" - # TODO / Question: Do I need port to "port this up" to the MapScope level? I guess so? - # cached_node_is_input = False - # if memlet.data in cache: - # for _name, in_memlet in node.in_memlets.items(): - # if memlet.data == in_memlet.data: - # cached_node_is_input = True - # break - - # if memlet.data not in cache or cached_node_is_input: - # cache write access node - write_access_node = self._current_state.add_write(memlet.data) - cache[memlet.data] = write_access_node + if memlet.data not in cache or self._current_state.out_degree(cache[memlet.data]) > 0: + # cache write access node + write_access_node = self._current_state.add_write(memlet.data) + cache[memlet.data] = write_access_node access_node = cache[memlet.data] self._current_state.add_memlet_path(tasklet, access_node, src_conn=name, memlet=memlet) diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index ab7bad004a..5d929e04b8 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -287,6 +287,32 @@ def test_create_tasklet_waw(): assert [(s2_tasklet, s2_anode)] == [(edge.src, edge.dst) for edge in s2.edges()] +def test_create_tasklet_war(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name="tester", + containers={"A": dace.data.Array(dace.float64, [20])}, + children=[ + tn.TaskletNode( + nodes.Tasklet("read_write", {"read"}, {"write"}, "write = read + 1"), + {"read": dace.Memlet("A[1]")}, + {"write": dace.Memlet("A[1]")}, + ) + ], + ) + + sdfg = stree.as_sdfg() + + sdfg_states = list(sdfg.states()) + assert len(sdfg_states) == 1 + + state_nodes = list(sdfg_states[0].nodes()) + assert [node.name for node in state_nodes + if isinstance(node, nodes.Tasklet)] == ["read_write"], "Expect one Tasklet node." + assert [node.data for node in state_nodes + if isinstance(node, nodes.AccessNode)] == ["A", "A"], "Expect two AccessNodes for A." + + def test_create_for_loop(): # yapf: disable loop=tn.ForScope( @@ -408,6 +434,23 @@ def test_create_map_scope_read_after_write(): sdfg.validate() +def test_create_map_scope_write_after_read(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name="tester", + containers={"A": dace.data.Array(dace.float64, [20])}, + children=[ + tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("read_write", {"read"}, {"write"}, "write = read+1"), + {"read": dace.Memlet("A[i]")}, {"write": dace.Memlet("A[i]")}) + ]) + ]) + + sdfg = stree.as_sdfg() + sdfg.validate() + + def test_create_map_scope_copy(): # Manually create a schedule tree stree = tn.ScheduleTreeRoot(name="tester", From 6ed3890287759e2164a395a339d0b06d50923102 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 16 Jul 2025 13:20:26 +0200 Subject: [PATCH 091/137] Patch: DeadDataflowElimination can't inline pointers into Tasklets --- dace/transformation/passes/dead_dataflow_elimination.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/dace/transformation/passes/dead_dataflow_elimination.py b/dace/transformation/passes/dead_dataflow_elimination.py index 4f5d718bdc..69953f236b 100644 --- a/dace/transformation/passes/dead_dataflow_elimination.py +++ b/dace/transformation/passes/dead_dataflow_elimination.py @@ -263,6 +263,13 @@ def _is_node_dead(self, node: nodes.Node, sdfg: SDFG, state: SDFGState, dead_nod and any(ie.data.data == node.data for ie in state.in_edges(l.src))): return False + # If data is connected to a Tasklet + if isinstance(l.src, nodes.Tasklet): + # We can't inline connected data that is a pointer + ctype = infer_types.infer_out_connector_type(sdfg, state, l.src, l.src_conn) + if l.src.language == dtypes.Language.Python and isinstance(ctype, dtypes.pointer): + return False + # If it is a stream and is read somewhere in the state, it may be popped after pushing if isinstance(desc, data.Stream) and node.data in access_set[0]: return False From 82541a9401dcadca43edc33cf1db61a0fe21d0e5 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 22 Jul 2025 10:43:47 +0200 Subject: [PATCH 092/137] Add support for NView nodes (+ minimal cleanup) This is a first version. To be tested on larger code to see if this is really all it needs. Includes minimal cleanup (e.g. a utility for finding the parent SDFG to get the array description from). --- .../analysis/schedule_tree/sdfg_to_tree.py | 6 + .../analysis/schedule_tree/tree_to_sdfg.py | 259 +++++++++++------- dace/sdfg/analysis/schedule_tree/treenodes.py | 18 ++ dace/subsets.py | 2 +- tests/schedule_tree/to_sdfg_test.py | 74 ++--- 5 files changed, 218 insertions(+), 141 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 7c6b2f88ec..4ee9caed5c 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -580,6 +580,12 @@ def state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: # Insert the nested SDFG flattened nested_stree = as_schedule_tree(node.sdfg, in_place=True, toplevel=False) result.extend(nested_stree.children) + + if generated_nviews: + # Insert matching NViewEnd nodes to define the scope NView nodes. + for target in generated_nviews: + result.append(tn.NViewEnd(target=target)) + elif isinstance(node, dace.nodes.Tasklet): in_memlets = {e.dst_conn: e.data for e in state.in_edges(node) if e.dst_conn} out_memlets = {e.src_conn: e.data for e in state.out_edges(node) if e.src_conn} diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 35cef27002..13c71dedf2 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -1,7 +1,7 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import copy from collections import defaultdict -from dace import symbolic +from dace import symbolic, data from dace.memlet import Memlet from dace.sdfg import nodes, memlet_utils as mmu from dace.sdfg.sdfg import SDFG, ControlFlowRegion, InterstateEdge @@ -34,6 +34,12 @@ def __init__(self, start_state: Optional[SDFGState] = None) -> None: self._interstate_symbols: List[tn.AssignNode] = [] """Interstate symbol assignments. Will be assigned with the next state transition.""" + self._nviews_free: List[tn.NView] = [] + """Keep track of NView (nested SDFG view) nodes that are "free" to be used.""" + + self._nviews_bound_per_scope: Dict[int, List[tn.NView]] = {} + """Mapping of id(SDFG) -> list of active NView nodes in that SDFG.""" + # state management self._state_stack: List[SDFGState] = [] @@ -42,10 +48,35 @@ def __init__(self, start_state: Optional[SDFGState] = None) -> None: self._dataflow_stack: List[Tuple[nodes.EntryNode, Dict[str, Tuple[nodes.AccessNode, Memlet]]] | Tuple[SDFG, Dict[str, Set[str]]]] = [] - # -- to be torched -- - # caches - # self._access_cache: Dict[SDFGState, Dict[str, nodes.AccessNode]] = {} - # end -- to be torched -- + def _apply_nview_array_override(self, array_name: str, sdfg: SDFG) -> bool: + """Apply an NView override if applicable. Returns true if the NView was applied.""" + length = len(self._nviews_free) + for index, nview in enumerate(reversed(self._nviews_free), start=1): + if nview.target == array_name: + # Add the "override" data descriptor + sdfg.add_datadesc(nview.target, nview.view_desc.clone()) + if nview.src_desc.transient: + sdfg.arrays[nview.target].transient = False + + # Keep track of used NViews per scope (to "free" them again once the scope ends) + self._nviews_bound_per_scope[id(sdfg)].append(nview) + + # This NView is in use now, remove it from the free NViews. + del self._nviews_free[length - index] + return True + + return False + + def _parent_sdfg_with_array(self, name: str, sdfg: SDFG) -> SDFG: + """Find the closest parent SDFG containing an array with the given name.""" + parent_sdfg = sdfg.parent.parent + sdfg_counter = 1 + while name not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: + parent_sdfg = parent_sdfg.parent.parent + assert isinstance(parent_sdfg, SDFG) + sdfg_counter += 1 + assert sdfg_counter < MAX_NESTED_SDFGS, f"Array '{name}' not found in any parent of SDFG '{sdfg.name}'." + return parent_sdfg def _pop_state(self, label: Optional[str] = None) -> SDFGState: """Pops the last state from the state stack. @@ -63,24 +94,6 @@ def _pop_state(self, label: Optional[str] = None) -> SDFGState: return popped - # def _ensure_access_cache(self, state: SDFGState) -> Dict[str, nodes.AccessNode]: - # """Ensure an access_cache entry for the given state. - - -# -# Checks if there exists an access_cache for `state`. Creates an empty one if it doesn't exist yet. -# -# :param SDFGState state: The state to check. -# -# :return: The state's access_cache. -# """ -# # -- to be torched -# raise RuntimeError("We shouldn't end up here anymore.") -# if state not in self._access_cache: -# self._access_cache[state] = {} -# -# return self._access_cache[state] - def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot, sdfg: SDFG) -> None: # -- to be torched -- assert self._current_state is None, "Expected no 'current_state' at root." @@ -130,23 +143,23 @@ def visit_AssignNode(self, node: tn.AssignNode, sdfg: SDFG) -> None: for memlet in input_memlets: # Copy data descriptor from parent SDFG and add input connector if memlet.data not in sdfg.arrays: - parent_sdfg = sdfg.parent.parent - sdfg_counter = 1 - while memlet.data not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: - parent_sdfg = parent_sdfg.parent.parent - assert isinstance(parent_sdfg, SDFG) - sdfg_counter += 1 - sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) - - # Transients passed into a nested SDFG become non-transient inside that nested SDFG - if parent_sdfg.arrays[memlet.data].transient: - sdfg.arrays[memlet.data].transient = False - # TODO - # ... unless they are only ever used inside the nested SDFG, in which case - # we should delete them from the parent SDFG's array list. - # NOTE This can probably be done automatically by a cleanup pass in the end. - # Something like DDE should be able to do this. - + parent_sdfg = self._parent_sdfg_with_array(memlet.data, sdfg) + + # Support for NView nodes + use_nview = self._apply_nview_array_override(memlet.data, sdfg) + if not use_nview: + sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) + + # Transients passed into a nested SDFG become non-transient inside that nested SDFG + if parent_sdfg.arrays[memlet.data].transient: + sdfg.arrays[memlet.data].transient = False + # TODO + # ... unless they are only ever used inside the nested SDFG, in which case + # we should delete them from the parent SDFG's array list. + # NOTE This can probably be done automatically by a cleanup pass in the end. + # Something like DDE should be able to do this. + + # Dev note: nview.target and memlet.data are identical assert memlet.data not in to_connect["inputs"] to_connect["inputs"].add(memlet.data) return @@ -297,6 +310,7 @@ def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: old_state_label = self._current_state.label self._state_stack.append(self._current_state) self._dataflow_stack.append((inner_sdfg, {"inputs": set(), "outputs": set()})) + self._nviews_bound_per_scope[id(inner_sdfg)] = [] self._current_state = start_state # visit children @@ -328,8 +342,18 @@ def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: assert new_in_connector == True assert new_in_connector == new_out_connector - self._current_state.add_edge(map_entry, out_connector, nsdfg, name, - Memlet.from_array(name, nsdfg.sdfg.arrays[name])) + # Add Memlet for NView node (if applicable) + edge_added = False + for nview in self._nviews_bound_per_scope[id(inner_sdfg)]: + if name == nview.target: + self._current_state.add_edge(map_entry, out_connector, nsdfg, name, + Memlet.from_memlet(nview.memlet)) + edge_added = True + break + + if not edge_added: + self._current_state.add_edge(map_entry, out_connector, nsdfg, name, + Memlet.from_array(name, nsdfg.sdfg.arrays[name])) # Add empty memlet if we didn't add any in the loop above if self._current_state.out_degree(map_entry) < 1: @@ -337,7 +361,22 @@ def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: # connect nsdfg output memlets (to be propagated) for name in nsdfg.out_connectors: - to_connect[name] = (nsdfg, Memlet.from_array(name, nsdfg.sdfg.arrays[name])) + # Add memlets for NView node (if applicable) + edge_added = False + for nview in self._nviews_bound_per_scope[id(inner_sdfg)]: + if name == nview.target: + to_connect[name] = (nsdfg, Memlet.from_memlet(nview.memlet)) + edge_added = True + break + + if not edge_added: + to_connect[name] = (nsdfg, Memlet.from_array(name, nsdfg.sdfg.arrays[name])) + + # Move NViews back to "free" NViews for usage in a sibling scope. + for nview in self._nviews_bound_per_scope[id(inner_sdfg)]: + self._nviews_free.append(nview) + + del self._nviews_bound_per_scope[id(inner_sdfg)] def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: dataflow_stack_size = len(self._dataflow_stack) @@ -364,9 +403,6 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: with node.scope(self._current_state, self._ctx): self.visit(node.children, sdfg=sdfg) - if cache_state != self._current_state: - breakpoint - cache_key = (cache_state, id(self._ctx.current_scope)) if cache_key not in self._ctx.access_cache: self._ctx.access_cache[cache_key] = {} @@ -406,23 +442,23 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: if isinstance(outer_map_entry, SDFG): # Copy data descriptor from parent SDFG and add input connector if memlet_data not in sdfg.arrays: - parent_sdfg = sdfg.parent.parent - sdfg_counter = 1 - while memlet_data not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: - parent_sdfg = parent_sdfg.parent.parent - assert isinstance(parent_sdfg, SDFG) - sdfg_counter += 1 - sdfg.add_datadesc(memlet_data, parent_sdfg.arrays[memlet_data].clone()) - - # Transients passed into a nested SDFG become non-transient inside that nested SDFG - if parent_sdfg.arrays[memlet_data].transient: - sdfg.arrays[memlet_data].transient = False - # TODO - # ... unless they are only ever used inside the nested SDFG, in which case - # we should delete them from the parent SDFG's array list. - # NOTE This can probably be done automatically by a cleanup pass in the end. - # Something like DDE should be able to do this. - + parent_sdfg: SDFG = self._parent_sdfg_with_array(memlet_data, sdfg) + + # Add support for NView nodes + use_nview = self._apply_nview_array_override(memlet_data, sdfg) + if not use_nview: + sdfg.add_datadesc(memlet_data, parent_sdfg.arrays[memlet_data].clone()) + + # Transients passed into a nested SDFG become non-transient inside that nested SDFG + if parent_sdfg.arrays[memlet_data].transient: + sdfg.arrays[memlet_data].transient = False + # TODO + # ... unless they are only ever used inside the nested SDFG, in which case + # we should delete them from the parent SDFG's array list. + # NOTE This can probably be done automatically by a cleanup pass in the end. + # Something like DDE should be able to do this. + + # Dev note: nview.target and memlet_data are identical assert memlet_data not in outer_to_connect["inputs"] outer_to_connect["inputs"].add(memlet_data) else: @@ -478,19 +514,19 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: if isinstance(outer_map_entry, SDFG): if name not in sdfg.arrays: - parent_sdfg = sdfg.parent.parent - sdfg_counter = 1 - while name not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: - parent_sdfg = parent_sdfg.parent.parent - assert isinstance(parent_sdfg, SDFG) - sdfg_counter += 1 - sdfg.add_datadesc(name, parent_sdfg.arrays[name].clone()) - - # Transients passed into a nested SDFG become non-transient inside that nested SDFG - if parent_sdfg.arrays[name].transient: - sdfg.arrays[name].transient = False + parent_sdfg = self._parent_sdfg_with_array(name, sdfg) + + # Support for NView nodes + use_nview = self._apply_nview_array_override(name, sdfg) + if not use_nview: + sdfg.add_datadesc(name, parent_sdfg.arrays[name].clone()) + + # Transients passed into a nested SDFG become non-transient inside that nested SDFG + if parent_sdfg.arrays[name].transient: + sdfg.arrays[name].transient = False # Add out_connector in any case if not yet present, e.g. write after read + # Dev not: name and nview.target are identical outer_to_connect["outputs"].add(name) # connect "outside the map" @@ -560,23 +596,23 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: if isinstance(scope_node, SDFG): # Copy data descriptor from parent SDFG and add input connector if memlet.data not in sdfg.arrays: - parent_sdfg = sdfg.parent.parent - sdfg_counter = 1 - while memlet.data not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: - parent_sdfg = parent_sdfg.parent.parent - assert isinstance(parent_sdfg, SDFG) - sdfg_counter += 1 - sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) - - # Transients passed into a nested SDFG become non-transient inside that nested SDFG - if parent_sdfg.arrays[memlet.data].transient: - sdfg.arrays[memlet.data].transient = False - # TODO - # ... unless they are only ever used inside the nested SDFG, in which case - # we should delete them from the parent SDFG's array list. - # NOTE This can probably be done automatically by a cleanup pass in the end. - # Something like DDE should be able to do this. + parent_sdfg = self._parent_sdfg_with_array(memlet.data, sdfg) + # Support for NView nodes + use_nview = self._apply_nview_array_override(memlet.data, sdfg) + if not use_nview: + sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) + + # Transients passed into a nested SDFG become non-transient inside that nested SDFG + if parent_sdfg.arrays[memlet.data].transient: + sdfg.arrays[memlet.data].transient = False + # TODO + # ... unless they are only ever used inside the nested SDFG, in which case + # we should delete them from the parent SDFG's array list. + # NOTE This can probably be done automatically by a cleanup pass in the end. + # Something like DDE should be able to do this. + + # Dev note: memlet.data and nview.target are identical assert memlet.data not in to_connect["inputs"] to_connect["inputs"].add(memlet.data) else: @@ -611,19 +647,19 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: if isinstance(scope_node, SDFG): if memlet.data not in sdfg.arrays: - parent_sdfg = sdfg.parent.parent - sdfg_counter = 1 - while memlet.data not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: - parent_sdfg = parent_sdfg.parent.parent - assert isinstance(parent_sdfg, SDFG) - sdfg_counter += 1 - sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) - - # Transients passed into a nested SDFG become non-transient inside that nested SDFG - if parent_sdfg.arrays[memlet.data].transient: - sdfg.arrays[memlet.data].transient = False + parent_sdfg: SDFG = self._parent_sdfg_with_array(memlet.data, sdfg) + + # Support for NView nodes + use_nview = self._apply_nview_array_override(memlet.data, sdfg) + if not use_nview: + sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) + + # Transients passed into a nested SDFG become non-transient inside that nested SDFG + if parent_sdfg.arrays[memlet.data].transient: + sdfg.arrays[memlet.data].transient = False # Add out_connector in any case if not yet present, e.g. write after read + # Dev note: memlet.data and nview.target are identical to_connect["outputs"].add(memlet.data) else: @@ -659,8 +695,25 @@ def visit_ViewNode(self, node: tn.ViewNode, sdfg: SDFG) -> None: raise NotImplementedError(f"{type(node)} not implemented") def visit_NView(self, node: tn.NView, sdfg: SDFG) -> None: - # TODO: Fillz and Ray_Fast will need these ... - raise NotImplementedError(f"{type(node)} not implemented") + # Basic working principle: + # + # - NView and (artificial) NViewEnd nodes are added in parallel to mark the region where the view applies. + # - Keep a stack of NView nodes (per name) that is pushed/popped when NView and NViewEnd nodes are visited. + # - In between, when going "down into" a NestedSDFG, use the current NView (if it applies) + # - In between, when "coming back up" from a NestedSDFG, pop the NView from the stack. + # - AccessNodes will automatically pick up the right name (from the NestedSDFG's array list) + self._nviews_free.append(node) + + def visit_NViewEnd(self, node: tn.NViewEnd, sdfg: SDFG) -> None: + length = len(self._nviews_free) + + for index, nview in enumerate(reversed(self._nviews_free), start=1): + if node.target == nview.target: + # Stack semantics: remove from the back of the list + del self._nviews_free[length - index] + return + + raise RuntimeError(f"No matching NView found for target {node.target} in {self._nviews_free}.") def visit_RefSetNode(self, node: tn.RefSetNode, sdfg: SDFG) -> None: # Let's see if we need this for the first prototype ... diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index aa6eb58951..f205bbed2f 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -743,6 +743,24 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = nview {self.memlet} as {self.view_desc.shape}' +@dataclass +class NViewEnd(ScheduleTreeNode): + """ + Artificial node to denote the scope end of the associated Nested SDFG view node. + """ + + target: str #: target name of the associated NView container + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f"end nview {self.target}" + + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + + @dataclass class RefSetNode(ScheduleTreeNode): """ diff --git a/dace/subsets.py b/dace/subsets.py index c5e0debc0b..82353bb5c4 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -80,7 +80,7 @@ def covers(self, other): # Subsets of different dimensionality can never cover each other. if self.dims() != other.dims(): return ValueError( - f"A subset of dimensionality {self.dim()} cannot test covering a subset of dimensionality {other.dims()}" + f"A subset of dimensionality {self.dims()} cannot test covering a subset of dimensionality {other.dims()}" ) if not Config.get('optimizer', 'symbolic_positive'): diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 5d929e04b8..1f0590c912 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -3,7 +3,7 @@ Tests components in conversion of schedule trees to SDFGs. """ import dace -from dace import subsets as sbs +from dace import data, subsets as sbs from dace.codegen import control_flow as cf from dace.properties import CodeBlock from dace.sdfg import nodes @@ -16,7 +16,7 @@ def test_state_boundaries_none(): stree = tn.ScheduleTreeRoot( name='tester', containers={ - 'A': dace.data.Array(dace.float64, [20]), + 'A': data.Array(dace.float64, [20]), }, children=[ tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), @@ -34,7 +34,7 @@ def test_state_boundaries_waw(): stree = tn.ScheduleTreeRoot( name='tester', containers={ - 'A': dace.data.Array(dace.float64, [20]), + 'A': data.Array(dace.float64, [20]), }, children=[ tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), @@ -53,7 +53,7 @@ def test_state_boundaries_waw_ranges(overlap): stree = tn.ScheduleTreeRoot( name='tester', containers={ - 'A': dace.data.Array(dace.float64, [20]), + 'A': data.Array(dace.float64, [20]), }, symbols={'N': N}, children=[ @@ -75,8 +75,8 @@ def test_state_boundaries_war(): stree = tn.ScheduleTreeRoot( name='tester', containers={ - 'A': dace.data.Array(dace.float64, [20]), - 'B': dace.data.Array(dace.float64, [20]), + 'A': data.Array(dace.float64, [20]), + 'B': data.Array(dace.float64, [20]), }, children=[ tn.TaskletNode(nodes.Tasklet('bla', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, @@ -94,8 +94,8 @@ def test_state_boundaries_read_write_chain(): stree = tn.ScheduleTreeRoot( name='tester', containers={ - 'A': dace.data.Array(dace.float64, [20]), - 'B': dace.data.Array(dace.float64, [20]), + 'A': data.Array(dace.float64, [20]), + 'B': data.Array(dace.float64, [20]), }, children=[ tn.TaskletNode(nodes.Tasklet('bla1', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, @@ -116,8 +116,8 @@ def test_state_boundaries_data_race(): stree = tn.ScheduleTreeRoot( name='tester', containers={ - 'A': dace.data.Array(dace.float64, [20]), - 'B': dace.data.Array(dace.float64, [20]), + 'A': data.Array(dace.float64, [20]), + 'B': data.Array(dace.float64, [20]), }, children=[ tn.TaskletNode(nodes.Tasklet('bla1', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, @@ -141,7 +141,7 @@ def test_state_boundaries_cfg(): stree = tn.ScheduleTreeRoot( name='tester', containers={ - 'A': dace.data.Array(dace.float64, [20]), + 'A': data.Array(dace.float64, [20]), }, children=[ tn.TaskletNode(nodes.Tasklet('bla1', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), @@ -160,7 +160,7 @@ def test_state_boundaries_state_transition(): stree = tn.ScheduleTreeRoot( name='tester', containers={ - 'A': dace.data.Array(dace.float64, [20]), + 'A': data.Array(dace.float64, [20]), }, symbols={ 'N': dace.symbol('N'), @@ -186,7 +186,7 @@ def test_state_boundaries_propagation(boundary): stree = tn.ScheduleTreeRoot( name='tester', containers={ - 'A': dace.data.Array(dace.float64, [20]), + 'A': data.Array(dace.float64, [20]), }, symbols={ 'N': N, @@ -237,7 +237,7 @@ def test_create_tasklet_raw(): stree = tn.ScheduleTreeRoot( name='tester', containers={ - 'A': dace.data.Array(dace.float64, [20]), + 'A': data.Array(dace.float64, [20]), }, children=[ tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), @@ -268,7 +268,7 @@ def test_create_tasklet_waw(): stree = tn.ScheduleTreeRoot( name='tester', containers={ - 'A': dace.data.Array(dace.float64, [20]), + 'A': data.Array(dace.float64, [20]), }, children=[ tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), @@ -291,7 +291,7 @@ def test_create_tasklet_war(): # Manually create a schedule tree stree = tn.ScheduleTreeRoot( name="tester", - containers={"A": dace.data.Array(dace.float64, [20])}, + containers={"A": data.Array(dace.float64, [20])}, children=[ tn.TaskletNode( nodes.Tasklet("read_write", {"read"}, {"write"}, "write = read + 1"), @@ -328,7 +328,7 @@ def test_create_for_loop(): # yapf: enable # Manually create a schedule tree - stree = tn.ScheduleTreeRoot(name='tester', containers={'A': dace.data.Array(dace.float64, [20])}, children=[loop]) + stree = tn.ScheduleTreeRoot(name='tester', containers={'A': data.Array(dace.float64, [20])}, children=[loop]) sdfg = stree.as_sdfg() sdfg.validate() @@ -353,7 +353,7 @@ def test_create_while_loop(): # yapf: enable # Manually create a schedule tree - stree = tn.ScheduleTreeRoot(name='tester', containers={'A': dace.data.Array(dace.float64, [20])}, children=[loop]) + stree = tn.ScheduleTreeRoot(name='tester', containers={'A': data.Array(dace.float64, [20])}, children=[loop]) sdfg = stree.as_sdfg() sdfg.validate() @@ -362,7 +362,7 @@ def test_create_while_loop(): def test_create_if_else(): # Manually create a schedule tree stree = tn.ScheduleTreeRoot(name="tester", - containers={'A': dace.data.Array(dace.float64, [20])}, + containers={'A': data.Array(dace.float64, [20])}, children=[ tn.IfScope(condition=CodeBlock("A[0] > 0"), children=[ @@ -382,7 +382,7 @@ def test_create_if_else(): def test_create_if_without_else(): # Manually create a schedule tree stree = tn.ScheduleTreeRoot(name="tester", - containers={'A': dace.data.Array(dace.float64, [20])}, + containers={'A': data.Array(dace.float64, [20])}, children=[ tn.IfScope(condition=CodeBlock("A[0] > 0"), children=[ @@ -398,7 +398,7 @@ def test_create_if_without_else(): def test_create_map_scope_write(): # Manually create a schedule tree stree = tn.ScheduleTreeRoot(name="tester", - containers={'A': dace.data.Array(dace.float64, [20])}, + containers={'A': data.Array(dace.float64, [20])}, children=[ tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), @@ -417,8 +417,8 @@ def test_create_map_scope_read_after_write(): stree = tn.ScheduleTreeRoot( name="tester", containers={ - 'A': dace.data.Array(dace.float64, [20]), - 'B': dace.data.Array(dace.float64, [20], transient=True), + 'A': data.Array(dace.float64, [20]), + 'B': data.Array(dace.float64, [20], transient=True), }, children=[ tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), @@ -438,7 +438,7 @@ def test_create_map_scope_write_after_read(): # Manually create a schedule tree stree = tn.ScheduleTreeRoot( name="tester", - containers={"A": dace.data.Array(dace.float64, [20])}, + containers={"A": data.Array(dace.float64, [20])}, children=[ tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), children=[ @@ -455,8 +455,8 @@ def test_create_map_scope_copy(): # Manually create a schedule tree stree = tn.ScheduleTreeRoot(name="tester", containers={ - 'A': dace.data.Array(dace.float64, [20]), - 'B': dace.data.Array(dace.float64, [20]), + 'A': data.Array(dace.float64, [20]), + 'B': data.Array(dace.float64, [20]), }, children=[ tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", @@ -477,8 +477,8 @@ def test_create_map_scope_double_memlet(): stree = tn.ScheduleTreeRoot( name="tester", containers={ - 'A': dace.data.Array(dace.float64, [20]), - 'B': dace.data.Array(dace.float64, [20]), + 'A': data.Array(dace.float64, [20]), + 'B': data.Array(dace.float64, [20]), }, children=[ tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:10"))), @@ -498,7 +498,7 @@ def test_create_nested_map_scope(): # Manually create a schedule tree stree = tn.ScheduleTreeRoot( name="tester", - containers={'A': dace.data.Array(dace.float64, [20])}, + containers={'A': data.Array(dace.float64, [20])}, children=[ tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:2"))), children=[ @@ -519,8 +519,8 @@ def test_create_nested_map_scope_multi_read(): stree = tn.ScheduleTreeRoot( name="tester", containers={ - 'A': dace.data.Array(dace.float64, [20]), - 'B': dace.data.Array(dace.float64, [10]) + 'A': data.Array(dace.float64, [20]), + 'B': data.Array(dace.float64, [10]) }, children=[ tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:2"))), @@ -543,7 +543,7 @@ def test_create_nested_map_scope_multi_read(): def test_map_with_state_boundary_inside(): # Manually create a schedule tree stree = tn.ScheduleTreeRoot(name="tester", - containers={'A': dace.data.Array(dace.float64, [20])}, + containers={'A': data.Array(dace.float64, [20])}, children=[ tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), @@ -564,8 +564,8 @@ def test_map_calculate_temporary_in_two_loops(): stree = tn.ScheduleTreeRoot( name="tester", containers={ - "A": dace.data.Array(dace.float64, [20]), - "tmp": dace.data.Array(dace.float64, [20], transient=True) + "A": data.Array(dace.float64, [20]), + "tmp": data.Array(dace.float64, [20], transient=True) }, children=[ tn.MapScope(node=nodes.MapEntry(nodes.Map("first_half", "i", sbs.Range.from_string("0:10"))), @@ -612,7 +612,7 @@ def test_assign_nodes_force_state_transition(): stree = tn.ScheduleTreeRoot( name='tester', containers={ - 'A': dace.data.Array(dace.float64, [20]), + 'A': data.Array(dace.float64, [20]), }, children=[ tn.AssignNode("mySymbol", CodeBlock("1"), dace.InterstateEdge()), @@ -629,7 +629,7 @@ def test_assign_nodes_multiple_force_one_transition(): stree = tn.ScheduleTreeRoot( name='tester', containers={ - 'A': dace.data.Array(dace.float64, [20]), + 'A': data.Array(dace.float64, [20]), }, children=[ tn.AssignNode("mySymbol", CodeBlock("1"), dace.InterstateEdge()), @@ -649,7 +649,7 @@ def test_assign_nodes_avoid_duplicate_boundaries(): stree = tn.ScheduleTreeRoot( name='tester', containers={ - 'A': dace.data.Array(dace.float64, [20]), + 'A': data.Array(dace.float64, [20]), }, children=[ tn.AssignNode("mySymbol", CodeBlock("1"), dace.InterstateEdge()), From bd9a0472275000b6275a2a63410a68d40ab39c93 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 25 Jul 2025 18:27:46 +0200 Subject: [PATCH 093/137] NView support: second half --- .../analysis/schedule_tree/tree_to_sdfg.py | 30 +++++++++++++++++-- .../passes/constant_propagation.py | 2 +- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 13c71dedf2..d5f7fe8969 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -31,6 +31,9 @@ def __init__(self, start_state: Optional[SDFGState] = None) -> None: self._current_state = start_state """Current SDFGState in the SDFG that we are building.""" + self._current_nestedSDFG: int | None = None + """Id of the current nested SDFG if we are inside one.""" + self._interstate_symbols: List[tn.AssignNode] = [] """Interstate symbol assignments. Will be assigned with the next state transition.""" @@ -40,6 +43,9 @@ def __init__(self, start_state: Optional[SDFGState] = None) -> None: self._nviews_bound_per_scope: Dict[int, List[tn.NView]] = {} """Mapping of id(SDFG) -> list of active NView nodes in that SDFG.""" + self._nviews_deferred_removal: Dict[int, List[tn.NView]] = {} + """"Mapping of id(SDFG) -> list of NView nodes to be removed once we exit this nested SDFG.""" + # state management self._state_stack: List[SDFGState] = [] @@ -52,7 +58,7 @@ def _apply_nview_array_override(self, array_name: str, sdfg: SDFG) -> bool: """Apply an NView override if applicable. Returns true if the NView was applied.""" length = len(self._nviews_free) for index, nview in enumerate(reversed(self._nviews_free), start=1): - if nview.target == array_name: + if nview.target == array_name and nview not in self._nviews_deferred_removal[id(sdfg)]: # Add the "override" data descriptor sdfg.add_datadesc(nview.target, nview.view_desc.clone()) if nview.src_desc.transient: @@ -301,6 +307,7 @@ def visit_ElseScope(self, node: tn.ElseScope, sdfg: SDFG) -> None: def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: dataflow_stack_size = len(self._dataflow_stack) state_stack_size = len(self._state_stack) + outer_nestedSDFG = self._current_nestedSDFG # prepare inner SDFG inner_sdfg = SDFG("nested_sdfg", parent=self._current_state) @@ -311,6 +318,8 @@ def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: self._state_stack.append(self._current_state) self._dataflow_stack.append((inner_sdfg, {"inputs": set(), "outputs": set()})) self._nviews_bound_per_scope[id(inner_sdfg)] = [] + self._nviews_deferred_removal[id(inner_sdfg)] = [] + self._current_nestedSDFG = id(inner_sdfg) self._current_state = start_state # visit children @@ -374,9 +383,18 @@ def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: # Move NViews back to "free" NViews for usage in a sibling scope. for nview in self._nviews_bound_per_scope[id(inner_sdfg)]: + # If this NView ended in the current nested SDFG, don't add it back to the + # "free NView" nodes. We need to keep it alive until here to make sure that + # we can add the memlets above. + if nview in self._nviews_deferred_removal[id(inner_sdfg)]: + continue self._nviews_free.append(nview) del self._nviews_bound_per_scope[id(inner_sdfg)] + del self._nviews_deferred_removal[id(inner_sdfg)] + + # Restore current nested SDFG + self._current_nestedSDFG = outer_nestedSDFG def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: dataflow_stack_size = len(self._dataflow_stack) @@ -705,8 +723,16 @@ def visit_NView(self, node: tn.NView, sdfg: SDFG) -> None: self._nviews_free.append(node) def visit_NViewEnd(self, node: tn.NViewEnd, sdfg: SDFG) -> None: - length = len(self._nviews_free) + # If bound to the current nested SDFG, defer cleanup + if self._current_nestedSDFG is not None: + currently_bound = self._nviews_bound_per_scope[self._current_nestedSDFG] + for index, nview in enumerate(reversed(currently_bound)): + if node.target == nview.target: + # Bound to current nested SDFG. Slate for deferred removal once we exit that nested SDFG. + self._nviews_deferred_removal[self._current_nestedSDFG].append(nview) + return + length = len(self._nviews_free) for index, nview in enumerate(reversed(self._nviews_free), start=1): if node.target == nview.target: # Stack semantics: remove from the back of the list diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index bfa0928415..0ac66e0ac9 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -28,7 +28,7 @@ class ConstantPropagation(ppl.Pass): CATEGORY: str = 'Simplification' - recursive = properties.Property(dtype=bool, default=True, desc='Propagagte recursively through nested SDFGs') + recursive = properties.Property(dtype=bool, default=True, desc='Propagate recursively through nested SDFGs') progress = properties.Property(dtype=bool, default=None, allow_none=True, desc='Show progress') def modifies(self) -> ppl.Modifies: From 7882153446a58d141fb013eeb88b2c198145e93d Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Sat, 26 Jul 2025 16:36:31 +0200 Subject: [PATCH 094/137] Backport of #2098 for v1/maintenance (#2099) --- .../passes/dead_dataflow_elimination.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/dace/transformation/passes/dead_dataflow_elimination.py b/dace/transformation/passes/dead_dataflow_elimination.py index 4f5d718bdc..9bb90c3ea4 100644 --- a/dace/transformation/passes/dead_dataflow_elimination.py +++ b/dace/transformation/passes/dead_dataflow_elimination.py @@ -258,6 +258,19 @@ def _is_node_dead(self, node: nodes.Node, sdfg: SDFG, state: SDFGState, dead_nod if _has_side_effects(l.src, sdfg): return False + # If data is connected to a tasklet through a pointer and more than 1 element is accessed, + # we cannot eliminate the connector, as it may require dataflow analysis inside the tasklet. + if isinstance(l.src, nodes.Tasklet): + ctype = infer_types.infer_out_connector_type(sdfg, state, l.src, l.src_conn) + if isinstance(ctype, dtypes.pointer): + is_larger = False + try: + is_larger = l.data.volume > 1 + except ValueError: + is_larger = True + if is_larger: + return False + # If data is connected to a nested SDFG or library node as an input/output, do not remove if (isinstance(l.src, (nodes.NestedSDFG, nodes.LibraryNode)) and any(ie.data.data == node.data for ie in state.in_edges(l.src))): From 74333ac0fbd3c0802948036da36f0673751310f1 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 6 Aug 2025 18:14:19 +0200 Subject: [PATCH 095/137] Re-add support for CopyNodes --- dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index d5f7fe8969..2592f7a899 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -688,11 +688,11 @@ def visit_LibraryCall(self, node: tn.LibraryCall, sdfg: SDFG) -> None: raise NotImplementedError(f"{type(node)} not implemented") def visit_CopyNode(self, node: tn.CopyNode, sdfg: SDFG) -> None: - raise NotImplementedError("Not yet ported to new stree bridge") - - # apparently we need this for the first prototype - self._ensure_access_cache(self._current_state) - access_cache = self._access_cache[self._current_state] + # ensure we have an access_cache and fetch it + cache_key = (self._current_state, id(self._ctx.current_scope)) + if cache_key not in self._ctx.access_cache: + self._ctx.access_cache[cache_key] = {} + access_cache = self._ctx.access_cache[cache_key] # assumption source access may or may not yet exist (in this state) src_name = node.memlet.data From eb160ced537eae79dbebb07dcbe7b53094876d38 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 19 Aug 2025 09:21:47 +0200 Subject: [PATCH 096/137] Mitigation: don't inline nested SDFG without unique connector names --- dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 4ee9caed5c..1c438e6fb7 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -67,6 +67,17 @@ def dealias_sdfg(sdfg: SDFG): inv_replacements[parent_name] = [name] break + # This function assumes input/outputs of the nested SDFG to be uniquely named. + # Let's assert this because failure to comply will result a potentially working, + # but potentially wrong schedule tree, which is non-trivial to debug. + for replacement in replacements.keys(): + connectors = list(parent_state.edges_by_connector(parent_node, replacement)) + if len(connectors) > 1: + raise ValueError( + f"Expected in/out connectors of nested SDFG '{parent_node.label}' to be uniquely named. " + f"Found duplicate '{replacement}' in inputs '{parent_node.in_connectors}' and " + f"outputs '{parent_node.out_connectors}'.") + if to_unsqueeze: for parent_name in to_unsqueeze: parent_arr = parent_sdfg.arrays[parent_name] From 77dc937350e4fb5e0166c481993d1fa0f51d4ea2 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 19 Aug 2025 16:41:34 +0200 Subject: [PATCH 097/137] fix: support de-aliasing SDFGs with in/out conns of the same name This commit adds support for de-aliasing nested SFDGs with input/output connectors that have the same name. When unsqueezing memlets (in the middle of the function) it's important to pick the right "outer memlet" and if we don't separate inputs and outputs, we (randomly) choose either one for both, which can lead to subtle bugs. --- .../analysis/schedule_tree/sdfg_to_tree.py | 71 ++++++++++++------- 1 file changed, 46 insertions(+), 25 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 1c438e6fb7..1c01bc95eb 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -44,7 +44,8 @@ def dealias_sdfg(sdfg: SDFG): replacements: Dict[str, str] = {} inv_replacements: Dict[str, List[str]] = {} - parent_edges: Dict[str, Memlet] = {} + parent_edges_inputs: Dict[str, Memlet] = {} + parent_edges_outputs: Dict[str, Memlet] = {} to_unsqueeze: Set[str] = set() parent_sdfg = nsdfg.parent_sdfg @@ -54,29 +55,42 @@ def dealias_sdfg(sdfg: SDFG): for name, desc in nsdfg.arrays.items(): if desc.transient: continue - for edge in parent_state.edges_by_connector(parent_node, name): + for edge in parent_state.in_edges_by_connector(parent_node, name): parent_name = edge.data.data assert parent_name in parent_sdfg.arrays if name != parent_name: + parent_edges_inputs[name] = edge + replacements[name] = parent_name - parent_edges[name] = edge if parent_name in inv_replacements: inv_replacements[parent_name].append(name) to_unsqueeze.add(parent_name) else: inv_replacements[parent_name] = [name] + # We found an incoming edge for name and we don't expect a second one. break - # This function assumes input/outputs of the nested SDFG to be uniquely named. - # Let's assert this because failure to comply will result a potentially working, - # but potentially wrong schedule tree, which is non-trivial to debug. - for replacement in replacements.keys(): - connectors = list(parent_state.edges_by_connector(parent_node, replacement)) - if len(connectors) > 1: - raise ValueError( - f"Expected in/out connectors of nested SDFG '{parent_node.label}' to be uniquely named. " - f"Found duplicate '{replacement}' in inputs '{parent_node.in_connectors}' and " - f"outputs '{parent_node.out_connectors}'.") + for edge in parent_state.out_edges_by_connector(parent_node, name): + parent_name = edge.data.data + assert parent_name in parent_sdfg.arrays + if name != parent_name: + parent_edges_outputs[name] = edge + + if replacements.get(name, None) is not None: + # There's an incoming and an outgoing connector with the same name. + # Make sure both map to the same memory in the parent sdfg + assert replacements[name] == parent_name + assert name in inv_replacements[parent_name] + break + else: + replacements[name] = parent_name + if parent_name in inv_replacements: + inv_replacements[parent_name].append(name) + to_unsqueeze.add(parent_name) + else: + inv_replacements[parent_name] = [name] + # We found an outgoing edge for name and we don't expect a second one. + break if to_unsqueeze: for parent_name in to_unsqueeze: @@ -106,14 +120,18 @@ def dealias_sdfg(sdfg: SDFG): # destination subset if isinstance(src, nd.AccessNode) and src.data in child_names: src_data = src.data - new_src_memlet = unsqueeze_memlet(e.data, parent_edges[src.data].data, use_src_subset=True) + new_src_memlet = unsqueeze_memlet(e.data, + parent_edges_inputs[src.data].data, + use_src_subset=True) else: src_data = None new_src_memlet = None # We need to take directionality of the memlet into account if isinstance(dst, nd.AccessNode) and dst.data in child_names: dst_data = dst.data - new_dst_memlet = unsqueeze_memlet(e.data, parent_edges[dst.data].data, use_dst_subset=True) + new_dst_memlet = unsqueeze_memlet(e.data, + parent_edges_outputs[dst.data].data, + use_dst_subset=True) else: dst_data = None new_dst_memlet = None @@ -132,23 +150,26 @@ def dealias_sdfg(sdfg: SDFG): syms = e.data.read_symbols() for memlet in e.data.get_read_memlets(nsdfg.arrays): if memlet.data in child_names: - repl_dict[str(memlet)] = unsqueeze_memlet(memlet, parent_edges[memlet.data].data) + repl_dict[str(memlet)] = unsqueeze_memlet(memlet, parent_edges_inputs[memlet.data].data) if memlet.data in syms: syms.remove(memlet.data) for s in syms: - if s in parent_edges: + if s in parent_edges_inputs: if s in nsdfg.arrays: - repl_dict[s] = parent_edges[s].data.data + repl_dict[s] = parent_edges_inputs[s].data.data else: - repl_dict[s] = str(parent_edges[s].data) + repl_dict[s] = str(parent_edges_inputs[s].data) e.data.replace_dict(repl_dict) for name in child_names: - edge = parent_edges[name] - for e in parent_state.memlet_tree(edge): - if e.data.data == parent_name: - e.data.subset = subsets.Range.from_array(parent_arr) - else: - e.data.other_subset = subsets.Range.from_array(parent_arr) + for edge in [parent_edges_inputs.get(name, None), parent_edges_outputs.get(name, None)]: + if edge is None: + continue + + for e in parent_state.memlet_tree(edge): + if e.data.data == parent_name: + e.data.subset = subsets.Range.from_array(parent_arr) + else: + e.data.other_subset = subsets.Range.from_array(parent_arr) if replacements: symbolic.safe_replace(replacements, lambda d: replace_datadesc_names(nsdfg, d), value_as_string=True) From 54c669ed0f0bc97eb077e84ace58f7596744c8b5 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 20 Aug 2025 17:05:28 +0200 Subject: [PATCH 098/137] fix: de-allocate arrays that are move from stack to heap --- dace/codegen/targets/cpu.py | 12 +++++++++--- tests/codegen/cpp_test.py | 36 +++++++++++++++++++++++++++++++----- 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py index 9ba202757e..534c14a30c 100644 --- a/dace/codegen/targets/cpu.py +++ b/dace/codegen/targets/cpu.py @@ -485,10 +485,10 @@ def allocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphV if symbolic.issymbolic(arrsize, sdfg.constants): warnings.warn('Variable-length array %s with size %s ' - 'detected and was allocated on heap instead of ' + 'detected and was allocated on the heap instead of ' '%s' % (name, cpp.sym2cpp(arrsize), nodedesc.storage)) elif (arrsize_bytes > Config.get("compiler", "max_stack_array_size")) == True: - warnings.warn("Array {} with size {} detected and was allocated on heap instead of " + warnings.warn("Array {} with size {} detected and was allocated on the heap instead of " "{} since its size is greater than max_stack_array_size ({})".format( name, cpp.sym2cpp(arrsize_bytes), nodedesc.storage, Config.get("compiler", "max_stack_array_size"))) @@ -572,6 +572,10 @@ def deallocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgrap node: nodes.AccessNode, nodedesc: data.Data, function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None: arrsize = nodedesc.total_size + arrsize_bytes = None + if not isinstance(nodedesc.dtype, dtypes.opaque): + arrsize_bytes = arrsize * nodedesc.dtype.bytes + alloc_name = cpp.ptr(node.data, nodedesc, sdfg, self._frame) if isinstance(nodedesc, data.Array) and nodedesc.start_offset != 0: alloc_name = f'({alloc_name} - {cpp.sym2cpp(nodedesc.start_offset)})' @@ -584,7 +588,9 @@ def deallocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgrap if isinstance(nodedesc, (data.Scalar, data.View, data.Stream, data.Reference)): return elif (nodedesc.storage == dtypes.StorageType.CPU_Heap - or (nodedesc.storage == dtypes.StorageType.Register and symbolic.issymbolic(arrsize, sdfg.constants))): + or (nodedesc.storage == dtypes.StorageType.Register and + (symbolic.issymbolic(arrsize, sdfg.constants) or + (arrsize_bytes and ((arrsize_bytes > Config.get("compiler", "max_stack_array_size")) == True))))): callsite_stream.write("delete[] %s;\n" % alloc_name, cfg, state_id, node) elif nodedesc.storage is dtypes.StorageType.CPU_ThreadLocal: # Deallocate in each OpenMP thread diff --git a/tests/codegen/cpp_test.py b/tests/codegen/cpp_test.py index 667997216b..271539455b 100644 --- a/tests/codegen/cpp_test.py +++ b/tests/codegen/cpp_test.py @@ -3,9 +3,10 @@ from functools import reduce from operator import mul from typing import Dict, Collection +import warnings -import dace -from dace import SDFG, Memlet +from dace import SDFG, Memlet, dtypes +from dace.codegen import codegen from dace.codegen.targets import cpp from dace.sdfg.state import SDFGState from dace.subsets import Range @@ -165,9 +166,9 @@ def test_reshape_strides_from_strided_and_offset_range(): def redundant_array_crashes_codegen_test_original_graph(): g = SDFG('prog') - g.add_array('A', (5, 5), dace.float32) - g.add_array('b', (1,), dace.float32, transient=True) - g.add_array('c', (5, 5), dace.float32, transient=True) + g.add_array('A', (5, 5), dtypes.float32) + g.add_array('b', (1,), dtypes.float32, transient=True) + g.add_array('c', (5, 5), dtypes.float32, transient=True) st0 = g.add_state('st0', is_start_block=True) st = st0 @@ -211,6 +212,30 @@ def test_redundant_array_does_not_crash_codegen_but_produces_bad_graph_now(): assert e.data.free_symbols == {'i', 'j'} +def test_arrays_bigger_than_max_stack_size_get_deallocated(): + # Setup SDFG with array A that is too big to be allocated on the stack. + sdfg = SDFG("test") + sdfg.add_array(name="A", shape=(10000,), dtype=dtypes.float64, storage=dtypes.StorageType.Register, transient=True) + state = sdfg.add_state("state", is_start_block=True) + read = state.add_access("A") + tasklet = state.add_tasklet("dummy", {"a"}, {}, "a = 1") + state.add_memlet_path(read, tasklet, dst_conn="a", memlet=Memlet("A[0]")) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + # Generate code for the program by traversing the SDFG state by state + program_objects = codegen.generate_code(sdfg) + + # Assert that we get the expected warning message + assert w + assert any("was allocated on the heap instead of" in str(warn.message) for warn in w) + + # In code, assert that we allocate _and_ deallocate on the heap + code = program_objects[0].clean_code + assert code.find("A = new double") > 0, "A is allocated on the heap." + assert code.find("delete[] A") > 0, "A is deallocated from the heap." + + if __name__ == '__main__': test_reshape_strides_multidim_array_all_dims_unit() test_reshape_strides_multidim_array_some_dims_unit() @@ -219,3 +244,4 @@ def test_redundant_array_does_not_crash_codegen_but_produces_bad_graph_now(): test_reshape_strides_from_strided_and_offset_range() test_redundant_array_does_not_crash_codegen_but_produces_bad_graph_now() + test_arrays_bigger_than_max_stack_size_get_deallocated() From 830ffa58a5d4f3f1ea023de4819d4e49ada395f8 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 3 Sep 2025 15:38:16 +0200 Subject: [PATCH 099/137] Fix: delete array moved to heap (backport to `v1/maintenance`) (#2135) # Description If arrays have storage type Register and are bigger than `max_stack_array_size`, they are moved to the heap. In that case, since allocation is now dynamic, the array also has to be de-allocated again. This issue is also present in mainline DaCe. This PR is a backport of [insert other PR here]. Co-authored-by: Roman Cattaneo <1116746+romanc@users.noreply.github.com> --- dace/codegen/targets/cpu.py | 12 +++++++++--- tests/codegen/cpp_test.py | 36 +++++++++++++++++++++++++++++++----- 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py index 9ba202757e..534c14a30c 100644 --- a/dace/codegen/targets/cpu.py +++ b/dace/codegen/targets/cpu.py @@ -485,10 +485,10 @@ def allocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphV if symbolic.issymbolic(arrsize, sdfg.constants): warnings.warn('Variable-length array %s with size %s ' - 'detected and was allocated on heap instead of ' + 'detected and was allocated on the heap instead of ' '%s' % (name, cpp.sym2cpp(arrsize), nodedesc.storage)) elif (arrsize_bytes > Config.get("compiler", "max_stack_array_size")) == True: - warnings.warn("Array {} with size {} detected and was allocated on heap instead of " + warnings.warn("Array {} with size {} detected and was allocated on the heap instead of " "{} since its size is greater than max_stack_array_size ({})".format( name, cpp.sym2cpp(arrsize_bytes), nodedesc.storage, Config.get("compiler", "max_stack_array_size"))) @@ -572,6 +572,10 @@ def deallocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgrap node: nodes.AccessNode, nodedesc: data.Data, function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None: arrsize = nodedesc.total_size + arrsize_bytes = None + if not isinstance(nodedesc.dtype, dtypes.opaque): + arrsize_bytes = arrsize * nodedesc.dtype.bytes + alloc_name = cpp.ptr(node.data, nodedesc, sdfg, self._frame) if isinstance(nodedesc, data.Array) and nodedesc.start_offset != 0: alloc_name = f'({alloc_name} - {cpp.sym2cpp(nodedesc.start_offset)})' @@ -584,7 +588,9 @@ def deallocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgrap if isinstance(nodedesc, (data.Scalar, data.View, data.Stream, data.Reference)): return elif (nodedesc.storage == dtypes.StorageType.CPU_Heap - or (nodedesc.storage == dtypes.StorageType.Register and symbolic.issymbolic(arrsize, sdfg.constants))): + or (nodedesc.storage == dtypes.StorageType.Register and + (symbolic.issymbolic(arrsize, sdfg.constants) or + (arrsize_bytes and ((arrsize_bytes > Config.get("compiler", "max_stack_array_size")) == True))))): callsite_stream.write("delete[] %s;\n" % alloc_name, cfg, state_id, node) elif nodedesc.storage is dtypes.StorageType.CPU_ThreadLocal: # Deallocate in each OpenMP thread diff --git a/tests/codegen/cpp_test.py b/tests/codegen/cpp_test.py index 667997216b..271539455b 100644 --- a/tests/codegen/cpp_test.py +++ b/tests/codegen/cpp_test.py @@ -3,9 +3,10 @@ from functools import reduce from operator import mul from typing import Dict, Collection +import warnings -import dace -from dace import SDFG, Memlet +from dace import SDFG, Memlet, dtypes +from dace.codegen import codegen from dace.codegen.targets import cpp from dace.sdfg.state import SDFGState from dace.subsets import Range @@ -165,9 +166,9 @@ def test_reshape_strides_from_strided_and_offset_range(): def redundant_array_crashes_codegen_test_original_graph(): g = SDFG('prog') - g.add_array('A', (5, 5), dace.float32) - g.add_array('b', (1,), dace.float32, transient=True) - g.add_array('c', (5, 5), dace.float32, transient=True) + g.add_array('A', (5, 5), dtypes.float32) + g.add_array('b', (1,), dtypes.float32, transient=True) + g.add_array('c', (5, 5), dtypes.float32, transient=True) st0 = g.add_state('st0', is_start_block=True) st = st0 @@ -211,6 +212,30 @@ def test_redundant_array_does_not_crash_codegen_but_produces_bad_graph_now(): assert e.data.free_symbols == {'i', 'j'} +def test_arrays_bigger_than_max_stack_size_get_deallocated(): + # Setup SDFG with array A that is too big to be allocated on the stack. + sdfg = SDFG("test") + sdfg.add_array(name="A", shape=(10000,), dtype=dtypes.float64, storage=dtypes.StorageType.Register, transient=True) + state = sdfg.add_state("state", is_start_block=True) + read = state.add_access("A") + tasklet = state.add_tasklet("dummy", {"a"}, {}, "a = 1") + state.add_memlet_path(read, tasklet, dst_conn="a", memlet=Memlet("A[0]")) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + # Generate code for the program by traversing the SDFG state by state + program_objects = codegen.generate_code(sdfg) + + # Assert that we get the expected warning message + assert w + assert any("was allocated on the heap instead of" in str(warn.message) for warn in w) + + # In code, assert that we allocate _and_ deallocate on the heap + code = program_objects[0].clean_code + assert code.find("A = new double") > 0, "A is allocated on the heap." + assert code.find("delete[] A") > 0, "A is deallocated from the heap." + + if __name__ == '__main__': test_reshape_strides_multidim_array_all_dims_unit() test_reshape_strides_multidim_array_some_dims_unit() @@ -219,3 +244,4 @@ def test_redundant_array_does_not_crash_codegen_but_produces_bad_graph_now(): test_reshape_strides_from_strided_and_offset_range() test_redundant_array_does_not_crash_codegen_but_produces_bad_graph_now() + test_arrays_bigger_than_max_stack_size_get_deallocated() From 1810f49c9b6a8bf43ae764bf707ba7cac0264632 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 9 Oct 2025 21:27:43 +0200 Subject: [PATCH 100/137] Backport Disabling Failing FPGA tests (#2171) Backport of #2169 --- dace/viewer/webclient | 2 +- setup.py | 2 +- tests/blas/nodes/axpy_test.py | 5 +++++ tests/fpga/streaming_memory_test.py | 8 ++++++++ tests/fpga/vec_sum_test.py | 4 +++- tests/fpga/veclen_conversion_connector_test.py | 5 ++++- tests/npbench/misc/arc_distance_test.py | 2 ++ 7 files changed, 24 insertions(+), 4 deletions(-) diff --git a/dace/viewer/webclient b/dace/viewer/webclient index 64861bbc05..f498ea57e6 160000 --- a/dace/viewer/webclient +++ b/dace/viewer/webclient @@ -1 +1 @@ -Subproject commit 64861bbc054c62bc6cb3f8525bfc4703d6c5e364 +Subproject commit f498ea57e69fe9963ceff3602c0a131e563743e0 diff --git a/setup.py b/setup.py index c228ae4558..b8c4b79364 100644 --- a/setup.py +++ b/setup.py @@ -74,7 +74,7 @@ include_package_data=True, install_requires=[ 'numpy < 2.0', 'networkx >= 2.5', 'astunparse', 'sympy >= 1.9', 'pyyaml', 'ply', - 'fparser >= 0.1.3', 'aenum >= 3.1', 'dataclasses; python_version < "3.7"', 'dill', + 'fparser == 0.2.0', 'aenum >= 3.1', 'dataclasses; python_version < "3.7"', 'dill', 'pyreadline;platform_system=="Windows"', 'typing-compat; python_version < "3.8"', 'packaging' ] + cmake_requires, extras_require={ diff --git a/tests/blas/nodes/axpy_test.py b/tests/blas/nodes/axpy_test.py index ed9ca4bad4..297961b78c 100755 --- a/tests/blas/nodes/axpy_test.py +++ b/tests/blas/nodes/axpy_test.py @@ -2,6 +2,7 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. import numpy as np +import pytest import argparse import scipy @@ -112,12 +113,16 @@ def stream_fpga_graph(veclen, precision, test_case, expansion): return sdfg +# TODO: Investigate and re-enable if possible. +@pytest.mark.skip(reason="Unexplained CI Regression") @fpga_test() def test_axpy_fpga_array(): configs = [(0.5, 1, dace.float32), (1.0, 4, dace.float64)] return run_test(configs, "fpga_array") +# TODO: Investigate and re-enable if possible. +@pytest.mark.skip(reason="Unexplained CI Regression") @fpga_test() def test_axpy_fpga_stream(): configs = [(0.5, 1, dace.float32), (1.0, 4, dace.float64)] diff --git a/tests/fpga/streaming_memory_test.py b/tests/fpga/streaming_memory_test.py index 11a56c42f4..1a48a8d128 100644 --- a/tests/fpga/streaming_memory_test.py +++ b/tests/fpga/streaming_memory_test.py @@ -548,6 +548,8 @@ def test_mem_buffer_vec_add_mixed_int(): return mem_buffer_vec_add_types(dace.int16, dace.int32, dace.int64, np.int16, np.int32, np.int64) +# TODO: Investigate and re-enable if possible. +@pytest.mark.skip(reason="Unexplained CI Regression") @xilinx_test() def test_mem_buffer_mat_add(): # Make SDFG @@ -629,6 +631,8 @@ def test_mem_buffer_tensor_add(): return sdfg +# TODO: Investigate and re-enable if possible. +@pytest.mark.skip(reason="Unexplained CI Regression") @xilinx_test() def test_mem_buffer_multistream(): # Make SDFG @@ -660,6 +664,8 @@ def test_mem_buffer_multistream(): return sdfg +# TODO: Investigate and re-enable if possible. +@pytest.mark.skip(reason="Unexplained CI Regression") @xilinx_test() def test_mem_buffer_multistream_with_deps(): # Make SDFG @@ -714,6 +720,8 @@ def test_mem_buffer_mat_mul(): return sdfg +# TODO: Investigate and re-enable if possible. +@pytest.mark.skip(reason="Unexplained CI Regression") @xilinx_test() def test_mem_buffer_map_order(): # Make SDFG diff --git a/tests/fpga/vec_sum_test.py b/tests/fpga/vec_sum_test.py index 791ba80e5d..f0bf96f94f 100644 --- a/tests/fpga/vec_sum_test.py +++ b/tests/fpga/vec_sum_test.py @@ -71,7 +71,9 @@ def test_vec_sum_vectorize_first(): return run_vec_sum(True) -@fpga_test(assert_ii_1=False) +# TODO: Investigate and re-enable if possible. +@pytest.mark.skip(reason="Unexplained CI Regression") +@fpga_test(assert_ii_1=False, intel=False) def test_vec_sum_fpga_transform_first(): return run_vec_sum(False) diff --git a/tests/fpga/veclen_conversion_connector_test.py b/tests/fpga/veclen_conversion_connector_test.py index 1d271512ec..24b2d95fe3 100644 --- a/tests/fpga/veclen_conversion_connector_test.py +++ b/tests/fpga/veclen_conversion_connector_test.py @@ -1,9 +1,12 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. import numpy as np +import pytest from veclen_conversion_test import make_sdfg from dace.fpga_testing import fpga_test +#TODO: Investigate and re-enable if possible. +@pytest.mark.skip(reason="Unexplained CI Regression") @fpga_test() def test_veclen_conversion_connector(): @@ -33,4 +36,4 @@ def test_veclen_conversion_connector(): if __name__ == "__main__": - test_veclen_conversion_connector(None) + test_veclen_conversion_connector() diff --git a/tests/npbench/misc/arc_distance_test.py b/tests/npbench/misc/arc_distance_test.py index 81d6a819a7..571c0dd772 100644 --- a/tests/npbench/misc/arc_distance_test.py +++ b/tests/npbench/misc/arc_distance_test.py @@ -80,6 +80,8 @@ def test_gpu(): run_arc_distance(dace.dtypes.DeviceType.GPU) +# TODO: Investigate and re-enable if possible. +@pytest.mark.skip(reason="Unexplained CI Regression") @fpga_test(assert_ii_1=False) def test_fpga(): return run_arc_distance(dace.dtypes.DeviceType.FPGA) From c2e952fcc34eaacb51e4025a8e51b2d687725c54 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Fri, 10 Oct 2025 13:38:01 +0200 Subject: [PATCH 101/137] Backport of #2165 (#2166) Co-authored-by: Tal Ben-Nun --- dace/frontend/python/newast.py | 20 ++++++++--- tests/sdfg/reference_test.py | 66 ++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 4 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index d2813371c9..eed91f05ee 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -26,7 +26,7 @@ from dace.frontend.python import nested_call, replacements, preprocessing from dace.frontend.python.memlet_parser import DaceSyntaxError, parse_memlet, ParseMemlet, inner_eval_ast, MemletExpr from dace.sdfg import nodes -from dace.sdfg.propagation import propagate_memlet, propagate_subset, propagate_states +from dace.sdfg.propagation import propagate_memlet, propagate_subset, propagate_states, align_memlet from dace.memlet import Memlet from dace.properties import LambdaProperty, CodeBlock from dace.sdfg import SDFG, SDFGState @@ -2774,7 +2774,12 @@ def _add_assignment(self, memlet.other_subset = op_subset if op: memlet.wcr = LambdaProperty.from_string('lambda x, y: x {} y'.format(op)) - state.add_nedge(op1, op2, memlet) + if isinstance(self.sdfg.arrays[target_name], data.Reference): + e = state.add_edge(op1, None, op2, 'set', memlet) + # Align memlet to referenced array + e.data = align_memlet(state, e, dst=False) + else: + state.add_nedge(op1, op2, memlet) else: memlet = Memlet("{a}[{s}]".format(a=target_name, s=','.join(['__i%d' % i for i in range(len(target_subset))]))) @@ -3272,9 +3277,10 @@ def visit_AnnAssign(self, node: ast.AnnAssign): storage = dtypes.StorageType.Default type_name = rname(node.annotation) warnings.warn('typeclass {} is not supported'.format(type_name)) - if node.value is None and dtype is not None: # Annotating type without assignment + if dtype is not None: self.annotated_types[rname(node.target)] = dtype - return + if node.value is None: # Annotating type without assignment + return results = self._visit_assign(node, node.target, None, dtype=dtype) if storage != dtypes.StorageType.Default: self.sdfg.arrays[results[0][0]].storage = storage @@ -3403,6 +3409,12 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): true_name, new_data = self.sdfg.add_temp_transient([1], result_data.dtype) self.variables[name] = true_name defined_vars[name] = true_name + elif name in self.annotated_types and isinstance(self.annotated_types[name], data.Reference): + desc = copy.deepcopy(self.annotated_types[name]) + desc.transient = True + true_name = self.sdfg.add_datadesc(name, desc, find_new_name=True) + self.variables[name] = true_name + defined_vars[name] = true_name elif (not name.startswith('__return') and (isinstance(result_data, data.View) or (not result_data.transient and isinstance(result_data, data.Array)))): diff --git a/tests/sdfg/reference_test.py b/tests/sdfg/reference_test.py index da5c4a0111..79b4ceff00 100644 --- a/tests/sdfg/reference_test.py +++ b/tests/sdfg/reference_test.py @@ -10,6 +10,70 @@ import networkx as nx +def test_frontend_reference(): + N = dace.symbol('N') + M = dace.symbol('M') + mystruct = dace.data.Structure(members={ + "data": dace.data.Array(dace.float32, (N, M), strides=(1, N)), + "arrA": dace.data.ArrayReference(dace.float32, (N, )), + "arrB": dace.data.ArrayReference(dace.float32, (N, )), + }, + name="MyStruct") + + @dace.program + def init_prog(mydat: mystruct, fill_value: int) -> None: + mydat.arrA = mydat.data[:, 2] + mydat.arrB = mydat.data[:, 0] + + # loop over all arrays and initialize them with `fill_value` + for index in range(M): + mydat.data[:, index] = fill_value + + # Initialize the two named ones by name + mydat.arrA[:] = fill_value + 1 + mydat.arrB[:] = fill_value + 2 + + dat = np.zeros((10, 5), dtype=np.float32) + inp_struct = mystruct.dtype._typeclass.as_ctypes()(data=dat.__array_interface__['data'][0]) + + func = init_prog.compile() + func(mydat=inp_struct, fill_value=3, N=10, M=5) + + assert np.allclose(dat[0, :], 5) and np.allclose(dat[1, :], 5) + assert np.allclose(dat[2, :], 3) and np.allclose(dat[3, :], 3) + assert np.allclose(dat[4, :], 4) and np.allclose(dat[5, :], 4) + assert np.allclose(dat[6, :], 3) and np.allclose(dat[7, :], 3) + assert np.allclose(dat[8, :], 3) and np.allclose(dat[9, :], 3) + + +def test_type_annotation_reference(): + N = dace.symbol('N') + + @dace.program + def ref(A: dace.float64[N], B: dace.float64[N], T: dace.int32, out: dace.float64[N]): + ref1: dace.data.ArrayReference(A.dtype, A.shape) = A + ref2: dace.data.ArrayReference(A.dtype, A.shape) = B + if T <= 0: + out[:] = ref1[:] + 1 + else: + out[:] = ref2[:] + 1 + + a = np.random.rand(20) + a_verif = a.copy() + b = np.random.rand(20) + b_verif = b.copy() + out = np.random.rand(20) + out_verif = out.copy() + + ref(a, b, 1, out, N=20) + ref.f(a_verif, b_verif, 1, out_verif) + assert np.allclose(out, out_verif) + + ref(a, b, -1, out, N=20) + ref.f(a_verif, b_verif, -1, out_verif) + assert np.allclose(out, out_verif) + + def test_unset_reference(): sdfg = dace.SDFG('tester') sdfg.add_reference('ref', [20], dace.float64) @@ -683,6 +747,8 @@ def test_ref2view_reconnection(): if __name__ == '__main__': + test_frontend_reference() + test_type_annotation_reference() test_unset_reference() test_reference_branch() test_reference_sources_pass() From 1033dfcf9d118856d82c6ee8d6f6cfacec662335 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 14 Oct 2025 10:46:14 +0200 Subject: [PATCH 102/137] Unrelated: fix a bunch of typos --- dace/data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dace/data.py b/dace/data.py index 9749411fe6..19a45e7d92 100644 --- a/dace/data.py +++ b/dace/data.py @@ -1132,7 +1132,7 @@ def __init__(self, :param tensor_shape: logical shape of tensor (#rows, #cols, etc...) :param indices: a list of tuples, each tuple represents a level in the tensor - storage hirachy, specifying the levels tensor index type, and the + storage hierarchy, specifying the levels tensor index type, and the corresponding dimension this level encodes (as index of the tensor_shape tuple above). The order of the dimensions may differ from the logical shape of the tensor, e.g. as seen in the CSC @@ -1154,9 +1154,9 @@ def __init__(self, num_dims = len(tensor_shape) dimension_order = [idx for idx in self.index_ordering if isinstance(idx, int)] - # all tensor dimensions must occure exactly once in indices + # all tensor dimensions must occur exactly once in indices if not sorted(dimension_order) == list(range(num_dims)): - raise TypeError((f"All tensor dimensions must be refferenced exactly once in " + raise TypeError((f"All tensor dimensions must be referenced exactly once in " f"tensor indices. (referenced dimensions: {dimension_order}; " f"tensor dimensions: {list(range(num_dims))})")) From 4c042ee167e749305ac5b53406c6bbae6b2d306e Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 28 Oct 2025 20:23:53 -0700 Subject: [PATCH 103/137] Backport fix from #1853 and support `dace.map` syntax for struct fields (#2186) --- dace/data.py | 5 + dace/frontend/python/newast.py | 18 ++- dace/symbolic.py | 8 + tests/codegen/allocation_lifetime_test.py | 3 +- tests/python_frontend/loops_test.py | 18 +++ .../structures/structure_python_test.py | 139 ++++++++++++++---- 6 files changed, 159 insertions(+), 32 deletions(-) diff --git a/dace/data.py b/dace/data.py index 9749411fe6..93c40de450 100644 --- a/dace/data.py +++ b/dace/data.py @@ -387,6 +387,9 @@ def __init__(self, self.members = OrderedDict(members) for k, v in self.members.items(): + if isinstance(v, dtypes.typeclass): + v = Scalar(v) + self.members[k] = v v.transient = transient self.name = name @@ -402,6 +405,8 @@ def __init__(self, elif isinstance(v, Scalar): symbols |= v.free_symbols fields_and_types[k] = v.dtype + elif isinstance(v, dtypes.typeclass): + fields_and_types[k] = v elif isinstance(v, (sp.Basic, symbolic.SymExpr)): symbols |= v.free_symbols fields_and_types[k] = symbolic.symtype(v) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index eed91f05ee..71de818877 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -1852,7 +1852,7 @@ def _parse_map_inputs(self, name: str, params: List[Tuple[str, str]], if symbolic.issymbolic(atom, self.sdfg.constants): # Check for undefined variables atomstr = str(atom) - if atomstr not in self.defined: + if atomstr not in self.defined and atomstr not in self.sdfg.arrays: raise DaceSyntaxError(self, node, 'Undefined variable "%s"' % atom) # Add to global SDFG symbols @@ -2350,7 +2350,7 @@ def visit_For(self, node: ast.For): if symbolic.issymbolic(atom, self.sdfg.constants): astr = str(atom) # Check for undefined variables - if astr not in self.defined: + if astr not in self.defined and not ('.' in astr and astr in self.sdfg.arrays): raise DaceSyntaxError(self, node, 'Undefined variable "%s"' % atom) # Add to global SDFG symbols if not a scalar if (astr not in self.sdfg.symbols and not (astr in self.variables or astr in self.sdfg.arrays)): @@ -3079,8 +3079,14 @@ def _add_access( else: var_name = self.sdfg.temp_data_name() - parent_name = self.scope_vars[name] - parent_array = self.scope_arrays[parent_name] + parent_name = self.scope_vars[until(name, '.')] + if '.' in name: + struct_field = name[name.index('.'):] + parent_name += struct_field + scope_ndict = dace.sdfg.NestedDict(self.scope_arrays) + parent_array = scope_ndict[parent_name] + else: + parent_array = self.scope_arrays[parent_name] has_indirection = (_subset_has_indirection(rng, self) or _subset_is_local_symbol_dependent(rng, self)) if has_indirection: @@ -3244,7 +3250,7 @@ def _add_write_access(self, return self.accesses[(name, rng, 'w')] elif name in self.variables: return (self.variables[name], rng) - elif (name, rng, 'r') in self.accesses or name in self.scope_vars: + elif (name, rng, 'r') in self.accesses or until(name, '.') in self.scope_vars: return self._add_access(name, rng, 'w', target, new_name, arr_type) else: raise NotImplementedError @@ -3498,7 +3504,7 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): raise IndexError('Boolean array indexing cannot be combined with indirect access') if self.nested and not new_data: - new_name, new_rng = self._add_write_access(name, rng, target) + new_name, new_rng = self._add_write_access(true_name, rng, target) # Local symbol or local data dependent if _subset_is_local_symbol_dependent(rng, self): new_rng = rng diff --git a/dace/symbolic.py b/dace/symbolic.py index 98ffa008d3..56308751d8 100644 --- a/dace/symbolic.py +++ b/dace/symbolic.py @@ -307,6 +307,8 @@ def symlist(values): except TypeError: values = [values] + skip = set() + for expr in values: if isinstance(expr, SymExpr): true_expr = expr.expr @@ -315,6 +317,12 @@ def symlist(values): else: continue for atom in sympy.preorder_traversal(true_expr): + if atom in skip: + continue + if isinstance(atom, Attr): + # Skip attributes + skip.add(atom.args[1]) + continue if isinstance(atom, symbol): result[atom.name] = atom return result diff --git a/tests/codegen/allocation_lifetime_test.py b/tests/codegen/allocation_lifetime_test.py index 2b53e87644..c366488c80 100644 --- a/tests/codegen/allocation_lifetime_test.py +++ b/tests/codegen/allocation_lifetime_test.py @@ -206,6 +206,7 @@ def persistentmem(output: dace.int32[1]): del csdfg +@pytest.mark.skip(reason="In v1, produces two tasklets side-by-side, leading to nondeterministic code order") def test_alloc_persistent_threadlocal(): @dace.program @@ -599,7 +600,7 @@ def test_multisize(): test_persistent_gpu_transpose_regression() test_alloc_persistent_register() test_alloc_persistent() - test_alloc_persistent_threadlocal() + # test_alloc_persistent_threadlocal() test_alloc_persistent_threadlocal_naming() test_alloc_multistate() test_nested_view_samename() diff --git a/tests/python_frontend/loops_test.py b/tests/python_frontend/loops_test.py index e0c869f20c..1a59ba13c4 100644 --- a/tests/python_frontend/loops_test.py +++ b/tests/python_frontend/loops_test.py @@ -5,6 +5,7 @@ from dace.frontend.python.common import DaceSyntaxError + @dace.program def for_loop(): A = dace.ndarray([10], dtype=dace.int32) @@ -499,6 +500,22 @@ def test_branch_in_while(): assert len(sdfg.source_nodes()) == 1 +def test_for_with_field(): + struct = dace.data.Structure({'data': dace.float64[20], 'length': dace.int32}, name='MyStruct') + + @dace.program + def for_with_field(S: struct): + for i in range(S.length): + S.data[i] = S.data[i] + 1.0 + + A = np.random.rand(20) + inp_struct = struct.dtype.base_type.as_ctypes()(data=A.__array_interface__['data'][0], length=10) + expected = np.copy(A) + expected[:10] += 1.0 + for_with_field.compile()(S=inp_struct) + assert np.allclose(A, expected) + + if __name__ == "__main__": test_for_loop() test_for_loop_with_break_continue() @@ -522,3 +539,4 @@ def test_branch_in_while(): test_while_else() test_branch_in_for() test_branch_in_while() + test_for_with_field() diff --git a/tests/python_frontend/structures/structure_python_test.py b/tests/python_frontend/structures/structure_python_test.py index bc4ab58d7c..5908a0b9e3 100644 --- a/tests/python_frontend/structures/structure_python_test.py +++ b/tests/python_frontend/structures/structure_python_test.py @@ -1,4 +1,5 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import ctypes import dace import numpy as np import pytest @@ -18,7 +19,7 @@ def csr_to_dense_python(A: CSR, B: dace.float32[M, N]): for i in dace.map[0:M]: for idx in dace.map[A.indptr[i]:A.indptr[i + 1]]: B[i, A.indices[idx]] = A.data[idx] - + rng = np.random.default_rng(42) A = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng) B = np.zeros((20, 20), dtype=np.float32) @@ -41,7 +42,7 @@ def test_write_structure(): M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) CSR = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), name='CSRMatrix') - + @dace.program def dense_to_csr_python(A: dace.float32[M, N], B: CSR): idx = 0 @@ -53,7 +54,7 @@ def dense_to_csr_python(A: dace.float32[M, N], B: CSR): B.indices[idx] = j idx += 1 B.indptr[M] = idx - + rng = np.random.default_rng(42) tmp = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng) A = tmp.toarray() @@ -75,7 +76,7 @@ def test_local_structure(): M, N, nnz = (dace.symbol(s) for s in ('M', 'N', 'nnz')) CSR = dace.data.Structure(dict(indptr=dace.int32[M + 1], indices=dace.int32[nnz], data=dace.float32[nnz]), name='CSRMatrix') - + @dace.program def dense_to_csr_local_python(A: dace.float32[M, N], B: CSR): tmp = dace.define_local_structure(CSR) @@ -91,7 +92,7 @@ def dense_to_csr_local_python(A: dace.float32[M, N], B: CSR): B.indptr[:] = tmp.indptr[:] B.indices[:] = tmp.indices[:] B.data[:] = tmp.data[:] - + rng = np.random.default_rng(42) tmp = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng) A = tmp.toarray() @@ -118,12 +119,11 @@ def __init__(self, diag, upper, lower): self.lower = lower n, nblocks = dace.symbol('n'), dace.symbol('nblocks') - BlockTriDiagonal = dace.data.Structure( - dict(diagonal=dace.complex128[nblocks, n, n], - upper=dace.complex128[nblocks, n, n], - lower=dace.complex128[nblocks, n, n]), - name='BlockTriDiagonalMatrix') - + BlockTriDiagonal = dace.data.Structure(dict(diagonal=dace.complex128[nblocks, n, n], + upper=dace.complex128[nblocks, n, n], + lower=dace.complex128[nblocks, n, n]), + name='BlockTriDiagonalMatrix') + @dace.program def rgf_leftToRight(A: BlockTriDiagonal, B: BlockTriDiagonal, n_: dace.int32, nblocks_: dace.int32): @@ -139,42 +139,41 @@ def rgf_leftToRight(A: BlockTriDiagonal, B: BlockTriDiagonal, n_: dace.int32, nb # 2. Forward substitution # From left to right for i in range(1, nblocks_): - tmp[i] = np.linalg.inv(A.diagonal[i] - A.lower[i-1] @ tmp[i-1] @ A.upper[i-1]) + tmp[i] = np.linalg.inv(A.diagonal[i] - A.lower[i - 1] @ tmp[i - 1] @ A.upper[i - 1]) # 3. Initialisation of last element of B B.diagonal[-1] = tmp[-1] # 4. Backward substitution # From right to left - for i in range(nblocks_-2, -1, -1): - B.diagonal[i] = tmp[i] @ (identity + A.upper[i] @ B.diagonal[i+1] @ A.lower[i] @ tmp[i]) - B.upper[i] = -tmp[i] @ A.upper[i] @ B.diagonal[i+1] - B.lower[i] = np.transpose(B.upper[i]) - + for i in range(nblocks_ - 2, -1, -1): + B.diagonal[i] = tmp[i] @ (identity + A.upper[i] @ B.diagonal[i + 1] @ A.lower[i] @ tmp[i]) + B.upper[i] = -tmp[i] @ A.upper[i] @ B.diagonal[i + 1] + B.lower[i] = np.transpose(B.upper[i]) + rng = np.random.default_rng(42) A_diag = rng.random((10, 20, 20)) + 1j * rng.random((10, 20, 20)) A_upper = rng.random((10, 20, 20)) + 1j * rng.random((10, 20, 20)) - A_lower = rng.random((10, 20, 20)) + 1j * rng.random((10, 20, 20)) + A_lower = rng.random((10, 20, 20)) + 1j * rng.random((10, 20, 20)) inpBTD = BlockTriDiagonal.dtype._typeclass.as_ctypes()(diagonal=A_diag.__array_interface__['data'][0], upper=A_upper.__array_interface__['data'][0], lower=A_lower.__array_interface__['data'][0]) - + B_diag = np.zeros((10, 20, 20), dtype=np.complex128) B_upper = np.zeros((10, 20, 20), dtype=np.complex128) B_lower = np.zeros((10, 20, 20), dtype=np.complex128) outBTD = BlockTriDiagonal.dtype._typeclass.as_ctypes()(diagonal=B_diag.__array_interface__['data'][0], upper=B_upper.__array_interface__['data'][0], lower=B_lower.__array_interface__['data'][0]) - + func = rgf_leftToRight.compile() func(A=inpBTD, B=outBTD, n_=A_diag.shape[1], nblocks_=A_diag.shape[0], n=A_diag.shape[1], nblocks=A_diag.shape[0]) A = BTD(A_diag, A_upper, A_lower) - B = BTD(np.zeros((10, 20, 20), dtype=np.complex128), - np.zeros((10, 20, 20), dtype=np.complex128), + B = BTD(np.zeros((10, 20, 20), dtype=np.complex128), np.zeros((10, 20, 20), dtype=np.complex128), np.zeros((10, 20, 20), dtype=np.complex128)) - + rgf_leftToRight.f(A, B, A_diag.shape[1], A_diag.shape[0]) assert np.allclose(B.diagonal, B_diag) @@ -195,7 +194,7 @@ def csr_to_dense_python(A: CSR, B: dace.float32[M, N]): for i in dace.map[0:M]: for idx in dace.map[A.indptr[i]:A.indptr[i + 1]]: B[i, A.indices[idx]] = A.data[idx] - + rng = np.random.default_rng(42) A = sparse.random(20, 20, density=0.1, format='csr', dtype=np.float32, random_state=rng) ref = A.toarray() @@ -203,7 +202,7 @@ def csr_to_dense_python(A: CSR, B: dace.float32[M, N]): inpA = CSR.dtype._typeclass.as_ctypes()(indptr=A.indptr.__array_interface__['data'][0], indices=A.indices.__array_interface__['data'][0], data=A.data.__array_interface__['data'][0]) - + # TODO: The following doesn't work because we need to create a Structure data descriptor from the ctypes class. # csr_to_dense_python(inpA, B) naive = csr_to_dense_python.to_sdfg(simplify=False) @@ -224,9 +223,99 @@ def csr_to_dense_python(A: CSR, B: dace.float32[M, N]): assert np.allclose(B, ref) +def test_write_structure_in_map(): + M = dace.symbol('M') + N = dace.symbol('N') + Bundle = dace.data.Structure(members={ + "data": dace.data.Array(dace.float32, (M, N)), + "size": dace.data.Scalar(dace.int64) + }, + name="BundleType") + + @dace.program + def init_prog(bundle: Bundle, fill_value: int) -> None: + for index in dace.map[0:bundle.size]: + bundle.data[index, :] = fill_value + + data = np.zeros((10, 5), dtype=np.float32) + fill_value = 42 + inp_struct = Bundle.dtype.base_type.as_ctypes()( + data=data.__array_interface__['data'][0], + size=9, + ) + ref = np.zeros((10, 5), dtype=np.float32) + ref[:9, :] = fill_value + + init_prog.compile()(inp_struct, fill_value, M=10, N=5) + + assert np.allclose(data, ref) + + +def test_readwrite_structure_in_map(): + M = dace.symbol('M') + N = dace.symbol('N') + Bundle = dace.data.Structure(members={ + "data": dace.data.Array(dace.float32, (M, N)), + "data2": dace.data.Array(dace.float32, (M, N)), + "size": dace.data.Scalar(dace.int64) + }, + name="BundleTypeTwoArrays") + + @dace.program + def copy_prog(bundle: Bundle) -> None: + for index in dace.map[0:bundle.size]: + bundle.data[index, :] = bundle.data2[index, :] + 5 + + data = np.zeros((10, 5), dtype=np.float32) + data2 = np.ones((10, 5), dtype=np.float32) + inp_struct = Bundle.dtype.base_type.as_ctypes()( + data=data.__array_interface__['data'][0], + data2=data2.__array_interface__['data'][0], + size=ctypes.c_int64(6), + ) + ref = np.zeros((10, 5), dtype=np.float32) + ref[:6, :] = 6.0 + + csdfg = copy_prog.compile() + csdfg.fast_call((ctypes.byref(inp_struct), ctypes.c_int(5)), (ctypes.c_int(5),)) + + assert np.allclose(data, ref) + + +def test_write_structure_in_loop(): + M = dace.symbol('M') + N = dace.symbol('N') + Bundle = dace.data.Structure(members={ + "data": dace.data.Array(dace.float32, (M, N)), + "size": dace.data.Scalar(dace.int64) + }, + name="BundleType") + + @dace.program + def init_prog(bundle: Bundle, fill_value: int) -> None: + for index in range(bundle.size): + bundle.data[index, :] = fill_value + + data = np.zeros((10, 5), dtype=np.float32) + fill_value = 42 + inp_struct = Bundle.dtype.base_type.as_ctypes()( + data=data.__array_interface__['data'][0], + size=6, + ) + ref = np.zeros((10, 5), dtype=np.float32) + ref[:6, :] = fill_value + + init_prog.compile()(inp_struct, fill_value, M=10, N=5) + + assert np.allclose(data, ref) + + if __name__ == '__main__': test_read_structure() test_write_structure() test_local_structure() test_rgf() # test_read_structure_gpu() + test_write_structure_in_map() + test_readwrite_structure_in_map() + test_write_structure_in_loop() From 22f9a30768953468893d3f3707b93ff69792bbb4 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Tue, 25 Nov 2025 12:22:54 +0100 Subject: [PATCH 104/137] limit networkx version <= 3.5 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 4687adce2c..f562b06389 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ }, include_package_data=True, install_requires=[ - 'numpy < 2.0', 'networkx >= 2.5', 'astunparse', 'sympy >= 1.9', 'pyyaml', 'ply', + 'numpy < 2.0', 'networkx >= 2.5, <= 3.5', 'astunparse', 'sympy >= 1.9', 'pyyaml', 'ply', 'fparser == 0.2.0', 'aenum >= 3.1', 'dataclasses; python_version < "3.7"', 'dill', 'pyreadline;platform_system=="Windows"', 'typing-compat; python_version < "3.8"', 'packaging' ] + cmake_requires, From aa5f8b5b22a064e4586a822b1cf168e4fc95c480 Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Tue, 23 Dec 2025 17:27:52 -0500 Subject: [PATCH 105/137] [v1 Maintenance] `networkx` under 3.6 (#2258) Backporting the findings of an exploding `networkx` that was dealt with on mainline with #2235 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index b8c4b79364..4737651fff 100644 --- a/setup.py +++ b/setup.py @@ -73,7 +73,7 @@ }, include_package_data=True, install_requires=[ - 'numpy < 2.0', 'networkx >= 2.5', 'astunparse', 'sympy >= 1.9', 'pyyaml', 'ply', + 'numpy < 2.0', 'networkx >= 2.5, <= 3.5', 'astunparse', 'sympy >= 1.9', 'pyyaml', 'ply', 'fparser == 0.2.0', 'aenum >= 3.1', 'dataclasses; python_version < "3.7"', 'dill', 'pyreadline;platform_system=="Windows"', 'typing-compat; python_version < "3.8"', 'packaging' ] + cmake_requires, From c7b5a1014fb4a56f1cf821da0aeab81e753842ae Mon Sep 17 00:00:00 2001 From: Florian Deconinck Date: Thu, 25 Dec 2025 01:51:23 -0500 Subject: [PATCH 106/137] [v1 Maintenance] Back port `MapExpansion` fixes from v2 (#2257) This PR brings the fixes to `MapExpansion` [transform done in v2](https://github.com/spcl/dace/blob/156567b1eea3b54cd3dda0b6d3f259995127be68/dace/transformation/dataflow/map_expansion.py#L135) back into v1. --- dace/transformation/dataflow/map_expansion.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/dace/transformation/dataflow/map_expansion.py b/dace/transformation/dataflow/map_expansion.py index 8bc14213b0..0fbded6ff7 100644 --- a/dace/transformation/dataflow/map_expansion.py +++ b/dace/transformation/dataflow/map_expansion.py @@ -136,10 +136,14 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG): graph.add_edge(entries[-1], edge.src_conn, edge.dst, edge.dst_conn, memlet=copy.deepcopy(edge.data)) graph.remove_edge(edge) - if graph.in_degree(map_entry) == 0: + if graph.in_degree(map_entry) == 0 or all( + e.dst_conn is None or not e.dst_conn.startswith("IN_") + for e in graph.in_edges(map_entry)): graph.add_memlet_path(map_entry, *entries, memlet=dace.Memlet()) else: for edge in graph.in_edges(map_entry): + if edge.dst_conn is None: + continue if not edge.dst_conn.startswith("IN_"): continue From e85315c0bed85de298056a890f2d429ca4b16a1b Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Thu, 8 Jan 2026 15:55:39 +0100 Subject: [PATCH 107/137] cleanup after merging main Looks like me and/or git got a bit confused while merging main. Let's undo changes that don't need to be in this branch/PR (since it's anyway rather big). --- .github/workflows/fpga-ci.yml | 6 +- .github/workflows/general-ci.yml | 8 +- .github/workflows/heterogeneous-ci.yml | 6 +- dace/FillZ-schedule-tree.txt.jl | 821 ------------------ .../analysis/schedule_tree/sdfg_to_tree.py | 29 - .../interstate/loop_detection.py | 45 - tests/codegen/allocation_lifetime_test.py | 3 +- tests/codegen/cpp_test.py | 3 +- .../structures/structure_python_test.py | 2 - 9 files changed, 12 insertions(+), 911 deletions(-) delete mode 100644 dace/FillZ-schedule-tree.txt.jl diff --git a/.github/workflows/fpga-ci.yml b/.github/workflows/fpga-ci.yml index 21bef74b48..21ad7c1ac2 100644 --- a/.github/workflows/fpga-ci.yml +++ b/.github/workflows/fpga-ci.yml @@ -2,11 +2,11 @@ name: FPGA Tests on: push: - branches: [ main, v1/maintenance, ci-fix ] + branches: [ main, ci-fix ] pull_request: - branches: [ main, v1/maintenance, ci-fix ] + branches: [ main, ci-fix ] merge_group: - branches: [ main, v1/maintenance, ci-fix ] + branches: [ main, ci-fix ] env: CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/general-ci.yml b/.github/workflows/general-ci.yml index b3f8ffc049..59e0aae179 100644 --- a/.github/workflows/general-ci.yml +++ b/.github/workflows/general-ci.yml @@ -2,11 +2,11 @@ name: General Tests on: push: - branches: [ main, v1/maintenance, ci-fix ] + branches: [ main, ci-fix ] pull_request: - branches: [ main, v1/maintenance, ci-fix ] + branches: [ main, ci-fix ] merge_group: - branches: [ main, v1/maintenance, ci-fix ] + branches: [ main, ci-fix ] concurrency: group: ${{github.workflow}}-${{github.ref}} @@ -15,7 +15,7 @@ concurrency: jobs: test: if: "!contains(github.event.pull_request.labels.*.name, 'no-ci')" - runs-on: ubuntu-22.04 + runs-on: ubuntu-latest strategy: matrix: python-version: ['3.9','3.13'] diff --git a/.github/workflows/heterogeneous-ci.yml b/.github/workflows/heterogeneous-ci.yml index 475768b011..53a8788dce 100644 --- a/.github/workflows/heterogeneous-ci.yml +++ b/.github/workflows/heterogeneous-ci.yml @@ -2,11 +2,11 @@ name: Heterogeneous Tests on: push: - branches: [ main, v1/maintenance, ci-fix ] + branches: [ main, ci-fix ] pull_request: - branches: [ main, v1/maintenance, ci-fix ] + branches: [ main, ci-fix ] merge_group: - branches: [ main, v1/maintenance, ci-fix ] + branches: [ main, ci-fix ] env: CUDA_HOME: /usr/local/cuda diff --git a/dace/FillZ-schedule-tree.txt.jl b/dace/FillZ-schedule-tree.txt.jl deleted file mode 100644 index 083844bcf6..0000000000 --- a/dace/FillZ-schedule-tree.txt.jl +++ /dev/null @@ -1,821 +0,0 @@ -) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_15 = (2 - 1); (__k_15 > (1 - 1)); __k_15 = (__k_15 + (- 1)): - assign mask_140581926202832_gen_0_2 = (__g_tracers__qice__[0, 0, (__k_15 - 1)] < 0.0) - assign if_expression_140581858477968 = mask_140581926202832_gen_0_2 - assign if_condition_17 = if_expression_140581858477968 - state boundary - if if_condition_17: - state boundary - __g_tracers__qice__[__i, __j, __k_15] = tasklet(dp2[__i, __j, __k_15], dp2[__i, __j, __k_15 - 1], __g_tracers__qice__[__i, __j, __k_15], __g_tracers__qice__[__i, __j, __k_15 - 1]) - state boundary - for __k_15 = (1 - 1); (__k_15 > (0 - 1)); __k_15 = (__k_15 + (- 1)): - mask_140581926659024_gen_0_2[0] = tasklet(__g_tracers__qice__[__i, __j, __k_15]) - assign if_expression_140581853768016 = mask_140581926659024_gen_0_2 - assign if_condition_17 = if_expression_140581853768016 - state boundary - if (not if_condition_17): - pass - state boundary - else: - __g_tracers__qice__[__i, __j, __k_15] = tasklet() - dm_2[__i, __j, __k_15] = tasklet(dp2[__i, __j, __k_15], __g_tracers__qice__[__i, __j, __k_15]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_14 = 0; (__k_14 < (79 + 0)); __k_14 = (__k_14 + 1): - __g_self__sum0[__i, __j], __g_self__sum1[__i, __j], __g_self__zfix[__i, __j] = tasklet() - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - map __k in [0:79]: - lower_fix_2[__i, __j, __k], upper_fix_2[__i, __j, __k] = tasklet() - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_16 = 1; (__k_16 < (79 - 1)); __k_16 = (__k_16 + 1): - assign mask_140581914760464_gen_0_2 = (lower_fix[0, 0, (__k_16 - 1)] != 0.0) - assign if_expression_140581849443152 = mask_140581914760464_gen_0_2 - assign if_condition_19 = if_expression_140581849443152 - state boundary - if if_condition_19: - state boundary - __g_tracers__qice__[__i, __j, __k_16] = tasklet(dp2[__i, __j, __k_16], lower_fix_2[__i, __j, __k_16 - 1], __g_tracers__qice__[__i, __j, __k_16]) - assign mask_140581925917968_gen_0_2 = (__g_tracers__qice__[__i, __j, __k_16] < 0.0) - state boundary - state boundary - else: - assign mask_140581925917968_gen_0_2 = (__g_tracers__qice__[__i, __j, __k_16] < 0.0) - state boundary - assign if_expression_140581849490064 = mask_140581925917968_gen_0_2 - assign if_condition_19 = if_expression_140581849490064 - state boundary - if if_condition_19: - state boundary - mask_140581914769424_gen_0_2[0], __g_self__zfix[__i, __j] = tasklet(__g_tracers__qice__[__i, __j, __k_16 - 1], __g_self__zfix[__i, __j]) - assign if_expression_140581849501648 = mask_140581914769424_gen_0_2 - assign if_condition_19 = if_expression_140581849501648 - state boundary - if if_condition_19: - state boundary - __g_tracers__qice__[__i, __j, __k_16], upper_fix_2[__i, __j, __k_16] = tasklet(dp2[__i, __j, __k_16], dp2[__i, __j, __k_16 - 1], __g_tracers__qice__[__i, __j, __k_16], __g_tracers__qice__[__i, __j, __k_16 - 1]) - assign mask_140581910137744_gen_0_2 = ((__g_tracers__qice__[__i, __j, __k_16] < 0.0) and (__g_tracers__qice__[0, 0, (__k_16 + 1)] > 0.0)) - state boundary - state boundary - else: - assign mask_140581910137744_gen_0_2 = ((__g_tracers__qice__[__i, __j, __k_16] < 0.0) and (__g_tracers__qice__[0, 0, (__k_16 + 1)] > 0.0)) - state boundary - assign if_expression_140581850085456 = mask_140581910137744_gen_0_2 - assign if_condition_19 = if_expression_140581850085456 - state boundary - if if_condition_19: - state boundary - lower_fix_2[__i, __j, __k_16], __g_tracers__qice__[__i, __j, __k_16] = tasklet(dp2[__i, __j, __k_16], dp2[__i, __j, __k_16 + 1], __g_tracers__qice__[__i, __j, __k_16], __g_tracers__qice__[__i, __j, __k_16 + 1]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - map __k in [0:78]: - assign mask_140581914652944_gen_0_2 = (upper_fix_2[__i, __j, (__k + 1)] != 0.0) - assign if_expression_140581850264400 = mask_140581914652944_gen_0_2 - assign if_condition_18 = if_expression_140581850264400 - state boundary - if (not if_condition_18): - pass - state boundary - else: - state boundary - __g_tracers__qice__[__i, __j, __k] = tasklet(dp2[__i, __j, __k], __g_tracers__qice__[__i, __j, __k], upper_fix_2[__i, __j, __k + 1]) - dm_2[__i, __j, __k], dm_pos_2[__i, __j, __k] = tasklet(dp2[__i, __j, __k], __g_tracers__qice__[__i, __j, __k]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_18 = (79 - 1); (__k_18 < (79 + 0)); __k_18 = (__k_18 + 1): - assign mask_140581921273680_gen_0_2 = (lower_fix[0, 0, (((- 79) + __k_18) + 1)] != 0.0) - assign if_expression_140581845110288 = mask_140581921273680_gen_0_2 - assign if_condition_22 = if_expression_140581845110288 - state boundary - if (not if_condition_22): - pass - state boundary - else: - state boundary - __g_tracers__qice__[__i, __j, __k_18] = tasklet(dp2[__i, __j, __k_18], lower_fix_2[__i, __j, __k_18 - 1], __g_tracers__qice__[__i, __j, __k_18]) - dup_gen_0_2[0], mask_140581899838608_gen_0_2[0] = tasklet(dp2[__i, __j, __k_18], dp2[__i, __j, __k_18 - 1], __g_tracers__qice__[__i, __j, __k_18], __g_tracers__qice__[__i, __j, __k_18 - 1]) - assign if_expression_140581845146256 = mask_140581899838608_gen_0_2 - assign if_condition_22 = if_expression_140581845146256 - state boundary - if (not if_condition_22): - pass - state boundary - else: - state boundary - __g_tracers__qice__[__i, __j, __k_18], upper_fix_2[__i, __j, __k_18], __g_self__zfix[__i, __j] = tasklet(dp2[__i, __j, __k_18], dup_gen_0_2[0], __g_tracers__qice__[__i, __j, __k_18], __g_self__zfix[__i, __j]) - dm_2[__i, __j, __k_18], dm_pos_2[__i, __j, __k_18] = tasklet(dp2[__i, __j, __k_18], __g_tracers__qice__[__i, __j, __k_18]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - map __k in [77]: - assign mask_140581914647888_gen_0_2 = (upper_fix_2[__i, __j, (__k + 1)] != 0.0) - assign if_expression_140581845266000 = mask_140581914647888_gen_0_2 - assign if_condition_21 = if_expression_140581845266000 - state boundary - if if_condition_21: - state boundary - dm_2[__i, __j, __k], dm_pos_2[__i, __j, __k], __g_tracers__qice__[__i, __j, __k] = tasklet(dp2[__i, __j, __k], __g_tracers__qice__[__i, __j, __k], upper_fix_2[__i, __j, __k + 1]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_17 = 1; (__k_17 < (79 + 0)); __k_17 = (__k_17 + 1): - state boundary - __g_self__sum0[__i, __j], __g_self__sum1[__i, __j] = tasklet(dm_2[__i, __j, __k_17], dm_pos_2[__i, __j, __k_17], __g_self__sum0[__i, __j], __g_self__sum1[__i, __j]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - map __k in [1:79]: - fac_gen_0_2[0], mask_140581910137424_gen_0_2[0] = tasklet(__g_self__sum0[__i, __j], __g_self__sum1[__i, __j], __g_self__zfix[__i, __j]) - assign if_expression_140581845436688 = mask_140581910137424_gen_0_2 - assign if_condition_20 = if_expression_140581845436688 - state boundary - if if_condition_20: - __g_tracers__qice__[__i, __j, __k] = tasklet(dm_2[__i, __j, __k], dp2[__i, __j, __k], fac_gen_0_2[0]) - state boundary - __g_tracers__qsnow__ = nview __g_tracers__qsnow__[3:15, 3:15, 0:79] as (12, 12, 79) - state boundary - dp2 = nview dp2[3:15, 3:15, 0:79] as (12, 12, 79) - state boundary - __g_self__sum1 = nview __g_self__sum1[3:15, 3:15] as (12, 12) - state boundary - __g_self__sum0 = nview __g_self__sum0[3:15, 3:15] as (12, 12) - state boundary - __g_self__zfix = nview __g_self__zfix[3:15, 3:15] as (12, 12) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_20 = (2 - 1); (__k_20 > (1 - 1)); __k_20 = (__k_20 + (- 1)): - assign mask_140581926202832_gen_0_3 = (__g_tracers__qsnow__[0, 0, (__k_20 - 1)] < 0.0) - assign if_expression_140581854284688 = mask_140581926202832_gen_0_3 - assign if_condition_23 = if_expression_140581854284688 - state boundary - if if_condition_23: - state boundary - __g_tracers__qsnow__[__i, __j, __k_20] = tasklet(dp2[__i, __j, __k_20], dp2[__i, __j, __k_20 - 1], __g_tracers__qsnow__[__i, __j, __k_20], __g_tracers__qsnow__[__i, __j, __k_20 - 1]) - state boundary - for __k_20 = (1 - 1); (__k_20 > (0 - 1)); __k_20 = (__k_20 + (- 1)): - mask_140581926659024_gen_0_3[0] = tasklet(__g_tracers__qsnow__[__i, __j, __k_20]) - assign if_expression_140581850253712 = mask_140581926659024_gen_0_3 - assign if_condition_23 = if_expression_140581850253712 - state boundary - if (not if_condition_23): - pass - state boundary - else: - __g_tracers__qsnow__[__i, __j, __k_20] = tasklet() - dm_3[__i, __j, __k_20] = tasklet(dp2[__i, __j, __k_20], __g_tracers__qsnow__[__i, __j, __k_20]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_19 = 0; (__k_19 < (79 + 0)); __k_19 = (__k_19 + 1): - __g_self__sum0[__i, __j], __g_self__sum1[__i, __j], __g_self__zfix[__i, __j] = tasklet() - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - map __k in [0:79]: - lower_fix_3[__i, __j, __k], upper_fix_3[__i, __j, __k] = tasklet() - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_21 = 1; (__k_21 < (79 - 1)); __k_21 = (__k_21 + 1): - assign mask_140581914760464_gen_0_3 = (lower_fix[0, 0, (__k_21 - 1)] != 0.0) - assign if_expression_140581845852880 = mask_140581914760464_gen_0_3 - assign if_condition_25 = if_expression_140581845852880 - state boundary - if if_condition_25: - state boundary - __g_tracers__qsnow__[__i, __j, __k_21] = tasklet(dp2[__i, __j, __k_21], lower_fix_3[__i, __j, __k_21 - 1], __g_tracers__qsnow__[__i, __j, __k_21]) - assign mask_140581925917968_gen_0_3 = (__g_tracers__qsnow__[__i, __j, __k_21] < 0.0) - state boundary - state boundary - else: - assign mask_140581925917968_gen_0_3 = (__g_tracers__qsnow__[__i, __j, __k_21] < 0.0) - state boundary - assign if_expression_140581845908368 = mask_140581925917968_gen_0_3 - assign if_condition_25 = if_expression_140581845908368 - state boundary - if if_condition_25: - state boundary - mask_140581914769424_gen_0_3[0], __g_self__zfix[__i, __j] = tasklet(__g_tracers__qsnow__[__i, __j, __k_21 - 1], __g_self__zfix[__i, __j]) - assign if_expression_140581845926736 = mask_140581914769424_gen_0_3 - assign if_condition_25 = if_expression_140581845926736 - state boundary - if if_condition_25: - state boundary - __g_tracers__qsnow__[__i, __j, __k_21], upper_fix_3[__i, __j, __k_21] = tasklet(dp2[__i, __j, __k_21], dp2[__i, __j, __k_21 - 1], __g_tracers__qsnow__[__i, __j, __k_21], __g_tracers__qsnow__[__i, __j, __k_21 - 1]) - assign mask_140581910137744_gen_0_3 = ((__g_tracers__qsnow__[__i, __j, __k_21] < 0.0) and (__g_tracers__qsnow__[0, 0, (__k_21 + 1)] > 0.0)) - state boundary - state boundary - else: - assign mask_140581910137744_gen_0_3 = ((__g_tracers__qsnow__[__i, __j, __k_21] < 0.0) and (__g_tracers__qsnow__[0, 0, (__k_21 + 1)] > 0.0)) - state boundary - assign if_expression_140581841429328 = mask_140581910137744_gen_0_3 - assign if_condition_25 = if_expression_140581841429328 - state boundary - if if_condition_25: - state boundary - lower_fix_3[__i, __j, __k_21], __g_tracers__qsnow__[__i, __j, __k_21] = tasklet(dp2[__i, __j, __k_21], dp2[__i, __j, __k_21 + 1], __g_tracers__qsnow__[__i, __j, __k_21], __g_tracers__qsnow__[__i, __j, __k_21 + 1]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - map __k in [0:78]: - assign mask_140581914652944_gen_0_3 = (upper_fix_3[__i, __j, (__k + 1)] != 0.0) - assign if_expression_140581841493584 = mask_140581914652944_gen_0_3 - assign if_condition_24 = if_expression_140581841493584 - state boundary - if (not if_condition_24): - pass - state boundary - else: - state boundary - __g_tracers__qsnow__[__i, __j, __k] = tasklet(dp2[__i, __j, __k], __g_tracers__qsnow__[__i, __j, __k], upper_fix_3[__i, __j, __k + 1]) - dm_3[__i, __j, __k], dm_pos_3[__i, __j, __k] = tasklet(dp2[__i, __j, __k], __g_tracers__qsnow__[__i, __j, __k]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_23 = (79 - 1); (__k_23 < (79 + 0)); __k_23 = (__k_23 + 1): - assign mask_140581921273680_gen_0_3 = (lower_fix[0, 0, (((- 79) + __k_23) + 1)] != 0.0) - assign if_expression_140581841549520 = mask_140581921273680_gen_0_3 - assign if_condition_28 = if_expression_140581841549520 - state boundary - if (not if_condition_28): - pass - state boundary - else: - state boundary - __g_tracers__qsnow__[__i, __j, __k_23] = tasklet(dp2[__i, __j, __k_23], lower_fix_3[__i, __j, __k_23 - 1], __g_tracers__qsnow__[__i, __j, __k_23]) - dup_gen_0_3[0], mask_140581899838608_gen_0_3[0] = tasklet(dp2[__i, __j, __k_23], dp2[__i, __j, __k_23 - 1], __g_tracers__qsnow__[__i, __j, __k_23], __g_tracers__qsnow__[__i, __j, __k_23 - 1]) - assign if_expression_140581841618320 = mask_140581899838608_gen_0_3 - assign if_condition_28 = if_expression_140581841618320 - state boundary - if (not if_condition_28): - pass - state boundary - else: - state boundary - __g_tracers__qsnow__[__i, __j, __k_23], upper_fix_3[__i, __j, __k_23], __g_self__zfix[__i, __j] = tasklet(dp2[__i, __j, __k_23], dup_gen_0_3[0], __g_tracers__qsnow__[__i, __j, __k_23], __g_self__zfix[__i, __j]) - dm_3[__i, __j, __k_23], dm_pos_3[__i, __j, __k_23] = tasklet(dp2[__i, __j, __k_23], __g_tracers__qsnow__[__i, __j, __k_23]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - map __k in [77]: - assign mask_140581914647888_gen_0_3 = (upper_fix_3[__i, __j, (__k + 1)] != 0.0) - assign if_expression_140581841738064 = mask_140581914647888_gen_0_3 - assign if_condition_27 = if_expression_140581841738064 - state boundary - if if_condition_27: - state boundary - dm_3[__i, __j, __k], dm_pos_3[__i, __j, __k], __g_tracers__qsnow__[__i, __j, __k] = tasklet(dp2[__i, __j, __k], __g_tracers__qsnow__[__i, __j, __k], upper_fix_3[__i, __j, __k + 1]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_22 = 1; (__k_22 < (79 + 0)); __k_22 = (__k_22 + 1): - state boundary - __g_self__sum0[__i, __j], __g_self__sum1[__i, __j] = tasklet(dm_3[__i, __j, __k_22], dm_pos_3[__i, __j, __k_22], __g_self__sum0[__i, __j], __g_self__sum1[__i, __j]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - map __k in [1:79]: - fac_gen_0_3[0], mask_140581910137424_gen_0_3[0] = tasklet(__g_self__sum0[__i, __j], __g_self__sum1[__i, __j], __g_self__zfix[__i, __j]) - assign if_expression_140581841908688 = mask_140581910137424_gen_0_3 - assign if_condition_26 = if_expression_140581841908688 - state boundary - if if_condition_26: - __g_tracers__qsnow__[__i, __j, __k] = tasklet(dm_3[__i, __j, __k], dp2[__i, __j, __k], fac_gen_0_3[0]) - state boundary - __g_tracers__qgraupel__ = nview __g_tracers__qgraupel__[3:15, 3:15, 0:79] as (12, 12, 79) - state boundary - __g_self__sum1 = nview __g_self__sum1[3:15, 3:15] as (12, 12) - state boundary - __g_self__sum0 = nview __g_self__sum0[3:15, 3:15] as (12, 12) - state boundary - __g_self__zfix = nview __g_self__zfix[3:15, 3:15] as (12, 12) - state boundary - dp2 = nview dp2[3:15, 3:15, 0:79] as (12, 12, 79) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_25 = (2 - 1); (__k_25 > (1 - 1)); __k_25 = (__k_25 + (- 1)): - assign mask_140581926202832_gen_0_4 = (__g_tracers__qgraupel__[0, 0, (__k_25 - 1)] < 0.0) - assign if_expression_140581841056656 = mask_140581926202832_gen_0_4 - assign if_condition_29 = if_expression_140581841056656 - state boundary - if if_condition_29: - state boundary - __g_tracers__qgraupel__[__i, __j, __k_25] = tasklet(dp2[__i, __j, __k_25], dp2[__i, __j, __k_25 - 1], __g_tracers__qgraupel__[__i, __j, __k_25], __g_tracers__qgraupel__[__i, __j, __k_25 - 1]) - state boundary - for __k_25 = (1 - 1); (__k_25 > (0 - 1)); __k_25 = (__k_25 + (- 1)): - mask_140581926659024_gen_0_4[0] = tasklet(__g_tracers__qgraupel__[__i, __j, __k_25]) - assign if_expression_140581841425616 = mask_140581926659024_gen_0_4 - assign if_condition_29 = if_expression_140581841425616 - state boundary - if (not if_condition_29): - pass - state boundary - else: - __g_tracers__qgraupel__[__i, __j, __k_25] = tasklet() - dm_4[__i, __j, __k_25] = tasklet(dp2[__i, __j, __k_25], __g_tracers__qgraupel__[__i, __j, __k_25]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_24 = 0; (__k_24 < (79 + 0)); __k_24 = (__k_24 + 1): - __g_self__sum0[__i, __j], __g_self__sum1[__i, __j], __g_self__zfix[__i, __j] = tasklet() - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - map __k in [0:79]: - lower_fix_4[__i, __j, __k], upper_fix_4[__i, __j, __k] = tasklet() - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_26 = 1; (__k_26 < (79 - 1)); __k_26 = (__k_26 + 1): - assign mask_140581914760464_gen_0_4 = (lower_fix[0, 0, (__k_26 - 1)] != 0.0) - assign if_expression_140581837110288 = mask_140581914760464_gen_0_4 - assign if_condition_31 = if_expression_140581837110288 - state boundary - if if_condition_31: - state boundary - __g_tracers__qgraupel__[__i, __j, __k_26] = tasklet(dp2[__i, __j, __k_26], lower_fix_4[__i, __j, __k_26 - 1], __g_tracers__qgraupel__[__i, __j, __k_26]) - assign mask_140581925917968_gen_0_4 = (__g_tracers__qgraupel__[__i, __j, __k_26] < 0.0) - state boundary - state boundary - else: - assign mask_140581925917968_gen_0_4 = (__g_tracers__qgraupel__[__i, __j, __k_26] < 0.0) - state boundary - assign if_expression_140581837137936 = mask_140581925917968_gen_0_4 - assign if_condition_31 = if_expression_140581837137936 - state boundary - if if_condition_31: - state boundary - mask_140581914769424_gen_0_4[0], __g_self__zfix[__i, __j] = tasklet(__g_tracers__qgraupel__[__i, __j, __k_26 - 1], __g_self__zfix[__i, __j]) - assign if_expression_140581837148496 = mask_140581914769424_gen_0_4 - assign if_condition_31 = if_expression_140581837148496 - state boundary - if if_condition_31: - state boundary - __g_tracers__qgraupel__[__i, __j, __k_26], upper_fix_4[__i, __j, __k_26] = tasklet(dp2[__i, __j, __k_26], dp2[__i, __j, __k_26 - 1], __g_tracers__qgraupel__[__i, __j, __k_26], __g_tracers__qgraupel__[__i, __j, __k_26 - 1]) - assign mask_140581910137744_gen_0_4 = ((__g_tracers__qgraupel__[__i, __j, __k_26] < 0.0) and (__g_tracers__qgraupel__[0, 0, (__k_26 + 1)] > 0.0)) - state boundary - state boundary - else: - assign mask_140581910137744_gen_0_4 = ((__g_tracers__qgraupel__[__i, __j, __k_26] < 0.0) and (__g_tracers__qgraupel__[0, 0, (__k_26 + 1)] > 0.0)) - state boundary - assign if_expression_140581837717072 = mask_140581910137744_gen_0_4 - assign if_condition_31 = if_expression_140581837717072 - state boundary - if if_condition_31: - state boundary - lower_fix_4[__i, __j, __k_26], __g_tracers__qgraupel__[__i, __j, __k_26] = tasklet(dp2[__i, __j, __k_26], dp2[__i, __j, __k_26 + 1], __g_tracers__qgraupel__[__i, __j, __k_26], __g_tracers__qgraupel__[__i, __j, __k_26 + 1]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - map __k in [0:78]: - assign mask_140581914652944_gen_0_4 = (upper_fix_4[__i, __j, (__k + 1)] != 0.0) - assign if_expression_140581830539600 = mask_140581914652944_gen_0_4 - assign if_condition_30 = if_expression_140581830539600 - state boundary - if (not if_condition_30): - pass - state boundary - else: - state boundary - __g_tracers__qgraupel__[__i, __j, __k] = tasklet(dp2[__i, __j, __k], __g_tracers__qgraupel__[__i, __j, __k], upper_fix_4[__i, __j, __k + 1]) - dm_4[__i, __j, __k], dm_pos_4[__i, __j, __k] = tasklet(dp2[__i, __j, __k], __g_tracers__qgraupel__[__i, __j, __k]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_28 = (79 - 1); (__k_28 < (79 + 0)); __k_28 = (__k_28 + 1): - assign mask_140581921273680_gen_0_4 = (lower_fix[0, 0, (((- 79) + __k_28) + 1)] != 0.0) - assign if_expression_140581830611984 = mask_140581921273680_gen_0_4 - assign if_condition_34 = if_expression_140581830611984 - state boundary - if (not if_condition_34): - pass - state boundary - else: - state boundary - __g_tracers__qgraupel__[__i, __j, __k_28] = tasklet(dp2[__i, __j, __k_28], lower_fix_4[__i, __j, __k_28 - 1], __g_tracers__qgraupel__[__i, __j, __k_28]) - dup_gen_0_4[0], mask_140581899838608_gen_0_4[0] = tasklet(dp2[__i, __j, __k_28], dp2[__i, __j, __k_28 - 1], __g_tracers__qgraupel__[__i, __j, __k_28], __g_tracers__qgraupel__[__i, __j, __k_28 - 1]) - assign if_expression_140581830680720 = mask_140581899838608_gen_0_4 - assign if_condition_34 = if_expression_140581830680720 - state boundary - if (not if_condition_34): - pass - state boundary - else: - state boundary - __g_tracers__qgraupel__[__i, __j, __k_28], upper_fix_4[__i, __j, __k_28], __g_self__zfix[__i, __j] = tasklet(dp2[__i, __j, __k_28], dup_gen_0_4[0], __g_tracers__qgraupel__[__i, __j, __k_28], __g_self__zfix[__i, __j]) - dm_4[__i, __j, __k_28], dm_pos_4[__i, __j, __k_28] = tasklet(dp2[__i, __j, __k_28], __g_tracers__qgraupel__[__i, __j, __k_28]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - map __k in [77]: - assign mask_140581914647888_gen_0_4 = (upper_fix_4[__i, __j, (__k + 1)] != 0.0) - assign if_expression_140581830800464 = mask_140581914647888_gen_0_4 - assign if_condition_33 = if_expression_140581830800464 - state boundary - if if_condition_33: - state boundary - dm_4[__i, __j, __k], dm_pos_4[__i, __j, __k], __g_tracers__qgraupel__[__i, __j, __k] = tasklet(dp2[__i, __j, __k], __g_tracers__qgraupel__[__i, __j, __k], upper_fix_4[__i, __j, __k + 1]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_27 = 1; (__k_27 < (79 + 0)); __k_27 = (__k_27 + 1): - state boundary - __g_self__sum0[__i, __j], __g_self__sum1[__i, __j] = tasklet(dm_4[__i, __j, __k_27], dm_pos_4[__i, __j, __k_27], __g_self__sum0[__i, __j], __g_self__sum1[__i, __j]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - map __k in [1:79]: - fac_gen_0_4[0], mask_140581910137424_gen_0_4[0] = tasklet(__g_self__sum0[__i, __j], __g_self__sum1[__i, __j], __g_self__zfix[__i, __j]) - assign if_expression_140581830987536 = mask_140581910137424_gen_0_4 - assign if_condition_32 = if_expression_140581830987536 - state boundary - if if_condition_32: - __g_tracers__qgraupel__[__i, __j, __k] = tasklet(dm_4[__i, __j, __k], dp2[__i, __j, __k], fac_gen_0_4[0]) - state boundary - dp2 = nview dp2[3:15, 3:15, 0:79] as (12, 12, 79) - state boundary - __g_self__sum1 = nview __g_self__sum1[3:15, 3:15] as (12, 12) - state boundary - __g_tracers__qo3mr__ = nview __g_tracers__qo3mr__[3:15, 3:15, 0:79] as (12, 12, 79) - state boundary - __g_self__sum0 = nview __g_self__sum0[3:15, 3:15] as (12, 12) - state boundary - __g_self__zfix = nview __g_self__zfix[3:15, 3:15] as (12, 12) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_30 = (2 - 1); (__k_30 > (1 - 1)); __k_30 = (__k_30 + (- 1)): - assign mask_140581926202832_gen_0_5 = (__g_tracers__qo3mr__[0, 0, (__k_30 - 1)] < 0.0) - assign if_expression_140581837581840 = mask_140581926202832_gen_0_5 - assign if_condition_35 = if_expression_140581837581840 - state boundary - if if_condition_35: - state boundary - __g_tracers__qo3mr__[__i, __j, __k_30] = tasklet(dp2[__i, __j, __k_30], dp2[__i, __j, __k_30 - 1], __g_tracers__qo3mr__[__i, __j, __k_30], __g_tracers__qo3mr__[__i, __j, __k_30 - 1]) - state boundary - for __k_30 = (1 - 1); (__k_30 > (0 - 1)); __k_30 = (__k_30 + (- 1)): - mask_140581926659024_gen_0_5[0] = tasklet(__g_tracers__qo3mr__[__i, __j, __k_30]) - assign if_expression_140581837082832 = mask_140581926659024_gen_0_5 - assign if_condition_35 = if_expression_140581837082832 - state boundary - if (not if_condition_35): - pass - state boundary - else: - __g_tracers__qo3mr__[__i, __j, __k_30] = tasklet() - dm_5[__i, __j, __k_30] = tasklet(dp2[__i, __j, __k_30], __g_tracers__qo3mr__[__i, __j, __k_30]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_29 = 0; (__k_29 < (79 + 0)); __k_29 = (__k_29 + 1): - __g_self__sum0[__i, __j], __g_self__sum1[__i, __j], __g_self__zfix[__i, __j] = tasklet() - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - map __k in [0:79]: - lower_fix_5[__i, __j, __k], upper_fix_5[__i, __j, __k] = tasklet() - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_31 = 1; (__k_31 < (79 - 1)); __k_31 = (__k_31 + 1): - assign mask_140581914760464_gen_0_5 = (lower_fix[0, 0, (__k_31 - 1)] != 0.0) - assign if_expression_140581831424784 = mask_140581914760464_gen_0_5 - assign if_condition_37 = if_expression_140581831424784 - state boundary - if if_condition_37: - state boundary - __g_tracers__qo3mr__[__i, __j, __k_31] = tasklet(dp2[__i, __j, __k_31], lower_fix_5[__i, __j, __k_31 - 1], __g_tracers__qo3mr__[__i, __j, __k_31]) - assign mask_140581925917968_gen_0_5 = (__g_tracers__qo3mr__[__i, __j, __k_31] < 0.0) - state boundary - state boundary - else: - assign mask_140581925917968_gen_0_5 = (__g_tracers__qo3mr__[__i, __j, __k_31] < 0.0) - state boundary - assign if_expression_140581826208656 = mask_140581925917968_gen_0_5 - assign if_condition_37 = if_expression_140581826208656 - state boundary - if if_condition_37: - state boundary - mask_140581914769424_gen_0_5[0], __g_self__zfix[__i, __j] = tasklet(__g_tracers__qo3mr__[__i, __j, __k_31 - 1], __g_self__zfix[__i, __j]) - assign if_expression_140581826252944 = mask_140581914769424_gen_0_5 - assign if_condition_37 = if_expression_140581826252944 - state boundary - if if_condition_37: - state boundary - __g_tracers__qo3mr__[__i, __j, __k_31], upper_fix_5[__i, __j, __k_31] = tasklet(dp2[__i, __j, __k_31], dp2[__i, __j, __k_31 - 1], __g_tracers__qo3mr__[__i, __j, __k_31], __g_tracers__qo3mr__[__i, __j, __k_31 - 1]) - assign mask_140581910137744_gen_0_5 = ((__g_tracers__qo3mr__[__i, __j, __k_31] < 0.0) and (__g_tracers__qo3mr__[0, 0, (__k_31 + 1)] > 0.0)) - state boundary - state boundary - else: - assign mask_140581910137744_gen_0_5 = ((__g_tracers__qo3mr__[__i, __j, __k_31] < 0.0) and (__g_tracers__qo3mr__[0, 0, (__k_31 + 1)] > 0.0)) - state boundary - assign if_expression_140581826986064 = mask_140581910137744_gen_0_5 - assign if_condition_37 = if_expression_140581826986064 - state boundary - if if_condition_37: - state boundary - lower_fix_5[__i, __j, __k_31], __g_tracers__qo3mr__[__i, __j, __k_31] = tasklet(dp2[__i, __j, __k_31], dp2[__i, __j, __k_31 + 1], __g_tracers__qo3mr__[__i, __j, __k_31], __g_tracers__qo3mr__[__i, __j, __k_31 + 1]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - map __k in [0:78]: - assign mask_140581914652944_gen_0_5 = (upper_fix_5[__i, __j, (__k + 1)] != 0.0) - assign if_expression_140581827050256 = mask_140581914652944_gen_0_5 - assign if_condition_36 = if_expression_140581827050256 - state boundary - if (not if_condition_36): - pass - state boundary - else: - state boundary - __g_tracers__qo3mr__[__i, __j, __k] = tasklet(dp2[__i, __j, __k], __g_tracers__qo3mr__[__i, __j, __k], upper_fix_5[__i, __j, __k + 1]) - dm_5[__i, __j, __k], dm_pos_5[__i, __j, __k] = tasklet(dp2[__i, __j, __k], __g_tracers__qo3mr__[__i, __j, __k]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_33 = (79 - 1); (__k_33 < (79 + 0)); __k_33 = (__k_33 + 1): - assign mask_140581921273680_gen_0_5 = (lower_fix[0, 0, (((- 79) + __k_33) + 1)] != 0.0) - assign if_expression_140581827122640 = mask_140581921273680_gen_0_5 - assign if_condition_40 = if_expression_140581827122640 - state boundary - if (not if_condition_40): - pass - state boundary - else: - state boundary - __g_tracers__qo3mr__[__i, __j, __k_33] = tasklet(dp2[__i, __j, __k_33], lower_fix_5[__i, __j, __k_33 - 1], __g_tracers__qo3mr__[__i, __j, __k_33]) - dup_gen_0_5[0], mask_140581899838608_gen_0_5[0] = tasklet(dp2[__i, __j, __k_33], dp2[__i, __j, __k_33 - 1], __g_tracers__qo3mr__[__i, __j, __k_33], __g_tracers__qo3mr__[__i, __j, __k_33 - 1]) - assign if_expression_140581827174992 = mask_140581899838608_gen_0_5 - assign if_condition_40 = if_expression_140581827174992 - state boundary - if (not if_condition_40): - pass - state boundary - else: - state boundary - __g_tracers__qo3mr__[__i, __j, __k_33], upper_fix_5[__i, __j, __k_33], __g_self__zfix[__i, __j] = tasklet(dp2[__i, __j, __k_33], dup_gen_0_5[0], __g_tracers__qo3mr__[__i, __j, __k_33], __g_self__zfix[__i, __j]) - dm_5[__i, __j, __k_33], dm_pos_5[__i, __j, __k_33] = tasklet(dp2[__i, __j, __k_33], __g_tracers__qo3mr__[__i, __j, __k_33]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - map __k in [77]: - assign mask_140581914647888_gen_0_5 = (upper_fix_5[__i, __j, (__k + 1)] != 0.0) - assign if_expression_140581822051856 = mask_140581914647888_gen_0_5 - assign if_condition_39 = if_expression_140581822051856 - state boundary - if if_condition_39: - state boundary - dm_5[__i, __j, __k], dm_pos_5[__i, __j, __k], __g_tracers__qo3mr__[__i, __j, __k] = tasklet(dp2[__i, __j, __k], __g_tracers__qo3mr__[__i, __j, __k], upper_fix_5[__i, __j, __k + 1]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_32 = 1; (__k_32 < (79 + 0)); __k_32 = (__k_32 + 1): - state boundary - __g_self__sum0[__i, __j], __g_self__sum1[__i, __j] = tasklet(dm_5[__i, __j, __k_32], dm_pos_5[__i, __j, __k_32], __g_self__sum0[__i, __j], __g_self__sum1[__i, __j]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - map __k in [1:79]: - fac_gen_0_5[0], mask_140581910137424_gen_0_5[0] = tasklet(__g_self__sum0[__i, __j], __g_self__sum1[__i, __j], __g_self__zfix[__i, __j]) - assign if_expression_140581822206160 = mask_140581910137424_gen_0_5 - assign if_condition_38 = if_expression_140581822206160 - state boundary - if if_condition_38: - __g_tracers__qo3mr__[__i, __j, __k] = tasklet(dm_5[__i, __j, __k], dp2[__i, __j, __k], fac_gen_0_5[0]) - state boundary - __g_tracers__qsgs_tke__ = nview __g_tracers__qsgs_tke__[3:15, 3:15, 0:79] as (12, 12, 79) - state boundary - __g_self__sum1 = nview __g_self__sum1[3:15, 3:15] as (12, 12) - state boundary - __g_self__sum0 = nview __g_self__sum0[3:15, 3:15] as (12, 12) - state boundary - __g_self__zfix = nview __g_self__zfix[3:15, 3:15] as (12, 12) - state boundary - dp2 = nview dp2[3:15, 3:15, 0:79] as (12, 12, 79) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_35 = (2 - 1); (__k_35 > (1 - 1)); __k_35 = (__k_35 + (- 1)): - assign mask_140581926202832_gen_0_6 = (__g_tracers__qsgs_tke__[0, 0, (__k_35 - 1)] < 0.0) - assign if_expression_140581827083152 = mask_140581926202832_gen_0_6 - assign if_condition_41 = if_expression_140581827083152 - state boundary - if if_condition_41: - state boundary - __g_tracers__qsgs_tke__[__i, __j, __k_35] = tasklet(dp2[__i, __j, __k_35], dp2[__i, __j, __k_35 - 1], __g_tracers__qsgs_tke__[__i, __j, __k_35], __g_tracers__qsgs_tke__[__i, __j, __k_35 - 1]) - state boundary - for __k_35 = (1 - 1); (__k_35 > (0 - 1)); __k_35 = (__k_35 + (- 1)): - mask_140581926659024_gen_0_6[0] = tasklet(__g_tracers__qsgs_tke__[__i, __j, __k_35]) - assign if_expression_140581831436304 = mask_140581926659024_gen_0_6 - assign if_condition_41 = if_expression_140581831436304 - state boundary - if (not if_condition_41): - pass - state boundary - else: - __g_tracers__qsgs_tke__[__i, __j, __k_35] = tasklet() - dm_6[__i, __j, __k_35] = tasklet(dp2[__i, __j, __k_35], __g_tracers__qsgs_tke__[__i, __j, __k_35]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_34 = 0; (__k_34 < (79 + 0)); __k_34 = (__k_34 + 1): - __g_self__sum0[__i, __j], __g_self__sum1[__i, __j], __g_self__zfix[__i, __j] = tasklet() - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - map __k in [0:79]: - lower_fix_6[__i, __j, __k], upper_fix_6[__i, __j, __k] = tasklet() - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_36 = 1; (__k_36 < (79 - 1)); __k_36 = (__k_36 + 1): - assign mask_140581914760464_gen_0_6 = (lower_fix[0, 0, (__k_36 - 1)] != 0.0) - assign if_expression_140581822672528 = mask_140581914760464_gen_0_6 - assign if_condition_43 = if_expression_140581822672528 - state boundary - if if_condition_43: - state boundary - __g_tracers__qsgs_tke__[__i, __j, __k_36] = tasklet(dp2[__i, __j, __k_36], lower_fix_6[__i, __j, __k_36 - 1], __g_tracers__qsgs_tke__[__i, __j, __k_36]) - assign mask_140581925917968_gen_0_6 = (__g_tracers__qsgs_tke__[__i, __j, __k_36] < 0.0) - state boundary - state boundary - else: - assign mask_140581925917968_gen_0_6 = (__g_tracers__qsgs_tke__[__i, __j, __k_36] < 0.0) - state boundary - assign if_expression_140581822704400 = mask_140581925917968_gen_0_6 - assign if_condition_43 = if_expression_140581822704400 - state boundary - if if_condition_43: - state boundary - mask_140581914769424_gen_0_6[0], __g_self__zfix[__i, __j] = tasklet(__g_tracers__qsgs_tke__[__i, __j, __k_36 - 1], __g_self__zfix[__i, __j]) - assign if_expression_140581822726864 = mask_140581914769424_gen_0_6 - assign if_condition_43 = if_expression_140581822726864 - state boundary - if if_condition_43: - state boundary - __g_tracers__qsgs_tke__[__i, __j, __k_36], upper_fix_6[__i, __j, __k_36] = tasklet(dp2[__i, __j, __k_36], dp2[__i, __j, __k_36 - 1], __g_tracers__qsgs_tke__[__i, __j, __k_36], __g_tracers__qsgs_tke__[__i, __j, __k_36 - 1]) - assign mask_140581910137744_gen_0_6 = ((__g_tracers__qsgs_tke__[__i, __j, __k_36] < 0.0) and (__g_tracers__qsgs_tke__[0, 0, (__k_36 + 1)] > 0.0)) - state boundary - state boundary - else: - assign mask_140581910137744_gen_0_6 = ((__g_tracers__qsgs_tke__[__i, __j, __k_36] < 0.0) and (__g_tracers__qsgs_tke__[0, 0, (__k_36 + 1)] > 0.0)) - state boundary - assign if_expression_140581818162000 = mask_140581910137744_gen_0_6 - assign if_condition_43 = if_expression_140581818162000 - state boundary - if if_condition_43: - state boundary - lower_fix_6[__i, __j, __k_36], __g_tracers__qsgs_tke__[__i, __j, __k_36] = tasklet(dp2[__i, __j, __k_36], dp2[__i, __j, __k_36 + 1], __g_tracers__qsgs_tke__[__i, __j, __k_36], __g_tracers__qsgs_tke__[__i, __j, __k_36 + 1]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - map __k in [0:78]: - assign mask_140581914652944_gen_0_6 = (upper_fix_6[__i, __j, (__k + 1)] != 0.0) - assign if_expression_140581818226256 = mask_140581914652944_gen_0_6 - assign if_condition_42 = if_expression_140581818226256 - state boundary - if (not if_condition_42): - pass - state boundary - else: - state boundary - __g_tracers__qsgs_tke__[__i, __j, __k] = tasklet(dp2[__i, __j, __k], __g_tracers__qsgs_tke__[__i, __j, __k], upper_fix_6[__i, __j, __k + 1]) - dm_6[__i, __j, __k], dm_pos_6[__i, __j, __k] = tasklet(dp2[__i, __j, __k], __g_tracers__qsgs_tke__[__i, __j, __k]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_38 = (79 - 1); (__k_38 < (79 + 0)); __k_38 = (__k_38 + 1): - assign mask_140581921273680_gen_0_6 = (lower_fix[0, 0, (((- 79) + __k_38) + 1)] != 0.0) - assign if_expression_140581818282192 = mask_140581921273680_gen_0_6 - assign if_condition_46 = if_expression_140581818282192 - state boundary - if (not if_condition_46): - pass - state boundary - else: - state boundary - __g_tracers__qsgs_tke__[__i, __j, __k_38] = tasklet(dp2[__i, __j, __k_38], lower_fix_6[__i, __j, __k_38 - 1], __g_tracers__qsgs_tke__[__i, __j, __k_38]) - dup_gen_0_6[0], mask_140581899838608_gen_0_6[0] = tasklet(dp2[__i, __j, __k_38], dp2[__i, __j, __k_38 - 1], __g_tracers__qsgs_tke__[__i, __j, __k_38], __g_tracers__qsgs_tke__[__i, __j, __k_38 - 1]) - assign if_expression_140581818334544 = mask_140581899838608_gen_0_6 - assign if_condition_46 = if_expression_140581818334544 - state boundary - if (not if_condition_46): - pass - state boundary - else: - state boundary - __g_tracers__qsgs_tke__[__i, __j, __k_38], upper_fix_6[__i, __j, __k_38], __g_self__zfix[__i, __j] = tasklet(dp2[__i, __j, __k_38], dup_gen_0_6[0], __g_tracers__qsgs_tke__[__i, __j, __k_38], __g_self__zfix[__i, __j]) - dm_6[__i, __j, __k_38], dm_pos_6[__i, __j, __k_38] = tasklet(dp2[__i, __j, __k_38], __g_tracers__qsgs_tke__[__i, __j, __k_38]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - map __k in [77]: - assign mask_140581914647888_gen_0_6 = (upper_fix_6[__i, __j, (__k + 1)] != 0.0) - assign if_expression_140581818487120 = mask_140581914647888_gen_0_6 - assign if_condition_45 = if_expression_140581818487120 - state boundary - if if_condition_45: - state boundary - dm_6[__i, __j, __k], dm_pos_6[__i, __j, __k], __g_tracers__qsgs_tke__[__i, __j, __k] = tasklet(dp2[__i, __j, __k], __g_tracers__qsgs_tke__[__i, __j, __k], upper_fix_6[__i, __j, __k + 1]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - for __k_37 = 1; (__k_37 < (79 + 0)); __k_37 = (__k_37 + 1): - state boundary - __g_self__sum0[__i, __j], __g_self__sum1[__i, __j] = tasklet(dm_6[__i, __j, __k_37], dm_pos_6[__i, __j, __k_37], __g_self__sum0[__i, __j], __g_self__sum1[__i, __j]) - state boundary - map __tile_j, __tile_i in [0:12:8, 0:12:8]: - state boundary - map __i, __j in [__tile_i:__tile_i + Min(8, 12 - __tile_i), __tile_j:__tile_j + Min(8, 12 - __tile_j)]: - state boundary - map __k in [1:79]: - fac_gen_0_6[0], mask_140581910137424_gen_0_6[0] = tasklet(__g_self__sum0[__i, __j], __g_self__sum1[__i, __j], __g_self__zfix[__i, __j]) - assign if_expression_140581818657744 = mask_140581910137424_gen_0_6 - assign if_condition_44 = if_expression_140581818657744 - state boundary - if if_condition_44: - __g_tracers__qsgs_tke__[__i, __j, __k] = tasklet(dm_6[__i, __j, __k], dp2[__i, __j, __k], fac_gen_0_6[0]) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 8c477bec4c..6d9fe31c57 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -369,35 +369,6 @@ def remove_name_collisions(sdfg: SDFG): nsdfg.replace_dict(replacements) -def create_unified_descriptor_repository(sdfg: SDFG, stree: tn.ScheduleTreeRoot): - """ - Creates a single descriptor repository from an SDFG and all nested SDFGs. This includes - data containers, symbols, constants, etc. - - :param sdfg: The top-level SDFG to create the repository from. - :param stree: The tree root in which to make the unified descriptor repository. - """ - stree.containers = sdfg.arrays - stree.symbols = sdfg.symbols - stree.constants = sdfg.constants_prop - - # Since the SDFG is assumed to be de-aliased and contain unique names, we union the contents of - # the nested SDFGs' descriptor repositories - for nsdfg in sdfg.all_sdfgs_recursive(): - transients = {k: v for k, v in nsdfg.arrays.items() if v.transient} - - # Get all symbols that are not participating in nested SDFG symbol mappings (they will be removed) - syms_to_ignore = set() - if nsdfg.parent_nsdfg_node is not None: - syms_to_ignore = nsdfg.parent_nsdfg_node.symbol_mapping.keys() - symbols = {k: v for k, v in nsdfg.symbols.items() if k not in stree.symbols and k not in syms_to_ignore} - - constants = {k: v for k, v in nsdfg.constants_prop.items() if k not in stree.constants} - stree.containers.update(transients) - stree.symbols.update(symbols) - stree.constants.update(constants) - - def _make_view_node(state: SDFGState, edge: gr.MultiConnectorEdge[Memlet], view_name: str, viewed_name: str) -> tn.ViewNode: """ diff --git a/dace/transformation/interstate/loop_detection.py b/dace/transformation/interstate/loop_detection.py index d231437bda..05100906b8 100644 --- a/dace/transformation/interstate/loop_detection.py +++ b/dace/transformation/interstate/loop_detection.py @@ -8,7 +8,6 @@ from dace import sdfg as sd, symbolic from dace.sdfg import graph as gr, utils as sdutil, InterstateEdge from dace.sdfg.state import ControlFlowRegion, ControlFlowBlock -from dace.subsets import Range from dace.transformation import transformation @@ -858,47 +857,3 @@ def find_rotated_for_loop( return None return itervar, (start, end, stride), (start_states, last_loop_state) - - -class LoopRangeAnnotator(DetectLoop, transformation.MultiStateTransformation): - - def can_be_applied(self, graph, expr_index, sdfg, permissive = False): - if super().can_be_applied(graph, expr_index, sdfg, permissive): - loop_info = self.loop_information() - if loop_info is None: - return False - return True - return False - - def loop_guard_state(self): - """ - Returns the loop guard state of this loop (i.e., latch state or begin state for inverted or self loops). - """ - if self.expr_index in (0, 1): - return self.loop_guard - elif self.expr_index in (2, 3, 5, 6, 7): - return self.loop_latch - else: - return self.loop_begin - - def apply(self, graph, sdfg): - itvar, rng, _ = self.loop_information() - - body = self.loop_body() - meta = self.loop_meta_states() - full_body = set(body) - full_body.update(meta) - - # Make sure the range is flipped such that the stride is positive (in order to match subsets.Range). - start, stop, stride = rng - # ===== - # NOTE: This inequality needs to be checked exactly like this due to sympy limitations, do not simplify! - if (stride < 0) == True: - rng = (stop, start, -stride) - # ===== - - for v in full_body: - v.ranges[itvar] = Range([rng]) - guard_state = self.loop_guard_state() - guard_state.is_loop_guard = True - guard_state.itvar = itvar diff --git a/tests/codegen/allocation_lifetime_test.py b/tests/codegen/allocation_lifetime_test.py index 1f9f03e654..ebfc716ddf 100644 --- a/tests/codegen/allocation_lifetime_test.py +++ b/tests/codegen/allocation_lifetime_test.py @@ -206,7 +206,6 @@ def persistentmem(output: dace.int32[1]): del csdfg -@pytest.mark.skip(reason="In v1, produces two tasklets side-by-side, leading to nondeterministic code order") def test_alloc_persistent_threadlocal(): @dace.program @@ -600,7 +599,7 @@ def test_multisize(): test_persistent_gpu_transpose_regression() test_alloc_persistent_register() test_alloc_persistent() - # test_alloc_persistent_threadlocal() + test_alloc_persistent_threadlocal() test_alloc_persistent_threadlocal_naming() test_alloc_multistate() test_nested_view_samename() diff --git a/tests/codegen/cpp_test.py b/tests/codegen/cpp_test.py index c117de83ef..2e9f9b7ed1 100644 --- a/tests/codegen/cpp_test.py +++ b/tests/codegen/cpp_test.py @@ -2,7 +2,6 @@ from functools import reduce from operator import mul -from typing import Dict, Collection import warnings from dace import SDFG, Memlet, dtypes @@ -151,7 +150,7 @@ def test_reshape_strides_from_strided_and_offset_range(): def test_arrays_bigger_than_max_stack_size_get_deallocated(): # Setup SDFG with array A that is too big to be allocated on the stack. sdfg = SDFG("test") - sdfg.add_array(name="A", shape=(10000,), dtype=dtypes.float64, storage=dtypes.StorageType.Register, transient=True) + sdfg.add_array(name="A", shape=(10000, ), dtype=dtypes.float64, storage=dtypes.StorageType.Register, transient=True) state = sdfg.add_state("state", is_start_block=True) read = state.add_access("A") tasklet = state.add_tasklet("dummy", {"a"}, {}, "a = 1") diff --git a/tests/python_frontend/structures/structure_python_test.py b/tests/python_frontend/structures/structure_python_test.py index 5f505fd342..af317be7d8 100644 --- a/tests/python_frontend/structures/structure_python_test.py +++ b/tests/python_frontend/structures/structure_python_test.py @@ -1,5 +1,4 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -import ctypes import dace from dataclasses import dataclass import numpy as np @@ -460,7 +459,6 @@ def struct_recursive(A: Struct, B: Struct): test_write_structure_in_map() test_readwrite_structure_in_map() test_write_structure_in_loop() - test_struct_interface() test_struct_recursive() test_struct_recursive_from_dataclass() From 163f7fd24263a5494f6e78838c4cbcb68c0bdc62 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Thu, 8 Jan 2026 15:57:42 +0100 Subject: [PATCH 108/137] running pre-commit --- dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py | 2 ++ dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py | 2 +- dace/sdfg/memlet_utils.py | 2 ++ dace/sdfg/sdfg.py | 2 +- 4 files changed, 6 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 6d9fe31c57..cb75f4a6ef 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -762,6 +762,7 @@ def _generate_views_in_scope( return result + def _prepare_sdfg_for_conversion(sdfg: SDFG, toplevel: bool) -> None: from dace.transformation import helpers as xfh # Avoid import loop @@ -778,6 +779,7 @@ def _prepare_sdfg_for_conversion(sdfg: SDFG, toplevel: bool) -> None: # Ensure no arrays alias in SDFG tree dealias_sdfg(sdfg) + def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) -> tn.ScheduleTreeRoot: """ Converts an SDFG into a schedule tree. The schedule tree is a tree of nodes that represent the execution order of diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 2592f7a899..bf18f5ccf3 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -800,7 +800,7 @@ def insert_state_boundaries_to_tree(stree: tn.ScheduleTreeRoot) -> tn.ScheduleTr """ Inserts StateBoundaryNode objects into a schedule tree where more than one SDFG state would be necessary. Operates in-place on the given schedule tree. - + This happens when there is a: * write-after-write dependency; * write-after-read dependency that cannot be fulfilled via memlets; diff --git a/dace/sdfg/memlet_utils.py b/dace/sdfg/memlet_utils.py index d843ccfe5b..4f5f507ea2 100644 --- a/dace/sdfg/memlet_utils.py +++ b/dace/sdfg/memlet_utils.py @@ -11,6 +11,8 @@ from dace.frontend.python import memlet_parser import itertools from typing import Callable, Dict, Iterable, Optional, Set, TypeVar, Tuple, Union + + class MemletReplacer(ast.NodeTransformer): """ Iterates over all memlet expressions (name or subscript with matching array in SDFG) in a code block. diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 9aa0adfcb8..77fd9d6ed8 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -1132,7 +1132,7 @@ def as_schedule_tree(self, in_place: bool = False) -> 'ScheduleTreeRoot': etc.) or a ``ScheduleTreeScope`` block (map, for-loop, pipeline, etc.) that contains other nodes. It can be used to generate code from an SDFG, or to perform schedule transformations on the SDFG. For example, - erasing an empty if branch, or merging two consecutive for-loops. The SDFG can then be reconstructed via the + erasing an empty if branch, or merging two consecutive for-loops. The SDFG can then be reconstructed via the ``as_sdfg`` method or the ``from_schedule_tree`` function in ``dace.sdfg.analysis.schedule_tree.tree_to_sdfg``. :param in_place: If True, the SDFG is modified in-place. Otherwise, a copy is made. Note that the SDFG might From 32201beef5c16b25e92589814d020d4427c58fe0 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 9 Jan 2026 12:07:06 +0100 Subject: [PATCH 109/137] fix import errors in tests --- .../analysis/schedule_tree/tree_to_sdfg.py | 109 +++++++++--------- dace/sdfg/analysis/schedule_tree/treenodes.py | 5 +- tests/schedule_tree/propagation_test.py | 3 +- 3 files changed, 59 insertions(+), 58 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index bf18f5ccf3..ef23fae244 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -176,59 +176,62 @@ def visit_AssignNode(self, node: tn.AssignNode, sdfg: SDFG) -> None: if memlet.data not in sdfg.arrays: raise ValueError(f"Parsing AssignNode {node} failed. Can't find {memlet.data} in {sdfg}.") - def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: - before_state = self._current_state - pending = self._pending_interstate_assignments() - pending[node.header.itervar] = node.header.init - - guard_state = _insert_and_split_assignments(sdfg, before_state, label="loop_guard", assignments=pending) - self._current_state = guard_state - - body_state = sdfg.add_state(label="loop_body") - self._current_state = body_state - sdfg.add_edge(guard_state, body_state, InterstateEdge(condition=node.header.condition)) - - # visit children inside the loop - self.visit(node.children, sdfg=sdfg) - - pending = self._pending_interstate_assignments() - pending[node.header.itervar] = node.header.update - _insert_and_split_assignments(sdfg, self._current_state, after_state=guard_state, assignments=pending) - - after_state = sdfg.add_state(label="loop_after") - self._current_state = after_state - sdfg.add_edge(guard_state, after_state, InterstateEdge(condition=f"not {node.header.condition.as_string}")) - - def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: - before_state = self._current_state - guard_state = _insert_and_split_assignments(sdfg, - before_state, - label="guard_state", - assignments=self._pending_interstate_assignments()) - self._current_state = guard_state - - body_state = sdfg.add_state(label="loop_body") - self._current_state = body_state - sdfg.add_edge(guard_state, body_state, InterstateEdge(condition=node.header.test)) - - # visit children inside the loop - self.visit(node.children, sdfg=sdfg) - _insert_and_split_assignments(sdfg, - before_state=self._current_state, - after_state=guard_state, - assignments=self._pending_interstate_assignments()) - - after_state = sdfg.add_state(label="loop_after") - self._current_state = after_state - sdfg.add_edge(guard_state, after_state, InterstateEdge(f"not {node.header.test.as_string}")) - - def visit_DoWhileScope(self, node: tn.DoWhileScope, sdfg: SDFG) -> None: - # AFAIK we don't support for do-while loops in the gt4py -> dace bridge. - raise NotImplementedError(f"{type(node)} not implemented") - - def visit_GeneralLoopScope(self, node: tn.GeneralLoopScope, sdfg: SDFG) -> None: - # Let's see if we need this for the first prototype ... - raise NotImplementedError(f"{type(node)} not implemented") + #def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: + # before_state = self._current_state + # pending = self._pending_interstate_assignments() + # pending[node.header.itervar] = node.header.init + # + # guard_state = _insert_and_split_assignments(sdfg, before_state, label="loop_guard", assignments=pending) + # self._current_state = guard_state + # + # body_state = sdfg.add_state(label="loop_body") + # self._current_state = body_state + # sdfg.add_edge(guard_state, body_state, InterstateEdge(condition=node.header.condition)) + # + # # visit children inside the loop + # self.visit(node.children, sdfg=sdfg) + # + # pending = self._pending_interstate_assignments() + # pending[node.header.itervar] = node.header.update + # _insert_and_split_assignments(sdfg, self._current_state, after_state=guard_state, assignments=pending) + # + # after_state = sdfg.add_state(label="loop_after") + # self._current_state = after_state + # sdfg.add_edge(guard_state, after_state, InterstateEdge(condition=f"not {node.header.condition.as_string}")) + + #def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: + # before_state = self._current_state + # guard_state = _insert_and_split_assignments(sdfg, + # before_state, + # label="guard_state", + # assignments=self._pending_interstate_assignments()) + # self._current_state = guard_state + # + # body_state = sdfg.add_state(label="loop_body") + # self._current_state = body_state + # sdfg.add_edge(guard_state, body_state, InterstateEdge(condition=node.header.test)) + # + # # visit children inside the loop + # self.visit(node.children, sdfg=sdfg) + # _insert_and_split_assignments(sdfg, + # before_state=self._current_state, + # after_state=guard_state, + # assignments=self._pending_interstate_assignments()) + # + # after_state = sdfg.add_state(label="loop_after") + # self._current_state = after_state + # sdfg.add_edge(guard_state, after_state, InterstateEdge(f"not {node.header.test.as_string}")) + + #def visit_DoWhileScope(self, node: tn.DoWhileScope, sdfg: SDFG) -> None: + # # AFAIK we don't support for do-while loops in the gt4py -> dace bridge. + # raise NotImplementedError(f"{type(node)} not implemented") + + #def visit_GeneralLoopScope(self, node: tn.GeneralLoopScope, sdfg: SDFG) -> None: + # # Let's see if we need this for the first prototype ... + # raise NotImplementedError(f"{type(node)} not implemented") + + def visit_LoopScope(self, node: tn.LoopScope, sdfg: SDFG) -> None: + raise NotImplementedError("TODO: LoopScopes are not yet implemented") def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: before_state = self._current_state diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index b471740199..446e4a7e1c 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -1,16 +1,13 @@ # Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. -import ast from dataclasses import dataclass, field from dace import nodes, data, subsets, dtypes -from dace.codegen import control_flow as cf from dace.properties import CodeBlock from dace.sdfg import InterstateEdge from dace.sdfg.memlet_utils import MemletSet from dace.sdfg.propagation import propagate_subset from dace.sdfg.sdfg import InterstateEdge, SDFG, memlets_in_ast -from dace.sdfg.state import ConditionalBlock, LoopRegion, SDFGState -from dace.symbolic import symbol +from dace.sdfg.state import LoopRegion, SDFGState from dace.memlet import Memlet from types import TracebackType from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Literal, Optional, Set, Tuple, Union diff --git a/tests/schedule_tree/propagation_test.py b/tests/schedule_tree/propagation_test.py index 507a3d7226..2b09fe612f 100644 --- a/tests/schedule_tree/propagation_test.py +++ b/tests/schedule_tree/propagation_test.py @@ -22,7 +22,8 @@ def tester(a: dace.float64[20]): stree = t2s.insert_state_boundaries_to_tree(stree) node_types = [n for n in stree.preorder_traversal()] - assert isinstance(node_types[2], tn.ForScope) + assert isinstance(node_types[2], tn.LoopScope) + assert node_types[2]._check_loop_variant() == "for" memlet = dace.Memlet('a[1:N]') memlet._is_data_src = False assert list(node_types[2].output_memlets()) == [memlet] From e9ecf5cfb11b93f8669d6c7d1148ab6f323c64f6 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 9 Jan 2026 15:09:42 +0100 Subject: [PATCH 110/137] remove duplicate code Merging main duplicated this `elif` block because PR https://github.com/spcl/dace/pull/2165/ was backported as PR https://github.com/spcl/dace/pull/2166/ and then the merge saw those two at different lines, which lead to a duplicate `elif` block. --- dace/frontend/python/newast.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 7c25f0cef7..807e1e80df 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -3654,12 +3654,6 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): true_name = self.sdfg.add_datadesc(name, desc, find_new_name=True) self.variables[name] = true_name defined_vars[name] = true_name - elif name in self.annotated_types and isinstance(self.annotated_types[name], data.Reference): - desc = copy.deepcopy(self.annotated_types[name]) - desc.transient = True - true_name = self.sdfg.add_datadesc(name, desc, find_new_name=True) - self.variables[name] = true_name - defined_vars[name] = true_name elif (not name.startswith('__return') and (isinstance(result_data, data.View) or (not result_data.transient and isinstance(result_data, data.Array)))): From 4631de6d8a143cb395bb8fc9459066182ced0111 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 9 Jan 2026 15:13:43 +0100 Subject: [PATCH 111/137] fixup: fixing bad merge in sdfg_to_stree.py --- dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index cb75f4a6ef..734cc089bf 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -800,7 +800,7 @@ def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) if not in_place: sdfg = copy.deepcopy(sdfg) - _prepare_sdfg_for_conversion(sdfg) + _prepare_sdfg_for_conversion(sdfg, toplevel) result = tn.ScheduleTreeScope(children=_block_schedule_tree(sdfg)) tn.validate_has_no_other_node_types(result) From 14b8fc0d688a2d59fffd957435aa3dfc509b35f4 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 12 Jan 2026 17:51:46 +0100 Subject: [PATCH 112/137] WIP: re-introduce ForScope, WhileScope and DoWhileScope --- .../analysis/schedule_tree/sdfg_to_tree.py | 33 +++-- .../analysis/schedule_tree/tree_to_sdfg.py | 6 +- dace/sdfg/analysis/schedule_tree/treenodes.py | 128 ++++++++++++------ dace/sdfg/propagation.py | 20 ++- dace/sdfg/state.py | 28 ++-- tests/schedule_tree/propagation_test.py | 3 +- 6 files changed, 139 insertions(+), 79 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 93c857a890..bd79b916ac 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -701,11 +701,24 @@ def _block_schedule_tree(block: ControlFlowBlock) -> List[tn.ScheduleTreeNode]: pivot = None if isinstance(block, LoopRegion): - # If this is a loop region, wrap everything in a LoopScope node. - loop_node = tn.LoopScope(loop=block, children=children) - return [loop_node] + # If this is a loop region, wrap everything in a loop scope node. + variant = tn.loop_variant(block) + if variant == "for": + return [tn.ForScope(loop=block, children=children)] + + if variant == "while": + return [tn.WhileScope(loop=block, children=children)] + + if variant == "do-while": + return [tn.DoWhileScope(loop=block, children=children)] + + # If we end up here, we don't need more granularity and just use + # a general loop scope + return [tn.LoopScope(loop=block, children=children)] + return children - elif isinstance(block, ConditionalBlock): + + if isinstance(block, ConditionalBlock): result: List[tn.ScheduleTreeNode] = [] if_node = tn.IfScope(condition=block.branches[0][0], children=_block_schedule_tree(block.branches[0][1])) result.append(if_node) @@ -717,9 +730,11 @@ def _block_schedule_tree(block: ControlFlowBlock) -> List[tn.ScheduleTreeNode]: else_node = tn.ElseScope(children=_block_schedule_tree(branch_body)) result.append(else_node) return result - elif isinstance(block, SDFGState): + + if isinstance(block, SDFGState): return _state_schedule_tree(block) - elif isinstance(block, ReturnBlock): + + if isinstance(block, ReturnBlock): # For return blocks, add a goto node to the end of the schedule tree. # NOTE: Return blocks currently always exit the entire SDFG context they are contained in, meaning that the exit # goto has target=None. However, in the future we want to adapt Return blocks to be able to return only a @@ -727,8 +742,8 @@ def _block_schedule_tree(block: ControlFlowBlock) -> List[tn.ScheduleTreeNode]: # entire SDFG. goto_node = tn.GotoNode(target=None) return [goto_node] - else: - raise tn.UnsupportedScopeException(type(block).__name__) + + raise tn.UnsupportedScopeException(type(block).__name__) def _generate_views_in_scope( @@ -807,7 +822,7 @@ def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) # Clean up tree stpasses.remove_unused_and_duplicate_labels(result) - return result + return tn.ScheduleTreeRoot(children=result.children, name="my_stree") if __name__ == '__main__': diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index ef23fae244..9eb61669f4 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -579,9 +579,9 @@ def visit_ConsumeScope(self, node: tn.ConsumeScope, sdfg: SDFG) -> None: # AFAIK we don't support consume scopes in the gt4py/dace bridge. raise NotImplementedError(f"{type(node)} not implemented") - def visit_PipelineScope(self, node: tn.PipelineScope, sdfg: SDFG) -> None: - # AFAIK we don't support pipeline scopes in the gt4py/dace bridge. - raise NotImplementedError(f"{type(node)} not implemented") + # def visit_PipelineScope(self, node: tn.PipelineScope, sdfg: SDFG) -> None: + # # AFAIK we don't support pipeline scopes in the gt4py/dace bridge. + # raise NotImplementedError(f"{type(node)} not implemented") def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # Add Tasklet to current state diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index c6c16ff46f..d3fc90708c 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -362,51 +362,78 @@ class LoopScope(ControlFlowScope): """ loop: LoopRegion - def _check_loop_variant( - self - ) -> Union[Literal['for'], Literal['while'], Literal['do-while'], Literal['do-for-uncond-increment'], - Literal['do-for']]: - if self.loop.update_statement and self.loop.init_statement and self.loop.loop_variable: - if self.loop.inverted: - if self.loop.update_before_condition: - return 'do-for-uncond-increment' - else: - return 'do-for' - else: - return 'for' - else: - if self.loop.inverted: - return 'do-while' - else: - return 'while' - def as_string(self, indent: int = 0): loop = self.loop - loop_variant = self._check_loop_variant() - if loop_variant == 'do-for-uncond-increment': + variant = loop_variant(loop) + if variant == 'do-for-uncond-increment': pre_header = indent * INDENTATION + f'{loop.init_statement.as_string}\n' header = indent * INDENTATION + 'do:\n' pre_footer = (indent + 1) * INDENTATION + f'{loop.update_statement.as_string}\n' footer = indent * INDENTATION + f'while {loop.loop_condition.as_string}' return pre_header + header + super().as_string(indent) + '\n' + pre_footer + footer - elif loop_variant == 'do-for': + + if variant == 'do-for': pre_header = indent * INDENTATION + f'{loop.init_statement.as_string}\n' header = indent * INDENTATION + 'while True:\n' pre_footer = (indent + 1) * INDENTATION + f'if (not {loop.loop_condition.as_string}):\n' pre_footer += (indent + 2) * INDENTATION + 'break\n' footer = (indent + 1) * INDENTATION + f'{loop.update_statement.as_string}\n' return pre_header + header + super().as_string(indent) + '\n' + pre_footer + footer - elif loop_variant == 'for': - result = (indent * INDENTATION + f'for {loop.init_statement.as_string}; ' + - f'{loop.loop_condition.as_string}; ' + f'{loop.update_statement.as_string}:\n') - return result + super().as_string(indent) - elif loop_variant == 'while': - result = indent * INDENTATION + f'while {loop.loop_condition.as_string}:\n' - return result + super().as_string(indent) - else: # 'do-while' - header = indent * INDENTATION + 'do:\n' - footer = indent * INDENTATION + f'while {loop.loop_condition.as_string}' - return header + super().as_string(indent) + '\n' + footer + + if variant in ["for", "while", "do-while"]: + return super().as_string(indent) + + return NotImplementedError # TODO: nice error message + + +@dataclass +class ForScope(LoopScope): + """Specialized LoopScope for for-loops.""" + + def as_string(self, indent: int = 0) -> str: + init_statement = self.loop.init_statement.as_string + condition = self.loop.loop_condition.as_string + update_statement = self.loop.update_statement.as_string + result = indent * INDENTATION + f"for {init_statement}; {condition}; {update_statement}:\n" + return result + super().as_string(indent) + + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + raise NotImplementedError + + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + raise NotImplementedError + + +@dataclass +class WhileScope(LoopScope): + """Specialized LoopScope for while-loops.""" + + def as_string(self, indent: int = 0) -> str: + condition = self.loop.loop_condition.as_string + result = indent * INDENTATION + f'while {condition}:\n' + return result + super().as_string(indent) + + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + raise NotImplementedError + + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + raise NotImplementedError + + +@dataclass +class DoWhileScope(LoopScope): + """Specialized LoopScope for do-while-loops""" + + def as_string(self, indent: int = 0) -> str: + header = indent * INDENTATION + 'do:\n' + footer = indent * INDENTATION + f'while {self.loop.loop_condition.as_string}' + return header + super().as_string(indent) + '\n' + footer + + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + raise NotImplementedError + + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + raise NotImplementedError @dataclass @@ -545,17 +572,16 @@ def as_string(self, indent: int = 0): @dataclass -class PipelineScope(MapScope): - """ - Pipeline scope. - """ - node: nodes.PipelineEntry - - def as_string(self, indent: int = 0): - rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) - result = indent * INDENTATION + f'pipeline {", ".join(self.node.map.params)} in [{rangestr}]:\n' - return result + super().as_string(indent) - +# class PipelineScope(MapScope): +# """ +# Pipeline scope. +# """ +# node: nodes.PipelineEntry +# +# def as_string(self, indent: int = 0): +# rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) +# result = indent * INDENTATION + f'pipeline {", ".join(self.node.map.params)} in [{rangestr}]:\n' +# return result + super().as_string(indent) @dataclass class TaskletNode(ScheduleTreeNode): @@ -802,3 +828,19 @@ def validate_has_no_other_node_types(stree: ScheduleTreeScope) -> None: raise RuntimeError(f'Unsupported node type: {type(child).__name__}') if isinstance(child, ScheduleTreeScope): validate_has_no_other_node_types(child) + + +def loop_variant( + loop: LoopRegion +) -> Union[Literal['for'], Literal['while'], Literal['do-while'], Literal['do-for-uncond-increment'], + Literal['do-for']]: + if loop.update_statement and loop.init_statement and loop.loop_variable: + if loop.inverted: + if loop.update_before_condition: + return 'do-for-uncond-increment' + return 'do-for' + return 'for' + + if loop.inverted: + return 'do-while' + return 'while' diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index c1962c6188..594f8fbfd6 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -423,9 +423,9 @@ def can_be_applied(self, dim_exprs, variable_context, node_range, orig_edges, di if symbolic.issymbolic(dim): used_symbols.update(dim.free_symbols) - if any(s not in defined_vars for s in (used_symbols - set(self.params))): - # Cannot propagate symbols that are undefined outside scope (e.g., internal symbols) - return False + # if any(s not in defined_vars for s in (used_symbols - set(self.params))): + # # Cannot propagate symbols that are undefined outside scope (e.g., internal symbols) + # return False if (used_symbols & set(self.params) and any(symbolic.pystr_to_symbolic(s) not in defined_vars for s in node_range.free_symbols)): @@ -1430,7 +1430,12 @@ def propagate_memlet(dfg_state, # Propagate subset if isinstance(entry_node, nodes.MapEntry): mapnode = entry_node.map - return propagate_subset(aggdata, arr, mapnode.params, mapnode.range, defined_vars, use_dst=use_dst) + return propagate_subset(aggdata, + arr, + mapnode.params, + mapnode.range, + defined_variables=defined_vars, + use_dst=use_dst) elif isinstance(entry_node, nodes.ConsumeEntry): # Nothing to analyze/propagate in consume @@ -1449,6 +1454,7 @@ def propagate_subset(memlets: List[Memlet], arr: data.Data, params: List[str], rng: subsets.Subset, + *, defined_variables: Set[symbolic.SymbolicType] = None, undefined_variables: Set[symbolic.SymbolicType] = None, use_dst: bool = False) -> Memlet: @@ -1485,8 +1491,10 @@ def propagate_subset(memlets: List[Memlet], else: defined_variables = set(defined_variables) - if undefined_variables: + if undefined_variables is not None: defined_variables = defined_variables - set(symbolic.pystr_to_symbolic(p) for p in undefined_variables) + else: + undefined_variables = set() # Propagate subset variable_context = [defined_variables, [symbolic.pystr_to_symbolic(p) for p in params]] @@ -1536,7 +1544,7 @@ def propagate_subset(memlets: List[Memlet], fsyms = _freesyms(sdim) fsyms_str = set(map(str, fsyms)) contains_params |= len(fsyms_str & paramset) != 0 - contains_undefs |= len(fsyms - defined_variables) != 0 + contains_undefs |= len(fsyms & undefined_variables) != 0 if contains_params or contains_undefs: tmp_subset_rng.append(ea) else: diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index cf64642128..51c8ba0379 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2609,6 +2609,7 @@ class AbstractControlFlowRegion(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.Inte ControlFlowBlock, abc.ABC): """ Abstract superclass to represent all kinds of control flow regions in an SDFG. + This is consequently one of the three main classes of control flow graph nodes, which include ``ControlFlowBlock``s, ``SDFGState``s, and nested ``AbstractControlFlowRegion``s. An ``AbstractControlFlowRegion`` can further be either a region that directly contains a control flow graph (``ControlFlowRegion``s and subclasses thereof), or something @@ -3131,6 +3132,7 @@ def start_block(self, block_id): class ControlFlowRegion(AbstractControlFlowRegion): """ A ``ControlFlowRegion`` represents a control flow graph node that itself contains a control flow graph. + This can be an arbitrary control flow graph, but may also be a specific type of control flow region with additional semantics, such as a loop or a function call. """ @@ -3189,32 +3191,26 @@ def __init__(self, update_expr: Optional[Union[str, CodeBlock]] = None, inverted: bool = False, sdfg: Optional['SDFG'] = None, - update_before_condition=True): + update_before_condition: bool = True): super(LoopRegion, self).__init__(label, sdfg) - if initialize_expr is not None: - if isinstance(initialize_expr, CodeBlock): - self.init_statement = initialize_expr - else: - self.init_statement = CodeBlock(initialize_expr) + if initialize_expr is None or isinstance(initialize_expr, CodeBlock): + self.init_statement = initialize_expr else: - self.init_statement = None + self.init_statement = CodeBlock(initialize_expr) - if condition_expr: + if condition_expr is None: + self.loop_condition = CodeBlock('True') + else: if isinstance(condition_expr, CodeBlock): self.loop_condition = condition_expr else: self.loop_condition = CodeBlock(condition_expr) - else: - self.loop_condition = CodeBlock('True') - if update_expr is not None: - if isinstance(update_expr, CodeBlock): - self.update_statement = update_expr - else: - self.update_statement = CodeBlock(update_expr) + if update_expr is None or isinstance(update_expr, CodeBlock): + self.update_statement = update_expr else: - self.update_statement = None + self.update_statement = CodeBlock(update_expr) self.loop_variable = loop_var or '' self.inverted = inverted diff --git a/tests/schedule_tree/propagation_test.py b/tests/schedule_tree/propagation_test.py index 2b09fe612f..507a3d7226 100644 --- a/tests/schedule_tree/propagation_test.py +++ b/tests/schedule_tree/propagation_test.py @@ -22,8 +22,7 @@ def tester(a: dace.float64[20]): stree = t2s.insert_state_boundaries_to_tree(stree) node_types = [n for n in stree.preorder_traversal()] - assert isinstance(node_types[2], tn.LoopScope) - assert node_types[2]._check_loop_variant() == "for" + assert isinstance(node_types[2], tn.ForScope) memlet = dace.Memlet('a[1:N]') memlet._is_data_src = False assert list(node_types[2].output_memlets()) == [memlet] From d3e615fb018a8b6c53eda2c05662bb0ffb159d00 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 13 Jan 2026 21:27:02 +0100 Subject: [PATCH 113/137] WIP: more work understanding LoopRegions and how to work with them --- dace/sdfg/analysis/schedule_tree/treenodes.py | 81 ++++++++++++++++--- dace/sdfg/sdfg.py | 4 +- 2 files changed, 71 insertions(+), 14 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index d3fc90708c..9fa9c8e73c 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -1,7 +1,7 @@ # Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. from dataclasses import dataclass, field -from dace import nodes, data, subsets, dtypes +from dace import nodes, data, subsets, dtypes, symbolic from dace.properties import CodeBlock from dace.sdfg import InterstateEdge from dace.sdfg.memlet_utils import MemletSet @@ -398,10 +398,29 @@ def as_string(self, indent: int = 0) -> str: return result + super().as_string(indent) def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: - raise NotImplementedError + root = root if root is not None else self.get_root() + result = MemletSet() + result.update(self.loop.get_meta_read_memlets()) - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: - raise NotImplementedError + # If loop range is well-formed, use it in propagation + range = loop_range(self.loop) + if range is not None: + propagate = {self.loop.loop_variable: range} + else: + propagate = None + + result.update(super().input_memlets(root, propagate=propagate, **kwargs)) + return result + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + # If loop range is well-formed, use it in propagation + range = loop_range(self.loop) + if range is not None: + propagate = {self.loop.loop_variable: range} + else: + propagate = None + + return super().output_memlets(root, propagate=propagate, **kwargs) @dataclass @@ -414,10 +433,11 @@ def as_string(self, indent: int = 0) -> str: return result + super().as_string(indent) def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: - raise NotImplementedError - - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: - raise NotImplementedError + root = root if root is not None else self.get_root() + result = MemletSet() + result.update(self.loop.get_meta_read_memlets()) + result.update(super().input_memlets(root, **kwargs)) + return result @dataclass @@ -430,10 +450,11 @@ def as_string(self, indent: int = 0) -> str: return header + super().as_string(indent) + '\n' + footer def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: - raise NotImplementedError - - def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: - raise NotImplementedError + root = root if root is not None else self.get_root() + result = MemletSet() + result.update(self.loop.get_meta_read_memlets()) + result.update(super().input_memlets(root, **kwargs)) + return result @dataclass @@ -844,3 +865,39 @@ def loop_variant( if loop.inverted: return 'do-while' return 'while' + + +def loop_range( + loop: LoopRegion) -> Optional[Tuple[symbolic.SymbolicType, symbolic.SymbolicType, symbolic.SymbolicType]]: + """ + For well-formed for-loops, returns a tuple of (start, end, stride). Otherwise, returns None. + """ + if loop_variant(loop) != "for": + # Loop range is only defined in for-loops + return None + + # TODO: + # The following is the old (dace v1) way of doing things. + # This is not how it's done anymore. + + # Get inspired by `LoopRegion.can_normalize()` and + # `LoopRegion.normalize()` to see how to figure out + # the range of for-loops (which is essential for + # correct memlet propagation). + + condition_edge = None + for edge in loop.all_interstate_edges(): + if edge.data.condition == loop.loop_condition: + condition_edge = edge + + if condition_edge is None: + return None # Condition edge not found + + from dace.transformation.interstate.loop_detection import find_for_loop + result = find_for_loop(loop.root_sdfg, condition_edge.src, condition_edge.dst, loop.loop_variable) + + if result is None: + # proper for-loop was not detected + return None + + return result[1] # (start, end, stride) where `end` is inclusive diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index df95eb516b..1c2099ddfb 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -417,7 +417,7 @@ def from_json(json_obj, context=None): def label(self): assignments = ','.join(['%s=%s' % (k, v) for k, v in self.assignments.items()]) - # Edge with assigment only (no condition) + # Edge with assignment only (no condition) if self.condition.as_string == '1': # Edge without conditions or assignments if len(self.assignments) == 0: @@ -428,7 +428,7 @@ def label(self): if len(self.assignments) == 0: return self.condition.as_string - # Edges with assigments and conditions + # Edges with assignments and conditions return self.condition.as_string + '; ' + assignments From abbac85a604ed03953d89ffdfe609921b0a078dd Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 14 Jan 2026 15:47:36 +0100 Subject: [PATCH 114/137] WIP: memlet propagation for simple cases --- .../analysis/schedule_tree/sdfg_to_tree.py | 38 ++++++++-- dace/sdfg/analysis/schedule_tree/treenodes.py | 71 ++++++++++++------- 2 files changed, 80 insertions(+), 29 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index bd79b916ac..7d7cc16cf4 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -777,7 +777,7 @@ def _generate_views_in_scope( return result -def _prepare_sdfg_for_conversion(sdfg: SDFG, toplevel: bool) -> None: +def _prepare_sdfg_for_conversion(sdfg: SDFG, *, toplevel: bool) -> None: from dace.transformation import helpers as xfh # Avoid import loop # Split edges with assignments and conditions @@ -794,7 +794,29 @@ def _prepare_sdfg_for_conversion(sdfg: SDFG, toplevel: bool) -> None: dealias_sdfg(sdfg) -def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) -> tn.ScheduleTreeRoot: +def _create_unified_descriptor_repository(sdfg: SDFG, stree: tn.ScheduleTreeRoot): + """ + Creates a single descriptor repository from an SDFG and all nested SDFGs. This includes + data containers, symbols, constants, etc. + :param sdfg: The top-level SDFG to create the repository from. + :param stree: The tree root in which to make the unified descriptor repository. + """ + stree.containers = sdfg.arrays + stree.symbols = sdfg.symbols + stree.constants = sdfg.constants_prop + + # Since the SDFG is assumed to be de-aliased and contain unique names, we union the contents of + # the nested SDFGs' descriptor repositories + for nsdfg in sdfg.all_sdfgs_recursive(): + transients = {k: v for k, v in nsdfg.arrays.items() if v.transient} + symbols = {k: v for k, v in nsdfg.symbols.items() if k not in stree.symbols} + constants = {k: v for k, v in nsdfg.constants_prop.items() if k not in stree.constants} + stree.containers.update(transients) + stree.symbols.update(symbols) + stree.constants.update(constants) + + +def as_schedule_tree(sdfg: SDFG, *, in_place: bool = False, toplevel: bool = False) -> tn.ScheduleTreeRoot: """ Converts an SDFG into a schedule tree. The schedule tree is a tree of nodes that represent the execution order of the SDFG. @@ -814,15 +836,21 @@ def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) if not in_place: sdfg = copy.deepcopy(sdfg) - _prepare_sdfg_for_conversion(sdfg, toplevel) + _prepare_sdfg_for_conversion(sdfg, toplevel=toplevel) + + if toplevel: + result = tn.ScheduleTreeRoot(children=[], name="my-stree-name") + _create_unified_descriptor_repository(sdfg, result) + result.children = _block_schedule_tree(sdfg) + else: + result = tn.ScheduleTreeScope(children=_block_schedule_tree(sdfg)) - result = tn.ScheduleTreeScope(children=_block_schedule_tree(sdfg)) tn.validate_has_no_other_node_types(result) # Clean up tree stpasses.remove_unused_and_duplicate_labels(result) - return tn.ScheduleTreeRoot(children=result.children, name="my_stree") + return result if __name__ == '__main__': diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 9fa9c8e73c..8f002f62c3 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -1,5 +1,6 @@ # Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. from dataclasses import dataclass, field +import sympy from dace import nodes, data, subsets, dtypes, symbolic from dace.properties import CodeBlock @@ -403,7 +404,7 @@ def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> result.update(self.loop.get_meta_read_memlets()) # If loop range is well-formed, use it in propagation - range = loop_range(self.loop) + range = _loop_range(self.loop) if range is not None: propagate = {self.loop.loop_variable: range} else: @@ -414,7 +415,7 @@ def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: # If loop range is well-formed, use it in propagation - range = loop_range(self.loop) + range = _loop_range(self.loop) if range is not None: propagate = {self.loop.loop_variable: range} else: @@ -867,37 +868,59 @@ def loop_variant( return 'while' -def loop_range( +def _loop_range( loop: LoopRegion) -> Optional[Tuple[symbolic.SymbolicType, symbolic.SymbolicType, symbolic.SymbolicType]]: """ - For well-formed for-loops, returns a tuple of (start, end, stride). Otherwise, returns None. + Derive loop range for well-formed `for`-loops. + + :param: loop The loop to be analyzed. + :return: If well formed, `(start, end, step)` where `end` is inclusive, otherwise `None`. """ - if loop_variant(loop) != "for": + + if loop_variant(loop) != "for" or loop.loop_variable is None: # Loop range is only defined in for-loops + # and we need to know the loop variable. return None - # TODO: - # The following is the old (dace v1) way of doing things. - # This is not how it's done anymore. + # Avoid cyclic import + from dace.transformation.passes.analysis import loop_analysis - # Get inspired by `LoopRegion.can_normalize()` and - # `LoopRegion.normalize()` to see how to figure out - # the range of for-loops (which is essential for - # correct memlet propagation). + # If loop information cannot be determined, we cannot derive loop range + start = loop_analysis.get_init_assignment(loop) + step = loop_analysis.get_loop_stride(loop) + end = _match_loop_condition(loop) + if start is None or step is None or end is None: + return None - condition_edge = None - for edge in loop.all_interstate_edges(): - if edge.data.condition == loop.loop_condition: - condition_edge = edge + return (start, end, step) # `end` is inclusive - if condition_edge is None: - return None # Condition edge not found - from dace.transformation.interstate.loop_detection import find_for_loop - result = find_for_loop(loop.root_sdfg, condition_edge.src, condition_edge.dst, loop.loop_variable) +def _match_loop_condition(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + """ + Try to find the end of a for-loop by symbolically matching the loop condition. - if result is None: - # proper for-loop was not detected - return None + :return: loop end (inclusive) or `None` if matching failed. + """ + + condition = symbolic.pystr_to_symbolic(loop.loop_condition.as_string) + loop_symbol = symbolic.pystr_to_symbolic(loop.loop_variable) + a = sympy.Wild('a') + + match = condition.match(loop_symbol < a) + if match is not None: + return match[a] - 1 + + match = condition.match(loop_symbol <= a) + if match is not None: + return match[a] + + match = condition.match(loop_symbol >= a) + if match is not None: + return match[a] + + match = condition.match(loop_symbol > a) + if match is not None: + return match[a] + 1 - return result[1] # (start, end, stride) where `end` is inclusive + # Matching failed - we can't derive end of loop + return None From 509da805dd518b5ba23969a4d5dc5279a476d2d7 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Thu, 15 Jan 2026 17:52:58 +0100 Subject: [PATCH 115/137] wip: work in progress towards building a valid stree --- .../analysis/schedule_tree/sdfg_to_tree.py | 7 +- dace/sdfg/analysis/schedule_tree/treenodes.py | 121 ++++++++++++++++-- 2 files changed, 115 insertions(+), 13 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 7d7cc16cf4..57996dc814 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -816,7 +816,7 @@ def _create_unified_descriptor_repository(sdfg: SDFG, stree: tn.ScheduleTreeRoot stree.constants.update(constants) -def as_schedule_tree(sdfg: SDFG, *, in_place: bool = False, toplevel: bool = False) -> tn.ScheduleTreeRoot: +def as_schedule_tree(sdfg: SDFG, *, in_place: bool = False, toplevel: bool = True) -> tn.ScheduleTreeRoot: """ Converts an SDFG into a schedule tree. The schedule tree is a tree of nodes that represent the execution order of the SDFG. @@ -839,13 +839,14 @@ def as_schedule_tree(sdfg: SDFG, *, in_place: bool = False, toplevel: bool = Fal _prepare_sdfg_for_conversion(sdfg, toplevel=toplevel) if toplevel: - result = tn.ScheduleTreeRoot(children=[], name="my-stree-name") + result = tn.ScheduleTreeRoot(name="my-stree-name", children=[]) _create_unified_descriptor_repository(sdfg, result) - result.children = _block_schedule_tree(sdfg) + result.add_children(_block_schedule_tree(sdfg)) else: result = tn.ScheduleTreeScope(children=_block_schedule_tree(sdfg)) tn.validate_has_no_other_node_types(result) + tn.validate_children_and_parents_align(result, root=toplevel) # Clean up tree stpasses.remove_unused_and_duplicate_labels(result) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 8f002f62c3..dd2a68d29b 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -11,7 +11,7 @@ from dace.sdfg.state import LoopRegion, SDFGState from dace.memlet import Memlet from types import TracebackType -from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Literal, Optional, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Literal, Optional, Set, Tuple, Union if TYPE_CHECKING: from dace import SDFG @@ -112,13 +112,19 @@ def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> class ScheduleTreeScope(ScheduleTreeNode): children: List['ScheduleTreeNode'] - def __init__(self, children: Optional[List['ScheduleTreeNode']] = None): - self.children = children or [] - if self.children: - for child in children: - child.parent = self - self.containers = {} - self.symbols = {} + def __init__(self, children: List['ScheduleTreeNode']) -> None: + for child in children: + child.parent = self + + self.children = children + + def add_children(self, children: Iterable['ScheduleTreeNode']) -> None: + for child in children: + child.parent = self + self.children.append(child) + + def add_child(self, child: ScheduleTreeNode) -> None: + self.add_children([child]) def as_string(self, indent: int = 0): if not self.children: @@ -242,6 +248,25 @@ class ScheduleTreeRoot(ScheduleTreeScope): callback_mapping: Dict[str, str] = field(default_factory=dict) arg_names: List[str] = field(default_factory=list) + def __init__( + self, + *, + name: str, + children: List[ScheduleTreeNode], + containers: Optional[Dict[str, data.Data]] = None, + symbols: Optional[Dict[str, dtypes.typeclass]] = None, + constants: Optional[Dict[str, Tuple[data.Data, Any]]] = None, + ) -> None: + super().__init__(children) + + self.name = name + if containers is not None: + self.containers = containers + if symbols is not None: + self.symbols = symbols + if constants is not None: + self.constants = constants + def as_sdfg(self, validate: bool = True, simplify: bool = True, @@ -282,7 +307,9 @@ def scope(self, state: SDFGState, ctx: Context) -> ContextPushPop: @dataclass class ControlFlowScope(ScheduleTreeScope): - pass + + def __init__(self, children: List[ScheduleTreeNode]) -> None: + super().__init__(children) @dataclass @@ -290,6 +317,15 @@ class DataflowScope(ScheduleTreeScope): node: nodes.EntryNode state: Optional[SDFGState] = None + def __init__(self, + node: nodes.EntryNode, + children: List[ScheduleTreeNode], + state: Optional[SDFGState] = None) -> None: + super().__init__(children) + + self.node = node + self.state = state + def scope(self, state: SDFGState, ctx: Context) -> ContextPushPop: return ContextPushPop(ctx, state, self) @@ -302,6 +338,9 @@ class GBlock(ControlFlowScope): Normally contains irreducible control flow. """ + def __init__(self, children: List[ScheduleTreeNode]) -> None: + super().__init__(children) + def as_string(self, indent: int = 0): result = indent * INDENTATION + 'gblock:\n' return result + super().as_string(indent) @@ -363,6 +402,11 @@ class LoopScope(ControlFlowScope): """ loop: LoopRegion + def __init__(self, loop: LoopRegion, children: List[ScheduleTreeNode]) -> None: + super().__init__(children) + + self.loop = loop + def as_string(self, indent: int = 0): loop = self.loop variant = loop_variant(loop) @@ -391,6 +435,9 @@ def as_string(self, indent: int = 0): class ForScope(LoopScope): """Specialized LoopScope for for-loops.""" + def __init__(self, loop: LoopRegion, children: List[ScheduleTreeNode]) -> None: + super().__init__(loop, children) + def as_string(self, indent: int = 0) -> str: init_statement = self.loop.init_statement.as_string condition = self.loop.loop_condition.as_string @@ -399,7 +446,6 @@ def as_string(self, indent: int = 0) -> str: return result + super().as_string(indent) def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: - root = root if root is not None else self.get_root() result = MemletSet() result.update(self.loop.get_meta_read_memlets()) @@ -428,6 +474,9 @@ def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> class WhileScope(LoopScope): """Specialized LoopScope for while-loops.""" + def __init__(self, loop: LoopRegion, children: List[ScheduleTreeNode]) -> None: + super().__init__(loop, children) + def as_string(self, indent: int = 0) -> str: condition = self.loop.loop_condition.as_string result = indent * INDENTATION + f'while {condition}:\n' @@ -445,6 +494,9 @@ def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> class DoWhileScope(LoopScope): """Specialized LoopScope for do-while-loops""" + def __init__(self, loop: LoopRegion, children: List[ScheduleTreeNode]) -> None: + super().__init__(loop, children) + def as_string(self, indent: int = 0) -> str: header = indent * INDENTATION + 'do:\n' footer = indent * INDENTATION + f'while {self.loop.loop_condition.as_string}' @@ -465,6 +517,11 @@ class IfScope(ControlFlowScope): """ condition: CodeBlock + def __init__(self, condition: CodeBlock, children: List[ScheduleTreeNode]) -> None: + super().__init__(children) + + self.condition = condition + def as_string(self, indent: int = 0): result = indent * INDENTATION + f'if {self.condition.as_string}:\n' return result + super().as_string(indent) @@ -483,6 +540,9 @@ class StateIfScope(IfScope): A special class of an if scope in general blocks for if statements that are part of a state transition. """ + def __init__(self, condition: CodeBlock, children: List[ScheduleTreeNode]) -> None: + super().__init__(condition, children) + def as_string(self, indent: int = 0): result = indent * INDENTATION + f'stateif {self.condition.as_string}:\n' return result + super(IfScope, self).as_string(indent) @@ -527,6 +587,11 @@ class ElifScope(ControlFlowScope): """ condition: CodeBlock + def __init__(self, condition: CodeBlock, children: List[ScheduleTreeNode]) -> None: + super().__init__(children) + + self.condition = condition + def as_string(self, indent: int = 0): result = indent * INDENTATION + f'elif {self.condition.as_string}:\n' return result + super().as_string(indent) @@ -545,6 +610,9 @@ class ElseScope(ControlFlowScope): Else branch scope. """ + def __init__(self, children: List[ScheduleTreeNode]) -> None: + super().__init__(children) + def as_string(self, indent: int = 0): result = indent * INDENTATION + 'else:\n' return result + super().as_string(indent) @@ -557,6 +625,12 @@ class MapScope(DataflowScope): """ node: nodes.MapEntry + def __init__(self, + node: nodes.MapEntry, + children: List[ScheduleTreeNode], + state: Optional[SDFGState] = None) -> None: + super().__init__(node, children, state) + def as_string(self, indent: int = 0): rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) result = indent * INDENTATION + f'map {", ".join(self.node.map.params)} in [{rangestr}]:\n' @@ -586,6 +660,12 @@ class ConsumeScope(DataflowScope): """ node: nodes.ConsumeEntry + def __init__(self, + node: nodes.ConsumeEntry, + children: List[ScheduleTreeNode], + state: Optional[SDFGState] = None) -> None: + super().__init__(node, children, state) + def as_string(self, indent: int = 0): node: nodes.ConsumeEntry = self.node cond = 'stream not empty' if node.consume.condition is None else node.consume.condition.as_string @@ -852,6 +932,27 @@ def validate_has_no_other_node_types(stree: ScheduleTreeScope) -> None: validate_has_no_other_node_types(child) +def validate_children_and_parents_align(stree: ScheduleTreeScope, *, root: bool = False) -> None: + """ + Validates the child/parent information of schedule tree scopes are consistent. + + Walks through all children of a scope and raises if the children's parent isn't + the scope. If `root` is true, we additionally check that the top-most scope is + of type `ScheduleTreeRoot`. + + :param stree: Schedule tree scope to be analyzed + :param root: If true, we raise if the top-most scope isn't of type `ScheduleTreeRoot`. + """ + if root and not isinstance(stree, ScheduleTreeRoot): + raise RuntimeError("Expected schedule tree root.") + + for child in stree.children: + if id(child.parent) != id(stree): + raise RuntimeError(f"Inconsistent parent/child relationship. child: {child}, parent: {stree}") + if isinstance(child, ScheduleTreeScope): + validate_children_and_parents_align(child) + + def loop_variant( loop: LoopRegion ) -> Union[Literal['for'], Literal['while'], Literal['do-while'], Literal['do-for-uncond-increment'], From db1c1b85a3efcd31ae0783048ea2b5b05f61f379 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 19 Jan 2026 10:38:21 +0100 Subject: [PATCH 116/137] WIP: fix parent/children relationship in scope nodes --- dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 57996dc814..ca2271c6e5 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -554,6 +554,10 @@ def _state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: result = subnodes elif isinstance(node, dace.nodes.ExitNode): result = scopes.pop() + parent = result[-1] + assert isinstance(parent, tn.ScheduleTreeScope) + for child in parent.children: + child.parent = parent elif isinstance(node, dace.nodes.NestedSDFG): nested_array_mapping_input = {} nested_array_mapping_output = {} From d817afe7e62117ff24093010b47fe94d1f431db0 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 19 Jan 2026 12:20:16 +0100 Subject: [PATCH 117/137] fix memlet propagation for indices --- dace/sdfg/analysis/schedule_tree/treenodes.py | 3 ++- dace/sdfg/propagation.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index dd2a68d29b..da4fe06ca1 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -673,7 +673,8 @@ def as_string(self, indent: int = 0): return result + super().as_string(indent) -@dataclass +# TODO: to be removed. looks like `Pipeline` nodes aren't a thing anymore +# @dataclass # class PipelineScope(MapScope): # """ # Pipeline scope. diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 594f8fbfd6..b752430354 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -1536,7 +1536,7 @@ def propagate_subset(memlets: List[Memlet], fsyms = _freesyms(s) fsyms_str = set(map(str, fsyms)) contains_params = len(fsyms_str & paramset) != 0 - contains_undefs = len(fsyms - defined_variables) != 0 + contains_undefs = len(fsyms & undefined_variables) != 0 else: contains_params = False contains_undefs = False From 1729dddbfc780dfb1fa7a0b3ef656d3c9d5a37a1 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 19 Jan 2026 14:56:34 +0100 Subject: [PATCH 118/137] easy fix for some tests --- dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py | 2 +- dace/sdfg/analysis/schedule_tree/treenodes.py | 3 ++- tests/schedule_tree/nesting_test.py | 4 ++-- tests/schedule_tree/schedule_test.py | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index ca2271c6e5..14c1dcd7fc 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -843,7 +843,7 @@ def as_schedule_tree(sdfg: SDFG, *, in_place: bool = False, toplevel: bool = Tru _prepare_sdfg_for_conversion(sdfg, toplevel=toplevel) if toplevel: - result = tn.ScheduleTreeRoot(name="my-stree-name", children=[]) + result = tn.ScheduleTreeRoot(name="default_stree_name", children=[]) _create_unified_descriptor_repository(sdfg, result) result.add_children(_block_schedule_tree(sdfg)) else: diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index da4fe06ca1..c56733c11d 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -115,7 +115,7 @@ class ScheduleTreeScope(ScheduleTreeNode): def __init__(self, children: List['ScheduleTreeNode']) -> None: for child in children: child.parent = self - + self.children = children def add_children(self, children: Iterable['ScheduleTreeNode']) -> None: @@ -686,6 +686,7 @@ def as_string(self, indent: int = 0): # result = indent * INDENTATION + f'pipeline {", ".join(self.node.map.params)} in [{rangestr}]:\n' # return result + super().as_string(indent) + @dataclass class TaskletNode(ScheduleTreeNode): node: nodes.Tasklet diff --git a/tests/schedule_tree/nesting_test.py b/tests/schedule_tree/nesting_test.py index 59512f88ab..a920b07067 100644 --- a/tests/schedule_tree/nesting_test.py +++ b/tests/schedule_tree/nesting_test.py @@ -63,8 +63,8 @@ def tester(A: dace.float64[N, N]): simplified = dace.Config.get_bool('optimizer', 'automatic_simplification') if simplified: - assert [type(n) - for n in stree.preorder_traversal()][1:] == [tn.MapScope, tn.MapScope, tn.LoopScope, tn.TaskletNode] + node_types = [type(n) for n in stree.preorder_traversal()][1:] + assert node_types == [tn.MapScope, tn.MapScope, tn.ForScope, tn.TaskletNode] tasklet: tn.TaskletNode = list(stree.preorder_traversal())[-1] diff --git a/tests/schedule_tree/schedule_test.py b/tests/schedule_tree/schedule_test.py index c15eb99f88..1f929bd779 100644 --- a/tests/schedule_tree/schedule_test.py +++ b/tests/schedule_tree/schedule_test.py @@ -158,7 +158,7 @@ def test_irreducible_sub_sdfg(): stree = as_schedule_tree(sdfg) node_types = [type(n) for n in stree.preorder_traversal()] assert node_types.count(tn.GBlock) == 1 # Only one gblock - assert node_types.count(tn.LoopScope) == 1 # Check that the loop was detected + assert node_types.count(tn.ForScope) == 1 # Check that the for-loop was detected def test_irreducible_in_loops(): From ee96f6702c8484233cb847d46ed2c2263c8bda6e Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Thu, 22 Jan 2026 12:01:24 +0100 Subject: [PATCH 119/137] Explicit constructors for treenodes / simple test for children of scopes --- dace/sdfg/analysis/schedule_tree/treenodes.py | 112 ++++++++++------ tests/schedule_tree/to_sdfg_test.py | 37 +++--- tests/schedule_tree/treenodes_test.py | 124 ++++++++++++++++++ 3 files changed, 219 insertions(+), 54 deletions(-) create mode 100644 tests/schedule_tree/treenodes_test.py diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index c56733c11d..2632377e95 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -110,15 +110,16 @@ def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> @dataclass class ScheduleTreeScope(ScheduleTreeNode): - children: List['ScheduleTreeNode'] + children: List[ScheduleTreeNode] - def __init__(self, children: List['ScheduleTreeNode']) -> None: + def __init__(self, *, children: List[ScheduleTreeNode], parent: Optional['ScheduleTreeScope'] = None) -> None: for child in children: child.parent = self self.children = children + self.parent = parent - def add_children(self, children: Iterable['ScheduleTreeNode']) -> None: + def add_children(self, children: Iterable[ScheduleTreeNode]) -> None: for child in children: child.parent = self self.children.append(child) @@ -242,11 +243,11 @@ class ScheduleTreeRoot(ScheduleTreeScope): the available descriptors, symbol types, and constants of the tree, aka the descriptor repository. """ name: str - containers: Dict[str, data.Data] = field(default_factory=dict) - symbols: Dict[str, dtypes.typeclass] = field(default_factory=dict) - constants: Dict[str, Tuple[data.Data, Any]] = field(default_factory=dict) - callback_mapping: Dict[str, str] = field(default_factory=dict) - arg_names: List[str] = field(default_factory=list) + containers: Dict[str, data.Data] + symbols: Dict[str, dtypes.typeclass] + constants: Dict[str, Tuple[data.Data, Any]] + callback_mapping: Dict[str, str] + arg_names: List[str] def __init__( self, @@ -256,16 +257,17 @@ def __init__( containers: Optional[Dict[str, data.Data]] = None, symbols: Optional[Dict[str, dtypes.typeclass]] = None, constants: Optional[Dict[str, Tuple[data.Data, Any]]] = None, + callback_mapping: Optional[Dict[str, str]] = None, + arg_names: Optional[List[str]] = None, ) -> None: - super().__init__(children) + super().__init__(children=children, parent=None) self.name = name - if containers is not None: - self.containers = containers - if symbols is not None: - self.symbols = symbols - if constants is not None: - self.constants = constants + self.containers = containers if containers is not None else dict() + self.symbols = symbols if symbols is not None else dict() + self.constants = constants if constants is not None else dict() + self.callback_mapping = callback_mapping if callback_mapping is not None else dict() + self.arg_names = arg_names if arg_names is not None else list() def as_sdfg(self, validate: bool = True, @@ -308,8 +310,8 @@ def scope(self, state: SDFGState, ctx: Context) -> ContextPushPop: @dataclass class ControlFlowScope(ScheduleTreeScope): - def __init__(self, children: List[ScheduleTreeNode]) -> None: - super().__init__(children) + def __init__(self, *, children: List[ScheduleTreeNode], parent: Optional[ScheduleTreeScope] = None) -> None: + super().__init__(children=children, parent=parent) @dataclass @@ -318,10 +320,12 @@ class DataflowScope(ScheduleTreeScope): state: Optional[SDFGState] = None def __init__(self, + *, node: nodes.EntryNode, children: List[ScheduleTreeNode], + parent: Optional[ScheduleTreeScope] = None, state: Optional[SDFGState] = None) -> None: - super().__init__(children) + super().__init__(children=children, parent=parent) self.node = node self.state = state @@ -338,8 +342,8 @@ class GBlock(ControlFlowScope): Normally contains irreducible control flow. """ - def __init__(self, children: List[ScheduleTreeNode]) -> None: - super().__init__(children) + def __init__(self, *, children: List[ScheduleTreeNode], parent: Optional[ScheduleTreeScope] = None) -> None: + super().__init__(children=children, parent=parent) def as_string(self, indent: int = 0): result = indent * INDENTATION + 'gblock:\n' @@ -402,8 +406,12 @@ class LoopScope(ControlFlowScope): """ loop: LoopRegion - def __init__(self, loop: LoopRegion, children: List[ScheduleTreeNode]) -> None: - super().__init__(children) + def __init__(self, + *, + loop: LoopRegion, + children: List[ScheduleTreeNode], + parent: Optional[ScheduleTreeScope] = None) -> None: + super().__init__(children=children, parent=parent) self.loop = loop @@ -435,8 +443,12 @@ def as_string(self, indent: int = 0): class ForScope(LoopScope): """Specialized LoopScope for for-loops.""" - def __init__(self, loop: LoopRegion, children: List[ScheduleTreeNode]) -> None: - super().__init__(loop, children) + def __init__(self, + *, + loop: LoopRegion, + children: List[ScheduleTreeNode], + parent: Optional[ScheduleTreeScope] = None) -> None: + super().__init__(loop=loop, children=children, parent=parent) def as_string(self, indent: int = 0) -> str: init_statement = self.loop.init_statement.as_string @@ -474,8 +486,12 @@ def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> class WhileScope(LoopScope): """Specialized LoopScope for while-loops.""" - def __init__(self, loop: LoopRegion, children: List[ScheduleTreeNode]) -> None: - super().__init__(loop, children) + def __init__(self, + *, + loop: LoopRegion, + children: List[ScheduleTreeNode], + parent: Optional[ScheduleTreeScope] = None) -> None: + super().__init__(loop=loop, children=children, parent=parent) def as_string(self, indent: int = 0) -> str: condition = self.loop.loop_condition.as_string @@ -494,8 +510,12 @@ def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> class DoWhileScope(LoopScope): """Specialized LoopScope for do-while-loops""" - def __init__(self, loop: LoopRegion, children: List[ScheduleTreeNode]) -> None: - super().__init__(loop, children) + def __init__(self, + *, + loop: LoopRegion, + children: List[ScheduleTreeNode], + parent: Optional[ScheduleTreeScope] = None) -> None: + super().__init__(loop=loop, children=children, parent=parent) def as_string(self, indent: int = 0) -> str: header = indent * INDENTATION + 'do:\n' @@ -517,8 +537,12 @@ class IfScope(ControlFlowScope): """ condition: CodeBlock - def __init__(self, condition: CodeBlock, children: List[ScheduleTreeNode]) -> None: - super().__init__(children) + def __init__(self, + *, + condition: CodeBlock, + children: List[ScheduleTreeNode], + parent: Optional[ScheduleTreeScope] = None) -> None: + super().__init__(children=children, parent=parent) self.condition = condition @@ -540,8 +564,12 @@ class StateIfScope(IfScope): A special class of an if scope in general blocks for if statements that are part of a state transition. """ - def __init__(self, condition: CodeBlock, children: List[ScheduleTreeNode]) -> None: - super().__init__(condition, children) + def __init__(self, + *, + condition: CodeBlock, + children: List[ScheduleTreeNode], + parent: Optional[ScheduleTreeScope] = None) -> None: + super().__init__(condition=condition, children=children, parent=parent) def as_string(self, indent: int = 0): result = indent * INDENTATION + f'stateif {self.condition.as_string}:\n' @@ -587,8 +615,12 @@ class ElifScope(ControlFlowScope): """ condition: CodeBlock - def __init__(self, condition: CodeBlock, children: List[ScheduleTreeNode]) -> None: - super().__init__(children) + def __init__(self, + *, + condition: CodeBlock, + children: List[ScheduleTreeNode], + parent: Optional[ScheduleTreeScope] = None) -> None: + super().__init__(children=children, parent=parent) self.condition = condition @@ -610,8 +642,8 @@ class ElseScope(ControlFlowScope): Else branch scope. """ - def __init__(self, children: List[ScheduleTreeNode]) -> None: - super().__init__(children) + def __init__(self, *, children: List[ScheduleTreeNode], parent: Optional[ScheduleTreeScope] = None) -> None: + super().__init__(children=children, parent=parent) def as_string(self, indent: int = 0): result = indent * INDENTATION + 'else:\n' @@ -626,10 +658,12 @@ class MapScope(DataflowScope): node: nodes.MapEntry def __init__(self, + *, node: nodes.MapEntry, children: List[ScheduleTreeNode], + parent: Optional[ScheduleTreeScope] = None, state: Optional[SDFGState] = None) -> None: - super().__init__(node, children, state) + super().__init__(node=node, state=state, children=children, parent=parent) def as_string(self, indent: int = 0): rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) @@ -661,10 +695,12 @@ class ConsumeScope(DataflowScope): node: nodes.ConsumeEntry def __init__(self, + *, node: nodes.ConsumeEntry, children: List[ScheduleTreeNode], + parent: Optional[ScheduleTreeScope] = None, state: Optional[SDFGState] = None) -> None: - super().__init__(node, children, state) + super().__init__(node=node, state=state, children=children, parent=parent) def as_string(self, indent: int = 0): node: nodes.ConsumeEntry = self.node diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 1f0590c912..6ce45fa584 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -145,9 +145,15 @@ def test_state_boundaries_cfg(): }, children=[ tn.TaskletNode(nodes.Tasklet('bla1', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), - tn.ForScope([ - tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'out = i'), {}, {'out': dace.Memlet('A[1]')}), - ], cf.ForScope(None, None, True, 'i', None, '0', CodeBlock('i < 20'), 'i + 1', None, [])), + tn.ForScope(loop=cf.LoopRegion(label="for-loop", + condition_expr=CodeBlock("i < 20"), + loop_var="i", + initialize_expr=CodeBlock("i=0"), + update_expr=CodeBlock("i = i+1")), + children=[ + tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'out = i'), {}, + {'out': dace.Memlet('A[1]')}), + ]), ], ) @@ -361,19 +367,18 @@ def test_create_while_loop(): def test_create_if_else(): # Manually create a schedule tree - stree = tn.ScheduleTreeRoot(name="tester", - containers={'A': data.Array(dace.float64, [20])}, - children=[ - tn.IfScope(condition=CodeBlock("A[0] > 0"), - children=[ - tn.TaskletNode(nodes.Tasklet("bla", {}, {"out"}, "out=1"), {}, - {"out": dace.Memlet("A[1]")}), - ]), - tn.ElseScope([ - tn.TaskletNode(nodes.Tasklet("blub", {}, {"out"}, "out=2"), {}, - {"out": dace.Memlet("A[1]")}) - ]) - ]) + stree = tn.ScheduleTreeRoot( + name="tester", + containers={'A': data.Array(dace.float64, [20])}, + children=[ + tn.IfScope(condition=CodeBlock("A[0] > 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("bla", {}, {"out"}, "out=1"), {}, {"out": dace.Memlet("A[1]")}), + ]), + tn.ElseScope(children=[ + tn.TaskletNode(nodes.Tasklet("blub", {}, {"out"}, "out=2"), {}, {"out": dace.Memlet("A[1]")}) + ]) + ]) sdfg = stree.as_sdfg() sdfg.validate() diff --git a/tests/schedule_tree/treenodes_test.py b/tests/schedule_tree/treenodes_test.py new file mode 100644 index 0000000000..28c5004974 --- /dev/null +++ b/tests/schedule_tree/treenodes_test.py @@ -0,0 +1,124 @@ +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace import nodes + +import pytest + + +@pytest.fixture +def tasklet() -> nodes.Tasklet: + return tn.TaskletNode(nodes.Tasklet("noop", {}, {}, code="pass"), {}, {}) + + +@pytest.mark.parametrize('ScopeClass', ( + tn.ScheduleTreeScope, + tn.ControlFlowScope, + tn.GBlock, + tn.ElseScope, +)) +def test_schedule_tree_scope_children(ScopeClass: type[tn.ScheduleTreeScope], tasklet: nodes.Tasklet) -> None: + scope = ScopeClass(children=[tasklet]) + + for child in scope.children: + assert child.parent == scope + + scope = ScopeClass(children=[]) + scope.add_child(tasklet) + + for child in scope.children: + assert child.parent == scope + + scope = ScopeClass(children=[]) + scope.add_children([tasklet]) + + for child in scope.children: + assert child.parent == scope + + +@pytest.mark.parametrize('LoopScope', ( + tn.LoopScope, + tn.ForScope, + tn.WhileScope, + tn.DoWhileScope, +)) +def test_loop_scope_children(LoopScope: type[tn.LoopScope], tasklet: nodes.Tasklet) -> None: + scope = LoopScope(loop=None, children=[tasklet]) + + for child in scope.children: + assert child.parent == scope + + scope = LoopScope(loop=None, children=[]) + scope.add_child(tasklet) + + for child in scope.children: + assert child.parent == scope + + scope = LoopScope(loop=None, children=[]) + scope.add_children([tasklet]) + + for child in scope.children: + assert child.parent == scope + + +@pytest.mark.parametrize('IfScope', ( + tn.IfScope, + tn.StateIfScope, + tn.ElifScope, +)) +def test_if_scope_children(IfScope: type[tn.IfScope], tasklet: nodes.Tasklet) -> None: + scope = IfScope(condition=None, children=[tasklet]) + + for child in scope.children: + assert child.parent == scope + + scope = IfScope(condition=None, children=[]) + scope.add_child(tasklet) + + for child in scope.children: + assert child.parent == scope + + scope = IfScope(condition=None, children=[]) + scope.add_children([tasklet]) + + for child in scope.children: + assert child.parent == scope + + +@pytest.mark.parametrize('DataflowScope', ( + tn.DataflowScope, + tn.MapScope, + tn.ConsumeScope, +)) +def test_dataflow_scope_children(DataflowScope: type[tn.DataflowScope], tasklet: nodes.Tasklet) -> None: + scope = DataflowScope(node=None, children=[tasklet]) + + for child in scope.children: + assert child.parent == scope + + scope = DataflowScope(node=None, children=[]) + scope.add_child(tasklet) + + for child in scope.children: + assert child.parent == scope + + scope = DataflowScope(node=None, children=[]) + scope.add_children([tasklet]) + + for child in scope.children: + assert child.parent == scope + + +if __name__ == '__main__': + test_schedule_tree_scope_children(tn.ScheduleTreeScope, tasklet) + test_schedule_tree_scope_children(tn.ControlFlowScope, tasklet) + test_schedule_tree_scope_children(tn.GBlock, tasklet) + test_schedule_tree_scope_children(tn.ElseScope, tasklet) + test_loop_scope_children(tn.LoopScope, tasklet) + test_loop_scope_children(tn.ForScope, tasklet) + test_loop_scope_children(tn.WhileScope, tasklet) + test_loop_scope_children(tn.DoWhileScope, tasklet) + test_if_scope_children(tn.IfScope, tasklet) + test_if_scope_children(tn.StateIfScope, tasklet) + test_if_scope_children(tn.ElifScope, tasklet) + test_dataflow_scope_children(tn.DataflowScope, tasklet) + test_dataflow_scope_children(tn.MapScope, tasklet) + test_dataflow_scope_children(tn.ConsumeScope, tasklet) From f81abc9dba6336583a32106060ffa523cc58da47 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Thu, 22 Jan 2026 17:54:18 +0100 Subject: [PATCH 120/137] WIP: a bit of cleanup and the first new-style tests --- .../analysis/schedule_tree/tree_to_sdfg.py | 108 +++++------ tests/schedule_tree/to_sdfg_test.py | 171 +++++------------- 2 files changed, 98 insertions(+), 181 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 9eb61669f4..2ca9061c12 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -1,7 +1,7 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import copy from collections import defaultdict -from dace import symbolic, data +from dace import symbolic from dace.memlet import Memlet from dace.sdfg import nodes, memlet_utils as mmu from dace.sdfg.sdfg import SDFG, ControlFlowRegion, InterstateEdge @@ -176,59 +176,55 @@ def visit_AssignNode(self, node: tn.AssignNode, sdfg: SDFG) -> None: if memlet.data not in sdfg.arrays: raise ValueError(f"Parsing AssignNode {node} failed. Can't find {memlet.data} in {sdfg}.") - #def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: - # before_state = self._current_state - # pending = self._pending_interstate_assignments() - # pending[node.header.itervar] = node.header.init - # - # guard_state = _insert_and_split_assignments(sdfg, before_state, label="loop_guard", assignments=pending) - # self._current_state = guard_state - # - # body_state = sdfg.add_state(label="loop_body") - # self._current_state = body_state - # sdfg.add_edge(guard_state, body_state, InterstateEdge(condition=node.header.condition)) - # - # # visit children inside the loop - # self.visit(node.children, sdfg=sdfg) - # - # pending = self._pending_interstate_assignments() - # pending[node.header.itervar] = node.header.update - # _insert_and_split_assignments(sdfg, self._current_state, after_state=guard_state, assignments=pending) - # - # after_state = sdfg.add_state(label="loop_after") - # self._current_state = after_state - # sdfg.add_edge(guard_state, after_state, InterstateEdge(condition=f"not {node.header.condition.as_string}")) - - #def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: - # before_state = self._current_state - # guard_state = _insert_and_split_assignments(sdfg, - # before_state, - # label="guard_state", - # assignments=self._pending_interstate_assignments()) - # self._current_state = guard_state - # - # body_state = sdfg.add_state(label="loop_body") - # self._current_state = body_state - # sdfg.add_edge(guard_state, body_state, InterstateEdge(condition=node.header.test)) - # - # # visit children inside the loop - # self.visit(node.children, sdfg=sdfg) - # _insert_and_split_assignments(sdfg, - # before_state=self._current_state, - # after_state=guard_state, - # assignments=self._pending_interstate_assignments()) - # - # after_state = sdfg.add_state(label="loop_after") - # self._current_state = after_state - # sdfg.add_edge(guard_state, after_state, InterstateEdge(f"not {node.header.test.as_string}")) - - #def visit_DoWhileScope(self, node: tn.DoWhileScope, sdfg: SDFG) -> None: - # # AFAIK we don't support for do-while loops in the gt4py -> dace bridge. - # raise NotImplementedError(f"{type(node)} not implemented") - - #def visit_GeneralLoopScope(self, node: tn.GeneralLoopScope, sdfg: SDFG) -> None: - # # Let's see if we need this for the first prototype ... - # raise NotImplementedError(f"{type(node)} not implemented") + def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: + before_state = self._current_state + pending = self._pending_interstate_assignments() + pending[node.header.itervar] = node.header.init + + guard_state = _insert_and_split_assignments(sdfg, before_state, label="loop_guard", assignments=pending) + self._current_state = guard_state + + body_state = sdfg.add_state(label="loop_body") + self._current_state = body_state + sdfg.add_edge(guard_state, body_state, InterstateEdge(condition=node.header.condition)) + + # visit children inside the loop + self.visit(node.children, sdfg=sdfg) + + pending = self._pending_interstate_assignments() + pending[node.header.itervar] = node.header.update + _insert_and_split_assignments(sdfg, self._current_state, after_state=guard_state, assignments=pending) + + after_state = sdfg.add_state(label="loop_after") + self._current_state = after_state + sdfg.add_edge(guard_state, after_state, InterstateEdge(condition=f"not {node.header.condition.as_string}")) + + def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: + before_state = self._current_state + guard_state = _insert_and_split_assignments(sdfg, + before_state, + label="guard_state", + assignments=self._pending_interstate_assignments()) + self._current_state = guard_state + + body_state = sdfg.add_state(label="loop_body") + self._current_state = body_state + sdfg.add_edge(guard_state, body_state, InterstateEdge(condition=node.header.test)) + + # visit children inside the loop + self.visit(node.children, sdfg=sdfg) + _insert_and_split_assignments(sdfg, + before_state=self._current_state, + after_state=guard_state, + assignments=self._pending_interstate_assignments()) + + after_state = sdfg.add_state(label="loop_after") + self._current_state = after_state + sdfg.add_edge(guard_state, after_state, InterstateEdge(f"not {node.header.test.as_string}")) + + def visit_DoWhileScope(self, node: tn.DoWhileScope, sdfg: SDFG) -> None: + # AFAIK we don't support for do-while loops in the gt4py -> dace bridge. + raise NotImplementedError(f"{type(node)} not implemented") def visit_LoopScope(self, node: tn.LoopScope, sdfg: SDFG) -> None: raise NotImplementedError("TODO: LoopScopes are not yet implemented") @@ -579,10 +575,6 @@ def visit_ConsumeScope(self, node: tn.ConsumeScope, sdfg: SDFG) -> None: # AFAIK we don't support consume scopes in the gt4py/dace bridge. raise NotImplementedError(f"{type(node)} not implemented") - # def visit_PipelineScope(self, node: tn.PipelineScope, sdfg: SDFG) -> None: - # # AFAIK we don't support pipeline scopes in the gt4py/dace bridge. - # raise NotImplementedError(f"{type(node)} not implemented") - def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # Add Tasklet to current state tasklet = node.node diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 6ce45fa584..4dca6dca58 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -10,6 +10,8 @@ from dace.sdfg.analysis.schedule_tree import tree_to_sdfg as t2s, treenodes as tn import pytest +from dace.sdfg.state import ConditionalBlock, LoopRegion + def test_state_boundaries_none(): # Manually create a schedule tree @@ -239,7 +241,6 @@ def test_create_state_boundary_empty_memlet(control_flow): def test_create_tasklet_raw(): - # Manually create a schedule tree stree = tn.ScheduleTreeRoot( name='tester', containers={ @@ -270,7 +271,6 @@ def test_create_tasklet_raw(): def test_create_tasklet_waw(): - # Manually create a schedule tree stree = tn.ScheduleTreeRoot( name='tester', containers={ @@ -294,7 +294,6 @@ def test_create_tasklet_waw(): def test_create_tasklet_war(): - # Manually create a schedule tree stree = tn.ScheduleTreeRoot( name="tester", containers={"A": data.Array(dace.float64, [20])}, @@ -320,20 +319,18 @@ def test_create_tasklet_war(): def test_create_for_loop(): - # yapf: disable - loop=tn.ForScope( - children=[ - tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), - tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), - ], - header=cf.ForScope( - itervar="i", init="0", condition=CodeBlock("i<3"), update="i+1", - dispatch_state=None, parent=None, last_block=True, guard=None, body=None, init_edges=[] - ) - ) - # yapf: enable + loop = tn.ForScope(loop=LoopRegion(label="my_for_loop", + loop_var="i", + initialize_expr=CodeBlock("i = 0 "), + condition_expr=CodeBlock("i < 3"), + update_expr=CodeBlock("i = i+1")), + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, + {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 2'), {}, + {'out': dace.Memlet('A[1]')}), + ]) - # Manually create a schedule tree stree = tn.ScheduleTreeRoot(name='tester', containers={'A': data.Array(dace.float64, [20])}, children=[loop]) sdfg = stree.as_sdfg() @@ -341,24 +338,15 @@ def test_create_for_loop(): def test_create_while_loop(): - # yapf: disable - loop=tn.WhileScope( - children=[ - tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), - tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), - ], - header=cf.WhileScope( - test=CodeBlock("A[1] > 5"), - dispatch_state=None, - last_block=True, - parent=None, - guard=None, - body=None - ) - ) - # yapf: enable + loop = tn.WhileScope(children=[ + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + ], + loop=LoopRegion( + label="my_while_loop", + condition_expr=CodeBlock("A[1] > 5"), + )) - # Manually create a schedule tree stree = tn.ScheduleTreeRoot(name='tester', containers={'A': data.Array(dace.float64, [20])}, children=[loop]) sdfg = stree.as_sdfg() @@ -366,7 +354,6 @@ def test_create_while_loop(): def test_create_if_else(): - # Manually create a schedule tree stree = tn.ScheduleTreeRoot( name="tester", containers={'A': data.Array(dace.float64, [20])}, @@ -381,11 +368,25 @@ def test_create_if_else(): ]) sdfg = stree.as_sdfg() - sdfg.validate() + + blocks = list(filter(lambda x: isinstance(x, ConditionalBlock), sdfg.cfg_list)) + assert len(blocks) == 1, "SDFG contains one ConditionalBlock" + + block: ConditionalBlock = blocks[0] + assert len(block.branches) == 2, "Block contains two branches" + + if_branch = list(filter(lambda x: x[0] is not None, block.branches))[0] + tasklets = list(filter(lambda x: isinstance(x, nodes.Tasklet), if_branch[1].nodes()[0].nodes())) + assert if_branch[0] == CodeBlock("A[0] > 0"), "If branch has condition" + assert tasklets[0].label == "bla", "If branch contains Tasklet('bla')" + + else_branch = list(filter(lambda x: x[0] is None, block.branches))[0] + tasklets = list(filter(lambda x: isinstance(x, nodes.Tasklet), else_branch[1].nodes()[0].nodes())) + assert else_branch[0] is None, "Else branch has no condition" + assert tasklets[0].label == "blub", "Else branch contains Tasklet('blub')" def test_create_if_without_else(): - # Manually create a schedule tree stree = tn.ScheduleTreeRoot(name="tester", containers={'A': data.Array(dace.float64, [20])}, children=[ @@ -397,11 +398,20 @@ def test_create_if_without_else(): ]) sdfg = stree.as_sdfg() - sdfg.validate() + + blocks = list(filter(lambda x: isinstance(x, ConditionalBlock), sdfg.cfg_list)) + assert len(blocks) == 1, "SDFG contains one ConditionalBlock" + + block: ConditionalBlock = blocks[0] + assert len(block.branches) == 1, "Block contains one branch" + + branch = list(filter(lambda x: x[0] is not None, block.branches))[0] + tasklets = list(filter(lambda x: isinstance(x, nodes.Tasklet), branch[1].nodes()[0].nodes())) + assert branch[0] == CodeBlock("A[0] > 0"), "Branch has condition" + assert tasklets[0].label == "bla", "Branch contains Tasklet('bla')" def test_create_map_scope_write(): - # Manually create a schedule tree stree = tn.ScheduleTreeRoot(name="tester", containers={'A': data.Array(dace.float64, [20])}, children=[ @@ -418,7 +428,6 @@ def test_create_map_scope_write(): def test_create_map_scope_read_after_write(): - # Manually create a schedule tree stree = tn.ScheduleTreeRoot( name="tester", containers={ @@ -440,7 +449,6 @@ def test_create_map_scope_read_after_write(): def test_create_map_scope_write_after_read(): - # Manually create a schedule tree stree = tn.ScheduleTreeRoot( name="tester", containers={"A": data.Array(dace.float64, [20])}, @@ -457,7 +465,6 @@ def test_create_map_scope_write_after_read(): def test_create_map_scope_copy(): - # Manually create a schedule tree stree = tn.ScheduleTreeRoot(name="tester", containers={ 'A': data.Array(dace.float64, [20]), @@ -478,7 +485,6 @@ def test_create_map_scope_copy(): def test_create_map_scope_double_memlet(): - # Manually create a schedule tree stree = tn.ScheduleTreeRoot( name="tester", containers={ @@ -500,7 +506,6 @@ def test_create_map_scope_double_memlet(): def test_create_nested_map_scope(): - # Manually create a schedule tree stree = tn.ScheduleTreeRoot( name="tester", containers={'A': data.Array(dace.float64, [20])}, @@ -520,7 +525,6 @@ def test_create_nested_map_scope(): def test_create_nested_map_scope_multi_read(): - # Manually create a schedule tree stree = tn.ScheduleTreeRoot( name="tester", containers={ @@ -546,7 +550,6 @@ def test_create_nested_map_scope_multi_read(): def test_map_with_state_boundary_inside(): - # Manually create a schedule tree stree = tn.ScheduleTreeRoot(name="tester", containers={'A': data.Array(dace.float64, [20])}, children=[ @@ -565,7 +568,6 @@ def test_map_with_state_boundary_inside(): def test_map_calculate_temporary_in_two_loops(): - # Manually create a schedule tree stree = tn.ScheduleTreeRoot( name="tester", containers={ @@ -613,7 +615,6 @@ def test_edge_assignment_read_after_write(): def test_assign_nodes_force_state_transition(): - # Manually create a schedule tree stree = tn.ScheduleTreeRoot( name='tester', containers={ @@ -630,7 +631,6 @@ def test_assign_nodes_force_state_transition(): def test_assign_nodes_multiple_force_one_transition(): - # Manually create a schedule tree stree = tn.ScheduleTreeRoot( name='tester', containers={ @@ -650,7 +650,6 @@ def test_assign_nodes_multiple_force_one_transition(): def test_assign_nodes_avoid_duplicate_boundaries(): - # Manually create a schedule tree stree = tn.ScheduleTreeRoot( name='tester', containers={ @@ -668,80 +667,6 @@ def test_assign_nodes_avoid_duplicate_boundaries(): assert [type(child) for child in stree.children] == [tn.AssignNode, tn.StateBoundaryNode, tn.TaskletNode] -def test_XPPM_tmp(): - loaded = dace.SDFG.from_file("tmp_XPPM.sdfgz") - stree = loaded.as_schedule_tree() - - sdfg = stree.as_sdfg() - sdfg.validate() - - -def test_DelnFluxNoSG_tmp(): - loaded = dace.SDFG.from_file("tmp_DelnFluxNoSG.sdfgz") - stree = loaded.as_schedule_tree() - - sdfg = stree.as_sdfg() - sdfg.validate() - - -def test_DelnFlux_tmp(): - loaded = dace.SDFG.from_file("tmp_DelnFlux.sdfgz") - stree = loaded.as_schedule_tree() - - sdfg = stree.as_sdfg() - sdfg.validate() - - -def test_FvTp2d_tmp(): - loaded = dace.SDFG.from_file("tmp_FvTp2d.sdfgz") - stree = loaded.as_schedule_tree() - - sdfg = stree.as_sdfg() - sdfg.validate() - - -def test_FxAdv_tmp(): - loaded = dace.SDFG.from_file("tmp_FxAdv.sdfgz") - stree = loaded.as_schedule_tree() - - sdfg = stree.as_sdfg() - sdfg.validate() - - -def test_D_SW_tmp(): - loaded = dace.SDFG.from_file("tmp_D_SW.sdfgz") - stree = loaded.as_schedule_tree() - - sdfg = stree.as_sdfg() - sdfg.validate() - - -def test_UpdateDzD_tmp(): - loaded = dace.SDFG.from_file("tmp_UpdateDzD-ConstantPropagation.sdfgz") - stree = loaded.as_schedule_tree() - - sdfg = stree.as_sdfg() - sdfg.validate() - - -def test_Fillz_tmp(): - loaded = dace.SDFG.from_file("tmp_Fillz.sdfgz") - stree = loaded.as_schedule_tree() - - sdfg = stree.as_sdfg() - sdfg.validate() - - -def test_Ray_Fast_tmp(): - loaded = dace.SDFG.from_file("tmp_Ray_Fast.sdfgz") - stree = loaded.as_schedule_tree() - - sdfg = stree.as_sdfg() - sdfg.validate() - - -# TODO: find an automatic way to test stuff here - if __name__ == '__main__': test_state_boundaries_none() test_state_boundaries_waw() From 8e3cc90973fe2d744d314cdceeba28e709cbd128 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Fri, 23 Jan 2026 17:12:34 +0100 Subject: [PATCH 121/137] WIP: use ConditionalBlock for if/else scopes --- .../analysis/schedule_tree/tree_to_sdfg.py | 64 ++++++++++--------- .../passes/dead_state_elimination.py | 7 +- tests/schedule_tree/to_sdfg_test.py | 4 ++ 3 files changed, 44 insertions(+), 31 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 2ca9061c12..ab9f2138e1 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -5,7 +5,7 @@ from dace.memlet import Memlet from dace.sdfg import nodes, memlet_utils as mmu from dace.sdfg.sdfg import SDFG, ControlFlowRegion, InterstateEdge -from dace.sdfg.state import SDFGState +from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, SDFGState from dace.sdfg.analysis.schedule_tree import treenodes as tn from dace.sdfg import propagation from enum import Enum, auto @@ -232,23 +232,27 @@ def visit_LoopScope(self, node: tn.LoopScope, sdfg: SDFG) -> None: def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: before_state = self._current_state - # add guard state - guard_state = _insert_and_split_assignments(sdfg, - before_state, - label="guard_state", - assignments=self._pending_interstate_assignments()) + conditional_block = ConditionalBlock(f"if_scope_{id(node)}") + sdfg.add_node(conditional_block) + _insert_and_split_assignments(sdfg, + before_state, + conditional_block, + assignments=self._pending_interstate_assignments()) + + if_body = ControlFlowRegion("if_body", sdfg=sdfg) + conditional_block.add_branch(node.condition, if_body) - # add true_state - true_state = sdfg.add_state(label="true_state") - sdfg.add_edge(guard_state, true_state, InterstateEdge(condition=node.condition)) - self._current_state = true_state + if_state = if_body.add_state("if_state", is_start_block=True) + self._current_state = if_state - # visit children in the true branch + # visit children of that branch self.visit(node.children, sdfg=sdfg) + self._current_state = conditional_block + # add merge_state merge_state = _insert_and_split_assignments(sdfg, - self._current_state, + conditional_block, label="merge_state", assignments=self._pending_interstate_assignments()) @@ -261,14 +265,9 @@ def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: if has_else_branch: # push merge_state on the stack for later usage in `visit_ElseScope` self._state_stack.append(merge_state) - false_state = sdfg.add_state(label="false_state") - - sdfg.add_edge(guard_state, false_state, InterstateEdge(condition=f"not {node.condition.as_string}")) - - # push false_state on the stack for later usage in `visit_ElseScope` - self._state_stack.append(false_state) + # push condition_block on the stack for later usage in `visit_ElseScope` + self._state_stack.append(conditional_block) else: - sdfg.add_edge(guard_state, merge_state, InterstateEdge(condition=f"not {node.condition.as_string}")) self._current_state = merge_state def visit_StateIfScope(self, node: tn.StateIfScope, sdfg: SDFG) -> None: @@ -288,21 +287,25 @@ def visit_ElifScope(self, node: tn.ElifScope, sdfg: SDFG) -> None: raise NotImplementedError(f"{type(node)} not implemented") def visit_ElseScope(self, node: tn.ElseScope, sdfg: SDFG) -> None: - # get false_state form stack - false_state = self._pop_state("false_state") - self._current_state = false_state + # get ConditionalBlock form stack + conditional_block: ConditionalBlock = self._pop_state("if_scope") + + else_body = ControlFlowRegion("else_body", sdfg=sdfg) + conditional_block.add_branch(None, else_body) + + else_state = else_body.add_state("else_state", is_start_block=True) + self._current_state = else_state # visit children inside the else branch self.visit(node.children, sdfg=sdfg) # merge false-branch into merge_state merge_state = self._pop_state("merge_state") - _insert_and_split_assignments(sdfg, - before_state=self._current_state, - after_state=merge_state, - assignments=self._pending_interstate_assignments()) self._current_state = merge_state + if self._pending_interstate_assignments(): + raise NotImplementedError("TODO: update edge with new assignments") + def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: dataflow_stack_size = len(self._dataflow_stack) state_stack_size = len(self._state_stack) @@ -809,7 +812,7 @@ def insert_state_boundaries_to_tree(stree: tn.ScheduleTreeRoot) -> tn.ScheduleTr class SimpleStateBoundaryInserter(tn.ScheduleNodeTransformer): def visit_scope(self, scope: tn.ScheduleTreeScope): - if isinstance(scope, tn.ControlFlowScope): + if isinstance(scope, tn.ControlFlowScope) and not isinstance(scope, (tn.ElifScope, tn.ElseScope)): return [tn.StateBoundaryNode(True), self.generic_visit(scope)] return self.generic_visit(scope) @@ -971,10 +974,11 @@ def create_state_boundary(boundary_node: tn.StateBoundaryNode, def _insert_and_split_assignments(sdfg_region: ControlFlowRegion, - before_state: SDFGState, - after_state: Optional[SDFGState] = None, + before_state: ControlFlowBlock, + after_state: Optional[ControlFlowBlock] = None, + *, label: Optional[str] = None, - assignments: Optional[Dict] = None) -> SDFGState: + assignments: Optional[Dict] = None) -> ControlFlowBlock: """ Insert given assignments splitting them in case of potential race conditions. diff --git a/dace/transformation/passes/dead_state_elimination.py b/dace/transformation/passes/dead_state_elimination.py index cc46470d8e..20bd357c10 100644 --- a/dace/transformation/passes/dead_state_elimination.py +++ b/dace/transformation/passes/dead_state_elimination.py @@ -169,7 +169,12 @@ def _find_dead_branches(self, block: ConditionalBlock) -> List[Tuple[CodeBlock, for i, (cond, branch) in enumerate(block.branches): if cond is None: if not i == len(block.branches) - 1: - raise InvalidSDFGNodeError('Conditional block detected, where else branch is not the last branch') + raise InvalidSDFGNodeError( + 'Conditional block detected, where else branch is not the last branch', + sdfg=block.root_sdfg(), + state_id=block.block_id, + node_id=block.block_id, + ) break # If an unconditional branch is found, ignore all other branches that follow this one. if self._is_definitely_true(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg): diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 4dca6dca58..09062ff459 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -386,6 +386,10 @@ def test_create_if_else(): assert tasklets[0].label == "blub", "Else branch contains Tasklet('blub')" +# TODO +# support for if_elif_else + + def test_create_if_without_else(): stree = tn.ScheduleTreeRoot(name="tester", containers={'A': data.Array(dace.float64, [20])}, From 6bba45c80415442827d89363a1c97a3530d3235b Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 26 Jan 2026 11:53:14 +0100 Subject: [PATCH 122/137] WIP: Add support for basic ForScopes --- .../analysis/schedule_tree/tree_to_sdfg.py | 32 +++----- tests/schedule_tree/to_sdfg_test.py | 75 ++++++++++++++----- 2 files changed, 70 insertions(+), 37 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index ab9f2138e1..0ddbd2201f 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -5,7 +5,7 @@ from dace.memlet import Memlet from dace.sdfg import nodes, memlet_utils as mmu from dace.sdfg.sdfg import SDFG, ControlFlowRegion, InterstateEdge -from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, SDFGState +from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, LoopRegion, SDFGState from dace.sdfg.analysis.schedule_tree import treenodes as tn from dace.sdfg import propagation from enum import Enum, auto @@ -177,27 +177,17 @@ def visit_AssignNode(self, node: tn.AssignNode, sdfg: SDFG) -> None: raise ValueError(f"Parsing AssignNode {node} failed. Can't find {memlet.data} in {sdfg}.") def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: - before_state = self._current_state - pending = self._pending_interstate_assignments() - pending[node.header.itervar] = node.header.init - - guard_state = _insert_and_split_assignments(sdfg, before_state, label="loop_guard", assignments=pending) - self._current_state = guard_state + loop_region = node.loop + loop_state = loop_region.add_state(f"for_loop_state_{id(node)}", is_start_block=True) + current_state = self._current_state - body_state = sdfg.add_state(label="loop_body") - self._current_state = body_state - sdfg.add_edge(guard_state, body_state, InterstateEdge(condition=node.header.condition)) + _insert_and_split_assignments(sdfg, current_state, loop_region) - # visit children inside the loop - self.visit(node.children, sdfg=sdfg) + self._current_state = loop_state + self.visit(node.children, sdfg=loop_region) - pending = self._pending_interstate_assignments() - pending[node.header.itervar] = node.header.update - _insert_and_split_assignments(sdfg, self._current_state, after_state=guard_state, assignments=pending) - - after_state = sdfg.add_state(label="loop_after") + after_state = _insert_and_split_assignments(sdfg, loop_region, label="loop_after") self._current_state = after_state - sdfg.add_edge(guard_state, after_state, InterstateEdge(condition=f"not {node.header.condition.as_string}")) def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: before_state = self._current_state @@ -246,7 +236,7 @@ def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: self._current_state = if_state # visit children of that branch - self.visit(node.children, sdfg=sdfg) + self.visit(node.children, sdfg=if_body) self._current_state = conditional_block @@ -297,7 +287,7 @@ def visit_ElseScope(self, node: tn.ElseScope, sdfg: SDFG) -> None: self._current_state = else_state # visit children inside the else branch - self.visit(node.children, sdfg=sdfg) + self.visit(node.children, sdfg=else_body) # merge false-branch into merge_state merge_state = self._pop_state("merge_state") @@ -992,6 +982,8 @@ def _insert_and_split_assignments(sdfg_region: ControlFlowRegion, weaken (best case remove) the corresponding check from the sdfg validator. """ + assignments = assignments if assignments is not None else {} + has_potential_race = False for key, value in assignments.items(): syms = symbolic.free_symbols_and_functions(value) diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 09062ff459..7b8fa15630 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -10,7 +10,7 @@ from dace.sdfg.analysis.schedule_tree import tree_to_sdfg as t2s, treenodes as tn import pytest -from dace.sdfg.state import ConditionalBlock, LoopRegion +from dace.sdfg.state import ConditionalBlock, LoopRegion, SDFGState def test_state_boundaries_none(): @@ -319,22 +319,38 @@ def test_create_tasklet_war(): def test_create_for_loop(): - loop = tn.ForScope(loop=LoopRegion(label="my_for_loop", - loop_var="i", - initialize_expr=CodeBlock("i = 0 "), - condition_expr=CodeBlock("i < 3"), - update_expr=CodeBlock("i = i+1")), - children=[ - tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, - {'out': dace.Memlet('A[1]')}), - tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 2'), {}, - {'out': dace.Memlet('A[1]')}), - ]) - - stree = tn.ScheduleTreeRoot(name='tester', containers={'A': data.Array(dace.float64, [20])}, children=[loop]) + for_scope = tn.ForScope(loop=LoopRegion(label="my_for_loop", + loop_var="i", + initialize_expr=CodeBlock("i = 0 "), + condition_expr=CodeBlock("i < 3"), + update_expr=CodeBlock("i = i+1")), + children=[ + tn.TaskletNode(nodes.Tasklet('assign_1', {}, {'out'}, 'out = 1'), {}, + {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('assign_2', {}, {'out'}, 'out = 2'), {}, + {'out': dace.Memlet('A[1]')}), + ]) + + stree = tn.ScheduleTreeRoot(name='tester', containers={'A': data.Array(dace.float64, [20])}, children=[for_scope]) sdfg = stree.as_sdfg() - sdfg.validate() + + loops = list(filter(lambda x: isinstance(x, LoopRegion), sdfg.cfg_list)) + assert len(loops) == 1, "SDFG contains one LoopRegion" + + loop: LoopRegion = loops[0] + assert loop.loop_variable == "i" + assert loop.init_statement == CodeBlock("i = 0") + assert loop.loop_condition == CodeBlock("i < 3") + assert loop.update_statement == CodeBlock("i = i+1") + + loop_states = list(filter(lambda x: isinstance(x, SDFGState), loop.nodes())) + assert len(loop_states) == 2, "Loop contains two states" + + tasklet_1: nodes.Tasklet = list(filter(lambda x: isinstance(x, nodes.Tasklet), loop_states[0].nodes()))[0] + assert tasklet_1.label == "assign_1" + tasklet_2: nodes.Tasklet = list(filter(lambda x: isinstance(x, nodes.Tasklet), loop_states[1].nodes()))[0] + assert tasklet_2.label == "assign_2" def test_create_while_loop(): @@ -386,8 +402,33 @@ def test_create_if_else(): assert tasklets[0].label == "blub", "Else branch contains Tasklet('blub')" -# TODO -# support for if_elif_else +@pytest.mark.xfail(reason="Not yet implemented") +def test_create_if_elif_else() -> None: + stree = tn.ScheduleTreeRoot( + name="tester", + containers={'A': data.Array(dace.float64, [20])}, + children=[ + tn.IfScope(condition=CodeBlock("A[0] > 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("bla", {}, {"out"}, "out=1"), {}, {"out": dace.Memlet("A[1]")}), + ]), + tn.ElifScope(condition=CodeBlock("A[0] == 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("blub", {}, {"out"}, "out=2"), {}, + {"out": dace.Memlet("A[1]")}), + ]), + tn.ElseScope(children=[ + tn.TaskletNode(nodes.Tasklet("test", {}, {"out"}, "out=3"), {}, {"out": dace.Memlet("A[1]")}) + ]) + ]) + + sdfg = stree.as_sdfg() + + blocks = list(filter(lambda x: isinstance(x, ConditionalBlock), sdfg.cfg_list)) + assert len(blocks) == 1, "SDFG contains one ConditionalBlock" + + block: ConditionalBlock = blocks[0] + assert len(block.branches) == 3, "Block contains three branches" def test_create_if_without_else(): From 32d66d9d51743f5763dbd4cbbe21294119912a11 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 26 Jan 2026 12:08:24 +0100 Subject: [PATCH 123/137] WIP: support for simple while-loops --- .../analysis/schedule_tree/tree_to_sdfg.py | 27 ++++-------- tests/schedule_tree/to_sdfg_test.py | 43 +++++++++++++------ 2 files changed, 37 insertions(+), 33 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 0ddbd2201f..de0d6bb71a 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -190,27 +190,17 @@ def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: self._current_state = after_state def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: - before_state = self._current_state - guard_state = _insert_and_split_assignments(sdfg, - before_state, - label="guard_state", - assignments=self._pending_interstate_assignments()) - self._current_state = guard_state + loop_region = node.loop + loop_state = loop_region.add_state(f"while_loop_state_{id(node)}", is_start_block=True) + current_state = self._current_state - body_state = sdfg.add_state(label="loop_body") - self._current_state = body_state - sdfg.add_edge(guard_state, body_state, InterstateEdge(condition=node.header.test)) + _insert_and_split_assignments(sdfg, current_state, loop_region) - # visit children inside the loop - self.visit(node.children, sdfg=sdfg) - _insert_and_split_assignments(sdfg, - before_state=self._current_state, - after_state=guard_state, - assignments=self._pending_interstate_assignments()) + self._current_state = loop_state + self.visit(node.children, sdfg=loop_region) - after_state = sdfg.add_state(label="loop_after") + after_state = _insert_and_split_assignments(sdfg, loop_region, label="loop_after") self._current_state = after_state - sdfg.add_edge(guard_state, after_state, InterstateEdge(f"not {node.header.test.as_string}")) def visit_DoWhileScope(self, node: tn.DoWhileScope, sdfg: SDFG) -> None: # AFAIK we don't support for do-while loops in the gt4py -> dace bridge. @@ -952,8 +942,7 @@ def create_state_boundary(boundary_node: tn.StateBoundaryNode, :return: The newly created state. """ if behavior != StateBoundaryBehavior.STATE_TRANSITION: - # Only STATE_TRANSITION is supported as StateBoundaryBehavior in this prototype. - raise NotImplementedError + raise NotImplementedError("Only STATE_TRANSITION is supported as StateBoundaryBehavior in this prototype.") # TODO: Some boundaries (control flow, state labels with goto) could not be fulfilled with every # behavior. Fall back to state transition in that case. diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 7b8fa15630..7dbc581e77 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -318,7 +318,7 @@ def test_create_tasklet_war(): if isinstance(node, nodes.AccessNode)] == ["A", "A"], "Expect two AccessNodes for A." -def test_create_for_loop(): +def test_create_loop_for(): for_scope = tn.ForScope(loop=LoopRegion(label="my_for_loop", loop_var="i", initialize_expr=CodeBlock("i = 0 "), @@ -332,7 +332,6 @@ def test_create_for_loop(): ]) stree = tn.ScheduleTreeRoot(name='tester', containers={'A': data.Array(dace.float64, [20])}, children=[for_scope]) - sdfg = stree.as_sdfg() loops = list(filter(lambda x: isinstance(x, LoopRegion), sdfg.cfg_list)) @@ -353,20 +352,36 @@ def test_create_for_loop(): assert tasklet_2.label == "assign_2" -def test_create_while_loop(): - loop = tn.WhileScope(children=[ - tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), - tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), +def test_create_loop_while(): + while_scope = tn.WhileScope(children=[ + tn.TaskletNode(nodes.Tasklet('assign_1', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('assign_2', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), ], - loop=LoopRegion( - label="my_while_loop", - condition_expr=CodeBlock("A[1] > 5"), - )) + loop=LoopRegion( + label="my_while_loop", + condition_expr=CodeBlock("A[1] > 5"), + )) - stree = tn.ScheduleTreeRoot(name='tester', containers={'A': data.Array(dace.float64, [20])}, children=[loop]) + stree = tn.ScheduleTreeRoot(name='tester', containers={'A': data.Array(dace.float64, [20])}, children=[while_scope]) sdfg = stree.as_sdfg() - sdfg.validate() + + loops = list(filter(lambda x: isinstance(x, LoopRegion), sdfg.cfg_list)) + assert len(loops) == 1, "SDFG contains one LoopRegion" + + loop: LoopRegion = loops[0] + assert loop.loop_variable == "" + assert loop.init_statement == None + assert loop.loop_condition == CodeBlock("A[1] > 5") + assert loop.update_statement == None + + loop_states = list(filter(lambda x: isinstance(x, SDFGState), loop.nodes())) + assert len(loop_states) == 2, "Loop contains two states" + + tasklet_1: nodes.Tasklet = list(filter(lambda x: isinstance(x, nodes.Tasklet), loop_states[0].nodes()))[0] + assert tasklet_1.label == "assign_1" + tasklet_2: nodes.Tasklet = list(filter(lambda x: isinstance(x, nodes.Tasklet), loop_states[1].nodes()))[0] + assert tasklet_2.label == "assign_2" def test_create_if_else(): @@ -730,8 +745,8 @@ def test_assign_nodes_avoid_duplicate_boundaries(): # test_create_state_boundary_empty_memlet(control_flow=False) test_create_tasklet_raw() test_create_tasklet_waw() - test_create_for_loop() - test_create_while_loop() + test_create_loop_for() + test_create_loop_while() test_create_if_else() test_create_if_without_else() test_create_map_scope_write() From 513038cf744d7cba4d837907a242909a997d3fb3 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 26 Jan 2026 13:44:03 +0100 Subject: [PATCH 124/137] WIP: fix issues with nested SDFGs --- dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index de0d6bb71a..2d2362bc63 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -315,12 +315,9 @@ def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: assert len(self._dataflow_stack) == dataflow_stack_size # insert nested SDFG - nsdfg = self._current_state.add_nested_sdfg(inner_sdfg, - sdfg, + nsdfg = self._current_state.add_nested_sdfg(sdfg=inner_sdfg, inputs=connectors["inputs"], - outputs=connectors["outputs"], - schedule=node.node.map.schedule) - + outputs=connectors["outputs"]) # connect nested SDFG to surrounding map scope assert self._dataflow_stack map_entry, to_connect = self._dataflow_stack[-1] From cd474d2286bca59397012d7aedce0ba278535455 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 26 Jan 2026 13:58:43 +0100 Subject: [PATCH 125/137] WIP: fix tests --- tests/schedule_tree/roundtrip_test.py | 2 +- tests/schedule_tree/to_sdfg_test.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/schedule_tree/roundtrip_test.py b/tests/schedule_tree/roundtrip_test.py index e4aea2a56a..22f178cbc3 100644 --- a/tests/schedule_tree/roundtrip_test.py +++ b/tests/schedule_tree/roundtrip_test.py @@ -38,7 +38,7 @@ def tester(A: dace.float64[20, 20]): # Test SDFG a = np.random.rand(20, 20) - new_sdfg(a) # Tests arg_names + new_sdfg(A=a) # Tests arg_names assert np.allclose(a, 1) diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 7dbc581e77..ebe9b15781 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -647,8 +647,8 @@ def test_map_calculate_temporary_in_two_loops(): ]), tn.MapScope(node=nodes.MapEntry(nodes.Map("read_tmp", "i", sbs.Range.from_string("0:20"))), children=[ - tn.TaskletNode(nodes.Tasklet("read_temp", {"tmp"}, {"out"}, "out = tmp + 1"), - {"tmp": dace.Memlet("tmp[i]")}, {"out": dace.Memlet("A[i]")}) + tn.TaskletNode(nodes.Tasklet("read_temp", {"read"}, {"out"}, "out = read + 1"), + {"read": dace.Memlet("tmp[i]")}, {"out": dace.Memlet("A[i]")}) ]) ]) From 492061ff65fd4e927f39f9f9eb7a1c8563f03dec Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 27 Jan 2026 15:49:28 +0100 Subject: [PATCH 126/137] fix: add loops to sdfg --- .../analysis/schedule_tree/tree_to_sdfg.py | 86 ++--- tests/schedule_tree/to_sdfg_test.py | 309 ++++++++++-------- 2 files changed, 221 insertions(+), 174 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 2d2362bc63..06b5b8b705 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -178,28 +178,30 @@ def visit_AssignNode(self, node: tn.AssignNode, sdfg: SDFG) -> None: def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: loop_region = node.loop + sdfg.add_node(loop_region) loop_state = loop_region.add_state(f"for_loop_state_{id(node)}", is_start_block=True) current_state = self._current_state - _insert_and_split_assignments(sdfg, current_state, loop_region) + _insert_and_split_assignments(current_state, loop_region) self._current_state = loop_state - self.visit(node.children, sdfg=loop_region) + self.visit(node.children, sdfg=sdfg) - after_state = _insert_and_split_assignments(sdfg, loop_region, label="loop_after") + after_state = _insert_and_split_assignments(loop_region, label="loop_after") self._current_state = after_state def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: loop_region = node.loop + sdfg.add_node(loop_region) loop_state = loop_region.add_state(f"while_loop_state_{id(node)}", is_start_block=True) current_state = self._current_state - _insert_and_split_assignments(sdfg, current_state, loop_region) + _insert_and_split_assignments(current_state, loop_region) self._current_state = loop_state - self.visit(node.children, sdfg=loop_region) + self.visit(node.children, sdfg=sdfg) - after_state = _insert_and_split_assignments(sdfg, loop_region, label="loop_after") + after_state = _insert_and_split_assignments(loop_region, label="loop_after") self._current_state = after_state def visit_DoWhileScope(self, node: tn.DoWhileScope, sdfg: SDFG) -> None: @@ -214,10 +216,11 @@ def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: conditional_block = ConditionalBlock(f"if_scope_{id(node)}") sdfg.add_node(conditional_block) - _insert_and_split_assignments(sdfg, - before_state, - conditional_block, - assignments=self._pending_interstate_assignments()) + _insert_and_split_assignments( + before_state, + conditional_block, + assignments=self._pending_interstate_assignments(), + ) if_body = ControlFlowRegion("if_body", sdfg=sdfg) conditional_block.add_branch(node.condition, if_body) @@ -226,15 +229,16 @@ def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: self._current_state = if_state # visit children of that branch - self.visit(node.children, sdfg=if_body) + self.visit(node.children, sdfg=sdfg) self._current_state = conditional_block # add merge_state - merge_state = _insert_and_split_assignments(sdfg, - conditional_block, - label="merge_state", - assignments=self._pending_interstate_assignments()) + merge_state = _insert_and_split_assignments( + conditional_block, + label="merge_state", + assignments=self._pending_interstate_assignments(), + ) # Check if there's an `ElseScope` following this node (in the parent's children). # Filter StateBoundaryNodes, which we inserted earlier, for this analysis. @@ -277,7 +281,7 @@ def visit_ElseScope(self, node: tn.ElseScope, sdfg: SDFG) -> None: self._current_state = else_state # visit children inside the else branch - self.visit(node.children, sdfg=else_body) + self.visit(node.children, sdfg=sdfg) # merge false-branch into merge_state merge_state = self._pop_state("merge_state") @@ -724,11 +728,12 @@ def visit_StateBoundaryNode(self, node: tn.StateBoundaryNode, sdfg: SDFG) -> Non # When creating a state boundary, include all inter-state assignments that precede it. pending = self._pending_interstate_assignments() - self._current_state = create_state_boundary(node, - sdfg, - self._current_state, - StateBoundaryBehavior.STATE_TRANSITION, - assignments=pending) + self._current_state = create_state_boundary( + node, + self._current_state, + StateBoundaryBehavior.STATE_TRANSITION, + assignments=pending, + ) def _pending_interstate_assignments(self) -> Dict: """ @@ -924,16 +929,16 @@ def _insert_memory_dependency_state_boundaries(scope: tn.ScheduleTreeScope): # SDFG content creation functions -def create_state_boundary(boundary_node: tn.StateBoundaryNode, - sdfg_region: ControlFlowRegion, - state: SDFGState, - behavior: StateBoundaryBehavior, - assignments: Optional[Dict] = None) -> SDFGState: +def create_state_boundary( + boundary_node: tn.StateBoundaryNode, + state: SDFGState, + behavior: StateBoundaryBehavior, + assignments: Optional[Dict] = None, +) -> SDFGState: """ Creates a boundary between two states :param boundary_node: The state boundary node to generate. - :param sdfg_region: The control flow block in which to generate the boundary (e.g., SDFG). :param state: The last state prior to this boundary. :param behavior: The state boundary behavior with which to create the boundary. :return: The newly created state. @@ -946,15 +951,16 @@ def create_state_boundary(boundary_node: tn.StateBoundaryNode, label = "cf_state_boundary" if boundary_node.due_to_control_flow else "state_boundary" assignments = assignments if assignments is not None else {} - return _insert_and_split_assignments(sdfg_region, state, label=label, assignments=assignments) + return _insert_and_split_assignments(state, label=label, assignments=assignments) -def _insert_and_split_assignments(sdfg_region: ControlFlowRegion, - before_state: ControlFlowBlock, - after_state: Optional[ControlFlowBlock] = None, - *, - label: Optional[str] = None, - assignments: Optional[Dict] = None) -> ControlFlowBlock: +def _insert_and_split_assignments( + before_state: ControlFlowBlock, + after_state: Optional[ControlFlowBlock] = None, + *, + label: Optional[str] = None, + assignments: Optional[Dict] = None, +) -> ControlFlowBlock: """ Insert given assignments splitting them in case of potential race conditions. @@ -969,6 +975,9 @@ def _insert_and_split_assignments(sdfg_region: ControlFlowRegion, validator. """ assignments = assignments if assignments is not None else {} + cf_region = before_state.parent_graph + if after_state is not None and after_state.parent_graph != cf_region: + raise ValueError("Expected before_state and after_state to be in the same control flow region.") has_potential_race = False for key, value in assignments.items(): @@ -980,19 +989,20 @@ def _insert_and_split_assignments(sdfg_region: ControlFlowRegion, if not has_potential_race: if after_state is not None: - sdfg_region.add_edge(before_state, after_state, InterstateEdge(assignments=assignments)) + cf_region.add_edge(before_state, after_state, InterstateEdge(assignments=assignments)) return after_state - return sdfg_region.add_state_after(before_state, label=label, assignments=assignments) + + return cf_region.add_state_after(before_state, label=label, assignments=assignments) last_state = before_state for index, assignment in enumerate(assignments.items()): key, value = assignment is_last_state = index == len(assignments) - 1 if is_last_state and after_state is not None: - sdfg_region.add_edge(last_state, after_state, InterstateEdge(assignments={key: value})) + cf_region.add_edge(last_state, after_state, InterstateEdge(assignments={key: value})) last_state = after_state else: - last_state = sdfg_region.add_state_after(last_state, label=label, assignments={key: value}) + last_state = cf_region.add_state_after(last_state, label=label, assignments={key: value}) return last_state diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index ebe9b15781..cc01624135 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -225,7 +225,7 @@ def test_create_state_boundary_state_transition(control_flow): state = sdfg.add_state("start", is_start_block=True) bnode = tn.StateBoundaryNode(control_flow) - t2s.create_state_boundary(bnode, sdfg, state, t2s.StateBoundaryBehavior.STATE_TRANSITION) + t2s.create_state_boundary(bnode, state, t2s.StateBoundaryBehavior.STATE_TRANSITION) new_label = "cf_state_boundary" if control_flow else "state_boundary" assert ["start", new_label] == [state.label for state in sdfg.states()] @@ -237,7 +237,7 @@ def test_create_state_boundary_empty_memlet(control_flow): state = sdfg.add_state("start", is_start_block=True) bnode = tn.StateBoundaryNode(control_flow) - t2s.create_state_boundary(bnode, sdfg, state, t2s.StateBoundaryBehavior.EMPTY_MEMLET) + t2s.create_state_boundary(bnode, state, t2s.StateBoundaryBehavior.EMPTY_MEMLET) def test_create_tasklet_raw(): @@ -319,17 +319,17 @@ def test_create_tasklet_war(): def test_create_loop_for(): - for_scope = tn.ForScope(loop=LoopRegion(label="my_for_loop", - loop_var="i", - initialize_expr=CodeBlock("i = 0 "), - condition_expr=CodeBlock("i < 3"), - update_expr=CodeBlock("i = i+1")), - children=[ - tn.TaskletNode(nodes.Tasklet('assign_1', {}, {'out'}, 'out = 1'), {}, - {'out': dace.Memlet('A[1]')}), - tn.TaskletNode(nodes.Tasklet('assign_2', {}, {'out'}, 'out = 2'), {}, - {'out': dace.Memlet('A[1]')}), - ]) + for_scope = tn.ForScope( + loop=LoopRegion(label="my_for_loop", + loop_var="i", + initialize_expr=CodeBlock("i = 0 "), + condition_expr=CodeBlock("i < 3"), + update_expr=CodeBlock("i = i+1")), + children=[ + tn.TaskletNode(nodes.Tasklet('assign_1', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('assign_2', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + ], + ) stree = tn.ScheduleTreeRoot(name='tester', containers={'A': data.Array(dace.float64, [20])}, children=[for_scope]) sdfg = stree.as_sdfg() @@ -353,14 +353,16 @@ def test_create_loop_for(): def test_create_loop_while(): - while_scope = tn.WhileScope(children=[ - tn.TaskletNode(nodes.Tasklet('assign_1', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), - tn.TaskletNode(nodes.Tasklet('assign_2', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), - ], - loop=LoopRegion( - label="my_while_loop", - condition_expr=CodeBlock("A[1] > 5"), - )) + while_scope = tn.WhileScope( + children=[ + tn.TaskletNode(nodes.Tasklet('assign_1', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('assign_2', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + ], + loop=LoopRegion( + label="my_while_loop", + condition_expr=CodeBlock("A[1] > 5"), + ), + ) stree = tn.ScheduleTreeRoot(name='tester', containers={'A': data.Array(dace.float64, [20])}, children=[while_scope]) @@ -389,13 +391,15 @@ def test_create_if_else(): name="tester", containers={'A': data.Array(dace.float64, [20])}, children=[ - tn.IfScope(condition=CodeBlock("A[0] > 0"), - children=[ - tn.TaskletNode(nodes.Tasklet("bla", {}, {"out"}, "out=1"), {}, {"out": dace.Memlet("A[1]")}), - ]), + tn.IfScope( + condition=CodeBlock("A[0] > 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("bla", {}, {"out"}, "out=1"), {}, {"out": dace.Memlet("A[1]")}), + ], + ), tn.ElseScope(children=[ tn.TaskletNode(nodes.Tasklet("blub", {}, {"out"}, "out=2"), {}, {"out": dace.Memlet("A[1]")}) - ]) + ]), ]) sdfg = stree.as_sdfg() @@ -423,15 +427,18 @@ def test_create_if_elif_else() -> None: name="tester", containers={'A': data.Array(dace.float64, [20])}, children=[ - tn.IfScope(condition=CodeBlock("A[0] > 0"), - children=[ - tn.TaskletNode(nodes.Tasklet("bla", {}, {"out"}, "out=1"), {}, {"out": dace.Memlet("A[1]")}), - ]), - tn.ElifScope(condition=CodeBlock("A[0] == 0"), - children=[ - tn.TaskletNode(nodes.Tasklet("blub", {}, {"out"}, "out=2"), {}, - {"out": dace.Memlet("A[1]")}), - ]), + tn.IfScope( + condition=CodeBlock("A[0] > 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("bla", {}, {"out"}, "out=1"), {}, {"out": dace.Memlet("A[1]")}), + ], + ), + tn.ElifScope( + condition=CodeBlock("A[0] == 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("blub", {}, {"out"}, "out=2"), {}, {"out": dace.Memlet("A[1]")}), + ], + ), tn.ElseScope(children=[ tn.TaskletNode(nodes.Tasklet("test", {}, {"out"}, "out=3"), {}, {"out": dace.Memlet("A[1]")}) ]) @@ -447,15 +454,18 @@ def test_create_if_elif_else() -> None: def test_create_if_without_else(): - stree = tn.ScheduleTreeRoot(name="tester", - containers={'A': data.Array(dace.float64, [20])}, - children=[ - tn.IfScope(condition=CodeBlock("A[0] > 0"), - children=[ - tn.TaskletNode(nodes.Tasklet("bla", {}, {"out"}, "out=1"), {}, - {"out": dace.Memlet("A[1]")}), - ]), - ]) + stree = tn.ScheduleTreeRoot( + name="tester", + containers={'A': data.Array(dace.float64, [20])}, + children=[ + tn.IfScope( + condition=CodeBlock("A[0] > 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("bla", {}, {"out"}, "out=1"), {}, {"out": dace.Memlet("A[1]")}), + ], + ), + ], + ) sdfg = stree.as_sdfg() @@ -472,16 +482,18 @@ def test_create_if_without_else(): def test_create_map_scope_write(): - stree = tn.ScheduleTreeRoot(name="tester", - containers={'A': data.Array(dace.float64, [20])}, - children=[ - tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", - sbs.Range.from_string("0:20"))), - children=[ - tn.TaskletNode(nodes.Tasklet("asdf", {}, {"out"}, "out = i"), {}, - {"out": dace.Memlet("A[i]")}) - ]) - ]) + stree = tn.ScheduleTreeRoot( + name="tester", + containers={'A': data.Array(dace.float64, [20])}, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("assign_i", {}, {"out"}, "out = i"), {}, {"out": dace.Memlet("A[i]")}) + ], + ) + ], + ) sdfg = stree.as_sdfg() sdfg.validate() @@ -495,14 +507,16 @@ def test_create_map_scope_read_after_write(): 'B': data.Array(dace.float64, [20], transient=True), }, children=[ - tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), - children=[ - tn.TaskletNode(nodes.Tasklet("write", {}, {"out"}, "out = i"), {}, - {"out": dace.Memlet("B[i]")}), - tn.TaskletNode(nodes.Tasklet("read", {"in_field"}, {"out_field"}, "out_field = in_field"), - {"in_field": dace.Memlet("B[i]")}, {"out_field": dace.Memlet("A[i]")}) - ]) - ]) + tn.MapScope( + node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("write", {}, {"out"}, "out = i"), {}, {"out": dace.Memlet("B[i]")}), + tn.TaskletNode(nodes.Tasklet("read", {"in_field"}, {"out_field"}, "out_field = in_field"), + {"in_field": dace.Memlet("B[i]")}, {"out_field": dace.Memlet("A[i]")}) + ], + ) + ], + ) sdfg = stree.as_sdfg() sdfg.validate() @@ -513,32 +527,37 @@ def test_create_map_scope_write_after_read(): name="tester", containers={"A": data.Array(dace.float64, [20])}, children=[ - tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), - children=[ - tn.TaskletNode(nodes.Tasklet("read_write", {"read"}, {"write"}, "write = read+1"), - {"read": dace.Memlet("A[i]")}, {"write": dace.Memlet("A[i]")}) - ]) - ]) + tn.MapScope( + node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("read_write", {"read"}, {"write"}, "write = read+1"), + {"read": dace.Memlet("A[i]")}, {"write": dace.Memlet("A[i]")}) + ], + ) + ], + ) sdfg = stree.as_sdfg() sdfg.validate() def test_create_map_scope_copy(): - stree = tn.ScheduleTreeRoot(name="tester", - containers={ - 'A': data.Array(dace.float64, [20]), - 'B': data.Array(dace.float64, [20]), - }, - children=[ - tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", - sbs.Range.from_string("0:20"))), - children=[ - tn.TaskletNode(nodes.Tasklet("copy", {"inp"}, {"out"}, "out = inp"), - {"inp": dace.Memlet("A[i]")}, - {"out": dace.Memlet("B[i]")}) - ]) - ]) + stree = tn.ScheduleTreeRoot( + name="tester", + containers={ + 'A': data.Array(dace.float64, [20]), + 'B': data.Array(dace.float64, [20]), + }, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("copy", {"inp"}, {"out"}, "out = inp"), {"inp": dace.Memlet("A[i]")}, + {"out": dace.Memlet("B[i]")}) + ], + ) + ], + ) sdfg = stree.as_sdfg() sdfg.validate() @@ -570,15 +589,20 @@ def test_create_nested_map_scope(): name="tester", containers={'A': data.Array(dace.float64, [20])}, children=[ - tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:2"))), + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_i", "i", sbs.Range.from_string("0:4"))), + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_j", "j", sbs.Range.from_string("0:5"))), children=[ - tn.MapScope(node=nodes.MapEntry(nodes.Map("blub", "j", sbs.Range.from_string("0:10"))), - children=[ - tn.TaskletNode(nodes.Tasklet("asdf", {}, {"out"}, "out = i*10+j"), {}, - {"out": dace.Memlet("A[i*10+j]")}) - ]) - ]) - ]) + tn.TaskletNode(nodes.Tasklet("assign", {}, {"out"}, "out = i*5+j"), {}, + {"out": dace.Memlet("A[i*5+j]")}) + ], + ) + ], + ) + ], + ) sdfg = stree.as_sdfg() sdfg.validate() @@ -592,36 +616,41 @@ def test_create_nested_map_scope_multi_read(): 'B': data.Array(dace.float64, [10]) }, children=[ - tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:2"))), + tn.MapScope( + node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:2"))), + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("blub", "j", sbs.Range.from_string("0:5"))), children=[ - tn.MapScope(node=nodes.MapEntry(nodes.Map("blub", "j", sbs.Range.from_string("0:5"))), - children=[ - tn.TaskletNode( - nodes.Tasklet("asdf", {"a_1", "a_2"}, {"out"}, "out = a_1 + a_2"), { - "a_1": dace.Memlet("A[i*5+j]"), - "a_2": dace.Memlet("A[10+i*5+j]"), - }, {"out": dace.Memlet("B[i*5+j]")}) - ]) - ]) - ]) + tn.TaskletNode(nodes.Tasklet("asdf", {"a_1", "a_2"}, {"out"}, "out = a_1 + a_2"), { + "a_1": dace.Memlet("A[i*5+j]"), + "a_2": dace.Memlet("A[10+i*5+j]"), + }, {"out": dace.Memlet("B[i*5+j]")}) + ], + ) + ], + ) + ], + ) sdfg = stree.as_sdfg() sdfg.validate() def test_map_with_state_boundary_inside(): - stree = tn.ScheduleTreeRoot(name="tester", - containers={'A': data.Array(dace.float64, [20])}, - children=[ - tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", - sbs.Range.from_string("0:20"))), - children=[ - tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = i'), {}, - {'out': dace.Memlet('A[1]')}), - tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'out = 2*i'), {}, - {'out': dace.Memlet('A[1]')}), - ]) - ]) + stree = tn.ScheduleTreeRoot( + name="tester", + containers={'A': data.Array(dace.float64, [20])}, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = i'), {}, {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'out = 2*i'), {}, {'out': dace.Memlet('A[1]')}), + ], + ) + ], + ) sdfg = stree.as_sdfg() sdfg.validate() @@ -635,22 +664,28 @@ def test_map_calculate_temporary_in_two_loops(): "tmp": data.Array(dace.float64, [20], transient=True) }, children=[ - tn.MapScope(node=nodes.MapEntry(nodes.Map("first_half", "i", sbs.Range.from_string("0:10"))), - children=[ - tn.TaskletNode(nodes.Tasklet("beginning", {}, {'out'}, 'out = i'), {}, - {'out': dace.Memlet("tmp[i]")}) - ]), - tn.MapScope(node=nodes.MapEntry(nodes.Map("second_half", "i", sbs.Range.from_string("10:20"))), - children=[ - tn.TaskletNode(nodes.Tasklet("end", {}, {'out'}, 'out = i'), {}, - {'out': dace.Memlet("tmp[i]")}) - ]), - tn.MapScope(node=nodes.MapEntry(nodes.Map("read_tmp", "i", sbs.Range.from_string("0:20"))), - children=[ - tn.TaskletNode(nodes.Tasklet("read_temp", {"read"}, {"out"}, "out = read + 1"), - {"read": dace.Memlet("tmp[i]")}, {"out": dace.Memlet("A[i]")}) - ]) - ]) + tn.MapScope( + node=nodes.MapEntry(nodes.Map("first_half", "i", sbs.Range.from_string("0:10"))), + children=[ + tn.TaskletNode(nodes.Tasklet("beginning", {}, {'out'}, 'out = i'), {}, + {'out': dace.Memlet("tmp[i]")}) + ], + ), + tn.MapScope( + node=nodes.MapEntry(nodes.Map("second_half", "i", sbs.Range.from_string("10:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("end", {}, {'out'}, 'out = i'), {}, {'out': dace.Memlet("tmp[i]")}) + ], + ), + tn.MapScope( + node=nodes.MapEntry(nodes.Map("read_tmp", "i", sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("read_temp", {"read"}, {"out"}, "out = read + 1"), + {"read": dace.Memlet("tmp[i]")}, {"out": dace.Memlet("A[i]")}) + ], + ) + ], + ) sdfg = stree.as_sdfg(simplify=True) sdfg.validate() @@ -660,13 +695,15 @@ def test_map_calculate_temporary_in_two_loops(): def test_edge_assignment_read_after_write(): - stree = tn.ScheduleTreeRoot(name="tester", - containers={}, - children=[ - tn.AssignNode("my_condition", CodeBlock("True"), dace.InterstateEdge()), - tn.AssignNode("condition", CodeBlock("my_condition"), dace.InterstateEdge()), - tn.StateBoundaryNode() - ]) + stree = tn.ScheduleTreeRoot( + name="tester", + containers={}, + children=[ + tn.AssignNode("my_condition", CodeBlock("True"), dace.InterstateEdge()), + tn.AssignNode("condition", CodeBlock("my_condition"), dace.InterstateEdge()), + tn.StateBoundaryNode(), + ], + ) sdfg = stree.as_sdfg(simplify=False) From 0da5e3a5548f8bf8c26c41bd1c1450bdb015fc34 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Tue, 27 Jan 2026 17:41:04 +0100 Subject: [PATCH 127/137] Fix nested if statements inside (triple) loop --- .../analysis/schedule_tree/tree_to_sdfg.py | 15 ++-- tests/schedule_tree/to_sdfg_test.py | 85 +++++++++++++++++++ 2 files changed, 95 insertions(+), 5 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 06b5b8b705..9d5a84ca83 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -177,10 +177,12 @@ def visit_AssignNode(self, node: tn.AssignNode, sdfg: SDFG) -> None: raise ValueError(f"Parsing AssignNode {node} failed. Can't find {memlet.data} in {sdfg}.") def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: + current_state = self._current_state + cf_region = current_state.parent_graph + loop_region = node.loop - sdfg.add_node(loop_region) + cf_region.add_node(loop_region) loop_state = loop_region.add_state(f"for_loop_state_{id(node)}", is_start_block=True) - current_state = self._current_state _insert_and_split_assignments(current_state, loop_region) @@ -191,10 +193,12 @@ def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: self._current_state = after_state def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: + current_state = self._current_state + cf_region = current_state.parent_graph + loop_region = node.loop - sdfg.add_node(loop_region) + cf_region.add_node(loop_region) loop_state = loop_region.add_state(f"while_loop_state_{id(node)}", is_start_block=True) - current_state = self._current_state _insert_and_split_assignments(current_state, loop_region) @@ -213,9 +217,10 @@ def visit_LoopScope(self, node: tn.LoopScope, sdfg: SDFG) -> None: def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: before_state = self._current_state + cf_region = before_state.parent_graph conditional_block = ConditionalBlock(f"if_scope_{id(node)}") - sdfg.add_node(conditional_block) + cf_region.add_node(conditional_block) _insert_and_split_assignments( before_state, conditional_block, diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index cc01624135..069e0a7fc3 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -608,6 +608,91 @@ def test_create_nested_map_scope(): sdfg.validate() +def test_triple_map_flat_if(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={'A': data.Array(dace.float64, [60])}, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_i", "i", sbs.Range.from_string("0:4"))), + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_j", "j", sbs.Range.from_string("0:5"))), + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_k", "k", sbs.Range.from_string("0:3"))), + children=[ + tn.IfScope( + condition=CodeBlock("A[0] > 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("assign", {}, {"out"}, "out = 1"), {}, + {"out": dace.Memlet("A[i*15+j*3+k]")}) + ], + ), + tn.ElseScope(children=[ + tn.TaskletNode(nodes.Tasklet("assign", {}, {"out"}, "out = 2"), {}, + {"out": dace.Memlet("A[i*15+j*3+k]")}) + ], ), + ], + ) + ], + ) + ], + ) + ], + ) + + sdfg = stree.as_sdfg() + sdfg.validate() + + +def test_triple_map_nested_if(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={'A': data.Array(dace.float64, [60])}, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_i", "i", sbs.Range.from_string("0:4"))), + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_j", "j", sbs.Range.from_string("0:5"))), + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_k", "k", sbs.Range.from_string("0:3"))), + children=[ + tn.IfScope( + condition=CodeBlock("A[0] > 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("assign", {}, {"out"}, "out = 1"), {}, + {"out": dace.Memlet("A[i*15+j*3+k]")}) + ], + ), + tn.ElseScope(children=[ + tn.IfScope( + condition=CodeBlock("A[1] > 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("assign", {}, {"out"}, "out = 2"), {}, + {"out": dace.Memlet("A[i*15+j*3+k]")}) + ], + ), + tn.ElseScope(children=[ + tn.TaskletNode(nodes.Tasklet("assign", {}, {"out"}, "out = 3"), {}, + {"out": dace.Memlet("A[i*15+j*3+k]")}) + ], ) + ], ), + ], + ) + ], + ) + ], + ) + ], + ) + + sdfg = stree.as_sdfg() + sdfg.validate() + + def test_create_nested_map_scope_multi_read(): stree = tn.ScheduleTreeRoot( name="tester", From cff050a8b4b8f0c320904194e6c7c05d9e957ef1 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Wed, 28 Jan 2026 15:45:02 +0100 Subject: [PATCH 128/137] fix broken sdfg.array access in loops --- dace/sdfg/analysis/schedule_tree/treenodes.py | 10 ++++-- dace/sdfg/state.py | 17 ++++++--- tests/schedule_tree/to_sdfg_test.py | 35 +++++++++++++++++++ 3 files changed, 55 insertions(+), 7 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 2632377e95..c68db8d5cf 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -458,8 +458,10 @@ def as_string(self, indent: int = 0) -> str: return result + super().as_string(indent) def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + root = root if root is not None else self.get_root() + result = MemletSet() - result.update(self.loop.get_meta_read_memlets()) + result.update(self.loop.get_meta_read_memlets(arrays=root.containers)) # If loop range is well-formed, use it in propagation range = _loop_range(self.loop) @@ -500,8 +502,9 @@ def as_string(self, indent: int = 0) -> str: def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: root = root if root is not None else self.get_root() + result = MemletSet() - result.update(self.loop.get_meta_read_memlets()) + result.update(self.loop.get_meta_read_memlets(arrays=root.containers)) result.update(super().input_memlets(root, **kwargs)) return result @@ -524,8 +527,9 @@ def as_string(self, indent: int = 0) -> str: def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: root = root if root is not None else self.get_root() + result = MemletSet() - result.update(self.loop.get_meta_read_memlets()) + result.update(self.loop.get_meta_read_memlets(arrays=root.containers)) result.update(super().input_memlets(root, **kwargs)) return result diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 0ca9b33216..c1b3c08388 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -3502,14 +3502,23 @@ def get_meta_codeblocks(self): codes.append(self.update_statement) return codes - def get_meta_read_memlets(self) -> List[mm.Memlet]: + def get_meta_read_memlets(self, arrays: Optional[Dict[str, dt.Data]] = None) -> List[mm.Memlet]: + """ + Get a list of all (read) memlets in meta codeblocks. + + :param arrays: An optional dictionary mapping array names to their data descriptors. + If not not given defaults to ``self.sdfg.arrays``. + """ # Avoid cyclic imports. from dace.sdfg.sdfg import memlets_in_ast - read_memlets = memlets_in_ast(self.loop_condition.code[0], self.sdfg.arrays) + + arrays = arrays if arrays is not None else self.sdfg.arrays + + read_memlets = memlets_in_ast(self.loop_condition.code[0], arrays) if self.init_statement: - read_memlets.extend(memlets_in_ast(self.init_statement.code[0], self.sdfg.arrays)) + read_memlets.extend(memlets_in_ast(self.init_statement.code[0], arrays)) if self.update_statement: - read_memlets.extend(memlets_in_ast(self.update_statement.code[0], self.sdfg.arrays)) + read_memlets.extend(memlets_in_ast(self.update_statement.code[0], arrays)) return read_memlets def replace_meta_accesses(self, replacements): diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 069e0a7fc3..821fd88950 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -608,6 +608,41 @@ def test_create_nested_map_scope(): sdfg.validate() +def test_double_map_with_for_loop(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={'A': data.Array(dace.float64, [20])}, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_i", "i", sbs.Range.from_string("0:4"))), + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_j", "j", sbs.Range.from_string("0:5"))), + children=[ + tn.ForScope( + loop=LoopRegion( + label="loop_k", + loop_var="k", + initialize_expr=CodeBlock("k = 0 "), + condition_expr=CodeBlock("k < 3"), + update_expr=CodeBlock("k = k+1"), + ), + children=[ + tn.TaskletNode(nodes.Tasklet("assign", {}, {"out"}, "out = 1.0"), {}, + {"out": dace.Memlet("A[i*15+j*3+k]")}) + ], + ), + ], + ) + ], + ) + ], + ) + + sdfg = stree.as_sdfg() + assert sdfg.is_valid() + + def test_triple_map_flat_if(): stree = tn.ScheduleTreeRoot( name="tester", From aa1e4f4a872e3d96de2e6435e926c048a89241c2 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Thu, 29 Jan 2026 18:30:12 +0100 Subject: [PATCH 129/137] propagate sdfg name into stree name (and back) --- dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py | 2 +- tests/schedule_tree/roundtrip_test.py | 13 +++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 14c1dcd7fc..b5a0a664f5 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -843,7 +843,7 @@ def as_schedule_tree(sdfg: SDFG, *, in_place: bool = False, toplevel: bool = Tru _prepare_sdfg_for_conversion(sdfg, toplevel=toplevel) if toplevel: - result = tn.ScheduleTreeRoot(name="default_stree_name", children=[]) + result = tn.ScheduleTreeRoot(name=sdfg.name, children=[]) _create_unified_descriptor_repository(sdfg, result) result.add_children(_block_schedule_tree(sdfg)) else: diff --git a/tests/schedule_tree/roundtrip_test.py b/tests/schedule_tree/roundtrip_test.py index 22f178cbc3..17e7091d6b 100644 --- a/tests/schedule_tree/roundtrip_test.py +++ b/tests/schedule_tree/roundtrip_test.py @@ -42,5 +42,18 @@ def tester(A: dace.float64[20, 20]): assert np.allclose(a, 1) +def test_name_propagation(): + name = "my_complicated_sdfg_test_name" + sdfg = dace.SDFG(name) + sdfg.add_state("empty", is_start_block=True) + + stree = sdfg.as_schedule_tree() + assert stree.name == name + + sdfg = stree.as_sdfg() + assert sdfg.name == name + + if __name__ == '__main__': test_implicit_inline_and_constants() + test_name_propagation() From 491d74626720c1f604edd0ba941db6923bc76686 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Mon, 2 Feb 2026 11:04:58 +0100 Subject: [PATCH 130/137] fix symbol replacement for FvTp2d --- dace/sdfg/replace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py index 04ceafb261..4bf1e74a0d 100644 --- a/dace/sdfg/replace.py +++ b/dace/sdfg/replace.py @@ -81,7 +81,7 @@ def replace_dict(subgraph: 'StateSubgraphView', desc = node.desc(state) # In case the AccessNode name was replaced in the sdfg.arrays but not in the SDFG itself # then we have to look for the replaced value in the sdfg.arrays - elif repl[node.data] in state.sdfg.arrays: + elif node.data in repl and repl[node.data] in state.sdfg.arrays: desc = state.sdfg.arrays[repl[node.data]] else: continue From e701af6d9034a944ea14f2ad0e6f13d1bd571b31 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Wed, 4 Feb 2026 09:30:59 +0100 Subject: [PATCH 131/137] skip memlet propagation tests According to Tal, suboptimal memlet propagation won't affect validation. It can be fixed in a follow-up PR. --- tests/schedule_tree/naming_test.py | 3 +-- tests/schedule_tree/propagation_test.py | 6 ++++++ tests/schedule_tree/roundtrip_test.py | 2 +- tests/schedule_tree/schedule_test.py | 4 ++-- tests/schedule_tree/to_sdfg_test.py | 4 ++-- tests/schedule_tree/treenodes_test.py | 2 ++ 6 files changed, 14 insertions(+), 7 deletions(-) diff --git a/tests/schedule_tree/naming_test.py b/tests/schedule_tree/naming_test.py index 8c39e3033f..632b64dfdf 100644 --- a/tests/schedule_tree/naming_test.py +++ b/tests/schedule_tree/naming_test.py @@ -5,7 +5,6 @@ from dace.transformation.passes.constant_propagation import ConstantPropagation import pytest -from typing import List def _irreducible_loop_to_loop(): @@ -171,7 +170,7 @@ def test_edgecase_symbol_mapping(): def _check_for_name_clashes(stree: tn.ScheduleTreeNode): - def _traverse(node: tn.ScheduleTreeScope, scopes: List[str]): + def _traverse(node: tn.ScheduleTreeScope, scopes: list[str]): for child in node.children: if isinstance(child, tn.LoopScope): itervar = child.loop.loop.loop_variable diff --git a/tests/schedule_tree/propagation_test.py b/tests/schedule_tree/propagation_test.py index 507a3d7226..0a5e671d20 100644 --- a/tests/schedule_tree/propagation_test.py +++ b/tests/schedule_tree/propagation_test.py @@ -6,7 +6,9 @@ from dace.sdfg import nodes from dace.sdfg.analysis.schedule_tree import tree_to_sdfg as t2s, treenodes as tn from dace.properties import CodeBlock + import numpy as np +import pytest def test_stree_propagation_forloop(): @@ -28,6 +30,8 @@ def tester(a: dace.float64[20]): assert list(node_types[2].output_memlets()) == [memlet] +# TODO: write issue and link it here s.t. we don't forget +@pytest.mark.skip("Suboptimal memlet propagation") def test_stree_propagation_symassign(): # Manually create a schedule tree N = dace.symbol('N') @@ -55,6 +59,8 @@ def test_stree_propagation_symassign(): assert list(stree.children[0].input_memlets()) == [dace.Memlet('A[0:20]', volume=N - 1)] +# TODO: write issue and link it here s.t. we don't forget +@pytest.mark.skip("Suboptimal memlet propagation") def test_stree_propagation_dynset(): H = dace.symbol('H') nnz = dace.symbol('nnz') diff --git a/tests/schedule_tree/roundtrip_test.py b/tests/schedule_tree/roundtrip_test.py index 17e7091d6b..3d0fa42b8c 100644 --- a/tests/schedule_tree/roundtrip_test.py +++ b/tests/schedule_tree/roundtrip_test.py @@ -25,7 +25,7 @@ def tester(A: dace.float64[20, 20]): # Inject constant into nested SDFG assert len(list(sdfg.all_sdfgs_recursive())) > 1 sdfg.add_constant('cst', 13) # Add an unused constant - sdfg.sdfg_list[-1].add_constant('cst', 1, dace.data.Scalar(dace.float64)) + sdfg.cfg_list[-1].add_constant('cst', 1, dace.data.Scalar(dace.float64)) tasklet = next(n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, dace.nodes.Tasklet)) tasklet.code.as_string = tasklet.code.as_string.replace('12', 'cst') diff --git a/tests/schedule_tree/schedule_test.py b/tests/schedule_tree/schedule_test.py index 1f929bd779..fc2f3ea314 100644 --- a/tests/schedule_tree/schedule_test.py +++ b/tests/schedule_tree/schedule_test.py @@ -3,11 +3,11 @@ import dace from dace.sdfg.analysis.schedule_tree import treenodes as tn from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree -import numpy as np - from dace.transformation.pass_pipeline import FixedPointPipeline from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising +import numpy as np + def test_for_in_map_in_for(): diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py index 821fd88950..5fd5e04fd1 100644 --- a/tests/schedule_tree/to_sdfg_test.py +++ b/tests/schedule_tree/to_sdfg_test.py @@ -8,10 +8,10 @@ from dace.properties import CodeBlock from dace.sdfg import nodes from dace.sdfg.analysis.schedule_tree import tree_to_sdfg as t2s, treenodes as tn -import pytest - from dace.sdfg.state import ConditionalBlock, LoopRegion, SDFGState +import pytest + def test_state_boundaries_none(): # Manually create a schedule tree diff --git a/tests/schedule_tree/treenodes_test.py b/tests/schedule_tree/treenodes_test.py index 28c5004974..31a0abbd21 100644 --- a/tests/schedule_tree/treenodes_test.py +++ b/tests/schedule_tree/treenodes_test.py @@ -1,3 +1,5 @@ +# Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. + from dace.sdfg.analysis.schedule_tree import treenodes as tn from dace import nodes From 40b424cdde52edf7ffc75ab65077d9e983648112 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Wed, 4 Feb 2026 15:24:55 +0100 Subject: [PATCH 132/137] trivial changes to trigger a CI run --- dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index b5a0a664f5..2869020b1e 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -76,7 +76,7 @@ def dealias_sdfg(sdfg: SDFG): if replacements.get(name, None) is not None: # There's an incoming and an outgoing connector with the same name. - # Make sure both map to the same memory in the parent sdfg + # Make sure both map to the same memory in the parent sdfg. assert replacements[name] == parent_name assert name in inv_replacements[parent_name] break @@ -566,7 +566,7 @@ def _state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: # Replace symbols and memlets in nested SDFGs to match the namespace of the parent SDFG replace_symbols_until_set(node) - # Create memlets for nested SDFG mapping, or nview schedule nodes if slice cannot be determined + # Create memlets for nested SDFG mapping, or NView schedule nodes if slice cannot be determined. for e in state.all_edges(node): conn = e.dst_conn if e.dst is node else e.src_conn if e.data.is_empty() or not conn: @@ -585,7 +585,7 @@ def _state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: else: nested_array_mapping_output[conn] = e.data - if no_mapping: # Must use view (nview = nested SDFG view) + if no_mapping: # Must use view (NView = nested SDFG view) if conn not in generated_nviews: nview_node = tn.NView(target=conn, source=e.data.data, @@ -716,8 +716,7 @@ def _block_schedule_tree(block: ControlFlowBlock) -> List[tn.ScheduleTreeNode]: if variant == "do-while": return [tn.DoWhileScope(loop=block, children=children)] - # If we end up here, we don't need more granularity and just use - # a general loop scope + # If we end up here, we don't need more granularity and just use a general loop scope. return [tn.LoopScope(loop=block, children=children)] return children @@ -802,6 +801,7 @@ def _create_unified_descriptor_repository(sdfg: SDFG, stree: tn.ScheduleTreeRoot """ Creates a single descriptor repository from an SDFG and all nested SDFGs. This includes data containers, symbols, constants, etc. + :param sdfg: The top-level SDFG to create the repository from. :param stree: The tree root in which to make the unified descriptor repository. """ From 77359fd5158fb05d56432eb2b3f7b5e15c220d09 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Wed, 4 Feb 2026 15:54:47 +0100 Subject: [PATCH 133/137] WIP: cleanup from code review --- dace/sdfg/memlet_utils.py | 21 +++++++++++---------- dace/sdfg/propagation.py | 8 ++------ dace/version.py | 2 +- 3 files changed, 14 insertions(+), 17 deletions(-) diff --git a/dace/sdfg/memlet_utils.py b/dace/sdfg/memlet_utils.py index 4f5f507ea2..81f4ca6726 100644 --- a/dace/sdfg/memlet_utils.py +++ b/dace/sdfg/memlet_utils.py @@ -1,4 +1,5 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +from __future__ import annotations import ast from collections import defaultdict @@ -10,7 +11,7 @@ from dace.sdfg.graph import MultiConnectorEdge from dace.frontend.python import memlet_parser import itertools -from typing import Callable, Dict, Iterable, Optional, Set, TypeVar, Tuple, Union +from typing import Callable, Dict, Iterable, Optional, Set, TypeVar, Tuple, Union, Generator, Any class MemletReplacer(ast.NodeTransformer): @@ -92,7 +93,7 @@ class MemletSet(Set[Memlet]): Set updates and unions also perform unions on the contained memlet subsets. """ - def __init__(self, iterable: Optional[Iterable[Memlet]] = None, intersection_is_contained: bool = True): + def __init__(self, iterable: Optional[Iterable[Memlet]] = None, intersection_is_contained: bool = True) -> None: """ Initializes a memlet set. @@ -106,14 +107,14 @@ def __init__(self, iterable: Optional[Iterable[Memlet]] = None, intersection_is_ if iterable is not None: self.update(iterable) - def __iter__(self): + def __iter__(self) -> Generator[Memlet, Any, None]: for subset in self.internal_set.values(): yield from subset - def __len__(self): + def __len__(self) -> int: return len(self.internal_set) - def update(self, *iterable: Iterable[Memlet]): + def update(self, *iterable: Iterable[Memlet]) -> None: """ Updates set of memlets via union of existing ranges. """ @@ -128,7 +129,7 @@ def update(self, *iterable: Iterable[Memlet]): for elem in to_update: self.add(elem) - def add(self, elem: Memlet): + def add(self, elem: Memlet) -> None: """ Adds a memlet to the set, potentially performing a union of existing ranges. """ @@ -175,7 +176,7 @@ def __contains__(self, elem: Memlet) -> bool: return False - def union(self, *s: Iterable[Memlet]) -> 'MemletSet': + def union(self, *s: Iterable[Memlet]) -> MemletSet: """ Performs a set-union (with memlet union) over the given sets of memlets. @@ -195,7 +196,7 @@ class MemletDict(Dict[Memlet, T]): """ covers_cache: Dict[Tuple, bool] = {} - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: self.internal_dict: Dict[str, Dict[Memlet, T]] = defaultdict(dict) if kwargs: self.update(kwargs) @@ -229,10 +230,10 @@ def _getkey(self, elem: Memlet) -> Optional[Memlet]: def _setkey(self, key: Memlet, value: T) -> None: self.internal_dict[key.data][key] = value - def clear(self): + def clear(self) -> None: self.internal_dict.clear() - def update(self, mapping: Dict[Memlet, T]): + def update(self, mapping: Dict[Memlet, T]) -> None: for k, v in mapping.items(): ak = self._getkey(k) if ak is None: diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index b752430354..530b279a4b 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -423,10 +423,6 @@ def can_be_applied(self, dim_exprs, variable_context, node_range, orig_edges, di if symbolic.issymbolic(dim): used_symbols.update(dim.free_symbols) - # if any(s not in defined_vars for s in (used_symbols - set(self.params))): - # # Cannot propagate symbols that are undefined outside scope (e.g., internal symbols) - # return False - if (used_symbols & set(self.params) and any(symbolic.pystr_to_symbolic(s) not in defined_vars for s in node_range.free_symbols)): # Cannot propagate symbols that are undefined in the outer range @@ -1525,11 +1521,11 @@ def propagate_subset(memlets: List[Memlet], tmp_subset = pattern.propagate(arr, [subset], rng) break else: - # No patterns found. Propagate the entire array whenever symbols are used + # No patterns found. Propagate the entire array whenever symbols are used. entire_array = subsets.Range.from_array(arr) paramset = set(map(str, params)) # Fill in the entire array only if one of the parameters appears in the - # free symbols list of the subset dimension or is undefined outside + # free symbols list of the subset dimension or is undefined outside. tmp_subset_rng = [] for s, ea in zip(subset, entire_array): if isinstance(subset, subsets.Indices): diff --git a/dace/version.py b/dace/version.py index a6221b3de7..1f356cc57b 100644 --- a/dace/version.py +++ b/dace/version.py @@ -1 +1 @@ -__version__ = '1.0.2' +__version__ = '1.0.0' From f6cb1843eb7b70e3650b4dabaa18bd866e0be27f Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Wed, 4 Feb 2026 16:31:03 +0100 Subject: [PATCH 134/137] update tests: LoopScop -> ForScope --- tests/sdfg/loop_region_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/sdfg/loop_region_test.py b/tests/sdfg/loop_region_test.py index f09bf0004f..011787c98b 100644 --- a/tests/sdfg/loop_region_test.py +++ b/tests/sdfg/loop_region_test.py @@ -311,7 +311,7 @@ def test_loop_to_stree_triple_nested_for(): stree = s2t.as_schedule_tree(sdfg) po_nodes = list(stree.preorder_traversal())[1:] - assert [type(n) for n in po_nodes] == [tn.LoopScope, tn.LoopScope, tn.LoopScope, tn.TaskletNode, tn.LibraryCall] + assert [type(n) for n in po_nodes] == [tn.ForScope, tn.ForScope, tn.ForScope, tn.TaskletNode, tn.LibraryCall] if __name__ == '__main__': From 78b3305bcdbe14d4d07408e93d9ea828ced94ef5 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Wed, 4 Feb 2026 17:01:45 +0100 Subject: [PATCH 135/137] keep track of skipped tests --- tests/schedule_tree/propagation_test.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/schedule_tree/propagation_test.py b/tests/schedule_tree/propagation_test.py index 0a5e671d20..c5233531b9 100644 --- a/tests/schedule_tree/propagation_test.py +++ b/tests/schedule_tree/propagation_test.py @@ -30,8 +30,7 @@ def tester(a: dace.float64[20]): assert list(node_types[2].output_memlets()) == [memlet] -# TODO: write issue and link it here s.t. we don't forget -@pytest.mark.skip("Suboptimal memlet propagation") +@pytest.mark.skip("Suboptimal memlet propagation: https://github.com/spcl/dace/issues/2293") def test_stree_propagation_symassign(): # Manually create a schedule tree N = dace.symbol('N') @@ -59,8 +58,7 @@ def test_stree_propagation_symassign(): assert list(stree.children[0].input_memlets()) == [dace.Memlet('A[0:20]', volume=N - 1)] -# TODO: write issue and link it here s.t. we don't forget -@pytest.mark.skip("Suboptimal memlet propagation") +@pytest.mark.skip("Suboptimal memlet propagation: https://github.com/spcl/dace/issues/2293") def test_stree_propagation_dynset(): H = dace.symbol('H') nnz = dace.symbol('nnz') From 052ff1d265accf0bbcc0fe39eeb72ecf52c3f6b1 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Mon, 9 Feb 2026 14:55:23 +0100 Subject: [PATCH 136/137] simple tests for MemletSet and MemletDict --- dace/sdfg/memlet_utils.py | 28 ++++++++----- tests/sdfg/memlet_utils_test.py | 70 ++++++++++++++++++++++++++++++--- 2 files changed, 82 insertions(+), 16 deletions(-) diff --git a/dace/sdfg/memlet_utils.py b/dace/sdfg/memlet_utils.py index 81f4ca6726..ddba9d66ef 100644 --- a/dace/sdfg/memlet_utils.py +++ b/dace/sdfg/memlet_utils.py @@ -93,7 +93,7 @@ class MemletSet(Set[Memlet]): Set updates and unions also perform unions on the contained memlet subsets. """ - def __init__(self, iterable: Optional[Iterable[Memlet]] = None, intersection_is_contained: bool = True) -> None: + def __init__(self, iterable: Optional[Iterable[Memlet]] = None, *, intersection_is_contained: bool = True) -> None: """ Initializes a memlet set. @@ -192,22 +192,28 @@ def union(self, *s: Iterable[Memlet]) -> MemletSet: class MemletDict(Dict[Memlet, T]): """ - Implements a dictionary with memlet keys that considers subsets that intersect or are covered by its other memlets. + Implements a dictionary with memlet keys that considers subsets that intersect + or are covered by its other memlets. """ - covers_cache: Dict[Tuple, bool] = {} def __init__(self, **kwargs) -> None: self.internal_dict: Dict[str, Dict[Memlet, T]] = defaultdict(dict) + self.covers_cache: Dict[Tuple, bool] = defaultdict() + if kwargs: self.update(kwargs) + def __len__(self) -> int: + return len(self.internal_dict) + def _getkey(self, elem: Memlet) -> Optional[Memlet]: """ - Returns the corresponding key (exact, covered, intersecting, or indeterminately intersecting memlet) if - exists in the dictionary, or None if it does not. + Returns the corresponding key (exact, covered, intersecting, or indeterminately intersecting memlet) + if it exists in the dictionary, or None if it does not. """ if elem.data not in self.internal_dict: return None + for existing_memlet in self.internal_dict[elem.data]: key = (existing_memlet.subset, elem.subset) is_covered = self.covers_cache.get(key, None) @@ -216,6 +222,7 @@ def _getkey(self, elem: Memlet) -> Optional[Memlet]: self.covers_cache[key] = is_covered if is_covered: return existing_memlet + try: if subsets.intersects(existing_memlet.subset, elem.subset) == False: # Definitely does not intersect continue @@ -234,12 +241,12 @@ def clear(self) -> None: self.internal_dict.clear() def update(self, mapping: Dict[Memlet, T]) -> None: - for k, v in mapping.items(): - ak = self._getkey(k) - if ak is None: - self._setkey(k, v) + for key, value in mapping.items(): + actual_key = self._getkey(key) + if actual_key is None: + self._setkey(key, value) else: - self._setkey(ak, v) + self._setkey(actual_key, value) def __contains__(self, elem: Memlet) -> bool: """ @@ -251,6 +258,7 @@ def __getitem__(self, key: Memlet) -> T: actual_key = self._getkey(key) if actual_key is None: raise KeyError(key) + return self.internal_dict[key.data][actual_key] def __setitem__(self, key: Memlet, value: T) -> None: diff --git a/tests/sdfg/memlet_utils_test.py b/tests/sdfg/memlet_utils_test.py index 3c85d72f21..c4a78991c7 100644 --- a/tests/sdfg/memlet_utils_test.py +++ b/tests/sdfg/memlet_utils_test.py @@ -1,10 +1,10 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import copy import dace import numpy as np import pytest -from dace import symbolic -from dace.sdfg import memlet_utils as mu +from dace.sdfg import graph, memlet_utils as mu import re from typing import Tuple, Optional @@ -12,6 +12,7 @@ def _replace_zero_with_one(memlet: dace.Memlet) -> dace.Memlet: if not isinstance(memlet.subset, dace.subsets.Range): return memlet + for i, (rb, re, rs) in enumerate(memlet.subset.ndrange()): if rb == 0: memlet.subset.ranges[i] = (1, 1, rs) @@ -19,7 +20,7 @@ def _replace_zero_with_one(memlet: dace.Memlet) -> dace.Memlet: @pytest.mark.parametrize('filter_type', ['none', 'same_array', 'different_array']) -def test_replace_memlet(filter_type): +def test_replace_memlet(filter_type: str) -> None: # Prepare SDFG sdfg = dace.SDFG('replace_memlet') sdfg.add_array('A', [2, 2], dace.float64) @@ -66,7 +67,7 @@ def test_replace_memlet(filter_type): assert B[0] == 1 -def _perform_non_lin_delin_test(sdfg: dace.SDFG, edge) -> bool: +def _perform_non_lin_delin_test(sdfg: dace.SDFG, edge: graph.MultiConnectorEdge) -> None: assert sdfg.number_of_nodes() == 1 state: dace.SDFGState = sdfg.states()[0] assert state.number_of_nodes() == 2 @@ -104,8 +105,6 @@ def _perform_non_lin_delin_test(sdfg: dace.SDFG, edge) -> bool: sdfg(a=a, b=b_opt) assert np.allclose(b_unopt, b_opt) - return True - def _make_non_lin_delin_sdfg( shape_a: Tuple[int, ...], @@ -211,6 +210,62 @@ def test_non_lin_delin_8(): _perform_non_lin_delin_test(sdfg, e) +def test_MemletSet() -> None: + empty_set = mu.MemletSet() + assert len(empty_set) == 0 + + memlet_set = mu.MemletSet([dace.Memlet("A[0:5]")]) + covered_set = mu.MemletSet([dace.Memlet("A[0:5]")], intersection_is_contained=False) + + assert dace.Memlet("A[0:2]") in memlet_set + assert dace.Memlet("A[0:2]") in covered_set + assert dace.Memlet("A[2:7]") in memlet_set + assert dace.Memlet("A[2:7]") not in covered_set + + assert dace.Memlet("B[0:2]") not in memlet_set + + before = copy.deepcopy(covered_set.internal_set) + covered_set.add(dace.Memlet("A[0:2]")) + assert covered_set.internal_set == before + + covered_set.add(dace.Memlet("A[4:9]")) + assert dace.Memlet("A[2:7]") in covered_set + assert covered_set.internal_set != before + + union = empty_set.union(dace.Memlet("A[0:3]"), dace.Memlet("A[2:10]")) + assert dace.Memlet("A[5:7]") not in empty_set + assert dace.Memlet("A[5:7]") in union + assert len(union.internal_set["A"]) == 1 + internal_memlet = list(union.internal_set["A"])[0] + assert internal_memlet.subset == dace.subsets.Range.from_string("0:10") + + +def test_MemletDict() -> None: + A_01 = dace.Memlet("A[0:1]") + A_02 = dace.Memlet("A[0:2]") + A_34 = dace.Memlet("A[3:4]") + memlet_dict: mu.Memlet[list[int]] = mu.MemletDict() + assert len(memlet_dict) == 0 + assert A_02 not in memlet_dict + + memlet_dict[A_02] = [42] + assert A_02 in memlet_dict + assert A_01 in memlet_dict + assert A_34 not in memlet_dict + assert dace.Memlet("B[0:2]") not in memlet_dict + + memlet_dict[A_01].append(43) + assert memlet_dict[A_02] == [42, 43] + + memlet_dict.update({A_34: [0], A_01: [44]}) + assert A_34 in memlet_dict + assert memlet_dict[A_02] == [44] # @Tal this is expected, right? + assert memlet_dict[A_34] == [0] + + memlet_dict.clear() + assert len(memlet_dict) == 0 + + if __name__ == '__main__': test_replace_memlet('none') test_replace_memlet('same_array') @@ -224,3 +279,6 @@ def test_non_lin_delin_8(): test_non_lin_delin_6() test_non_lin_delin_7() test_non_lin_delin_8() + + test_MemletSet() + test_MemletDict() From 23afc4301bc3b8be85d77e05834347af60dfec99 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <> Date: Mon, 9 Feb 2026 15:56:47 +0100 Subject: [PATCH 137/137] cleanup: self-review --- .../analysis/schedule_tree/sdfg_to_tree.py | 2 +- .../analysis/schedule_tree/tree_to_sdfg.py | 47 ++++++------------- dace/sdfg/analysis/schedule_tree/treenodes.py | 26 ++-------- 3 files changed, 21 insertions(+), 54 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 2869020b1e..c0fa7457ee 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -797,7 +797,7 @@ def _prepare_sdfg_for_conversion(sdfg: SDFG, *, toplevel: bool) -> None: dealias_sdfg(sdfg) -def _create_unified_descriptor_repository(sdfg: SDFG, stree: tn.ScheduleTreeRoot): +def _create_unified_descriptor_repository(sdfg: SDFG, stree: tn.ScheduleTreeRoot) -> None: """ Creates a single descriptor repository from an SDFG and all nested SDFGs. This includes data containers, symbols, constants, etc. diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py index 9d5a84ca83..988e73fff9 100644 --- a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -5,7 +5,7 @@ from dace.memlet import Memlet from dace.sdfg import nodes, memlet_utils as mmu from dace.sdfg.sdfg import SDFG, ControlFlowRegion, InterstateEdge -from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, LoopRegion, SDFGState +from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, SDFGState from dace.sdfg.analysis.schedule_tree import treenodes as tn from dace.sdfg import propagation from enum import Enum, auto @@ -101,35 +101,28 @@ def _pop_state(self, label: Optional[str] = None) -> SDFGState: return popped def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot, sdfg: SDFG) -> None: - # -- to be torched -- assert self._current_state is None, "Expected no 'current_state' at root." assert not self._state_stack, "Expected empty state stack at root." assert not self._dataflow_stack, "Expected empty dataflow stack at root." assert not self._interstate_symbols, "Expected empty list of symbols at root." - # end -- to be torched -- self._current_state = sdfg.add_state(label="tree_root", is_start_block=True) self._ctx = tn.Context(root=node, access_cache={}, current_scope=None) with node.scope(self._current_state, self._ctx): self.visit(node.children, sdfg=sdfg) - # -- to be torched -- assert not self._state_stack, "Expected empty state stack." assert not self._dataflow_stack, "Expected empty dataflow stack." assert not self._interstate_symbols, "Expected empty list of symbols to add." - # end -- to be torched -- def visit_GBlock(self, node: tn.GBlock, sdfg: SDFG) -> None: - # Let's see if we need this for the first prototype ... - raise NotImplementedError(f"{type(node)} not implemented") + raise NotImplementedError(f"Support for {type(node)} not yet implemented.") def visit_StateLabel(self, node: tn.StateLabel, sdfg: SDFG) -> None: - # Let's see if we need this for the first prototype ... - raise NotImplementedError(f"{type(node)} not implemented") + raise NotImplementedError(f"Support for {type(node)} not yet implemented.") def visit_GotoNode(self, node: tn.GotoNode, sdfg: SDFG) -> None: - # Let's see if we need this for the first prototype ... - raise NotImplementedError(f"{type(node)} not implemented") + raise NotImplementedError(f"Support for{type(node)} not yet implemented.") def visit_AssignNode(self, node: tn.AssignNode, sdfg: SDFG) -> None: # We just collect them here. They'll be added when state boundaries are added, @@ -209,11 +202,10 @@ def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: self._current_state = after_state def visit_DoWhileScope(self, node: tn.DoWhileScope, sdfg: SDFG) -> None: - # AFAIK we don't support for do-while loops in the gt4py -> dace bridge. - raise NotImplementedError(f"{type(node)} not implemented") + raise NotImplementedError(f"Support for {type(node)} not yet implemented.") def visit_LoopScope(self, node: tn.LoopScope, sdfg: SDFG) -> None: - raise NotImplementedError("TODO: LoopScopes are not yet implemented") + raise NotImplementedError(f"Support for {type(node)} not yet implemented.") def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: before_state = self._current_state @@ -260,20 +252,16 @@ def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: self._current_state = merge_state def visit_StateIfScope(self, node: tn.StateIfScope, sdfg: SDFG) -> None: - # Let's see if we need this for the first prototype ... - raise NotImplementedError(f"{type(node)} not implemented") + raise NotImplementedError(f"Support for {type(node)} not yet implemented.") def visit_BreakNode(self, node: tn.BreakNode, sdfg: SDFG) -> None: - # AFAIK we don't support for break statements in the gt4py/dace bridge. - raise NotImplementedError(f"{type(node)} not implemented") + raise NotImplementedError(f"Support for {type(node)} not yet implemented.") def visit_ContinueNode(self, node: tn.ContinueNode, sdfg: SDFG) -> None: - # AFAIK we don't support for continue statements in the gt4py/dace bridge. - raise NotImplementedError(f"{type(node)} not implemented") + raise NotImplementedError(f"Support for {type(node)} not yet implemented.") def visit_ElifScope(self, node: tn.ElifScope, sdfg: SDFG) -> None: - # AFAIK we don't support elif scopes in the gt4py/dace bridge. - raise NotImplementedError(f"{type(node)} not implemented") + raise NotImplementedError(f"Support for {type(node)} not yet implemented.") def visit_ElseScope(self, node: tn.ElseScope, sdfg: SDFG) -> None: # get ConditionalBlock form stack @@ -561,8 +549,7 @@ def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: assert len(self._current_state.in_edges(map_exit)) > 0 def visit_ConsumeScope(self, node: tn.ConsumeScope, sdfg: SDFG) -> None: - # AFAIK we don't support consume scopes in the gt4py/dace bridge. - raise NotImplementedError(f"{type(node)} not implemented") + raise NotImplementedError(f"Support for {type(node)} not yet implemented.") def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: # Add Tasklet to current state @@ -668,8 +655,7 @@ def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None: assert scope_node is None def visit_LibraryCall(self, node: tn.LibraryCall, sdfg: SDFG) -> None: - # AFAIK we expand all library calls in the gt4py/dace bridge before coming here. - raise NotImplementedError(f"{type(node)} not implemented") + raise NotImplementedError(f"Support for {type(node)} not yet implemented.") def visit_CopyNode(self, node: tn.CopyNode, sdfg: SDFG) -> None: # ensure we have an access_cache and fetch it @@ -689,12 +675,10 @@ def visit_CopyNode(self, node: tn.CopyNode, sdfg: SDFG) -> None: self._current_state.add_memlet_path(source, target, memlet=node.memlet) def visit_DynScopeCopyNode(self, node: tn.DynScopeCopyNode, sdfg: SDFG) -> None: - # AFAIK we don't support dyn scope copy nodes in the gt4py/dace bridge. - raise NotImplementedError(f"{type(node)} not implemented") + raise NotImplementedError(f"Support for {type(node)} not yet implemented.") def visit_ViewNode(self, node: tn.ViewNode, sdfg: SDFG) -> None: - # Let's see if we need this for the first prototype ... - raise NotImplementedError(f"{type(node)} not implemented") + raise NotImplementedError(f"Support for {type(node)} not yet implemented.") def visit_NView(self, node: tn.NView, sdfg: SDFG) -> None: # Basic working principle: @@ -726,8 +710,7 @@ def visit_NViewEnd(self, node: tn.NViewEnd, sdfg: SDFG) -> None: raise RuntimeError(f"No matching NView found for target {node.target} in {self._nviews_free}.") def visit_RefSetNode(self, node: tn.RefSetNode, sdfg: SDFG) -> None: - # Let's see if we need this for the first prototype ... - raise NotImplementedError(f"{type(node)} not implemented") + raise NotImplementedError(f"Support for {type(node)} not yet implemented.") def visit_StateBoundaryNode(self, node: tn.StateBoundaryNode, sdfg: SDFG) -> None: # When creating a state boundary, include all inter-state assignments that precede it. diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index c68db8d5cf..a142f1fa89 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -51,9 +51,7 @@ def __init__(self, ctx: Context, state: SDFGState, node: 'ScheduleTreeScope') -> def __enter__(self) -> None: assert not self._ctx.access_cache[(self._state, id( self._node))], "Expecting an empty access_cache when entering the context." - # self._node.parent = self._parent_scope - # if self._parent_scope is not None: # Exception for ScheduleTreeRoot - # self._parent_scope.children.append(self._node) + self._ctx.current_scope = self._node def __exit__( @@ -64,7 +62,6 @@ def __exit__( ) -> None: cache_key = (self._state, id(self._node)) assert cache_key in self._ctx.access_cache - # self._ctx.access_cache[cache_key].clear() self._ctx.current_scope = self._parent_scope @@ -436,7 +433,7 @@ def as_string(self, indent: int = 0): if variant in ["for", "while", "do-while"]: return super().as_string(indent) - return NotImplementedError # TODO: nice error message + return NotImplementedError(f"Unknown LoopRegion variant '{variant}.") @dataclass @@ -713,20 +710,6 @@ def as_string(self, indent: int = 0): return result + super().as_string(indent) -# TODO: to be removed. looks like `Pipeline` nodes aren't a thing anymore -# @dataclass -# class PipelineScope(MapScope): -# """ -# Pipeline scope. -# """ -# node: nodes.PipelineEntry -# -# def as_string(self, indent: int = 0): -# rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) -# result = indent * INDENTATION + f'pipeline {", ".join(self.node.map.params)} in [{rangestr}]:\n' -# return result + super().as_string(indent) - - @dataclass class TaskletNode(ScheduleTreeNode): node: nodes.Tasklet @@ -859,7 +842,8 @@ class NViewEnd(ScheduleTreeNode): Artificial node to denote the scope end of the associated Nested SDFG view node. """ - target: str #: target name of the associated NView container + target: str + """Target name of the associated NView container.""" def as_string(self, indent: int = 0): return indent * INDENTATION + f"end nview {self.target}" @@ -1035,7 +1019,7 @@ def _loop_range( if start is None or step is None or end is None: return None - return (start, end, step) # `end` is inclusive + return (start, end, step) def _match_loop_condition(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: