diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 46eb37cdb2..c0fa7457ee 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -43,7 +43,8 @@ def dealias_sdfg(sdfg: SDFG): replacements: Dict[str, str] = {} inv_replacements: Dict[str, List[str]] = {} - parent_edges: Dict[str, Memlet] = {} + parent_edges_inputs: Dict[str, Memlet] = {} + parent_edges_outputs: Dict[str, Memlet] = {} to_unsqueeze: Set[str] = set() parent_sdfg = nsdfg.parent_sdfg @@ -53,19 +54,42 @@ def dealias_sdfg(sdfg: SDFG): for name, desc in nsdfg.arrays.items(): if desc.transient: continue - for edge in parent_state.edges_by_connector(parent_node, name): + for edge in parent_state.in_edges_by_connector(parent_node, name): parent_name = edge.data.data assert parent_name in parent_sdfg.arrays - if parent_name != name: + if name != parent_name: + parent_edges_inputs[name] = edge replacements[name] = parent_name - parent_edges[name] = edge if parent_name in inv_replacements: inv_replacements[parent_name].append(name) to_unsqueeze.add(parent_name) else: inv_replacements[parent_name] = [name] + # We found an incoming edge for name and we don't expect a second one. break + for edge in parent_state.out_edges_by_connector(parent_node, name): + parent_name = edge.data.data + assert parent_name in parent_sdfg.arrays + if name != parent_name: + parent_edges_outputs[name] = edge + + if replacements.get(name, None) is not None: + # There's an incoming and an outgoing connector with the same name. + # Make sure both map to the same memory in the parent sdfg. 
+ assert replacements[name] == parent_name + assert name in inv_replacements[parent_name] + break + else: + replacements[name] = parent_name + if parent_name in inv_replacements: + inv_replacements[parent_name].append(name) + to_unsqueeze.add(parent_name) + else: + inv_replacements[parent_name] = [name] + # We found an outgoing edge for name and we don't expect a second one. + break + if to_unsqueeze: for parent_name in to_unsqueeze: parent_arr = parent_sdfg.arrays[parent_name] @@ -94,14 +118,18 @@ def dealias_sdfg(sdfg: SDFG): # destination subset if isinstance(src, nd.AccessNode) and src.data in child_names: src_data = src.data - new_src_memlet = unsqueeze_memlet(e.data, parent_edges[src.data].data, use_src_subset=True) + new_src_memlet = unsqueeze_memlet(e.data, + parent_edges_inputs[src.data].data, + use_src_subset=True) else: src_data = None new_src_memlet = None # We need to take directionality of the memlet into account if isinstance(dst, nd.AccessNode) and dst.data in child_names: dst_data = dst.data - new_dst_memlet = unsqueeze_memlet(e.data, parent_edges[dst.data].data, use_dst_subset=True) + new_dst_memlet = unsqueeze_memlet(e.data, + parent_edges_outputs[dst.data].data, + use_dst_subset=True) else: dst_data = None new_dst_memlet = None @@ -120,23 +148,26 @@ def dealias_sdfg(sdfg: SDFG): syms = e.data.read_symbols() for memlet in e.data.get_read_memlets(nsdfg.arrays): if memlet.data in child_names: - repl_dict[str(memlet)] = unsqueeze_memlet(memlet, parent_edges[memlet.data].data) + repl_dict[str(memlet)] = unsqueeze_memlet(memlet, parent_edges_inputs[memlet.data].data) if memlet.data in syms: syms.remove(memlet.data) for s in syms: - if s in parent_edges: + if s in parent_edges_inputs: if s in nsdfg.arrays: - repl_dict[s] = parent_edges[s].data.data + repl_dict[s] = parent_edges_inputs[s].data.data else: - repl_dict[s] = str(parent_edges[s].data) + repl_dict[s] = str(parent_edges_inputs[s].data) e.data.replace_dict(repl_dict) for name in child_names: - 
edge = parent_edges[name] - for e in parent_state.memlet_tree(edge): - if e.data.data == parent_name: - e.data.subset = subsets.Range.from_array(parent_arr) - else: - e.data.other_subset = subsets.Range.from_array(parent_arr) + for edge in [parent_edges_inputs.get(name, None), parent_edges_outputs.get(name, None)]: + if edge is None: + continue + + for e in parent_state.memlet_tree(edge): + if e.data.data == parent_name: + e.data.subset = subsets.Range.from_array(parent_arr) + else: + e.data.other_subset = subsets.Range.from_array(parent_arr) if replacements: struct_outside_replacements: Dict[str, str] = {} @@ -523,6 +554,10 @@ def _state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: result = subnodes elif isinstance(node, dace.nodes.ExitNode): result = scopes.pop() + parent = result[-1] + assert isinstance(parent, tn.ScheduleTreeScope) + for child in parent.children: + child.parent = parent elif isinstance(node, dace.nodes.NestedSDFG): nested_array_mapping_input = {} nested_array_mapping_output = {} @@ -531,7 +566,7 @@ def _state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: # Replace symbols and memlets in nested SDFGs to match the namespace of the parent SDFG replace_symbols_until_set(node) - # Create memlets for nested SDFG mapping, or nview schedule nodes if slice cannot be determined + # Create memlets for nested SDFG mapping, or NView schedule nodes if slice cannot be determined. 
for e in state.all_edges(node): conn = e.dst_conn if e.dst is node else e.src_conn if e.data.is_empty() or not conn: @@ -550,7 +585,7 @@ def _state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: else: nested_array_mapping_output[conn] = e.data - if no_mapping: # Must use view (nview = nested SDFG view) + if no_mapping: # Must use view (NView = nested SDFG view) if conn not in generated_nviews: nview_node = tn.NView(target=conn, source=e.data.data, @@ -565,6 +600,12 @@ def _state_schedule_tree(state: SDFGState) -> List[tn.ScheduleTreeNode]: # Insert the nested SDFG flattened nested_stree = as_schedule_tree(node.sdfg, in_place=True, toplevel=False) result.extend(nested_stree.children) + + if generated_nviews: + # Insert matching NViewEnd nodes to define the scope NView nodes. + for target in generated_nviews: + result.append(tn.NViewEnd(target=target)) + elif isinstance(node, dace.nodes.Tasklet): in_memlets = {e.dst_conn: e.data for e in state.in_edges(node) if e.dst_conn} out_memlets = {e.src_conn: e.data for e in state.out_edges(node) if e.src_conn} @@ -664,11 +705,23 @@ def _block_schedule_tree(block: ControlFlowBlock) -> List[tn.ScheduleTreeNode]: pivot = None if isinstance(block, LoopRegion): - # If this is a loop region, wrap everything in a LoopScope node. - loop_node = tn.LoopScope(loop=block, children=children) - return [loop_node] + # If this is a loop region, wrap everything in a loop scope node. + variant = tn.loop_variant(block) + if variant == "for": + return [tn.ForScope(loop=block, children=children)] + + if variant == "while": + return [tn.WhileScope(loop=block, children=children)] + + if variant == "do-while": + return [tn.DoWhileScope(loop=block, children=children)] + + # If we end up here, we don't need more granularity and just use a general loop scope. 
+ return [tn.LoopScope(loop=block, children=children)] + return children - elif isinstance(block, ConditionalBlock): + + if isinstance(block, ConditionalBlock): result: List[tn.ScheduleTreeNode] = [] if_node = tn.IfScope(condition=block.branches[0][0], children=_block_schedule_tree(block.branches[0][1])) result.append(if_node) @@ -680,9 +733,11 @@ def _block_schedule_tree(block: ControlFlowBlock) -> List[tn.ScheduleTreeNode]: else_node = tn.ElseScope(children=_block_schedule_tree(branch_body)) result.append(else_node) return result - elif isinstance(block, SDFGState): + + if isinstance(block, SDFGState): return _state_schedule_tree(block) - elif isinstance(block, ReturnBlock): + + if isinstance(block, ReturnBlock): # For return blocks, add a goto node to the end of the schedule tree. # NOTE: Return blocks currently always exit the entire SDFG context they are contained in, meaning that the exit # goto has target=None. However, in the future we want to adapt Return blocks to be able to return only a @@ -690,8 +745,8 @@ def _block_schedule_tree(block: ControlFlowBlock) -> List[tn.ScheduleTreeNode]: # entire SDFG. 
goto_node = tn.GotoNode(target=None) return [goto_node] - else: - raise tn.UnsupportedScopeException(type(block).__name__) + + raise tn.UnsupportedScopeException(type(block).__name__) def _generate_views_in_scope( @@ -725,7 +780,47 @@ def _generate_views_in_scope( return result -def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) -> tn.ScheduleTreeScope: +def _prepare_sdfg_for_conversion(sdfg: SDFG, *, toplevel: bool) -> None: + from dace.transformation import helpers as xfh # Avoid import loop + + # Split edges with assignments and conditions + xfh.split_interstate_edges(sdfg) + + # Replace code->code edges with data<->code edges + xfh.replace_code_to_code_edges(sdfg) + + if toplevel: # Top-level SDFG preparation (only perform once) + # Handle name collisions (in arrays, state labels, symbols) + remove_name_collisions(sdfg) + + # Ensure no arrays alias in SDFG tree + dealias_sdfg(sdfg) + + +def _create_unified_descriptor_repository(sdfg: SDFG, stree: tn.ScheduleTreeRoot) -> None: + """ + Creates a single descriptor repository from an SDFG and all nested SDFGs. This includes + data containers, symbols, constants, etc. + + :param sdfg: The top-level SDFG to create the repository from. + :param stree: The tree root in which to make the unified descriptor repository. 
+ """ + stree.containers = sdfg.arrays + stree.symbols = sdfg.symbols + stree.constants = sdfg.constants_prop + + # Since the SDFG is assumed to be de-aliased and contain unique names, we union the contents of + # the nested SDFGs' descriptor repositories + for nsdfg in sdfg.all_sdfgs_recursive(): + transients = {k: v for k, v in nsdfg.arrays.items() if v.transient} + symbols = {k: v for k, v in nsdfg.symbols.items() if k not in stree.symbols} + constants = {k: v for k, v in nsdfg.constants_prop.items() if k not in stree.constants} + stree.containers.update(transients) + stree.symbols.update(symbols) + stree.constants.update(constants) + + +def as_schedule_tree(sdfg: SDFG, *, in_place: bool = False, toplevel: bool = True) -> tn.ScheduleTreeRoot: """ Converts an SDFG into a schedule tree. The schedule tree is a tree of nodes that represent the execution order of the SDFG. @@ -741,30 +836,21 @@ def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) usable after the conversion if ``in_place`` is True! :return: A schedule tree representing the given SDFG. 
""" - from dace.transformation import helpers as xfh # Avoid import loop if not in_place: sdfg = copy.deepcopy(sdfg) - # Prepare SDFG for conversion - ############################# - - # Split edges with assignments and conditions - xfh.split_interstate_edges(sdfg) - - # Replace code->code edges with data<->code edges - xfh.replace_code_to_code_edges(sdfg) - - if toplevel: # Top-level SDFG preparation (only perform once) - # Handle name collisions (in arrays, state labels, symbols) - remove_name_collisions(sdfg) - # Ensure no arrays alias in SDFG tree - dealias_sdfg(sdfg) + _prepare_sdfg_for_conversion(sdfg, toplevel=toplevel) - ############################# + if toplevel: + result = tn.ScheduleTreeRoot(name=sdfg.name, children=[]) + _create_unified_descriptor_repository(sdfg, result) + result.add_children(_block_schedule_tree(sdfg)) + else: + result = tn.ScheduleTreeScope(children=_block_schedule_tree(sdfg)) - result = tn.ScheduleTreeScope(children=_block_schedule_tree(sdfg)) tn.validate_has_no_other_node_types(result) + tn.validate_children_and_parents_align(result, root=toplevel) # Clean up tree stpasses.remove_unused_and_duplicate_labels(result) diff --git a/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py new file mode 100644 index 0000000000..988e73fff9 --- /dev/null +++ b/dace/sdfg/analysis/schedule_tree/tree_to_sdfg.py @@ -0,0 +1,1007 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
+import copy +from collections import defaultdict +from dace import symbolic +from dace.memlet import Memlet +from dace.sdfg import nodes, memlet_utils as mmu +from dace.sdfg.sdfg import SDFG, ControlFlowRegion, InterstateEdge +from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, SDFGState +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace.sdfg import propagation +from enum import Enum, auto +from typing import Dict, Final, List, Optional, Set, Tuple + + +class StateBoundaryBehavior(Enum): + STATE_TRANSITION = auto() #: Creates multiple states with a state transition + EMPTY_MEMLET = auto() #: Happens-before empty memlet edges in the same state + + +PREFIX_PASSTHROUGH_IN: Final[str] = "IN_" +PREFIX_PASSTHROUGH_OUT: Final[str] = "OUT_" +MAX_NESTED_SDFGS: Final[int] = 1000 + + +class StreeToSDFG(tn.ScheduleNodeVisitor): + + def __init__(self, start_state: Optional[SDFGState] = None) -> None: + self._ctx: tn.Context + """Context information like tree root and current scope.""" + + self._current_state = start_state + """Current SDFGState in the SDFG that we are building.""" + + self._current_nestedSDFG: int | None = None + """Id of the current nested SDFG if we are inside one.""" + + self._interstate_symbols: List[tn.AssignNode] = [] + """Interstate symbol assignments. 
Will be assigned with the next state transition.""" + + self._nviews_free: List[tn.NView] = [] + """Keep track of NView (nested SDFG view) nodes that are "free" to be used.""" + + self._nviews_bound_per_scope: Dict[int, List[tn.NView]] = {} + """Mapping of id(SDFG) -> list of active NView nodes in that SDFG.""" + + self._nviews_deferred_removal: Dict[int, List[tn.NView]] = {} + """"Mapping of id(SDFG) -> list of NView nodes to be removed once we exit this nested SDFG.""" + + # state management + self._state_stack: List[SDFGState] = [] + + # dataflow scopes + # List[ (MapEntryNode, ToConnect) | (SDFG, {"inputs": set(), "outputs": set()}) ] + self._dataflow_stack: List[Tuple[nodes.EntryNode, Dict[str, Tuple[nodes.AccessNode, Memlet]]] + | Tuple[SDFG, Dict[str, Set[str]]]] = [] + + def _apply_nview_array_override(self, array_name: str, sdfg: SDFG) -> bool: + """Apply an NView override if applicable. Returns true if the NView was applied.""" + length = len(self._nviews_free) + for index, nview in enumerate(reversed(self._nviews_free), start=1): + if nview.target == array_name and nview not in self._nviews_deferred_removal[id(sdfg)]: + # Add the "override" data descriptor + sdfg.add_datadesc(nview.target, nview.view_desc.clone()) + if nview.src_desc.transient: + sdfg.arrays[nview.target].transient = False + + # Keep track of used NViews per scope (to "free" them again once the scope ends) + self._nviews_bound_per_scope[id(sdfg)].append(nview) + + # This NView is in use now, remove it from the free NViews. 
+ del self._nviews_free[length - index] + return True + + return False + + def _parent_sdfg_with_array(self, name: str, sdfg: SDFG) -> SDFG: + """Find the closest parent SDFG containing an array with the given name.""" + parent_sdfg = sdfg.parent.parent + sdfg_counter = 1 + while name not in parent_sdfg.arrays and sdfg_counter < MAX_NESTED_SDFGS: + parent_sdfg = parent_sdfg.parent.parent + assert isinstance(parent_sdfg, SDFG) + sdfg_counter += 1 + assert sdfg_counter < MAX_NESTED_SDFGS, f"Array '{name}' not found in any parent of SDFG '{sdfg.name}'." + return parent_sdfg + + def _pop_state(self, label: Optional[str] = None) -> SDFGState: + """Pops the last state from the state stack. + + :param str, optional label: Ensures the popped state's label starts with the given string. + + :return: The popped state. + """ + if not self._state_stack: + raise ValueError("Can't pop state from empty stack.") + + popped = self._state_stack.pop() + if label is not None: + assert popped.label.startswith(label) + + return popped + + def visit_ScheduleTreeRoot(self, node: tn.ScheduleTreeRoot, sdfg: SDFG) -> None: + assert self._current_state is None, "Expected no 'current_state' at root." + assert not self._state_stack, "Expected empty state stack at root." + assert not self._dataflow_stack, "Expected empty dataflow stack at root." + assert not self._interstate_symbols, "Expected empty list of symbols at root." + + self._current_state = sdfg.add_state(label="tree_root", is_start_block=True) + self._ctx = tn.Context(root=node, access_cache={}, current_scope=None) + with node.scope(self._current_state, self._ctx): + self.visit(node.children, sdfg=sdfg) + + assert not self._state_stack, "Expected empty state stack." + assert not self._dataflow_stack, "Expected empty dataflow stack." + assert not self._interstate_symbols, "Expected empty list of symbols to add." 
+ + def visit_GBlock(self, node: tn.GBlock, sdfg: SDFG) -> None: + raise NotImplementedError(f"Support for {type(node)} not yet implemented.") + + def visit_StateLabel(self, node: tn.StateLabel, sdfg: SDFG) -> None: + raise NotImplementedError(f"Support for {type(node)} not yet implemented.") + + def visit_GotoNode(self, node: tn.GotoNode, sdfg: SDFG) -> None: + raise NotImplementedError(f"Support for{type(node)} not yet implemented.") + + def visit_AssignNode(self, node: tn.AssignNode, sdfg: SDFG) -> None: + # We just collect them here. They'll be added when state boundaries are added, + # see visitors below. + self._interstate_symbols.append(node) + + # If AssignNode depends on arrays, e.g. `my_sym = my_array[__k] > 0`, make sure array accesses can be resolved. + input_memlets = node.input_memlets() + if not input_memlets: + return + + for entry in reversed(self._dataflow_stack): + scope_node, to_connect = entry + if isinstance(scope_node, SDFG): + # In case we are inside a nested SDFG, make sure memlet data can be + # resolved by explicitly adding inputs. + for memlet in input_memlets: + # Copy data descriptor from parent SDFG and add input connector + if memlet.data not in sdfg.arrays: + parent_sdfg = self._parent_sdfg_with_array(memlet.data, sdfg) + + # Support for NView nodes + use_nview = self._apply_nview_array_override(memlet.data, sdfg) + if not use_nview: + sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone()) + + # Transients passed into a nested SDFG become non-transient inside that nested SDFG + if parent_sdfg.arrays[memlet.data].transient: + sdfg.arrays[memlet.data].transient = False + # TODO + # ... unless they are only ever used inside the nested SDFG, in which case + # we should delete them from the parent SDFG's array list. + # NOTE This can probably be done automatically by a cleanup pass in the end. + # Something like DDE should be able to do this. 
+ + # Dev note: nview.target and memlet.data are identical + assert memlet.data not in to_connect["inputs"] + to_connect["inputs"].add(memlet.data) + return + + for memlet in input_memlets: + # If we aren't inside a nested SDFG, make sure all memlets can be resolved. + # Imo, this should always be the case. It not, raise an error. + if memlet.data not in sdfg.arrays: + raise ValueError(f"Parsing AssignNode {node} failed. Can't find {memlet.data} in {sdfg}.") + + def visit_ForScope(self, node: tn.ForScope, sdfg: SDFG) -> None: + current_state = self._current_state + cf_region = current_state.parent_graph + + loop_region = node.loop + cf_region.add_node(loop_region) + loop_state = loop_region.add_state(f"for_loop_state_{id(node)}", is_start_block=True) + + _insert_and_split_assignments(current_state, loop_region) + + self._current_state = loop_state + self.visit(node.children, sdfg=sdfg) + + after_state = _insert_and_split_assignments(loop_region, label="loop_after") + self._current_state = after_state + + def visit_WhileScope(self, node: tn.WhileScope, sdfg: SDFG) -> None: + current_state = self._current_state + cf_region = current_state.parent_graph + + loop_region = node.loop + cf_region.add_node(loop_region) + loop_state = loop_region.add_state(f"while_loop_state_{id(node)}", is_start_block=True) + + _insert_and_split_assignments(current_state, loop_region) + + self._current_state = loop_state + self.visit(node.children, sdfg=sdfg) + + after_state = _insert_and_split_assignments(loop_region, label="loop_after") + self._current_state = after_state + + def visit_DoWhileScope(self, node: tn.DoWhileScope, sdfg: SDFG) -> None: + raise NotImplementedError(f"Support for {type(node)} not yet implemented.") + + def visit_LoopScope(self, node: tn.LoopScope, sdfg: SDFG) -> None: + raise NotImplementedError(f"Support for {type(node)} not yet implemented.") + + def visit_IfScope(self, node: tn.IfScope, sdfg: SDFG) -> None: + before_state = self._current_state + cf_region = 
before_state.parent_graph + + conditional_block = ConditionalBlock(f"if_scope_{id(node)}") + cf_region.add_node(conditional_block) + _insert_and_split_assignments( + before_state, + conditional_block, + assignments=self._pending_interstate_assignments(), + ) + + if_body = ControlFlowRegion("if_body", sdfg=sdfg) + conditional_block.add_branch(node.condition, if_body) + + if_state = if_body.add_state("if_state", is_start_block=True) + self._current_state = if_state + + # visit children of that branch + self.visit(node.children, sdfg=sdfg) + + self._current_state = conditional_block + + # add merge_state + merge_state = _insert_and_split_assignments( + conditional_block, + label="merge_state", + assignments=self._pending_interstate_assignments(), + ) + + # Check if there's an `ElseScope` following this node (in the parent's children). + # Filter StateBoundaryNodes, which we inserted earlier, for this analysis. + filtered = [n for n in node.parent.children if not isinstance(n, tn.StateBoundaryNode)] + if_index = _list_index(filtered, node) + has_else_branch = len(filtered) > if_index + 1 and isinstance(filtered[if_index + 1], tn.ElseScope) + + if has_else_branch: + # push merge_state on the stack for later usage in `visit_ElseScope` + self._state_stack.append(merge_state) + # push condition_block on the stack for later usage in `visit_ElseScope` + self._state_stack.append(conditional_block) + else: + self._current_state = merge_state + + def visit_StateIfScope(self, node: tn.StateIfScope, sdfg: SDFG) -> None: + raise NotImplementedError(f"Support for {type(node)} not yet implemented.") + + def visit_BreakNode(self, node: tn.BreakNode, sdfg: SDFG) -> None: + raise NotImplementedError(f"Support for {type(node)} not yet implemented.") + + def visit_ContinueNode(self, node: tn.ContinueNode, sdfg: SDFG) -> None: + raise NotImplementedError(f"Support for {type(node)} not yet implemented.") + + def visit_ElifScope(self, node: tn.ElifScope, sdfg: SDFG) -> None: + raise 
NotImplementedError(f"Support for {type(node)} not yet implemented.") + + def visit_ElseScope(self, node: tn.ElseScope, sdfg: SDFG) -> None: + # get ConditionalBlock form stack + conditional_block: ConditionalBlock = self._pop_state("if_scope") + + else_body = ControlFlowRegion("else_body", sdfg=sdfg) + conditional_block.add_branch(None, else_body) + + else_state = else_body.add_state("else_state", is_start_block=True) + self._current_state = else_state + + # visit children inside the else branch + self.visit(node.children, sdfg=sdfg) + + # merge false-branch into merge_state + merge_state = self._pop_state("merge_state") + self._current_state = merge_state + + if self._pending_interstate_assignments(): + raise NotImplementedError("TODO: update edge with new assignments") + + def _insert_nestedSDFG(self, node: tn.MapScope, sdfg: SDFG) -> None: + dataflow_stack_size = len(self._dataflow_stack) + state_stack_size = len(self._state_stack) + outer_nestedSDFG = self._current_nestedSDFG + + # prepare inner SDFG + inner_sdfg = SDFG("nested_sdfg", parent=self._current_state) + start_state = inner_sdfg.add_state("nested_root", is_start_block=True) + + # update stacks and current state + old_state_label = self._current_state.label + self._state_stack.append(self._current_state) + self._dataflow_stack.append((inner_sdfg, {"inputs": set(), "outputs": set()})) + self._nviews_bound_per_scope[id(inner_sdfg)] = [] + self._nviews_deferred_removal[id(inner_sdfg)] = [] + self._current_nestedSDFG = id(inner_sdfg) + self._current_state = start_state + + # visit children + with node.scope(self._current_state, self._ctx): + self.visit(node.children, sdfg=inner_sdfg) + + # restore current state and stacks + self._current_state = self._pop_state(old_state_label) + assert len(self._state_stack) == state_stack_size + _, connectors = self._dataflow_stack.pop() + assert len(self._dataflow_stack) == dataflow_stack_size + + # insert nested SDFG + nsdfg = 
self._current_state.add_nested_sdfg(sdfg=inner_sdfg, + inputs=connectors["inputs"], + outputs=connectors["outputs"]) + # connect nested SDFG to surrounding map scope + assert self._dataflow_stack + map_entry, to_connect = self._dataflow_stack[-1] + + # connect nsdfg input memlets (to be propagated upon completion of the SDFG) + for name in nsdfg.in_connectors: + out_connector = f"{PREFIX_PASSTHROUGH_OUT}{name}" + new_in_connector = map_entry.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{name}") + new_out_connector = map_entry.add_out_connector(out_connector) + assert new_in_connector == True + assert new_in_connector == new_out_connector + + # Add Memlet for NView node (if applicable) + edge_added = False + for nview in self._nviews_bound_per_scope[id(inner_sdfg)]: + if name == nview.target: + self._current_state.add_edge(map_entry, out_connector, nsdfg, name, + Memlet.from_memlet(nview.memlet)) + edge_added = True + break + + if not edge_added: + self._current_state.add_edge(map_entry, out_connector, nsdfg, name, + Memlet.from_array(name, nsdfg.sdfg.arrays[name])) + + # Add empty memlet if we didn't add any in the loop above + if self._current_state.out_degree(map_entry) < 1: + self._current_state.add_nedge(map_entry, nsdfg, Memlet()) + + # connect nsdfg output memlets (to be propagated) + for name in nsdfg.out_connectors: + # Add memlets for NView node (if applicable) + edge_added = False + for nview in self._nviews_bound_per_scope[id(inner_sdfg)]: + if name == nview.target: + to_connect[name] = (nsdfg, Memlet.from_memlet(nview.memlet)) + edge_added = True + break + + if not edge_added: + to_connect[name] = (nsdfg, Memlet.from_array(name, nsdfg.sdfg.arrays[name])) + + # Move NViews back to "free" NViews for usage in a sibling scope. + for nview in self._nviews_bound_per_scope[id(inner_sdfg)]: + # If this NView ended in the current nested SDFG, don't add it back to the + # "free NView" nodes. 
We need to keep it alive until here to make sure that + # we can add the memlets above. + if nview in self._nviews_deferred_removal[id(inner_sdfg)]: + continue + self._nviews_free.append(nview) + + del self._nviews_bound_per_scope[id(inner_sdfg)] + del self._nviews_deferred_removal[id(inner_sdfg)] + + # Restore current nested SDFG + self._current_nestedSDFG = outer_nestedSDFG + + def visit_MapScope(self, node: tn.MapScope, sdfg: SDFG) -> None: + dataflow_stack_size = len(self._dataflow_stack) + cache_state = self._current_state + + # map entry + # --------- + map_entry = nodes.MapEntry(node.node.map) + self._current_state.add_node(map_entry) + self._dataflow_stack.append((map_entry, dict())) + + # visit children inside the map + type_of_children = [type(child) for child in node.children] + last_child_is_MapScope = type_of_children[-1] == tn.MapScope + all_others_are_Boundaries = type_of_children.count(tn.StateBoundaryNode) == len(type_of_children) - 1 + if last_child_is_MapScope and all_others_are_Boundaries: + # skip weirdly added StateBoundaryNode + # tmp: use this - for now - to "backprop-insert" extra state boundaries for nested SDFGs + with node.scope(self._current_state, self._ctx): + self.visit(node.children[-1], sdfg=sdfg) + elif any([isinstance(child, tn.StateBoundaryNode) for child in node.children]): + self._insert_nestedSDFG(node, sdfg) + else: + with node.scope(self._current_state, self._ctx): + self.visit(node.children, sdfg=sdfg) + + cache_key = (cache_state, id(self._ctx.current_scope)) + if cache_key not in self._ctx.access_cache: + self._ctx.access_cache[cache_key] = {} + access_cache = self._ctx.access_cache[cache_key] + + # dataflow stack management + _, to_connect = self._dataflow_stack.pop() + assert len(self._dataflow_stack) == dataflow_stack_size + outer_map_entry, outer_to_connect = self._dataflow_stack[-1] if dataflow_stack_size else (None, None) + + # connect potential input connectors on map_entry + for connector in 
map_entry.in_connectors: + memlet_data = connector.removeprefix(PREFIX_PASSTHROUGH_IN) + + # connect to local access node (if available) + if memlet_data in access_cache: + cached_access = access_cache[memlet_data] + self._current_state.add_memlet_path(cached_access, + map_entry, + dst_conn=connector, + memlet=Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) + continue + + if isinstance(outer_map_entry, nodes.EntryNode): + + # get it from outside the map + connector_name = f"{PREFIX_PASSTHROUGH_OUT}{memlet_data}" + if connector_name not in outer_map_entry.out_connectors: + new_in_connector = outer_map_entry.add_in_connector(connector) + new_out_connector = outer_map_entry.add_out_connector(connector_name) + assert new_in_connector == True + assert new_in_connector == new_out_connector + + self._current_state.add_edge(outer_map_entry, connector_name, map_entry, connector, + Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) + else: + if isinstance(outer_map_entry, SDFG): + # Copy data descriptor from parent SDFG and add input connector + if memlet_data not in sdfg.arrays: + parent_sdfg: SDFG = self._parent_sdfg_with_array(memlet_data, sdfg) + + # Add support for NView nodes + use_nview = self._apply_nview_array_override(memlet_data, sdfg) + if not use_nview: + sdfg.add_datadesc(memlet_data, parent_sdfg.arrays[memlet_data].clone()) + + # Transients passed into a nested SDFG become non-transient inside that nested SDFG + if parent_sdfg.arrays[memlet_data].transient: + sdfg.arrays[memlet_data].transient = False + # TODO + # ... unless they are only ever used inside the nested SDFG, in which case + # we should delete them from the parent SDFG's array list. + # NOTE This can probably be done automatically by a cleanup pass in the end. + # Something like DDE should be able to do this. 
+ + # Dev note: nview.target and memlet_data are identical + assert memlet_data not in outer_to_connect["inputs"] + outer_to_connect["inputs"].add(memlet_data) + else: + assert outer_map_entry is None + + # cache local read access + assert memlet_data not in access_cache + access_cache[memlet_data] = self._current_state.add_read(memlet_data) + cached_access = access_cache[memlet_data] + self._current_state.add_memlet_path(cached_access, + map_entry, + dst_conn=connector, + memlet=Memlet.from_array(memlet_data, sdfg.arrays[memlet_data])) + + if isinstance(outer_map_entry, nodes.EntryNode) and self._current_state.out_degree(outer_map_entry) < 1: + self._current_state.add_nedge(outer_map_entry, map_entry, Memlet()) + + # map_exit + # -------- + map_exit = nodes.MapExit(node.node.map) + self._current_state.add_node(map_exit) + + # connect writes to map_exit node + for name in to_connect: + in_connector_name = f"{PREFIX_PASSTHROUGH_IN}{name}" + out_connector_name = f"{PREFIX_PASSTHROUGH_OUT}{name}" + new_in_connector = map_exit.add_in_connector(in_connector_name) + new_out_connector = map_exit.add_out_connector(out_connector_name) + assert new_in_connector == new_out_connector + + # connect "inside the map" + access_node, memlet = to_connect[name] + if isinstance(access_node, nodes.NestedSDFG): + self._current_state.add_edge(access_node, name, map_exit, in_connector_name, memlet) + else: + assert isinstance(access_node, nodes.AccessNode) + if self._current_state.out_degree(access_node) == 0 and self._current_state.in_degree(access_node) == 1: + # this access_node is not used for anything else. 
+ # let's remove it and add a direct connection instead + edges = [edge for edge in self._current_state.edges() if edge.dst == access_node] + assert len(edges) == 1 + self._current_state.add_memlet_path(edges[0].src, + map_exit, + src_conn=edges[0].src_conn, + dst_conn=in_connector_name, + memlet=edges[0].data) + self._current_state.remove_node(access_node) # edge is remove automatically + else: + self._current_state.add_memlet_path(access_node, + map_exit, + dst_conn=in_connector_name, + memlet=memlet) + + if isinstance(outer_map_entry, SDFG): + if name not in sdfg.arrays: + parent_sdfg = self._parent_sdfg_with_array(name, sdfg) + + # Support for NView nodes + use_nview = self._apply_nview_array_override(name, sdfg) + if not use_nview: + sdfg.add_datadesc(name, parent_sdfg.arrays[name].clone()) + + # Transients passed into a nested SDFG become non-transient inside that nested SDFG + if parent_sdfg.arrays[name].transient: + sdfg.arrays[name].transient = False + + # Add out_connector in any case if not yet present, e.g. write after read + # Dev not: name and nview.target are identical + outer_to_connect["outputs"].add(name) + + # connect "outside the map" + # only re-use cached write-only nodes, e.g. 
don't create a cycle for
            #     map i=0:20:
            #         A[i] = tasklet(A[i])
            if name not in access_cache or self._current_state.out_degree(access_cache[name]) > 0:
                # cache write access into access_cache
                write_access_node = self._current_state.add_write(name)
                access_cache[name] = write_access_node

            access_node = access_cache[name]
            self._current_state.add_memlet_path(map_exit,
                                                access_node,
                                                src_conn=out_connector_name,
                                                memlet=Memlet.from_array(name, sdfg.arrays[name]))

            if isinstance(outer_map_entry, nodes.EntryNode):
                # Propagate this write upward so the enclosing map scope can connect it as well.
                outer_to_connect[name] = (access_node, Memlet.from_array(name, sdfg.arrays[name]))
            else:
                assert isinstance(outer_map_entry, SDFG) or outer_map_entry is None

        # TODO If nothing is connected at this point, figure out what's the last thing that
        #      we should connect to. Then, add an empty memlet from that last thing to this
        #      map_exit.
        assert len(self._current_state.in_edges(map_exit)) > 0

    def visit_ConsumeScope(self, node: tn.ConsumeScope, sdfg: SDFG) -> None:
        raise NotImplementedError(f"Support for {type(node)} not yet implemented.")

    def visit_TaskletNode(self, node: tn.TaskletNode, sdfg: SDFG) -> None:
        """
        Insert a tasklet into the current state and wire up its input and output memlets.

        Inputs are taken from the per-scope access-node cache when possible; otherwise they are
        routed through the surrounding map entry, or read directly (copying data descriptors from
        the parent SDFG when inside a nested SDFG). Outputs symmetrically create/reuse write
        access nodes and register themselves with the enclosing dataflow scope.
        """
        # Add Tasklet to current state
        tasklet = node.node
        self._current_state.add_node(tasklet)

        # Per-(state, scope) cache of access nodes, so reads/writes of the same container reuse nodes.
        cache_key = (self._current_state, id(self._ctx.current_scope))
        if cache_key not in self._ctx.access_cache:
            self._ctx.access_cache[cache_key] = {}
        cache = self._ctx.access_cache[cache_key]
        scope_node, to_connect = self._dataflow_stack[-1] if self._dataflow_stack else (None, None)

        # Connect input memlets
        for name, memlet in node.in_memlets.items():
            # connect to local access node if possible
            if memlet.data in cache:
                cached_access = cache[memlet.data]
                self._current_state.add_memlet_path(cached_access, tasklet, dst_conn=name, memlet=memlet)
                continue

            if isinstance(scope_node, nodes.MapEntry):
                # get it from outside the map
                connector_name = f"{PREFIX_PASSTHROUGH_OUT}{memlet.data}"
                if connector_name not in scope_node.out_connectors:
                    new_in_connector = scope_node.add_in_connector(f"{PREFIX_PASSTHROUGH_IN}{memlet.data}")
                    new_out_connector = scope_node.add_out_connector(connector_name)
                    # NOTE(review): relies on add_*_connector returning a boolean success flag — confirm.
                    assert new_in_connector == True
                    assert new_in_connector == new_out_connector

                self._current_state.add_edge(scope_node, connector_name, tasklet, name, memlet)
                continue

            if isinstance(scope_node, SDFG):
                # Copy data descriptor from parent SDFG and add input connector
                if memlet.data not in sdfg.arrays:
                    parent_sdfg = self._parent_sdfg_with_array(memlet.data, sdfg)

                    # Support for NView nodes
                    use_nview = self._apply_nview_array_override(memlet.data, sdfg)
                    if not use_nview:
                        sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone())

                    # Transients passed into a nested SDFG become non-transient inside that nested SDFG
                    if parent_sdfg.arrays[memlet.data].transient:
                        sdfg.arrays[memlet.data].transient = False
                        # TODO
                        # ... unless they are only ever used inside the nested SDFG, in which case
                        # we should delete them from the parent SDFG's array list.
                        # NOTE This can probably be done automatically by a cleanup pass in the end.
                        #      Something like DDE should be able to do this.

                # Dev note: memlet.data and nview.target are identical
                assert memlet.data not in to_connect["inputs"]
                to_connect["inputs"].add(memlet.data)
            else:
                assert scope_node is None

            # cache local read access
            assert memlet.data not in cache
            cache[memlet.data] = self._current_state.add_read(memlet.data)
            cached_access = cache[memlet.data]
            self._current_state.add_memlet_path(cached_access, tasklet, dst_conn=name, memlet=memlet)

        # Add empty memlet if map_entry has no out_connectors to connect to
        if isinstance(scope_node, nodes.MapEntry) and self._current_state.out_degree(scope_node) < 1:
            self._current_state.add_nedge(scope_node, tasklet, Memlet())

        # Connect output memlets
        for name, memlet in node.out_memlets.items():
            # only re-use cached write-only nodes, e.g. don't create a cycle for
            # A[1] = tasklet(A[1])
            if memlet.data not in cache or self._current_state.out_degree(cache[memlet.data]) > 0:
                # cache write access node
                write_access_node = self._current_state.add_write(memlet.data)
                cache[memlet.data] = write_access_node

            access_node = cache[memlet.data]
            self._current_state.add_memlet_path(tasklet, access_node, src_conn=name, memlet=memlet)

            if isinstance(scope_node, nodes.MapEntry):
                # copy the memlet since we already used it in the memlet path above
                to_connect[memlet.data] = (access_node, copy.deepcopy(memlet))
                continue

            if isinstance(scope_node, SDFG):
                if memlet.data not in sdfg.arrays:
                    parent_sdfg: SDFG = self._parent_sdfg_with_array(memlet.data, sdfg)

                    # Support for NView nodes
                    use_nview = self._apply_nview_array_override(memlet.data, sdfg)
                    if not use_nview:
                        sdfg.add_datadesc(memlet.data, parent_sdfg.arrays[memlet.data].clone())

                    # Transients passed into a nested SDFG become non-transient inside that nested SDFG
                    if parent_sdfg.arrays[memlet.data].transient:
                        sdfg.arrays[memlet.data].transient = False

                # Add out_connector in any case if not yet present, e.g. write after read
                # Dev note: memlet.data and nview.target are identical
                to_connect["outputs"].add(memlet.data)

            else:
                assert scope_node is None

    def visit_LibraryCall(self, node: tn.LibraryCall, sdfg: SDFG) -> None:
        raise NotImplementedError(f"Support for {type(node)} not yet implemented.")

    def visit_CopyNode(self, node: tn.CopyNode, sdfg: SDFG) -> None:
        """Materialize a copy as an edge from a (possibly cached) read node to a fresh write node."""
        # ensure we have an access_cache and fetch it
        cache_key = (self._current_state, id(self._ctx.current_scope))
        if cache_key not in self._ctx.access_cache:
            self._ctx.access_cache[cache_key] = {}
        access_cache = self._ctx.access_cache[cache_key]

        # assumption source access may or may not yet exist (in this state)
        src_name = node.memlet.data
        source = access_cache[src_name] if src_name in access_cache else self._current_state.add_read(src_name)

        # assumption: target access node doesn't exist yet
        assert node.target not in access_cache
        target = self._current_state.add_write(node.target)

        self._current_state.add_memlet_path(source, target, memlet=node.memlet)

    def visit_DynScopeCopyNode(self, node: tn.DynScopeCopyNode, sdfg: SDFG) -> None:
        raise NotImplementedError(f"Support for {type(node)} not yet implemented.")

    def visit_ViewNode(self, node: tn.ViewNode, sdfg: SDFG) -> None:
        raise NotImplementedError(f"Support for {type(node)} not yet implemented.")

    def visit_NView(self, node: tn.NView, sdfg: SDFG) -> None:
        # Basic working principle:
        #
        #   - NView and (artificial) NViewEnd nodes are added in parallel to mark the region where the view applies.
        #   - Keep a stack of NView nodes (per name) that is pushed/popped when NView and NViewEnd nodes are visited.
        #   - In between, when going "down into" a NestedSDFG, use the current NView (if it applies)
        #   - In between, when "coming back up" from a NestedSDFG, pop the NView from the stack.
        #   - AccessNodes will automatically pick up the right name (from the NestedSDFG's array list)
        self._nviews_free.append(node)

    def visit_NViewEnd(self, node: tn.NViewEnd, sdfg: SDFG) -> None:
        """
        Close the innermost open NView whose target matches ``node.target``.

        NViews bound to the currently visited nested SDFG are not removed here; their removal
        is deferred until that nested SDFG is exited.

        :raises RuntimeError: If no matching open NView exists.
        """
        # If bound to the current nested SDFG, defer cleanup
        if self._current_nestedSDFG is not None:
            currently_bound = self._nviews_bound_per_scope[self._current_nestedSDFG]
            # NOTE(review): 'index' is unused in this loop — a plain iteration would suffice.
            for index, nview in enumerate(reversed(currently_bound)):
                if node.target == nview.target:
                    # Bound to current nested SDFG. Slate for deferred removal once we exit that nested SDFG.
                    self._nviews_deferred_removal[self._current_nestedSDFG].append(nview)
                    return

        length = len(self._nviews_free)
        # enumerate(..., start=1) so that (length - index) addresses the matched element from the back.
        for index, nview in enumerate(reversed(self._nviews_free), start=1):
            if node.target == nview.target:
                # Stack semantics: remove from the back of the list
                del self._nviews_free[length - index]
                return

        raise RuntimeError(f"No matching NView found for target {node.target} in {self._nviews_free}.")

    def visit_RefSetNode(self, node: tn.RefSetNode, sdfg: SDFG) -> None:
        raise NotImplementedError(f"Support for {type(node)} not yet implemented.")

    def visit_StateBoundaryNode(self, node: tn.StateBoundaryNode, sdfg: SDFG) -> None:
        """Start a new SDFG state, carrying along all pending inter-state assignments."""
        # When creating a state boundary, include all inter-state assignments that precede it.
        pending = self._pending_interstate_assignments()

        self._current_state = create_state_boundary(
            node,
            self._current_state,
            StateBoundaryBehavior.STATE_TRANSITION,
            assignments=pending,
        )

    def _pending_interstate_assignments(self) -> Dict:
        """
        Return currently pending interstate assignments. Clears the cache.
+ """ + assignments = {} + + for symbol in self._interstate_symbols: + assignments[symbol.name] = symbol.value.as_string + self._interstate_symbols.clear() + + return assignments + + +def from_schedule_tree(stree: tn.ScheduleTreeRoot, + state_boundary_behavior: StateBoundaryBehavior = StateBoundaryBehavior.STATE_TRANSITION) -> SDFG: + """ + Converts a schedule tree into an SDFG. + + :param stree: The schedule tree root to convert. + :param state_boundary_behavior: Sets the behavior upon encountering a state boundary (e.g., write-after-write). + See the ``StateBoundaryBehavior`` enumeration for more details. + :return: An SDFG representing the schedule tree. + """ + # Setup SDFG descriptor repository + result = SDFG(stree.name, propagate=False) + result.arg_names = copy.deepcopy(stree.arg_names) + for key, container in stree.containers.items(): + result._arrays[key] = copy.deepcopy(container) + result.constants_prop = copy.deepcopy(stree.constants) + result.symbols = copy.deepcopy(stree.symbols) + + # Insert artificial state boundaries after WAW, before label, etc. + stree = insert_state_boundaries_to_tree(stree) + + # Traverse tree and incrementally build SDFG, finally propagate memlets + StreeToSDFG().visit(stree, sdfg=result) + propagation.propagate_memlets_sdfg(result) + + return result + + +def insert_state_boundaries_to_tree(stree: tn.ScheduleTreeRoot) -> tn.ScheduleTreeRoot: + """ + Inserts StateBoundaryNode objects into a schedule tree where more than one SDFG state would be necessary. + Operates in-place on the given schedule tree. + + This happens when there is a: + * write-after-write dependency; + * write-after-read dependency that cannot be fulfilled via memlets; + * control flow block (for/if); or + * otherwise before a state label (which means a state transition could occur, e.g., in a gblock) + + :param stree: The schedule tree to operate on. 
+ """ + + # Simple boundary node inserter for control flow blocks and state labels + class SimpleStateBoundaryInserter(tn.ScheduleNodeTransformer): + + def visit_scope(self, scope: tn.ScheduleTreeScope): + if isinstance(scope, tn.ControlFlowScope) and not isinstance(scope, (tn.ElifScope, tn.ElseScope)): + return [tn.StateBoundaryNode(True), self.generic_visit(scope)] + return self.generic_visit(scope) + + def visit_StateLabel(self, node: tn.StateLabel): + return [tn.StateBoundaryNode(True), self.generic_visit(node)] + + # First, insert boundaries around labels and control flow + stree = SimpleStateBoundaryInserter().visit(stree) + + # Then, insert boundaries after unmet memory dependencies or potential data races + _insert_memory_dependency_state_boundaries(stree) + + # Insert a state boundary after every symbol assignment to ensure symbols are assigned before usage + class SymbolAssignmentBoundaryInserter(tn.ScheduleNodeTransformer): + + def visit_AssignNode(self, node: tn.AssignNode): + # We can assume that assignment nodes are at least contained in the root scope. + assert node.parent, "Expected assignment nodes live a parent scope." + + # Find this node in the parent's children. + node_index = _list_index(node.parent.children, node) + + # Don't add boundary if there's already one or for immediately following assignment nodes. 
+ if node_index < len(node.parent.children) - 1 and isinstance(node.parent.children[node_index + 1], + (tn.StateBoundaryNode, tn.AssignNode)): + return self.generic_visit(node) + + return [self.generic_visit(node), tn.StateBoundaryNode()] + + stree = SymbolAssignmentBoundaryInserter().visit(stree) + + # Hack: "backprop-insert" state boundaries from nested SDFGs + class NestedSDFGStateBoundaryInserter(tn.ScheduleNodeTransformer): + + def visit_MapScope(self, scope: tn.MapScope): + visited = self.generic_visit(scope) + if any([isinstance(child, tn.StateBoundaryNode) for child in scope.children]): + # We can assume that map nodes are at least contained in the root scope. + assert scope.parent is not None + + # Find this scope in its parent's children + node_index = _list_index(scope.parent.children, scope) + + # If there's already a state boundary before the map, don't add another one + if node_index > 0 and isinstance(scope.parent.children[node_index - 1], tn.StateBoundaryNode): + return visited + + return [tn.StateBoundaryNode(), visited] + return visited + + stree = NestedSDFGStateBoundaryInserter().visit(stree) + + return stree + + +def _insert_memory_dependency_state_boundaries(scope: tn.ScheduleTreeScope): + """ + Helper function that inserts boundaries after unmet memory dependencies. 
+ """ + reads: mmu.MemletDict[List[tn.ScheduleTreeNode]] = mmu.MemletDict() + writes: mmu.MemletDict[List[tn.ScheduleTreeNode]] = mmu.MemletDict() + parents: Dict[int, Set[int]] = defaultdict(set) + boundaries_to_insert: List[int] = [] + + for i, n in enumerate(scope.children): + if isinstance(n, (tn.StateBoundaryNode, tn.ControlFlowScope)): # Clear state + reads.clear() + writes.clear() + parents.clear() + if isinstance(n, tn.ControlFlowScope): # Insert memory boundaries recursively + _insert_memory_dependency_state_boundaries(n) + continue + + # If dataflow scope, insert state boundaries recursively and as a node + if isinstance(n, tn.DataflowScope): + _insert_memory_dependency_state_boundaries(n) + + inputs = n.input_memlets() + outputs = n.output_memlets() + + # Register reads + for inp in inputs: + if inp not in reads: + reads[inp] = [n] + else: + reads[inp].append(n) + + # Transitively add parents + if inp in writes: + for parent in writes[inp]: + parents[id(n)].add(id(parent)) + parents[id(n)].update(parents[id(parent)]) + + # Inter-state assignment nodes with reads necessitate a state transition if they were written to. 
+ if isinstance(n, tn.AssignNode) and any(inp in writes for inp in inputs): + boundaries_to_insert.append(i) + reads.clear() + writes.clear() + parents.clear() + continue + + # Write after write or potential write/write data race, insert state boundary + if any(o in writes and (o not in reads or any(id(r) not in parents for r in reads[o])) for o in outputs): + boundaries_to_insert.append(i) + reads.clear() + writes.clear() + parents.clear() + continue + + # Potential read/write data race: if any read is not in the parents of this node, it might + # be performed in parallel + if any(o in reads and any(id(r) not in parents for r in reads[o]) for o in outputs): + boundaries_to_insert.append(i) + reads.clear() + writes.clear() + parents.clear() + continue + + # Register writes after all hazards have been tested for + for out in outputs: + if out not in writes: + writes[out] = [n] + else: + writes[out].append(n) + + # Insert memory dependency state boundaries in reverse in order to keep indices intact + for i in reversed(boundaries_to_insert): + scope.children.insert(i, tn.StateBoundaryNode()) + + +############################################################################# +# SDFG content creation functions + + +def create_state_boundary( + boundary_node: tn.StateBoundaryNode, + state: SDFGState, + behavior: StateBoundaryBehavior, + assignments: Optional[Dict] = None, +) -> SDFGState: + """ + Creates a boundary between two states + + :param boundary_node: The state boundary node to generate. + :param state: The last state prior to this boundary. + :param behavior: The state boundary behavior with which to create the boundary. + :return: The newly created state. + """ + if behavior != StateBoundaryBehavior.STATE_TRANSITION: + raise NotImplementedError("Only STATE_TRANSITION is supported as StateBoundaryBehavior in this prototype.") + + # TODO: Some boundaries (control flow, state labels with goto) could not be fulfilled with every + # behavior. 
Fall back to state transition in that case. + + label = "cf_state_boundary" if boundary_node.due_to_control_flow else "state_boundary" + assignments = assignments if assignments is not None else {} + return _insert_and_split_assignments(state, label=label, assignments=assignments) + + +def _insert_and_split_assignments( + before_state: ControlFlowBlock, + after_state: Optional[ControlFlowBlock] = None, + *, + label: Optional[str] = None, + assignments: Optional[Dict] = None, +) -> ControlFlowBlock: + """ + Insert given assignments splitting them in case of potential race conditions. + + DaCe validation (currently) won't let us add multiple assignment with read after + write pattern on the same edge. We thus split them over multiple state transitions + (inserting empty states in between) to be safe. + + NOTE (later) This should be double-checked since python dictionaries preserve + insertion order since python 3.7 (which we rely on in this function + too). Depending on code generation it could(TM) be that we can + weaken (best case remove) the corresponding check from the sdfg + validator. 
+ """ + assignments = assignments if assignments is not None else {} + cf_region = before_state.parent_graph + if after_state is not None and after_state.parent_graph != cf_region: + raise ValueError("Expected before_state and after_state to be in the same control flow region.") + + has_potential_race = False + for key, value in assignments.items(): + syms = symbolic.free_symbols_and_functions(value) + also_assigned = (syms & assignments.keys()) - {key} + if also_assigned: + has_potential_race = True + break + + if not has_potential_race: + if after_state is not None: + cf_region.add_edge(before_state, after_state, InterstateEdge(assignments=assignments)) + return after_state + + return cf_region.add_state_after(before_state, label=label, assignments=assignments) + + last_state = before_state + for index, assignment in enumerate(assignments.items()): + key, value = assignment + is_last_state = index == len(assignments) - 1 + if is_last_state and after_state is not None: + cf_region.add_edge(last_state, after_state, InterstateEdge(assignments={key: value})) + last_state = after_state + else: + last_state = cf_region.add_state_after(last_state, label=label, assignments={key: value}) + + return last_state + + +def _list_index(list: List[tn.ScheduleTreeNode], node: tn.ScheduleTreeNode) -> int: + """Check if node is in list with "is" operator.""" + index = 0 + for element in list: + # compare with "is" to get memory comparison. ".index()" uses value comparison + if element is node: + return index + index += 1 + + raise StopIteration diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 0d709ee0fd..a142f1fa89 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -1,13 +1,17 @@ # Copyright 2019-2026 ETH Zurich and the DaCe authors. All rights reserved. 
from dataclasses import dataclass, field +import sympy -from dace import nodes, data, subsets +from dace import nodes, data, subsets, dtypes, symbolic from dace.properties import CodeBlock from dace.sdfg import InterstateEdge -from dace.sdfg.state import ConditionalBlock, LoopRegion, SDFGState -from dace.symbolic import symbol +from dace.sdfg.memlet_utils import MemletSet +from dace.sdfg.propagation import propagate_subset +from dace.sdfg.sdfg import InterstateEdge, SDFG, memlets_in_ast +from dace.sdfg.state import LoopRegion, SDFGState from dace.memlet import Memlet -from typing import TYPE_CHECKING, Dict, Iterator, List, Literal, Optional, Set, Union +from types import TracebackType +from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Literal, Optional, Set, Tuple, Union if TYPE_CHECKING: from dace import SDFG @@ -19,6 +23,49 @@ class UnsupportedScopeException(Exception): pass +@dataclass +class Context: + root: 'ScheduleTreeRoot' + current_scope: Optional['ScheduleTreeScope'] + + access_cache: Dict[Tuple[SDFGState, str], Dict[str, nodes.AccessNode]] + """Per scope (hashed by id(scope_node) access_cache.""" + + +class ContextPushPop: + """Append the given node to the scope, then push/pop the scope.""" + + def __init__(self, ctx: Context, state: SDFGState, node: 'ScheduleTreeScope') -> None: + if ctx.current_scope is None and not isinstance(node, ScheduleTreeRoot): + raise ValueError("ctx.current_scope is only allowed to be 'None' when node it tree root.") + + self._ctx = ctx + self._parent_scope = ctx.current_scope + self._node = node + self._state = state + + cache_key = (state, id(node)) + assert cache_key not in self._ctx.access_cache + self._ctx.access_cache[cache_key] = {} + + def __enter__(self) -> None: + assert not self._ctx.access_cache[(self._state, id( + self._node))], "Expecting an empty access_cache when entering the context." 
+ + self._ctx.current_scope = self._node + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + cache_key = (self._state, id(self._node)) + assert cache_key in self._ctx.access_cache + + self._ctx.current_scope = self._parent_scope + + @dataclass class ScheduleTreeNode: parent: Optional['ScheduleTreeScope'] = field(default=None, init=False, repr=False) @@ -32,20 +79,50 @@ def preorder_traversal(self) -> Iterator['ScheduleTreeNode']: """ yield self + def get_root(self) -> 'ScheduleTreeRoot': + if self.parent is None: + raise ValueError('Non-root schedule tree node has no parent') + return self.parent.get_root() + + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + """ + Returns a set of inputs for this node. For scopes, returns the union of its contents. + + :param root: An optional argument specifying the schedule tree's root. If not given, + the value is computed from the current tree node. + :return: A set of memlets representing the inputs of this node. + """ + raise NotImplementedError + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + """ + Returns a set of outputs for this node. For scopes, returns the union of its contents. + + :param root: An optional argument specifying the schedule tree's root. If not given, + the value is computed from the current tree node. + :return: A set of memlets representing the inputs of this node. 
+ """ + raise NotImplementedError + @dataclass class ScheduleTreeScope(ScheduleTreeNode): - children: List['ScheduleTreeNode'] - containers: Optional[Dict[str, data.Data]] = field(default_factory=dict, init=False) - symbols: Optional[Dict[str, symbol]] = field(default_factory=dict, init=False) - - def __init__(self, children: Optional[List['ScheduleTreeNode']] = None): - self.children = children or [] - if self.children: - for child in children: - child.parent = self - self.containers = {} - self.symbols = {} + children: List[ScheduleTreeNode] + + def __init__(self, *, children: List[ScheduleTreeNode], parent: Optional['ScheduleTreeScope'] = None) -> None: + for child in children: + child.parent = self + + self.children = children + self.parent = parent + + def add_children(self, children: Iterable[ScheduleTreeNode]) -> None: + for child in children: + child.parent = self + self.children.append(child) + + def add_child(self, child: ScheduleTreeNode) -> None: + self.add_children([child]) def as_string(self, indent: int = 0): if not self.children: @@ -60,12 +137,178 @@ def preorder_traversal(self) -> Iterator['ScheduleTreeNode']: for child in self.children: yield from child.preorder_traversal() - # TODO: Helper function that gets input/output memlets of the scope + def _gather_memlets_in_scope(self, inputs: bool, root: Optional['ScheduleTreeRoot'], keep_locals: bool, + propagate: Dict[str, + subsets.Range], disallow_propagation: Set[str], **kwargs) -> MemletSet: + gather = (lambda n, root: n.input_memlets(root, **kwargs)) if inputs else ( + lambda n, root: n.output_memlets(root, **kwargs)) + + # Fast path, no propagation necessary + if keep_locals: + return MemletSet().union(*(gather(c) for c in self.children)) + + root = root if root is not None else self.get_root() + + if propagate: + to_propagate = list(propagate.items()) + propagate_keys = [a[0] for a in to_propagate] + propagate_values = subsets.Range([a[1] for a in to_propagate]) + + current_locals = set() + 
current_locals |= disallow_propagation + result = MemletSet() + + # Loop over children in order, if any new symbol is defined within this scope (e.g., symbol assignment, + # dynamic map range), consider it as a new local + for c in self.children: + # Add new locals + if isinstance(c, AssignNode): + current_locals.add(c.name) + elif isinstance(c, DynScopeCopyNode): + current_locals.add(c.target) + + internal_memlets: MemletSet = gather(c, root) + if propagate: + for memlet in internal_memlets: + result.add( + propagate_subset([memlet], + root.containers[memlet.data], + propagate_keys, + propagate_values, + undefined_variables=current_locals, + use_dst=not inputs)) + + return result + + def input_memlets(self, + root: Optional['ScheduleTreeRoot'] = None, + keep_locals: bool = False, + propagate: Optional[Dict[str, subsets.Range]] = None, + disallow_propagation: Optional[Set[str]] = None, + **kwargs) -> MemletSet: + """ + Returns a union of the set of inputs for this scope. Propagates the memlets used in the scope if ``keep_locals`` + is set to False. + + :param root: An optional argument specifying the schedule tree's root. If not given, + the value is computed from the current tree node. + :param keep_locals: If True, keeps the local symbols defined within the scope as part of the resulting memlets. + Otherwise, performs memlet propagation (see ``propagate`` and ``disallow_propagation``) or + assumes the entire container is used. + :param propagate: An optional dictionary mapping symbols to their corresponding ranges outside of this scope. + For example, the range of values a for-loop may take. + If ``keep_locals`` is False, this dictionary will be used to create projection memlets over + the ranges. See :ref:`memprop` in the documentation for more information. + :param disallow_propagation: If ``keep_locals`` is False, this optional set of strings will be considered + as additional locals. + :return: A set of memlets representing the inputs of this scope. 
+ """ + return self._gather_memlets_in_scope(True, root, keep_locals, propagate or {}, disallow_propagation or set(), + **kwargs) + + def output_memlets(self, + root: Optional['ScheduleTreeRoot'] = None, + keep_locals: bool = False, + propagate: Optional[Dict[str, subsets.Range]] = None, + disallow_propagation: Optional[Set[str]] = None, + **kwargs) -> MemletSet: + """ + Returns a union of the set of outputs for this scope. Propagates the memlets used in the scope if + ``keep_locals`` is set to False. + + :param root: An optional argument specifying the schedule tree's root. If not given, + the value is computed from the current tree node. + :param keep_locals: If True, keeps the local symbols defined within the scope as part of the resulting memlets. + Otherwise, performs memlet propagation (see ``propagate`` and ``disallow_propagation``) or + assumes the entire container is used. + :param propagate: An optional dictionary mapping symbols to their corresponding ranges outside of this scope. + For example, the range of values a for-loop may take. + If ``keep_locals`` is False, this dictionary will be used to create projection memlets over + the ranges. See :ref:`memprop` in the documentation for more information. + :param disallow_propagation: If ``keep_locals`` is False, this optional set of strings will be considered + as additional locals. + :return: A set of memlets representing the inputs of this scope. + """ + return self._gather_memlets_in_scope(False, root, keep_locals, propagate or {}, disallow_propagation or set(), + **kwargs) + + +@dataclass +class ScheduleTreeRoot(ScheduleTreeScope): + """ + A root of an SDFG schedule tree. This is a schedule tree scope with additional information on + the available descriptors, symbol types, and constants of the tree, aka the descriptor repository. 
+ """ + name: str + containers: Dict[str, data.Data] + symbols: Dict[str, dtypes.typeclass] + constants: Dict[str, Tuple[data.Data, Any]] + callback_mapping: Dict[str, str] + arg_names: List[str] + + def __init__( + self, + *, + name: str, + children: List[ScheduleTreeNode], + containers: Optional[Dict[str, data.Data]] = None, + symbols: Optional[Dict[str, dtypes.typeclass]] = None, + constants: Optional[Dict[str, Tuple[data.Data, Any]]] = None, + callback_mapping: Optional[Dict[str, str]] = None, + arg_names: Optional[List[str]] = None, + ) -> None: + super().__init__(children=children, parent=None) + + self.name = name + self.containers = containers if containers is not None else dict() + self.symbols = symbols if symbols is not None else dict() + self.constants = constants if constants is not None else dict() + self.callback_mapping = callback_mapping if callback_mapping is not None else dict() + self.arg_names = arg_names if arg_names is not None else list() + + def as_sdfg(self, + validate: bool = True, + simplify: bool = True, + validate_all: bool = False, + skip: Set[str] = set(), + verbose: bool = False) -> SDFG: + """ + Convert this schedule tree representation (back) into an SDFG. + + :param validate: If true, validate generated SDFG. + :param simplify: If true, simplify generated SDFG. The conversion might insert things like extra + empty states that can be cleaned up automatically. The value of `validate` is + passed on to `simplify()`. + :param validate_all: When simplifying, validate all intermediate SDFGs. Unused if simplify is False. + :param skip: Set of names of simplify passes to skip. Unused if simplify is False. + :param verbose: Turn on verbose logging of simplify. Unused if simplify is False. + + :return: SDFG version of this schedule tree. 
+ """ + from dace.sdfg.analysis.schedule_tree import tree_to_sdfg as t2s # Avoid import loop + sdfg = t2s.from_schedule_tree(self) + + if validate: + sdfg.validate() + + if simplify: + from dace.transformation.passes.simplify import SimplifyPass + SimplifyPass(validate=validate, validate_all=validate_all, skip=skip, verbose=verbose).apply_pass(sdfg, {}) + + return sdfg + + def get_root(self) -> 'ScheduleTreeRoot': + return self + + def scope(self, state: SDFGState, ctx: Context) -> ContextPushPop: + return ContextPushPop(ctx, state, self) @dataclass class ControlFlowScope(ScheduleTreeScope): - pass + + def __init__(self, *, children: List[ScheduleTreeNode], parent: Optional[ScheduleTreeScope] = None) -> None: + super().__init__(children=children, parent=parent) @dataclass @@ -73,6 +316,20 @@ class DataflowScope(ScheduleTreeScope): node: nodes.EntryNode state: Optional[SDFGState] = None + def __init__(self, + *, + node: nodes.EntryNode, + children: List[ScheduleTreeNode], + parent: Optional[ScheduleTreeScope] = None, + state: Optional[SDFGState] = None) -> None: + super().__init__(children=children, parent=parent) + + self.node = node + self.state = state + + def scope(self, state: SDFGState, ctx: Context) -> ContextPushPop: + return ContextPushPop(ctx, state, self) + @dataclass class GBlock(ControlFlowScope): @@ -82,6 +339,9 @@ class GBlock(ControlFlowScope): Normally contains irreducible control flow. 
""" + def __init__(self, *, children: List[ScheduleTreeNode], parent: Optional[ScheduleTreeScope] = None) -> None: + super().__init__(children=children, parent=parent) + def as_string(self, indent: int = 0): result = indent * INDENTATION + 'gblock:\n' return result + super().as_string(indent) @@ -94,6 +354,12 @@ class StateLabel(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + f'label {self.state.name}:' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + @dataclass class GotoNode(ScheduleTreeNode): @@ -103,6 +369,12 @@ def as_string(self, indent: int = 0): name = self.target or 'exit' return indent * INDENTATION + f'goto {name}' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + @dataclass class AssignNode(ScheduleTreeNode): @@ -116,6 +388,13 @@ class AssignNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + f'assign {self.name} = {self.value.as_string}' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + root = root if root is not None else self.get_root() + return MemletSet(self.edge.get_read_memlets(root.containers)) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + @dataclass class LoopScope(ControlFlowScope): @@ -124,51 +403,132 @@ class LoopScope(ControlFlowScope): """ loop: LoopRegion - def _check_loop_variant( - self - ) -> Union[Literal['for'], Literal['while'], Literal['do-while'], Literal['do-for-uncond-increment'], - Literal['do-for']]: - if self.loop.update_statement and self.loop.init_statement and 
self.loop.loop_variable: - if self.loop.inverted: - if self.loop.update_before_condition: - return 'do-for-uncond-increment' - else: - return 'do-for' - else: - return 'for' - else: - if self.loop.inverted: - return 'do-while' - else: - return 'while' + def __init__(self, + *, + loop: LoopRegion, + children: List[ScheduleTreeNode], + parent: Optional[ScheduleTreeScope] = None) -> None: + super().__init__(children=children, parent=parent) + + self.loop = loop def as_string(self, indent: int = 0): loop = self.loop - loop_variant = self._check_loop_variant() - if loop_variant == 'do-for-uncond-increment': + variant = loop_variant(loop) + if variant == 'do-for-uncond-increment': pre_header = indent * INDENTATION + f'{loop.init_statement.as_string}\n' header = indent * INDENTATION + 'do:\n' pre_footer = (indent + 1) * INDENTATION + f'{loop.update_statement.as_string}\n' footer = indent * INDENTATION + f'while {loop.loop_condition.as_string}' return pre_header + header + super().as_string(indent) + '\n' + pre_footer + footer - elif loop_variant == 'do-for': + + if variant == 'do-for': pre_header = indent * INDENTATION + f'{loop.init_statement.as_string}\n' header = indent * INDENTATION + 'while True:\n' pre_footer = (indent + 1) * INDENTATION + f'if (not {loop.loop_condition.as_string}):\n' pre_footer += (indent + 2) * INDENTATION + 'break\n' footer = (indent + 1) * INDENTATION + f'{loop.update_statement.as_string}\n' return pre_header + header + super().as_string(indent) + '\n' + pre_footer + footer - elif loop_variant == 'for': - result = (indent * INDENTATION + f'for {loop.init_statement.as_string}; ' + - f'{loop.loop_condition.as_string}; ' + f'{loop.update_statement.as_string}:\n') - return result + super().as_string(indent) - elif loop_variant == 'while': - result = indent * INDENTATION + f'while {loop.loop_condition.as_string}:\n' - return result + super().as_string(indent) - else: # 'do-while' - header = indent * INDENTATION + 'do:\n' - footer = indent * 
INDENTATION + f'while {loop.loop_condition.as_string}' - return header + super().as_string(indent) + '\n' + footer + + if variant in ["for", "while", "do-while"]: + return super().as_string(indent) + + raise NotImplementedError(f"Unknown LoopRegion variant '{variant}'.") + + +@dataclass +class ForScope(LoopScope): + """Specialized LoopScope for for-loops.""" + + def __init__(self, + *, + loop: LoopRegion, + children: List[ScheduleTreeNode], + parent: Optional[ScheduleTreeScope] = None) -> None: + super().__init__(loop=loop, children=children, parent=parent) + + def as_string(self, indent: int = 0) -> str: + init_statement = self.loop.init_statement.as_string + condition = self.loop.loop_condition.as_string + update_statement = self.loop.update_statement.as_string + result = indent * INDENTATION + f"for {init_statement}; {condition}; {update_statement}:\n" + return result + super().as_string(indent) + + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + root = root if root is not None else self.get_root() + + result = MemletSet() + result.update(self.loop.get_meta_read_memlets(arrays=root.containers)) + + # If loop range is well-formed, use it in propagation + range = _loop_range(self.loop) + if range is not None: + propagate = {self.loop.loop_variable: range} + else: + propagate = None + + result.update(super().input_memlets(root, propagate=propagate, **kwargs)) + return result + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + # If loop range is well-formed, use it in propagation + range = _loop_range(self.loop) + if range is not None: + propagate = {self.loop.loop_variable: range} + else: + propagate = None + + return super().output_memlets(root, propagate=propagate, **kwargs) + + +@dataclass +class WhileScope(LoopScope): + """Specialized LoopScope for while-loops.""" + + def __init__(self, + *, + loop: LoopRegion, + children: List[ScheduleTreeNode], + parent:
Optional[ScheduleTreeScope] = None) -> None: + super().__init__(loop=loop, children=children, parent=parent) + + def as_string(self, indent: int = 0) -> str: + condition = self.loop.loop_condition.as_string + result = indent * INDENTATION + f'while {condition}:\n' + return result + super().as_string(indent) + + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + root = root if root is not None else self.get_root() + + result = MemletSet() + result.update(self.loop.get_meta_read_memlets(arrays=root.containers)) + result.update(super().input_memlets(root, **kwargs)) + return result + + +@dataclass +class DoWhileScope(LoopScope): + """Specialized LoopScope for do-while-loops""" + + def __init__(self, + *, + loop: LoopRegion, + children: List[ScheduleTreeNode], + parent: Optional[ScheduleTreeScope] = None) -> None: + super().__init__(loop=loop, children=children, parent=parent) + + def as_string(self, indent: int = 0) -> str: + header = indent * INDENTATION + 'do:\n' + footer = indent * INDENTATION + f'while {self.loop.loop_condition.as_string}' + return header + super().as_string(indent) + '\n' + footer + + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + root = root if root is not None else self.get_root() + + result = MemletSet() + result.update(self.loop.get_meta_read_memlets(arrays=root.containers)) + result.update(super().input_memlets(root, **kwargs)) + return result @dataclass @@ -178,10 +538,26 @@ class IfScope(ControlFlowScope): """ condition: CodeBlock + def __init__(self, + *, + condition: CodeBlock, + children: List[ScheduleTreeNode], + parent: Optional[ScheduleTreeScope] = None) -> None: + super().__init__(children=children, parent=parent) + + self.condition = condition + def as_string(self, indent: int = 0): result = indent * INDENTATION + f'if {self.condition.as_string}:\n' return result + super().as_string(indent) + def input_memlets(self, root: 
Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + root = root if root is not None else self.get_root() + result = MemletSet() + result.update(memlets_in_ast(self.condition.code[0], root.containers)) + result.update(super().input_memlets(root, **kwargs)) + return result + @dataclass class StateIfScope(IfScope): @@ -189,6 +565,13 @@ class StateIfScope(IfScope): A special class of an if scope in general blocks for if statements that are part of a state transition. """ + def __init__(self, + *, + condition: CodeBlock, + children: List[ScheduleTreeNode], + parent: Optional[ScheduleTreeScope] = None) -> None: + super().__init__(condition=condition, children=children, parent=parent) + def as_string(self, indent: int = 0): result = indent * INDENTATION + f'stateif {self.condition.as_string}:\n' return result + super(IfScope, self).as_string(indent) @@ -203,6 +586,12 @@ class BreakNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + 'break' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + @dataclass class ContinueNode(ScheduleTreeNode): @@ -213,6 +602,12 @@ class ContinueNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + 'continue' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + @dataclass class ElifScope(ControlFlowScope): @@ -221,10 +616,26 @@ class ElifScope(ControlFlowScope): """ condition: CodeBlock + def __init__(self, + *, + condition: CodeBlock, + children: List[ScheduleTreeNode], + parent: Optional[ScheduleTreeScope] = None) -> None: + super().__init__(children=children, parent=parent) + + self.condition = 
condition + def as_string(self, indent: int = 0): result = indent * INDENTATION + f'elif {self.condition.as_string}:\n' return result + super().as_string(indent) + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + root = root if root is not None else self.get_root() + result = MemletSet() + result.update(memlets_in_ast(self.condition.code[0], root.containers)) + result.update(super().input_memlets(root, **kwargs)) + return result + @dataclass class ElseScope(ControlFlowScope): @@ -232,6 +643,9 @@ class ElseScope(ControlFlowScope): Else branch scope. """ + def __init__(self, *, children: List[ScheduleTreeNode], parent: Optional[ScheduleTreeScope] = None) -> None: + super().__init__(children=children, parent=parent) + def as_string(self, indent: int = 0): result = indent * INDENTATION + 'else:\n' return result + super().as_string(indent) @@ -242,18 +656,52 @@ class MapScope(DataflowScope): """ Map scope. """ + node: nodes.MapEntry + + def __init__(self, + *, + node: nodes.MapEntry, + children: List[ScheduleTreeNode], + parent: Optional[ScheduleTreeScope] = None, + state: Optional[SDFGState] = None) -> None: + super().__init__(node=node, state=state, children=children, parent=parent) def as_string(self, indent: int = 0): rangestr = ', '.join(subsets.Range.dim_to_string(d) for d in self.node.map.range) result = indent * INDENTATION + f'map {", ".join(self.node.map.params)} in [{rangestr}]:\n' return result + super().as_string(indent) + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return super().input_memlets(root, + propagate={ + k: v + for k, v in zip(self.node.map.params, self.node.map.range) + }, + **kwargs) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return super().output_memlets(root, + propagate={ + k: v + for k, v in zip(self.node.map.params, self.node.map.range) + }, + **kwargs) + @dataclass class 
ConsumeScope(DataflowScope): """ Consume scope. """ + node: nodes.ConsumeEntry + + def __init__(self, + *, + node: nodes.ConsumeEntry, + children: List[ScheduleTreeNode], + parent: Optional[ScheduleTreeScope] = None, + state: Optional[SDFGState] = None) -> None: + super().__init__(node=node, state=state, children=children, parent=parent) def as_string(self, indent: int = 0): node: nodes.ConsumeEntry = self.node @@ -275,12 +723,18 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'tasklet({in_memlets})' return indent * INDENTATION + f'{out_memlets} = tasklet({in_memlets})' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet(self.in_memlets.values()) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet(self.out_memlets.values()) + @dataclass class LibraryCall(ScheduleTreeNode): node: nodes.LibraryNode - in_memlets: Union[Dict[str, Memlet], Set[Memlet]] - out_memlets: Union[Dict[str, Memlet], Set[Memlet]] + in_memlets: Union[Dict[str, Memlet], MemletSet] + out_memlets: Union[Dict[str, Memlet], MemletSet] def as_string(self, indent: int = 0): if isinstance(self.in_memlets, set): @@ -297,6 +751,16 @@ def as_string(self, indent: int = 0): if v.owner not in {nodes.Node, nodes.CodeNode, nodes.LibraryNode}) return indent * INDENTATION + f'{out_memlets} = library {libname}[{own_properties}]({in_memlets})' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + if isinstance(self.in_memlets, set): + return MemletSet(self.in_memlets) + return MemletSet(self.in_memlets.values()) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + if isinstance(self.out_memlets, set): + return MemletSet(self.out_memlets) + return MemletSet(self.out_memlets.values()) + @dataclass class CopyNode(ScheduleTreeNode): @@ -315,6 +779,16 @@ def as_string(self, indent: int = 0): 
return indent * INDENTATION + f'{self.target}{offset} = copy {self.memlet.data}[{self.memlet.subset}]{wcr}' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet({self.memlet}) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + root = root if root is not None else self.get_root() + if self.memlet.other_subset is not None: + return MemletSet({Memlet(data=self.target, subset=self.memlet.other_subset, wcr=self.memlet.wcr)}) + + return MemletSet({Memlet.from_array(self.target, root.containers[self.target], self.memlet.wcr)}) + @dataclass class DynScopeCopyNode(ScheduleTreeNode): @@ -327,6 +801,12 @@ class DynScopeCopyNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = dscopy {self.memlet.data}[{self.memlet.subset}]' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet({self.memlet}) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + @dataclass class ViewNode(ScheduleTreeNode): @@ -339,6 +819,12 @@ class ViewNode(ScheduleTreeNode): def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = view {self.memlet} as {self.view_desc.shape}' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet({self.memlet}) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet({Memlet.from_array(self.target, self.view_desc)}) + @dataclass class NView(ViewNode): @@ -350,6 +836,25 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = nview {self.memlet} as {self.view_desc.shape}' +@dataclass +class NViewEnd(ScheduleTreeNode): + """ + Artificial node to denote the scope end of the associated Nested SDFG view node. 
+ """ + + target: str + """Target name of the associated NView container.""" + + def as_string(self, indent: int = 0): + return indent * INDENTATION + f"end nview {self.target}" + + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + + @dataclass class RefSetNode(ScheduleTreeNode): """ @@ -365,25 +870,49 @@ def as_string(self, indent: int = 0): return indent * INDENTATION + f'{self.target} = refset from {type(self.src_desc).__name__.lower()}' return indent * INDENTATION + f'{self.target} = refset to {self.memlet}' + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet({self.memlet}) + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet({Memlet.from_array(self.target, self.ref_desc)}) + + +@dataclass +class StateBoundaryNode(ScheduleTreeNode): + """ + A node that represents a state boundary (e.g., when a write-after-write is encountered). This node + is used only during conversion from a schedule tree to an SDFG. 
+ """ + due_to_control_flow: bool = False + + def as_string(self, indent: int = 0): + return indent * INDENTATION + 'state boundary' + + def input_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + + def output_memlets(self, root: Optional['ScheduleTreeRoot'] = None, **kwargs) -> MemletSet: + return MemletSet() + # Classes based on Python's AST NodeVisitor/NodeTransformer for schedule tree nodes class ScheduleNodeVisitor: - def visit(self, node: ScheduleTreeNode): + def visit(self, node: ScheduleTreeNode, **kwargs: Any): """Visit a node.""" if isinstance(node, list): - return [self.visit(snode) for snode in node] + return [self.visit(snode, **kwargs) for snode in node] if isinstance(node, ScheduleTreeScope) and hasattr(self, 'visit_scope'): - return self.visit_scope(node) + return self.visit_scope(node, **kwargs) method = 'visit_' + node.__class__.__name__ visitor = getattr(self, method, self.generic_visit) - return visitor(node) + return visitor(node, **kwargs) - def generic_visit(self, node: ScheduleTreeNode): + def generic_visit(self, node: ScheduleTreeNode, **kwargs: Any): if isinstance(node, ScheduleTreeScope): for child in node.children: - self.visit(child) + self.visit(child, **kwargs) class ScheduleNodeTransformer(ScheduleNodeVisitor): @@ -427,3 +956,98 @@ def validate_has_no_other_node_types(stree: ScheduleTreeScope) -> None: raise RuntimeError(f'Unsupported node type: {type(child).__name__}') if isinstance(child, ScheduleTreeScope): validate_has_no_other_node_types(child) + + +def validate_children_and_parents_align(stree: ScheduleTreeScope, *, root: bool = False) -> None: + """ + Validates the child/parent information of schedule tree scopes are consistent. + + Walks through all children of a scope and raises if the children's parent isn't + the scope. If `root` is true, we additionally check that the top-most scope is + of type `ScheduleTreeRoot`. 
+ + :param stree: Schedule tree scope to be analyzed + :param root: If true, we raise if the top-most scope isn't of type `ScheduleTreeRoot`. + """ + if root and not isinstance(stree, ScheduleTreeRoot): + raise RuntimeError("Expected schedule tree root.") + + for child in stree.children: + if id(child.parent) != id(stree): + raise RuntimeError(f"Inconsistent parent/child relationship. child: {child}, parent: {stree}") + if isinstance(child, ScheduleTreeScope): + validate_children_and_parents_align(child) + + +def loop_variant( + loop: LoopRegion +) -> Union[Literal['for'], Literal['while'], Literal['do-while'], Literal['do-for-uncond-increment'], + Literal['do-for']]: + if loop.update_statement and loop.init_statement and loop.loop_variable: + if loop.inverted: + if loop.update_before_condition: + return 'do-for-uncond-increment' + return 'do-for' + return 'for' + + if loop.inverted: + return 'do-while' + return 'while' + + +def _loop_range( + loop: LoopRegion) -> Optional[Tuple[symbolic.SymbolicType, symbolic.SymbolicType, symbolic.SymbolicType]]: + """ + Derive loop range for well-formed `for`-loops. + + :param loop: The loop to be analyzed. + :return: If well formed, `(start, end, step)` where `end` is inclusive, otherwise `None`. + """ + + if loop_variant(loop) != "for" or loop.loop_variable is None: + # Loop range is only defined in for-loops + # and we need to know the loop variable. + return None + + # Avoid cyclic import + from dace.transformation.passes.analysis import loop_analysis + + # If loop information cannot be determined, we cannot derive loop range + start = loop_analysis.get_init_assignment(loop) + step = loop_analysis.get_loop_stride(loop) + end = _match_loop_condition(loop) + if start is None or step is None or end is None: + return None + + return (start, end, step) + + +def _match_loop_condition(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + """ + Try to find the end of a for-loop by symbolically matching the loop condition.
+ + :return: loop end (inclusive) or `None` if matching failed. + """ + + condition = symbolic.pystr_to_symbolic(loop.loop_condition.as_string) + loop_symbol = symbolic.pystr_to_symbolic(loop.loop_variable) + a = sympy.Wild('a') + + match = condition.match(loop_symbol < a) + if match is not None: + return match[a] - 1 + + match = condition.match(loop_symbol <= a) + if match is not None: + return match[a] + + match = condition.match(loop_symbol >= a) + if match is not None: + return match[a] + + match = condition.match(loop_symbol > a) + if match is not None: + return match[a] + 1 + + # Matching failed - we can't derive end of loop + return None diff --git a/dace/sdfg/memlet_utils.py b/dace/sdfg/memlet_utils.py index 2c91650314..ddba9d66ef 100644 --- a/dace/sdfg/memlet_utils.py +++ b/dace/sdfg/memlet_utils.py @@ -1,13 +1,17 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +from __future__ import annotations import ast -import itertools +from collections import defaultdict +import copy from dace import data, Memlet, subsets, symbolic, dtypes +from dace.frontend.python import memlet_parser from dace.sdfg import SDFGState, SDFG, nodes, utils as sdutil from dace.sdfg.scope import is_devicelevel_gpu from dace.sdfg.graph import MultiConnectorEdge from dace.frontend.python import memlet_parser -from typing import Callable, Dict, Optional, Set, Union, Tuple +import itertools +from typing import Callable, Dict, Iterable, Optional, Set, TypeVar, Tuple, Union, Generator, Any class MemletReplacer(ast.NodeTransformer): @@ -83,6 +87,188 @@ def visit_Subscript(self, node: ast.Subscript): return self.generic_visit(node) +class MemletSet(Set[Memlet]): + """ + Implements a set of memlets that considers subsets that intersect or are covered by its other memlets. + Set updates and unions also perform unions on the contained memlet subsets. 
+ """ + + def __init__(self, iterable: Optional[Iterable[Memlet]] = None, *, intersection_is_contained: bool = True) -> None: + """ + Initializes a memlet set. + + :param iterable: An optional iterable of memlets to initialize the set with. + :param intersection_is_contained: Whether the check ``m in memlet_set`` should return True if the memlet + only intersects with the contents of the set. If False, only completely + covered subsets would return True. + """ + self.internal_set: Dict[str, Set[Memlet]] = {} + self.intersection_is_contained = intersection_is_contained + if iterable is not None: + self.update(iterable) + + def __iter__(self) -> Generator[Memlet, Any, None]: + for subset in self.internal_set.values(): + yield from subset + + def __len__(self) -> int: + return len(self.internal_set) + + def update(self, *iterable: Iterable[Memlet]) -> None: + """ + Updates set of memlets via union of existing ranges. + """ + if len(iterable) == 0: + return + if len(iterable) > 1: + for i in iterable: + self.update(i) + return + + to_update, = iterable + for elem in to_update: + self.add(elem) + + def add(self, elem: Memlet) -> None: + """ + Adds a memlet to the set, potentially performing a union of existing ranges. 
+ """ + if elem.data not in self.internal_set: + self.internal_set[elem.data] = {elem} + return + + # Memlet is in set, either perform a union (if possible) or add to internal set + # TODO(later): Consider other_subset as well + for existing_memlet in self.internal_set[elem.data]: + try: + if subsets.intersects(existing_memlet.subset, elem.subset) == True: # Definitely intersects + if existing_memlet.subset.covers(elem.subset): + break # Nothing to do + + # Create a new union memlet + self.internal_set[elem.data].remove(existing_memlet) + new_memlet = copy.deepcopy(existing_memlet) + new_memlet.subset = subsets.union(existing_memlet.subset, elem.subset) + self.internal_set[elem.data].add(new_memlet) + break + except TypeError: # Indeterminate + pass + else: # all intersections were False or indeterminate (may or does not intersect with existing memlets) + self.internal_set[elem.data].add(elem) + + def __contains__(self, elem: Memlet) -> bool: + """ + Returns True iff the memlet or a range superset thereof exists in this set. + """ + if elem.data not in self.internal_set: + return False + for existing_memlet in self.internal_set[elem.data]: + if existing_memlet.subset.covers(elem.subset): + return True + if self.intersection_is_contained: + try: + if subsets.intersects(existing_memlet.subset, elem.subset) == False: + continue + else: # May intersect or indeterminate + return True + except TypeError: + return True + + return False + + def union(self, *s: Iterable[Memlet]) -> MemletSet: + """ + Performs a set-union (with memlet union) over the given sets of memlets. + + :return: New memlet set containing the union of this set and the inputs. + """ + newset = MemletSet(self) + newset.update(s) + return newset + + +T = TypeVar('T') + + +class MemletDict(Dict[Memlet, T]): + """ + Implements a dictionary with memlet keys that considers subsets that intersect + or are covered by its other memlets. 
+ """ + + def __init__(self, **kwargs) -> None: + self.internal_dict: Dict[str, Dict[Memlet, T]] = defaultdict(dict) + self.covers_cache: Dict[Tuple, bool] = defaultdict() + + if kwargs: + self.update(kwargs) + + def __len__(self) -> int: + return len(self.internal_dict) + + def _getkey(self, elem: Memlet) -> Optional[Memlet]: + """ + Returns the corresponding key (exact, covered, intersecting, or indeterminately intersecting memlet) + if it exists in the dictionary, or None if it does not. + """ + if elem.data not in self.internal_dict: + return None + + for existing_memlet in self.internal_dict[elem.data]: + key = (existing_memlet.subset, elem.subset) + is_covered = self.covers_cache.get(key, None) + if is_covered is None: + is_covered = existing_memlet.subset.covers(elem.subset) + self.covers_cache[key] = is_covered + if is_covered: + return existing_memlet + + try: + if subsets.intersects(existing_memlet.subset, elem.subset) == False: # Definitely does not intersect + continue + except TypeError: + pass + + # May or will intersect + return existing_memlet + + return None + + def _setkey(self, key: Memlet, value: T) -> None: + self.internal_dict[key.data][key] = value + + def clear(self) -> None: + self.internal_dict.clear() + + def update(self, mapping: Dict[Memlet, T]) -> None: + for key, value in mapping.items(): + actual_key = self._getkey(key) + if actual_key is None: + self._setkey(key, value) + else: + self._setkey(actual_key, value) + + def __contains__(self, elem: Memlet) -> bool: + """ + Returns True iff the memlet or a range superset thereof exists in this dictionary. 
+ """ + return self._getkey(elem) is not None + + def __getitem__(self, key: Memlet) -> T: + actual_key = self._getkey(key) + if actual_key is None: + raise KeyError(key) + + return self.internal_dict[key.data][actual_key] + + def __setitem__(self, key: Memlet, value: T) -> None: + actual_key = self._getkey(key) + if actual_key is None: + self._setkey(key, value) + else: + self._setkey(actual_key, value) + + def memlet_to_map( edge: MultiConnectorEdge, state: SDFGState, diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index a88ef0a367..1d5783de81 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. """ Functionality relating to Memlet propagation (deducing external memlets from internal memory accesses and scope ranges). @@ -1427,7 +1427,12 @@ def propagate_memlet(dfg_state, # Propagate subset if isinstance(entry_node, nodes.MapEntry): mapnode = entry_node.map - return propagate_subset(aggdata, arr, mapnode.params, mapnode.range, defined_vars, use_dst=use_dst) + return propagate_subset(aggdata, + arr, + mapnode.params, + mapnode.range, + defined_variables=defined_vars, + use_dst=use_dst) elif isinstance(entry_node, nodes.ConsumeEntry): # Nothing to analyze/propagate in consume @@ -1446,7 +1451,9 @@ def propagate_subset(memlets: List[Memlet], arr: data.Data, params: List[str], rng: subsets.Subset, + *, defined_variables: Set[symbolic.SymbolicType] = None, + undefined_variables: Set[symbolic.SymbolicType] = None, use_dst: bool = False) -> Memlet: """ Tries to propagate a list of memlets through a range (computes the image of the memlet function applied on an integer set of, e.g., a @@ -1459,8 +1466,12 @@ def propagate_subset(memlets: List[Memlet], range to propagate with. 
:param defined_variables: A set of symbols defined that will remain the same throughout propagation. If None, assumes - that all symbols outside of `params` have been - defined. + that all symbols outside of ``params``, except + for ``undefined_variables``, have been defined. + :param undefined_variables: A set of symbols that are explicitly considered + as not defined throughout propagation, such as + locals. Their existence will trigger propagating + the entire memlet. :param use_dst: Whether to propagate the memlets' dst subset or use the src instead, depending on propagation direction. :return: Memlet with propagated subset and volume. @@ -1474,6 +1485,13 @@ def propagate_subset(memlets: List[Memlet], defined_variables |= memlet.free_symbols defined_variables -= set(params) defined_variables = set(symbolic.pystr_to_symbolic(p) for p in defined_variables) + else: + defined_variables = set(defined_variables) + + if undefined_variables is not None: + defined_variables = defined_variables - set(symbolic.pystr_to_symbolic(p) for p in undefined_variables) + else: + undefined_variables = set() # Propagate subset variable_context = [defined_variables, [symbolic.pystr_to_symbolic(p) for p in params]] @@ -1504,18 +1522,33 @@ def propagate_subset(memlets: List[Memlet], tmp_subset = pattern.propagate(arr, [subset], rng) break else: - # No patterns found. Emit a warning and propagate the entire - # array whenever symbols are used + # No patterns found. Propagate the entire array whenever symbols are used. 
if Config.get_bool('debugprint'): print(f'Cannot find appropriate memlet pattern to propagate {subset} through {rng}') entire_array = subsets.Range.from_array(arr) paramset = set(map(str, params)) # Fill in the entire array only if one of the parameters appears in the - # free symbols list of the subset dimension - tmp_subset = subsets.Range([ - ea if any(set(map(str, _freesyms(sd))) & paramset for sd in s) else s - for s, ea in zip(subset, entire_array) - ]) + # free symbols list of the subset dimension or is undefined outside. + tmp_subset_rng = [] + for s, ea in zip(subset, entire_array): + if isinstance(subset, subsets.Indices): + fsyms = _freesyms(s) + fsyms_str = set(map(str, fsyms)) + contains_params = len(fsyms_str & paramset) != 0 + contains_undefs = len(fsyms & undefined_variables) != 0 + else: + contains_params = False + contains_undefs = False + for sdim in s: + fsyms = _freesyms(sdim) + fsyms_str = set(map(str, fsyms)) + contains_params |= len(fsyms_str & paramset) != 0 + contains_undefs |= len(fsyms & undefined_variables) != 0 + if contains_params or contains_undefs: + tmp_subset_rng.append(ea) + else: + tmp_subset_rng.append(s) + tmp_subset = subsets.Range(tmp_subset_rng) # Union edges as necessary if new_subset is None: diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py index 04ceafb261..4bf1e74a0d 100644 --- a/dace/sdfg/replace.py +++ b/dace/sdfg/replace.py @@ -81,7 +81,7 @@ def replace_dict(subgraph: 'StateSubgraphView', desc = node.desc(state) # In case the AccessNode name was replaced in the sdfg.arrays but not in the SDFG itself # then we have to look for the replaced value in the sdfg.arrays - elif repl[node.data] in state.sdfg.arrays: + elif node.data in repl and repl[node.data] in state.sdfg.arrays: desc = state.sdfg.arrays[repl[node.data]] else: continue diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index d88effd594..6b54d176f4 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -42,7 +42,7 @@ from 
dace.codegen.instrumentation.report import InstrumentationReport from dace.codegen.instrumentation.data.data_report import InstrumentedDataReport from dace.codegen.compiled_sdfg import CompiledSDFG - from dace.sdfg.analysis.schedule_tree.treenodes import ScheduleTreeScope + from dace.sdfg.analysis.schedule_tree.treenodes import ScheduleTreeRoot class NestedDict(dict): @@ -417,7 +417,7 @@ def from_json(json_obj, context=None): def label(self): assignments = ','.join(['%s=%s' % (k, v) for k, v in self.assignments.items()]) - # Edge with assigment only (no condition) + # Edge with assignment only (no condition) if self.condition.as_string == '1': # Edge without conditions or assignments if len(self.assignments) == 0: @@ -428,7 +428,7 @@ def label(self): if len(self.assignments) == 0: return self.condition.as_string - # Edges with assigments and conditions + # Edges with assignments and conditions return self.condition.as_string + '; ' + assignments @@ -1128,7 +1128,7 @@ def call_with_instrumented_data(self, dreport: 'InstrumentedDataReport', *args, ########################################## - def as_schedule_tree(self, in_place: bool = False) -> 'ScheduleTreeScope': + def as_schedule_tree(self, in_place: bool = False) -> 'ScheduleTreeRoot': """ Creates a schedule tree from this SDFG and all nested SDFGs. The schedule tree is a tree of nodes that represent the execution order of the SDFG. @@ -1136,7 +1136,8 @@ def as_schedule_tree(self, in_place: bool = False) -> 'ScheduleTreeScope': etc.) or a ``ScheduleTreeScope`` block (map, for-loop, etc.) that contains other nodes. It can be used to generate code from an SDFG, or to perform schedule transformations on the SDFG. For example, - erasing an empty if branch, or merging two consecutive for-loops. + erasing an empty if branch, or merging two consecutive for-loops. The SDFG can then be reconstructed via the + ``as_sdfg`` method or the ``from_schedule_tree`` function in ``dace.sdfg.analysis.schedule_tree.tree_to_sdfg``. 
:param in_place: If True, the SDFG is modified in-place. Otherwise, a copy is made. Note that the SDFG might not be usable after the conversion if ``in_place`` is True! diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index adc77bf439..e08fd18ce5 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2599,6 +2599,7 @@ class AbstractControlFlowRegion(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.Inte ControlFlowBlock, abc.ABC): """ Abstract superclass to represent all kinds of control flow regions in an SDFG. + This is consequently one of the three main classes of control flow graph nodes, which include ``ControlFlowBlock``s, ``SDFGState``s, and nested ``AbstractControlFlowRegion``s. An ``AbstractControlFlowRegion`` can further be either a region that directly contains a control flow graph (``ControlFlowRegion``s and subclasses thereof), or something @@ -3121,6 +3122,7 @@ def start_block(self, block_id): class ControlFlowRegion(AbstractControlFlowRegion): """ A ``ControlFlowRegion`` represents a control flow graph node that itself contains a control flow graph. + This can be an arbitrary control flow graph, but may also be a specific type of control flow region with additional semantics, such as a loop or a function call. 
""" @@ -3185,34 +3187,28 @@ def __init__(self, update_expr: Optional[Union[str, CodeBlock]] = None, inverted: bool = False, sdfg: Optional['SDFG'] = None, - update_before_condition=True, + update_before_condition: bool = True, unroll: bool = False, unroll_factor: int = 0): super(LoopRegion, self).__init__(label, sdfg) - if initialize_expr is not None: - if isinstance(initialize_expr, CodeBlock): - self.init_statement = initialize_expr - else: - self.init_statement = CodeBlock(initialize_expr) + if initialize_expr is None or isinstance(initialize_expr, CodeBlock): + self.init_statement = initialize_expr else: - self.init_statement = None + self.init_statement = CodeBlock(initialize_expr) - if condition_expr: + if condition_expr is None: + self.loop_condition = CodeBlock('True') + else: if isinstance(condition_expr, CodeBlock): self.loop_condition = condition_expr else: self.loop_condition = CodeBlock(condition_expr) - else: - self.loop_condition = CodeBlock('True') - if update_expr is not None: - if isinstance(update_expr, CodeBlock): - self.update_statement = update_expr - else: - self.update_statement = CodeBlock(update_expr) + if update_expr is None or isinstance(update_expr, CodeBlock): + self.update_statement = update_expr else: - self.update_statement = None + self.update_statement = CodeBlock(update_expr) self.loop_variable = loop_var or '' self.inverted = inverted @@ -3506,14 +3502,23 @@ def get_meta_codeblocks(self): codes.append(self.update_statement) return codes - def get_meta_read_memlets(self) -> List[mm.Memlet]: + def get_meta_read_memlets(self, arrays: Optional[Dict[str, dt.Data]] = None) -> List[mm.Memlet]: + """ + Get a list of all (read) memlets in meta codeblocks. + + :param arrays: An optional dictionary mapping array names to their data descriptors. + If not not given defaults to ``self.sdfg.arrays``. + """ # Avoid cyclic imports. 
from dace.sdfg.sdfg import memlets_in_ast - read_memlets = memlets_in_ast(self.loop_condition.code[0], self.sdfg.arrays) + + arrays = arrays if arrays is not None else self.sdfg.arrays + + read_memlets = memlets_in_ast(self.loop_condition.code[0], arrays) if self.init_statement: - read_memlets.extend(memlets_in_ast(self.init_statement.code[0], self.sdfg.arrays)) + read_memlets.extend(memlets_in_ast(self.init_statement.code[0], arrays)) if self.update_statement: - read_memlets.extend(memlets_in_ast(self.update_statement.code[0], self.sdfg.arrays)) + read_memlets.extend(memlets_in_ast(self.update_statement.code[0], arrays)) return read_memlets def replace_meta_accesses(self, replacements): diff --git a/dace/transformation/passes/dead_dataflow_elimination.py b/dace/transformation/passes/dead_dataflow_elimination.py index bc7598bbc8..c43910a585 100644 --- a/dace/transformation/passes/dead_dataflow_elimination.py +++ b/dace/transformation/passes/dead_dataflow_elimination.py @@ -282,6 +282,13 @@ def _is_node_dead(self, node: nodes.Node, sdfg: SDFG, state: SDFGState, dead_nod and any(ie.data.data == node.data for ie in state.in_edges(l.src))): return False + # If data is connected to a Tasklet + if isinstance(l.src, nodes.Tasklet): + # We can't inline connected data that is a pointer + ctype = infer_types.infer_out_connector_type(sdfg, state, l.src, l.src_conn) + if l.src.language == dtypes.Language.Python and isinstance(ctype, dtypes.pointer): + return False + # If it is a stream and is read somewhere in the state, it may be popped after pushing if isinstance(desc, data.Stream) and node.data in access_set[0]: return False diff --git a/tests/schedule_tree/naming_test.py b/tests/schedule_tree/naming_test.py index 8c39e3033f..632b64dfdf 100644 --- a/tests/schedule_tree/naming_test.py +++ b/tests/schedule_tree/naming_test.py @@ -5,7 +5,6 @@ from dace.transformation.passes.constant_propagation import ConstantPropagation import pytest -from typing import List def 
_irreducible_loop_to_loop(): @@ -171,7 +170,7 @@ def test_edgecase_symbol_mapping(): def _check_for_name_clashes(stree: tn.ScheduleTreeNode): - def _traverse(node: tn.ScheduleTreeScope, scopes: List[str]): + def _traverse(node: tn.ScheduleTreeScope, scopes: list[str]): for child in node.children: if isinstance(child, tn.LoopScope): itervar = child.loop.loop.loop_variable diff --git a/tests/schedule_tree/nesting_test.py b/tests/schedule_tree/nesting_test.py index 59512f88ab..a920b07067 100644 --- a/tests/schedule_tree/nesting_test.py +++ b/tests/schedule_tree/nesting_test.py @@ -63,8 +63,8 @@ def tester(A: dace.float64[N, N]): simplified = dace.Config.get_bool('optimizer', 'automatic_simplification') if simplified: - assert [type(n) - for n in stree.preorder_traversal()][1:] == [tn.MapScope, tn.MapScope, tn.LoopScope, tn.TaskletNode] + node_types = [type(n) for n in stree.preorder_traversal()][1:] + assert node_types == [tn.MapScope, tn.MapScope, tn.ForScope, tn.TaskletNode] tasklet: tn.TaskletNode = list(stree.preorder_traversal())[-1] diff --git a/tests/schedule_tree/propagation_test.py b/tests/schedule_tree/propagation_test.py new file mode 100644 index 0000000000..c5233531b9 --- /dev/null +++ b/tests/schedule_tree/propagation_test.py @@ -0,0 +1,109 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
+""" +Tests schedule tree input/output memlet computation +""" +import dace +from dace.sdfg import nodes +from dace.sdfg.analysis.schedule_tree import tree_to_sdfg as t2s, treenodes as tn +from dace.properties import CodeBlock + +import numpy as np +import pytest + + +def test_stree_propagation_forloop(): + N = dace.symbol('N') + + @dace.program + def tester(a: dace.float64[20]): + for i in range(1, N): + a[i] = 2 + a[1] = 1 + + stree = tester.to_sdfg().as_schedule_tree() + stree = t2s.insert_state_boundaries_to_tree(stree) + + node_types = [n for n in stree.preorder_traversal()] + assert isinstance(node_types[2], tn.ForScope) + memlet = dace.Memlet('a[1:N]') + memlet._is_data_src = False + assert list(node_types[2].output_memlets()) == [memlet] + + +@pytest.mark.skip("Suboptimal memlet propagation: https://github.com/spcl/dace/issues/2293") +def test_stree_propagation_symassign(): + # Manually create a schedule tree + N = dace.symbol('N') + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': dace.data.Array(dace.float64, [20]), + }, + symbols={ + 'N': N, + }, + children=[ + tn.MapScope(node=dace.nodes.MapEntry(dace.nodes.Map('map', ['i'], dace.subsets.Range([(1, N - 1, 1)]))), + children=[ + tn.AssignNode('j', CodeBlock('N + i'), dace.InterstateEdge(assignments=dict(j='N + i'))), + tn.TaskletNode(nodes.Tasklet('inner', {}, {'out'}, 'out = inp + 2'), + {'inp': dace.Memlet('A[j]')}, {'out': dace.Memlet('A[j]')}), + ]), + ], + ) + stree.children[0].parent = stree + for c in stree.children[0].children: + c.parent = stree.children[0] + + assert list(stree.children[0].input_memlets()) == [dace.Memlet('A[0:20]', volume=N - 1)] + + +@pytest.mark.skip("Suboptimal memlet propagation: https://github.com/spcl/dace/issues/2293") +def test_stree_propagation_dynset(): + H = dace.symbol('H') + nnz = dace.symbol('nnz') + W = dace.symbol('W') + + @dace.program + def spmv(A_row: dace.uint32[H + 1], A_col: dace.uint32[nnz], A_val: dace.float32[nnz], x: 
dace.float32[W]): + b = np.zeros([H], dtype=np.float32) + + for i in dace.map[0:H]: + for j in dace.map[A_row[i]:A_row[i + 1]]: + b[i] += A_val[j] * x[A_col[j]] + + return b + + sdfg = spmv.to_sdfg() + stree = sdfg.as_schedule_tree() + assert len(stree.children) == 2 + assert all(isinstance(c, tn.MapScope) for c in stree.children) + mapscope = stree.children[1] + _, _, dynrangemap = mapscope.children + assert isinstance(dynrangemap, tn.MapScope) + + # Check dynamic range map memlets + internal_memlets = list(dynrangemap.input_memlets()) + internal_memlet_data = [m.data for m in internal_memlets] + assert 'x' in internal_memlet_data + assert 'A_val' in internal_memlet_data + assert 'A_row' not in internal_memlet_data + for m in internal_memlets: + if m.data == 'A_val': + assert m.subset != dace.subsets.Range([(0, nnz - 1, 1)]) # Not propagated + + # Check top-level scope memlets + external_memlets = list(mapscope.input_memlets()) + assert dace.Memlet('A_row[0:H]') in external_memlets + assert dace.Memlet('A_row[1:H+1]') in external_memlets + assert dace.Memlet('x[0:W]', volume=0, dynamic=True) in external_memlets + assert dace.Memlet('A_val[0:nnz]', volume=0, dynamic=True) in external_memlets + for m in external_memlets: + if m.data == 'A_val': + assert m.subset == dace.subsets.Range([(0, nnz - 1, 1)]) # Propagated + + +if __name__ == '__main__': + test_stree_propagation_forloop() + test_stree_propagation_symassign() + test_stree_propagation_dynset() diff --git a/tests/schedule_tree/roundtrip_test.py b/tests/schedule_tree/roundtrip_test.py new file mode 100644 index 0000000000..3d0fa42b8c --- /dev/null +++ b/tests/schedule_tree/roundtrip_test.py @@ -0,0 +1,59 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +""" +Tests conversion of schedule trees to SDFGs. 
+""" +import dace +import numpy as np + + +def test_implicit_inline_and_constants(): + """ + Tests implicit inlining upon roundtrip conversion, as well as constants with conflicting names. + """ + + @dace + def nester(A: dace.float64[20]): + A[:] = 12 + + @dace.program + def tester(A: dace.float64[20, 20]): + for i in dace.map[0:20]: + nester(A[:, i]) + + sdfg = tester.to_sdfg(simplify=False) + + # Inject constant into nested SDFG + assert len(list(sdfg.all_sdfgs_recursive())) > 1 + sdfg.add_constant('cst', 13) # Add an unused constant + sdfg.cfg_list[-1].add_constant('cst', 1, dace.data.Scalar(dace.float64)) + tasklet = next(n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, dace.nodes.Tasklet)) + tasklet.code.as_string = tasklet.code.as_string.replace('12', 'cst') + + # Perform a roundtrip conversion + stree = sdfg.as_schedule_tree() + new_sdfg = stree.as_sdfg() + + assert len(list(new_sdfg.all_sdfgs_recursive())) == 1 + assert new_sdfg.constants['cst_0'].dtype == np.float64 + + # Test SDFG + a = np.random.rand(20, 20) + new_sdfg(A=a) # Tests arg_names + assert np.allclose(a, 1) + + +def test_name_propagation(): + name = "my_complicated_sdfg_test_name" + sdfg = dace.SDFG(name) + sdfg.add_state("empty", is_start_block=True) + + stree = sdfg.as_schedule_tree() + assert stree.name == name + + sdfg = stree.as_sdfg() + assert sdfg.name == name + + +if __name__ == '__main__': + test_implicit_inline_and_constants() + test_name_propagation() diff --git a/tests/schedule_tree/schedule_test.py b/tests/schedule_tree/schedule_test.py index c15eb99f88..fc2f3ea314 100644 --- a/tests/schedule_tree/schedule_test.py +++ b/tests/schedule_tree/schedule_test.py @@ -3,11 +3,11 @@ import dace from dace.sdfg.analysis.schedule_tree import treenodes as tn from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree -import numpy as np - from dace.transformation.pass_pipeline import FixedPointPipeline from dace.transformation.passes.simplification.control_flow_raising 
import ControlFlowRaising +import numpy as np + def test_for_in_map_in_for(): @@ -158,7 +158,7 @@ def test_irreducible_sub_sdfg(): stree = as_schedule_tree(sdfg) node_types = [type(n) for n in stree.preorder_traversal()] assert node_types.count(tn.GBlock) == 1 # Only one gblock - assert node_types.count(tn.LoopScope) == 1 # Check that the loop was detected + assert node_types.count(tn.ForScope) == 1 # Check that the for-loop was detected def test_irreducible_in_loops(): diff --git a/tests/schedule_tree/to_sdfg_test.py b/tests/schedule_tree/to_sdfg_test.py new file mode 100644 index 0000000000..5fd5e04fd1 --- /dev/null +++ b/tests/schedule_tree/to_sdfg_test.py @@ -0,0 +1,915 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +""" +Tests components in conversion of schedule trees to SDFGs. +""" +import dace +from dace import data, subsets as sbs +from dace.codegen import control_flow as cf +from dace.properties import CodeBlock +from dace.sdfg import nodes +from dace.sdfg.analysis.schedule_tree import tree_to_sdfg as t2s, treenodes as tn +from dace.sdfg.state import ConditionalBlock, LoopRegion, SDFGState + +import pytest + + +def test_state_boundaries_none(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('A[1]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert tn.StateBoundaryNode not in [type(n) for n in stree.children] + + +def test_state_boundaries_waw(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, 
{'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert [tn.TaskletNode, tn.StateBoundaryNode, tn.TaskletNode] == [type(n) for n in stree.children] + + +@pytest.mark.parametrize('overlap', (False, True)) +def test_state_boundaries_waw_ranges(overlap): + # Manually create a schedule tree + N = dace.symbol('N') + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': data.Array(dace.float64, [20]), + }, + symbols={'N': N}, + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'pass'), {}, {'out': dace.Memlet('A[0:N/2]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'pass'), {}, + {'out': dace.Memlet('A[1:N]' if overlap else 'A[N/2+1:N]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + if overlap: + assert [tn.TaskletNode, tn.StateBoundaryNode, tn.TaskletNode] == [type(n) for n in stree.children] + else: + assert [tn.TaskletNode, tn.TaskletNode] == [type(n) for n in stree.children] + + +def test_state_boundaries_war(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': data.Array(dace.float64, [20]), + 'B': data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('B[0]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert [tn.TaskletNode, tn.StateBoundaryNode, tn.TaskletNode] == [type(n) for n in stree.children] + + +def test_state_boundaries_read_write_chain(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': data.Array(dace.float64, [20]), + 'B': data.Array(dace.float64, [20]), + }, + children=[ + 
tn.TaskletNode(nodes.Tasklet('bla1', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('B[0]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('B[0]')}, + {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla3', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('B[0]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert [tn.TaskletNode, tn.TaskletNode, tn.TaskletNode] == [type(n) for n in stree.children] + + +def test_state_boundaries_data_race(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': data.Array(dace.float64, [20]), + 'B': data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla1', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('B[0]')}), + tn.TaskletNode(nodes.Tasklet('bla11', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('B[1]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('B[0]')}, + {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla3', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('B[0]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert [tn.TaskletNode, tn.TaskletNode, tn.StateBoundaryNode, tn.TaskletNode, + tn.TaskletNode] == [type(n) for n in stree.children] + + +def test_state_boundaries_cfg(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla1', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + tn.ForScope(loop=cf.LoopRegion(label="for-loop", + condition_expr=CodeBlock("i < 20"), + loop_var="i", + initialize_expr=CodeBlock("i=0"), + 
update_expr=CodeBlock("i = i+1")), + children=[ + tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'out = i'), {}, + {'out': dace.Memlet('A[1]')}), + ]), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert [tn.TaskletNode, tn.StateBoundaryNode, tn.ForScope] == [type(n) for n in stree.children] + + +def test_state_boundaries_state_transition(): + # Manually create a schedule tree + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': data.Array(dace.float64, [20]), + }, + symbols={ + 'N': dace.symbol('N'), + }, + children=[ + tn.AssignNode('irrelevant', CodeBlock('N + 1'), dace.InterstateEdge(assignments=dict(irrelevant='N + 1'))), + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + tn.AssignNode('relevant', CodeBlock('A[1] + 2'), + dace.InterstateEdge(assignments=dict(relevant='A[1] + 2'))), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert [ + tn.AssignNode, tn.StateBoundaryNode, tn.TaskletNode, tn.StateBoundaryNode, tn.AssignNode, tn.StateBoundaryNode + ] == [type(n) for n in stree.children] + + +@pytest.mark.parametrize('boundary', (False, True)) +def test_state_boundaries_propagation(boundary): + # Manually create a schedule tree + N = dace.symbol('N') + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': data.Array(dace.float64, [20]), + }, + symbols={ + 'N': N, + }, + children=[ + tn.MapScope(node=dace.nodes.MapEntry(dace.nodes.Map('map', ['i'], dace.subsets.Range([(1, N - 1, 1)]))), + children=[ + tn.TaskletNode(nodes.Tasklet('inner', {}, {'out'}, 'out = 2'), {}, + {'out': dace.Memlet('A[i]')}), + ]), + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 2'), {}, + {'out': dace.Memlet('A[1]' if boundary else 'A[0]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + + node_types = [type(n) for n in stree.preorder_traversal()] + if boundary: + assert [tn.MapScope, tn.TaskletNode, tn.StateBoundaryNode, 
tn.TaskletNode] == node_types[1:] + else: + assert [tn.MapScope, tn.TaskletNode, tn.TaskletNode] == node_types[1:] + + +@pytest.mark.parametrize("control_flow", (True, False)) +def test_create_state_boundary_state_transition(control_flow): + sdfg = dace.SDFG("tester") + state = sdfg.add_state("start", is_start_block=True) + bnode = tn.StateBoundaryNode(control_flow) + + t2s.create_state_boundary(bnode, state, t2s.StateBoundaryBehavior.STATE_TRANSITION) + new_label = "cf_state_boundary" if control_flow else "state_boundary" + assert ["start", new_label] == [state.label for state in sdfg.states()] + + +@pytest.mark.xfail(reason="Not yet implemented") +@pytest.mark.parametrize("control_flow", (True, False)) +def test_create_state_boundary_empty_memlet(control_flow): + sdfg = dace.SDFG("tester") + state = sdfg.add_state("start", is_start_block=True) + bnode = tn.StateBoundaryNode(control_flow) + + t2s.create_state_boundary(bnode, state, t2s.StateBoundaryBehavior.EMPTY_MEMLET) + + +def test_create_tasklet_raw(): + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {'inp'}, {'out'}, 'out = inp + 1'), {'inp': dace.Memlet('A[1]')}, + {'out': dace.Memlet('A[1]')}), + ], + ) + + sdfg = stree.as_sdfg() + assert len(sdfg.states()) == 1 + state = sdfg.states()[0] + first_tasklet, write_read_node, second_tasklet, write_node = state.nodes() + + assert first_tasklet.label == "bla" + assert not first_tasklet.in_connectors + assert first_tasklet.out_connectors.keys() == {"out"} + + assert second_tasklet.label == "bla2" + assert second_tasklet.in_connectors.keys() == {"inp"} + assert second_tasklet.out_connectors.keys() == {"out"} + + assert [(first_tasklet, write_read_node), (write_read_node, second_tasklet), + (second_tasklet, write_node)] == [(edge.src, edge.dst) for edge in 
state.edges()] + + +def test_create_tasklet_waw(): + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': data.Array(dace.float64, [20]), + }, + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + ], + ) + + sdfg = stree.as_sdfg() + assert len(sdfg.states()) == 2 + s1, s2 = sdfg.states() + + s1_tasklet, s1_anode = s1.nodes() + assert [(s1_tasklet, s1_anode)] == [(edge.src, edge.dst) for edge in s1.edges()] + + s2_tasklet, s2_anode = s2.nodes() + assert [(s2_tasklet, s2_anode)] == [(edge.src, edge.dst) for edge in s2.edges()] + + +def test_create_tasklet_war(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={"A": data.Array(dace.float64, [20])}, + children=[ + tn.TaskletNode( + nodes.Tasklet("read_write", {"read"}, {"write"}, "write = read + 1"), + {"read": dace.Memlet("A[1]")}, + {"write": dace.Memlet("A[1]")}, + ) + ], + ) + + sdfg = stree.as_sdfg() + + sdfg_states = list(sdfg.states()) + assert len(sdfg_states) == 1 + + state_nodes = list(sdfg_states[0].nodes()) + assert [node.name for node in state_nodes + if isinstance(node, nodes.Tasklet)] == ["read_write"], "Expect one Tasklet node." + assert [node.data for node in state_nodes + if isinstance(node, nodes.AccessNode)] == ["A", "A"], "Expect two AccessNodes for A." 
+ + +def test_create_loop_for(): + for_scope = tn.ForScope( + loop=LoopRegion(label="my_for_loop", + loop_var="i", + initialize_expr=CodeBlock("i = 0 "), + condition_expr=CodeBlock("i < 3"), + update_expr=CodeBlock("i = i+1")), + children=[ + tn.TaskletNode(nodes.Tasklet('assign_1', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('assign_2', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + ], + ) + + stree = tn.ScheduleTreeRoot(name='tester', containers={'A': data.Array(dace.float64, [20])}, children=[for_scope]) + sdfg = stree.as_sdfg() + + loops = list(filter(lambda x: isinstance(x, LoopRegion), sdfg.cfg_list)) + assert len(loops) == 1, "SDFG contains one LoopRegion" + + loop: LoopRegion = loops[0] + assert loop.loop_variable == "i" + assert loop.init_statement == CodeBlock("i = 0") + assert loop.loop_condition == CodeBlock("i < 3") + assert loop.update_statement == CodeBlock("i = i+1") + + loop_states = list(filter(lambda x: isinstance(x, SDFGState), loop.nodes())) + assert len(loop_states) == 2, "Loop contains two states" + + tasklet_1: nodes.Tasklet = list(filter(lambda x: isinstance(x, nodes.Tasklet), loop_states[0].nodes()))[0] + assert tasklet_1.label == "assign_1" + tasklet_2: nodes.Tasklet = list(filter(lambda x: isinstance(x, nodes.Tasklet), loop_states[1].nodes()))[0] + assert tasklet_2.label == "assign_2" + + +def test_create_loop_while(): + while_scope = tn.WhileScope( + children=[ + tn.TaskletNode(nodes.Tasklet('assign_1', {}, {'out'}, 'out = 1'), {}, {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('assign_2', {}, {'out'}, 'out = 2'), {}, {'out': dace.Memlet('A[1]')}), + ], + loop=LoopRegion( + label="my_while_loop", + condition_expr=CodeBlock("A[1] > 5"), + ), + ) + + stree = tn.ScheduleTreeRoot(name='tester', containers={'A': data.Array(dace.float64, [20])}, children=[while_scope]) + + sdfg = stree.as_sdfg() + + loops = list(filter(lambda x: isinstance(x, LoopRegion), 
sdfg.cfg_list)) + assert len(loops) == 1, "SDFG contains one LoopRegion" + + loop: LoopRegion = loops[0] + assert loop.loop_variable == "" + assert loop.init_statement == None + assert loop.loop_condition == CodeBlock("A[1] > 5") + assert loop.update_statement == None + + loop_states = list(filter(lambda x: isinstance(x, SDFGState), loop.nodes())) + assert len(loop_states) == 2, "Loop contains two states" + + tasklet_1: nodes.Tasklet = list(filter(lambda x: isinstance(x, nodes.Tasklet), loop_states[0].nodes()))[0] + assert tasklet_1.label == "assign_1" + tasklet_2: nodes.Tasklet = list(filter(lambda x: isinstance(x, nodes.Tasklet), loop_states[1].nodes()))[0] + assert tasklet_2.label == "assign_2" + + +def test_create_if_else(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={'A': data.Array(dace.float64, [20])}, + children=[ + tn.IfScope( + condition=CodeBlock("A[0] > 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("bla", {}, {"out"}, "out=1"), {}, {"out": dace.Memlet("A[1]")}), + ], + ), + tn.ElseScope(children=[ + tn.TaskletNode(nodes.Tasklet("blub", {}, {"out"}, "out=2"), {}, {"out": dace.Memlet("A[1]")}) + ]), + ]) + + sdfg = stree.as_sdfg() + + blocks = list(filter(lambda x: isinstance(x, ConditionalBlock), sdfg.cfg_list)) + assert len(blocks) == 1, "SDFG contains one ConditionalBlock" + + block: ConditionalBlock = blocks[0] + assert len(block.branches) == 2, "Block contains two branches" + + if_branch = list(filter(lambda x: x[0] is not None, block.branches))[0] + tasklets = list(filter(lambda x: isinstance(x, nodes.Tasklet), if_branch[1].nodes()[0].nodes())) + assert if_branch[0] == CodeBlock("A[0] > 0"), "If branch has condition" + assert tasklets[0].label == "bla", "If branch contains Tasklet('bla')" + + else_branch = list(filter(lambda x: x[0] is None, block.branches))[0] + tasklets = list(filter(lambda x: isinstance(x, nodes.Tasklet), else_branch[1].nodes()[0].nodes())) + assert else_branch[0] is None, "Else branch has no condition" + 
assert tasklets[0].label == "blub", "Else branch contains Tasklet('blub')" + + +@pytest.mark.xfail(reason="Not yet implemented") +def test_create_if_elif_else() -> None: + stree = tn.ScheduleTreeRoot( + name="tester", + containers={'A': data.Array(dace.float64, [20])}, + children=[ + tn.IfScope( + condition=CodeBlock("A[0] > 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("bla", {}, {"out"}, "out=1"), {}, {"out": dace.Memlet("A[1]")}), + ], + ), + tn.ElifScope( + condition=CodeBlock("A[0] == 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("blub", {}, {"out"}, "out=2"), {}, {"out": dace.Memlet("A[1]")}), + ], + ), + tn.ElseScope(children=[ + tn.TaskletNode(nodes.Tasklet("test", {}, {"out"}, "out=3"), {}, {"out": dace.Memlet("A[1]")}) + ]) + ]) + + sdfg = stree.as_sdfg() + + blocks = list(filter(lambda x: isinstance(x, ConditionalBlock), sdfg.cfg_list)) + assert len(blocks) == 1, "SDFG contains one ConditionalBlock" + + block: ConditionalBlock = blocks[0] + assert len(block.branches) == 3, "Block contains three branches" + + +def test_create_if_without_else(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={'A': data.Array(dace.float64, [20])}, + children=[ + tn.IfScope( + condition=CodeBlock("A[0] > 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("bla", {}, {"out"}, "out=1"), {}, {"out": dace.Memlet("A[1]")}), + ], + ), + ], + ) + + sdfg = stree.as_sdfg() + + blocks = list(filter(lambda x: isinstance(x, ConditionalBlock), sdfg.cfg_list)) + assert len(blocks) == 1, "SDFG contains one ConditionalBlock" + + block: ConditionalBlock = blocks[0] + assert len(block.branches) == 1, "Block contains one branch" + + branch = list(filter(lambda x: x[0] is not None, block.branches))[0] + tasklets = list(filter(lambda x: isinstance(x, nodes.Tasklet), branch[1].nodes()[0].nodes())) + assert branch[0] == CodeBlock("A[0] > 0"), "Branch has condition" + assert tasklets[0].label == "bla", "Branch contains Tasklet('bla')" + + +def test_create_map_scope_write(): + stree 
= tn.ScheduleTreeRoot( + name="tester", + containers={'A': data.Array(dace.float64, [20])}, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("assign_i", {}, {"out"}, "out = i"), {}, {"out": dace.Memlet("A[i]")}) + ], + ) + ], + ) + + sdfg = stree.as_sdfg() + sdfg.validate() + + +def test_create_map_scope_read_after_write(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={ + 'A': data.Array(dace.float64, [20]), + 'B': data.Array(dace.float64, [20], transient=True), + }, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("write", {}, {"out"}, "out = i"), {}, {"out": dace.Memlet("B[i]")}), + tn.TaskletNode(nodes.Tasklet("read", {"in_field"}, {"out_field"}, "out_field = in_field"), + {"in_field": dace.Memlet("B[i]")}, {"out_field": dace.Memlet("A[i]")}) + ], + ) + ], + ) + + sdfg = stree.as_sdfg() + sdfg.validate() + + +def test_create_map_scope_write_after_read(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={"A": data.Array(dace.float64, [20])}, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("read_write", {"read"}, {"write"}, "write = read+1"), + {"read": dace.Memlet("A[i]")}, {"write": dace.Memlet("A[i]")}) + ], + ) + ], + ) + + sdfg = stree.as_sdfg() + sdfg.validate() + + +def test_create_map_scope_copy(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={ + 'A': data.Array(dace.float64, [20]), + 'B': data.Array(dace.float64, [20]), + }, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("copy", {"inp"}, {"out"}, "out = inp"), {"inp": dace.Memlet("A[i]")}, + {"out": dace.Memlet("B[i]")}) + ], + ) + ], + ) + + sdfg = stree.as_sdfg() + 
sdfg.validate() + + +def test_create_map_scope_double_memlet(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={ + 'A': data.Array(dace.float64, [20]), + 'B': data.Array(dace.float64, [20]), + }, + children=[ + tn.MapScope(node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:10"))), + children=[ + tn.TaskletNode(nodes.Tasklet("sum", {"first", "second"}, {"out"}, "out = first + second"), { + "first": dace.Memlet("A[i]"), + "second": dace.Memlet("A[i+10]") + }, {"out": dace.Memlet("B[i]")}) + ]) + ]) + + sdfg = stree.as_sdfg() + sdfg.validate() + + +def test_create_nested_map_scope(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={'A': data.Array(dace.float64, [20])}, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_i", "i", sbs.Range.from_string("0:4"))), + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_j", "j", sbs.Range.from_string("0:5"))), + children=[ + tn.TaskletNode(nodes.Tasklet("assign", {}, {"out"}, "out = i*5+j"), {}, + {"out": dace.Memlet("A[i*5+j]")}) + ], + ) + ], + ) + ], + ) + + sdfg = stree.as_sdfg() + sdfg.validate() + + +def test_double_map_with_for_loop(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={'A': data.Array(dace.float64, [20])}, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_i", "i", sbs.Range.from_string("0:4"))), + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_j", "j", sbs.Range.from_string("0:5"))), + children=[ + tn.ForScope( + loop=LoopRegion( + label="loop_k", + loop_var="k", + initialize_expr=CodeBlock("k = 0 "), + condition_expr=CodeBlock("k < 3"), + update_expr=CodeBlock("k = k+1"), + ), + children=[ + tn.TaskletNode(nodes.Tasklet("assign", {}, {"out"}, "out = 1.0"), {}, + {"out": dace.Memlet("A[i*15+j*3+k]")}) + ], + ), + ], + ) + ], + ) + ], + ) + + sdfg = stree.as_sdfg() + assert sdfg.is_valid() + + +def test_triple_map_flat_if(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={'A': 
data.Array(dace.float64, [60])}, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_i", "i", sbs.Range.from_string("0:4"))), + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_j", "j", sbs.Range.from_string("0:5"))), + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_k", "k", sbs.Range.from_string("0:3"))), + children=[ + tn.IfScope( + condition=CodeBlock("A[0] > 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("assign", {}, {"out"}, "out = 1"), {}, + {"out": dace.Memlet("A[i*15+j*3+k]")}) + ], + ), + tn.ElseScope(children=[ + tn.TaskletNode(nodes.Tasklet("assign", {}, {"out"}, "out = 2"), {}, + {"out": dace.Memlet("A[i*15+j*3+k]")}) + ], ), + ], + ) + ], + ) + ], + ) + ], + ) + + sdfg = stree.as_sdfg() + sdfg.validate() + + +def test_triple_map_nested_if(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={'A': data.Array(dace.float64, [60])}, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_i", "i", sbs.Range.from_string("0:4"))), + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_j", "j", sbs.Range.from_string("0:5"))), + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("map_k", "k", sbs.Range.from_string("0:3"))), + children=[ + tn.IfScope( + condition=CodeBlock("A[0] > 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("assign", {}, {"out"}, "out = 1"), {}, + {"out": dace.Memlet("A[i*15+j*3+k]")}) + ], + ), + tn.ElseScope(children=[ + tn.IfScope( + condition=CodeBlock("A[1] > 0"), + children=[ + tn.TaskletNode(nodes.Tasklet("assign", {}, {"out"}, "out = 2"), {}, + {"out": dace.Memlet("A[i*15+j*3+k]")}) + ], + ), + tn.ElseScope(children=[ + tn.TaskletNode(nodes.Tasklet("assign", {}, {"out"}, "out = 3"), {}, + {"out": dace.Memlet("A[i*15+j*3+k]")}) + ], ) + ], ), + ], + ) + ], + ) + ], + ) + ], + ) + + sdfg = stree.as_sdfg() + sdfg.validate() + + +def test_create_nested_map_scope_multi_read(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={ + 'A': 
data.Array(dace.float64, [20]), + 'B': data.Array(dace.float64, [10]) + }, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:2"))), + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("blub", "j", sbs.Range.from_string("0:5"))), + children=[ + tn.TaskletNode(nodes.Tasklet("asdf", {"a_1", "a_2"}, {"out"}, "out = a_1 + a_2"), { + "a_1": dace.Memlet("A[i*5+j]"), + "a_2": dace.Memlet("A[10+i*5+j]"), + }, {"out": dace.Memlet("B[i*5+j]")}) + ], + ) + ], + ) + ], + ) + + sdfg = stree.as_sdfg() + sdfg.validate() + + +def test_map_with_state_boundary_inside(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={'A': data.Array(dace.float64, [20])}, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("bla", "i", sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = i'), {}, {'out': dace.Memlet('A[1]')}), + tn.TaskletNode(nodes.Tasklet('bla2', {}, {'out'}, 'out = 2*i'), {}, {'out': dace.Memlet('A[1]')}), + ], + ) + ], + ) + + sdfg = stree.as_sdfg() + sdfg.validate() + + +def test_map_calculate_temporary_in_two_loops(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={ + "A": data.Array(dace.float64, [20]), + "tmp": data.Array(dace.float64, [20], transient=True) + }, + children=[ + tn.MapScope( + node=nodes.MapEntry(nodes.Map("first_half", "i", sbs.Range.from_string("0:10"))), + children=[ + tn.TaskletNode(nodes.Tasklet("beginning", {}, {'out'}, 'out = i'), {}, + {'out': dace.Memlet("tmp[i]")}) + ], + ), + tn.MapScope( + node=nodes.MapEntry(nodes.Map("second_half", "i", sbs.Range.from_string("10:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("end", {}, {'out'}, 'out = i'), {}, {'out': dace.Memlet("tmp[i]")}) + ], + ), + tn.MapScope( + node=nodes.MapEntry(nodes.Map("read_tmp", "i", sbs.Range.from_string("0:20"))), + children=[ + tn.TaskletNode(nodes.Tasklet("read_temp", {"read"}, {"out"}, "out = read + 1"), + {"read": 
dace.Memlet("tmp[i]")}, {"out": dace.Memlet("A[i]")}) + ], + ) + ], + ) + + sdfg = stree.as_sdfg(simplify=True) + sdfg.validate() + + assert [node.name for node, _ in sdfg.all_nodes_recursive() + if isinstance(node, nodes.Tasklet)] == ["beginning", "end", "read_temp"] + + +def test_edge_assignment_read_after_write(): + stree = tn.ScheduleTreeRoot( + name="tester", + containers={}, + children=[ + tn.AssignNode("my_condition", CodeBlock("True"), dace.InterstateEdge()), + tn.AssignNode("condition", CodeBlock("my_condition"), dace.InterstateEdge()), + tn.StateBoundaryNode(), + ], + ) + + sdfg = stree.as_sdfg(simplify=False) + + assert [node.name for node in sdfg.nodes()] == ["tree_root", "state_boundary", "state_boundary_0"] + assert [edge.data.assignments for edge in sdfg.edges()] == [{"my_condition": "True"}, {"condition": "my_condition"}] + + +def test_assign_nodes_force_state_transition(): + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': data.Array(dace.float64, [20]), + }, + children=[ + tn.AssignNode("mySymbol", CodeBlock("1"), dace.InterstateEdge()), + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = mySymbol'), {}, {'out': dace.Memlet('A[1]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert [type(child) for child in stree.children] == [tn.AssignNode, tn.StateBoundaryNode, tn.TaskletNode] + + +def test_assign_nodes_multiple_force_one_transition(): + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': data.Array(dace.float64, [20]), + }, + children=[ + tn.AssignNode("mySymbol", CodeBlock("1"), dace.InterstateEdge()), + tn.AssignNode("myOtherSymbol", CodeBlock("2"), dace.InterstateEdge()), + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = mySymbol + myOtherSymbol'), {}, + {'out': dace.Memlet('A[1]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert [type(child) + for child in stree.children] == [tn.AssignNode, tn.AssignNode, tn.StateBoundaryNode, 
tn.TaskletNode] + + +def test_assign_nodes_avoid_duplicate_boundaries(): + stree = tn.ScheduleTreeRoot( + name='tester', + containers={ + 'A': data.Array(dace.float64, [20]), + }, + children=[ + tn.AssignNode("mySymbol", CodeBlock("1"), dace.InterstateEdge()), + tn.StateBoundaryNode(), + tn.TaskletNode(nodes.Tasklet('bla', {}, {'out'}, 'out = mySymbol + myOtherSymbol'), {}, + {'out': dace.Memlet('A[1]')}), + ], + ) + + stree = t2s.insert_state_boundaries_to_tree(stree) + assert [type(child) for child in stree.children] == [tn.AssignNode, tn.StateBoundaryNode, tn.TaskletNode] + + +if __name__ == '__main__': + test_state_boundaries_none() + test_state_boundaries_waw() + test_state_boundaries_waw_ranges(overlap=False) + test_state_boundaries_waw_ranges(overlap=True) + test_state_boundaries_war() + test_state_boundaries_read_write_chain() + test_state_boundaries_data_race() + test_state_boundaries_cfg() + test_state_boundaries_state_transition() + test_state_boundaries_propagation(boundary=False) + test_state_boundaries_propagation(boundary=True) + test_create_state_boundary_state_transition(control_flow=True) + test_create_state_boundary_state_transition(control_flow=False) + # test_create_state_boundary_empty_memlet(control_flow=True) + # test_create_state_boundary_empty_memlet(control_flow=False) + test_create_tasklet_raw() + test_create_tasklet_waw() + test_create_loop_for() + test_create_loop_while() + test_create_if_else() + test_create_if_without_else() + test_create_map_scope_write() + test_create_map_scope_copy() + test_create_map_scope_double_memlet() + test_create_nested_map_scope() + test_create_nested_map_scope_multi_read() + test_map_with_state_boundary_inside() + test_edge_assignment_read_after_write() diff --git a/tests/schedule_tree/treenodes_test.py b/tests/schedule_tree/treenodes_test.py new file mode 100644 index 0000000000..31a0abbd21 --- /dev/null +++ b/tests/schedule_tree/treenodes_test.py @@ -0,0 +1,126 @@ +# Copyright 2019-2026 ETH Zurich and 
the DaCe authors. All rights reserved. + +from dace.sdfg.analysis.schedule_tree import treenodes as tn +from dace import nodes + +import pytest + + +@pytest.fixture +def tasklet() -> nodes.Tasklet: + return tn.TaskletNode(nodes.Tasklet("noop", {}, {}, code="pass"), {}, {}) + + +@pytest.mark.parametrize('ScopeClass', ( + tn.ScheduleTreeScope, + tn.ControlFlowScope, + tn.GBlock, + tn.ElseScope, +)) +def test_schedule_tree_scope_children(ScopeClass: type[tn.ScheduleTreeScope], tasklet: nodes.Tasklet) -> None: + scope = ScopeClass(children=[tasklet]) + + for child in scope.children: + assert child.parent == scope + + scope = ScopeClass(children=[]) + scope.add_child(tasklet) + + for child in scope.children: + assert child.parent == scope + + scope = ScopeClass(children=[]) + scope.add_children([tasklet]) + + for child in scope.children: + assert child.parent == scope + + +@pytest.mark.parametrize('LoopScope', ( + tn.LoopScope, + tn.ForScope, + tn.WhileScope, + tn.DoWhileScope, +)) +def test_loop_scope_children(LoopScope: type[tn.LoopScope], tasklet: nodes.Tasklet) -> None: + scope = LoopScope(loop=None, children=[tasklet]) + + for child in scope.children: + assert child.parent == scope + + scope = LoopScope(loop=None, children=[]) + scope.add_child(tasklet) + + for child in scope.children: + assert child.parent == scope + + scope = LoopScope(loop=None, children=[]) + scope.add_children([tasklet]) + + for child in scope.children: + assert child.parent == scope + + +@pytest.mark.parametrize('IfScope', ( + tn.IfScope, + tn.StateIfScope, + tn.ElifScope, +)) +def test_if_scope_children(IfScope: type[tn.IfScope], tasklet: nodes.Tasklet) -> None: + scope = IfScope(condition=None, children=[tasklet]) + + for child in scope.children: + assert child.parent == scope + + scope = IfScope(condition=None, children=[]) + scope.add_child(tasklet) + + for child in scope.children: + assert child.parent == scope + + scope = IfScope(condition=None, children=[]) + 
scope.add_children([tasklet]) + + for child in scope.children: + assert child.parent == scope + + +@pytest.mark.parametrize('DataflowScope', ( + tn.DataflowScope, + tn.MapScope, + tn.ConsumeScope, +)) +def test_dataflow_scope_children(DataflowScope: type[tn.DataflowScope], tasklet: nodes.Tasklet) -> None: + scope = DataflowScope(node=None, children=[tasklet]) + + for child in scope.children: + assert child.parent == scope + + scope = DataflowScope(node=None, children=[]) + scope.add_child(tasklet) + + for child in scope.children: + assert child.parent == scope + + scope = DataflowScope(node=None, children=[]) + scope.add_children([tasklet]) + + for child in scope.children: + assert child.parent == scope + + +if __name__ == '__main__': + test_schedule_tree_scope_children(tn.ScheduleTreeScope, tasklet) + test_schedule_tree_scope_children(tn.ControlFlowScope, tasklet) + test_schedule_tree_scope_children(tn.GBlock, tasklet) + test_schedule_tree_scope_children(tn.ElseScope, tasklet) + test_loop_scope_children(tn.LoopScope, tasklet) + test_loop_scope_children(tn.ForScope, tasklet) + test_loop_scope_children(tn.WhileScope, tasklet) + test_loop_scope_children(tn.DoWhileScope, tasklet) + test_if_scope_children(tn.IfScope, tasklet) + test_if_scope_children(tn.StateIfScope, tasklet) + test_if_scope_children(tn.ElifScope, tasklet) + test_dataflow_scope_children(tn.DataflowScope, tasklet) + test_dataflow_scope_children(tn.MapScope, tasklet) + test_dataflow_scope_children(tn.ConsumeScope, tasklet) diff --git a/tests/sdfg/loop_region_test.py b/tests/sdfg/loop_region_test.py index f09bf0004f..011787c98b 100644 --- a/tests/sdfg/loop_region_test.py +++ b/tests/sdfg/loop_region_test.py @@ -311,7 +311,7 @@ def test_loop_to_stree_triple_nested_for(): stree = s2t.as_schedule_tree(sdfg) po_nodes = list(stree.preorder_traversal())[1:] - assert [type(n) for n in po_nodes] == [tn.LoopScope, tn.LoopScope, tn.LoopScope, tn.TaskletNode, tn.LibraryCall] + assert [type(n) for n in po_nodes] == 
[tn.ForScope, tn.ForScope, tn.ForScope, tn.TaskletNode, tn.LibraryCall] if __name__ == '__main__': diff --git a/tests/sdfg/memlet_utils_test.py b/tests/sdfg/memlet_utils_test.py index 3c85d72f21..c4a78991c7 100644 --- a/tests/sdfg/memlet_utils_test.py +++ b/tests/sdfg/memlet_utils_test.py @@ -1,10 +1,10 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +import copy import dace import numpy as np import pytest -from dace import symbolic -from dace.sdfg import memlet_utils as mu +from dace.sdfg import graph, memlet_utils as mu import re from typing import Tuple, Optional @@ -12,6 +12,7 @@ def _replace_zero_with_one(memlet: dace.Memlet) -> dace.Memlet: if not isinstance(memlet.subset, dace.subsets.Range): return memlet + for i, (rb, re, rs) in enumerate(memlet.subset.ndrange()): if rb == 0: memlet.subset.ranges[i] = (1, 1, rs) @@ -19,7 +20,7 @@ def _replace_zero_with_one(memlet: dace.Memlet) -> dace.Memlet: @pytest.mark.parametrize('filter_type', ['none', 'same_array', 'different_array']) -def test_replace_memlet(filter_type): +def test_replace_memlet(filter_type: str) -> None: # Prepare SDFG sdfg = dace.SDFG('replace_memlet') sdfg.add_array('A', [2, 2], dace.float64) @@ -66,7 +67,7 @@ def test_replace_memlet(filter_type): assert B[0] == 1 -def _perform_non_lin_delin_test(sdfg: dace.SDFG, edge) -> bool: +def _perform_non_lin_delin_test(sdfg: dace.SDFG, edge: graph.MultiConnectorEdge) -> None: assert sdfg.number_of_nodes() == 1 state: dace.SDFGState = sdfg.states()[0] assert state.number_of_nodes() == 2 @@ -104,8 +105,6 @@ def _perform_non_lin_delin_test(sdfg: dace.SDFG, edge) -> bool: sdfg(a=a, b=b_opt) assert np.allclose(b_unopt, b_opt) - return True - def _make_non_lin_delin_sdfg( shape_a: Tuple[int, ...], @@ -211,6 +210,62 @@ def test_non_lin_delin_8(): _perform_non_lin_delin_test(sdfg, e) +def test_MemletSet() -> None: + empty_set = mu.MemletSet() + assert len(empty_set) == 0 + + memlet_set = mu.MemletSet([dace.Memlet("A[0:5]")]) + 
covered_set = mu.MemletSet([dace.Memlet("A[0:5]")], intersection_is_contained=False) + + assert dace.Memlet("A[0:2]") in memlet_set + assert dace.Memlet("A[0:2]") in covered_set + assert dace.Memlet("A[2:7]") in memlet_set + assert dace.Memlet("A[2:7]") not in covered_set + + assert dace.Memlet("B[0:2]") not in memlet_set + + before = copy.deepcopy(covered_set.internal_set) + covered_set.add(dace.Memlet("A[0:2]")) + assert covered_set.internal_set == before + + covered_set.add(dace.Memlet("A[4:9]")) + assert dace.Memlet("A[2:7]") in covered_set + assert covered_set.internal_set != before + + union = empty_set.union(dace.Memlet("A[0:3]"), dace.Memlet("A[2:10]")) + assert dace.Memlet("A[5:7]") not in empty_set + assert dace.Memlet("A[5:7]") in union + assert len(union.internal_set["A"]) == 1 + internal_memlet = list(union.internal_set["A"])[0] + assert internal_memlet.subset == dace.subsets.Range.from_string("0:10") + + +def test_MemletDict() -> None: + A_01 = dace.Memlet("A[0:1]") + A_02 = dace.Memlet("A[0:2]") + A_34 = dace.Memlet("A[3:4]") + memlet_dict: mu.MemletDict[list[int]] = mu.MemletDict() + assert len(memlet_dict) == 0 + assert A_02 not in memlet_dict + + memlet_dict[A_02] = [42] + assert A_02 in memlet_dict + assert A_01 in memlet_dict + assert A_34 not in memlet_dict + assert dace.Memlet("B[0:2]") not in memlet_dict + + memlet_dict[A_01].append(43) + assert memlet_dict[A_02] == [42, 43] + + memlet_dict.update({A_34: [0], A_01: [44]}) + assert A_34 in memlet_dict + assert memlet_dict[A_02] == [44]  # A_01 and A_02 overlap, so update() replaced their shared entry's value + assert memlet_dict[A_34] == [0] + + memlet_dict.clear() + assert len(memlet_dict) == 0 + + if __name__ == '__main__': test_replace_memlet('none') test_replace_memlet('same_array') @@ -224,3 +279,6 @@ def test_non_lin_delin_8(): test_non_lin_delin_6() test_non_lin_delin_7() test_non_lin_delin_8() + + test_MemletSet() + test_MemletDict()