From fc58b3ee16e08a0e4a8d9f24cb89013a9bb445f5 Mon Sep 17 00:00:00 2001 From: Yakup Koray Budanaz Date: Wed, 29 Oct 2025 12:42:23 +0100 Subject: [PATCH 01/17] prepare --- dace/config_schema.yml | 2 +- dace/dtypes.py | 3 + dace/sdfg/construction_utils.py | 898 +++++++ dace/sdfg/utils.py | 260 +- .../interstate/branch_elimination.py | 1985 +++++++++++++++ dace/transformation/passes/__init__.py | 2 +- .../passes/eliminate_branches.py | 71 + .../passes/explicit_vectorization_cpu.py | 402 ++++ .../passes/explicit_vectorization_gpu.py | 129 + .../interstate/branch_elimination_test.py | 2136 +++++++++++++++++ ...rate_assignment_as_tasklet_instate_test.py | 23 + 11 files changed, 5758 insertions(+), 153 deletions(-) create mode 100644 dace/sdfg/construction_utils.py create mode 100644 dace/transformation/interstate/branch_elimination.py create mode 100644 dace/transformation/passes/eliminate_branches.py create mode 100644 dace/transformation/passes/explicit_vectorization_cpu.py create mode 100644 dace/transformation/passes/explicit_vectorization_gpu.py create mode 100644 tests/transformations/interstate/branch_elimination_test.py create mode 100644 tests/utils/generate_assignment_as_tasklet_instate_test.py diff --git a/dace/config_schema.yml b/dace/config_schema.yml index 812e24329e..a753a55f3b 100644 --- a/dace/config_schema.yml +++ b/dace/config_schema.yml @@ -261,7 +261,7 @@ required: type: str title: Arguments description: Compiler argument flags - default: '-std=c++14 -fPIC -Wall -Wextra -O3 -march=native -ffast-math -Wno-unused-parameter -Wno-unused-label' + default: '-fopenmp -std=c++14 -fPIC -Wall -Wextra -O3 -march=native -ffast-math -Wno-unused-parameter -Wno-unused-label' default_Windows: '/O2 /fp:fast /arch:AVX2 /D_USRDLL /D_WINDLL /D__restrict__=__restrict' libs: diff --git a/dace/dtypes.py b/dace/dtypes.py index e018384de3..2712d66685 100644 --- a/dace/dtypes.py +++ b/dace/dtypes.py @@ -1258,6 +1258,9 @@ def isconstant(var): string = stringtype() MPI_Request = opaque('MPI_Request') +FLOAT_TYPES = {float64, float32, float16} +INT_TYPES = {int8, int16, int32, int64, uintp, uint8, uint16, uint32, uint64} + @undefined_safe_enum @extensible_enum diff --git a/dace/sdfg/construction_utils.py b/dace/sdfg/construction_utils.py new file mode 100644 index 0000000000..58c201441b --- /dev/null +++ b/dace/sdfg/construction_utils.py @@ -0,0 +1,898 @@ +import re +from typing import Dict, Set, Union +import dace +import copy + +from dace.sdfg import ControlFlowRegion +from dace.sdfg.propagation import propagate_memlets_state +import copy +from dace.properties import CodeBlock +from dace.sdfg.state import ConditionalBlock, LoopRegion + +import sympy +from sympy import symbols, Function + +from sympy.printing.pycode import PythonCodePrinter +import dace.sdfg.utils as sdutil +from dace.transformation.passes import FuseStates + + +class BracketFunctionPrinter(PythonCodePrinter): + + def _print_Function(self, expr): + name = self._print(expr.func) + args = ", ".join([self._print(arg) for arg in expr.args]) + return f"{name}[{args}]" + + +def copy_state_contents(old_state: dace.SDFGState, new_state: dace.SDFGState) -> Dict[dace.nodes.Node, dace.nodes.Node]: + """ + Deep-copies all nodes and edges from one SDFG state into another. + + Args: + old_state: The source SDFG state to copy from. + new_state: The destination SDFG state to copy into. + + Returns: + A mapping from original nodes in `old_state` to their deep-copied + counterparts in `new_state`. + + Notes: + - Node objects are deep-copied. 
+      - Edge data are also deep-copied.
+      - Connections between the newly created nodes are preserved.
+    """
+    node_map = dict()
+
+    # Copy all nodes
+    for n in old_state.nodes():
+        c_n = copy.deepcopy(n)
+        node_map[n] = c_n
+        new_state.add_node(c_n)
+
+    # Copy all edges, reconnecting them to their new node counterparts
+    for e in old_state.edges():
+        c_src = node_map[e.src]
+        c_dst = node_map[e.dst]
+        new_state.add_edge(c_src, e.src_conn, c_dst, e.dst_conn, copy.deepcopy(e.data))
+
+    return node_map
+
+
+def copy_graph_contents(old_graph: ControlFlowRegion,
+                        new_graph: ControlFlowRegion) -> Dict[dace.nodes.Node, dace.nodes.Node]:
+    """
+    Deep-copies all nodes and edges from one control flow region into another.
+
+    Args:
+        old_graph: The source control flow region to copy from.
+        new_graph: The destination control flow region to copy into.
+
+    Returns:
+        A mapping from original nodes in `old_graph` to their deep-copied
+        counterparts in `new_graph`.
+
+    Notes:
+        - Node objects are deep-copied.
+        - Edge data are also deep-copied.
+        - Connections between the newly created nodes are preserved.
+        - The start block of `old_graph` is preserved in `new_graph`.
+    """
+    assert isinstance(old_graph, ControlFlowRegion)
+    assert isinstance(new_graph, ControlFlowRegion)
+
+    node_map = dict()
+
+    # Copy all nodes
+    for n in old_graph.nodes():
+        c_n = copy.deepcopy(n)
+        node_map[n] = c_n
+        new_graph.add_node(c_n, is_start_block=old_graph.start_block == n)
+
+    # Copy all edges, reconnecting them to their new node counterparts
+    for e in old_graph.edges():
+        c_src = node_map[e.src]
+        c_dst = node_map[e.dst]
+        new_graph.add_edge(c_src, c_dst, copy.deepcopy(e.data))
+
+    sdutil.set_nested_sdfg_parent_references(new_graph.sdfg)
+
+    return node_map
+
+
+def move_branch_cfg_up_discard_conditions(if_block: ConditionalBlock, body_to_take: ControlFlowRegion):
+    # Sanity-check that the passed arguments are consistent
+    bodies = {b for _, b in if_block.branches}
+    assert body_to_take in bodies
+    assert isinstance(if_block, ConditionalBlock)
+
+    graph = if_block.parent_graph
+
+    node_map = dict()
+    # Save end and start blocks for reconnections
+    new_start_block = None
+    new_end_block = None
+
+    for node in body_to_take.nodes():
+        # Copy over nodes
+        copynode = copy.deepcopy(node)
+        node_map[node] = copynode
+        # Check if we need to have a new start state
+        start_block_case = (body_to_take.start_block == node) and (graph.start_block == if_block)
+        if body_to_take.start_block == node:
+            assert new_start_block is None
+            new_start_block = copynode
+        if body_to_take.out_degree(node) == 0:
+            assert new_end_block is None
+            new_end_block = copynode
+        graph.add_node(copynode, is_start_block=start_block_case)
+
+    for edge in body_to_take.edges():
+        src = node_map[edge.src]
+        dst = node_map[edge.dst]
+        graph.add_edge(src, dst, copy.deepcopy(edge.data))
+
+    for ie in graph.in_edges(if_block):
+        graph.add_edge(ie.src, new_start_block, copy.deepcopy(ie.data))
+    for oe in graph.out_edges(if_block):
+        graph.add_edge(new_end_block, oe.dst, copy.deepcopy(oe.data))
+
+    graph.remove_node(if_block)
+
+
+# Put map-body into NSDFG
+# Convert Map to Loop
+# Put map into NSDFG
+
+
+def insert_non_transient_data_through_parent_scopes(non_transient_data: Set[str],
+                                                    nsdfg_node: 'dace.nodes.NestedSDFG',
+                                                    parent_graph: 'dace.SDFGState',
+                                                    parent_sdfg: 'dace.SDFG',
+                                                    add_to_output_too: bool = False,
+                                                    add_with_exact_subset: bool = False,
+                                                    exact_subset: Union[None, dace.subsets.Range] = None,
+                                                    nsdfg_connector_name: Union[str, None] = None):
+    """
+    Inserts non-transient data containers into all relevant parent scopes (through all map
scopes). + + This function connect data from top-level data + into nested SDFGs (and vice versa) by connecting AccessNodes, MapEntries, + and NestedSDFG connectors appropriately. + + Args: + non_transient_data: Set of data container names to propagate. + nsdfg_node: The nested SDFG node where the data should be connected. + parent_graph: The parent SDFG state that contains the NestedSDFG node. + parent_sdfg: The parent SDFG corresponding to `parent_graph.sdfg`. + add_to_output_too: If True, also connect the data as an output from the nested SDFG. + add_with_exact_subset: If True, use an explicitly provided subset for the memlet. + exact_subset: The explicit subset (if any) to use when `add_with_exact_subset` is True. + + Behavior: + - Adds data descriptors for any missing non-transient arrays to both + the parent SDFG and the nested SDFG. + - Connects data through all enclosing parent scopes (e.g., nested maps). + - Optionally adds symmetric output connections. + - Propagates memlets if exact subsets are used. + - Adds any newly required symbols (from shapes or strides) to the nested SDFG. + """ + + descs = [None] * len(non_transient_data) + assert len(descs) == len(non_transient_data) + + for data_access, desc in zip(non_transient_data, descs): + datadesc = desc or parent_sdfg.arrays[data_access] + assert isinstance(parent_graph, dace.SDFGState), "Parent graph must be a SDFGState" + inner_sdfg: dace.SDFG = nsdfg_node.sdfg + + # Skip if the connector already exists and is wired + if (data_access in nsdfg_node.in_connectors + and len(list(parent_graph.in_edges_by_connector(nsdfg_node, data_access))) > 0): + continue + + # Remove conflicting symbols in nested SDFG + if data_access in inner_sdfg.symbols: + inner_sdfg.remove_symbol(data_access) + + # Add the data descriptor to the nested SDFG if missing + inner_data_access = data_access if nsdfg_connector_name is None else nsdfg_connector_name + if inner_data_access not in inner_sdfg.arrays: + copydesc = copy.deepcopy(datadesc) + copydesc.transient = False + inner_sdfg.add_datadesc(name=inner_data_access, datadesc=copydesc) + + # Ensure the parent also has the data descriptor + if data_access not in parent_sdfg.arrays: + copydesc = copy.deepcopy(datadesc) + copydesc.transient = False + parent_sdfg.add_datadesc(name=data_access, datadesc=copydesc) + + # Collect enclosing map scopes to route data through + parent_scopes = [] + cur_parent_scope = nsdfg_node + scope_dict = parent_graph.scope_dict() + while scope_dict[cur_parent_scope] is not None: + parent_scopes.append(scope_dict[cur_parent_scope]) + cur_parent_scope = scope_dict[cur_parent_scope] + + # Helper: choose between full or exact-subset memlet + def _get_memlet(it_id: int, data_access: str, datadesc: dace.data.Data): + if add_with_exact_subset: + return dace.memlet.Memlet(data=data_access, subset=copy.deepcopy(exact_subset)) + else: + return dace.memlet.Memlet.from_array(data_access, datadesc) + + # --- Add input connection path --- + + state = { + 'cur_in_conn_name': f"IN_{data_access}_p", + 'cur_out_conn_name': f"OUT_{data_access}_p", + 'cur_name_set': False, + } + + def _get_in_conn_name(dst, state=state): + if state['cur_name_set'] is False: + i = 0 + while (state['cur_in_conn_name'] in dst.in_connectors + or state['cur_out_conn_name'] in dst.out_connectors): + state['cur_in_conn_name'] = f"IN_{data_access}_p_{i}" + state['cur_out_conn_name'] = f"OUT_{data_access}_p_{i}" + i += 1 + state['cur_name_set'] = True + + inner_data_access = data_access if nsdfg_connector_name is None else 
nsdfg_connector_name + + if isinstance(dst, dace.nodes.AccessNode): + return None + elif isinstance(dst, dace.nodes.NestedSDFG): + return inner_data_access + else: + return state['cur_in_conn_name'] + + def _get_out_conn_name(src, state=state): + if state['cur_name_set'] is False: + i = 0 + while (state['cur_in_conn_name'] in src.in_connectors + or state['cur_out_conn_name'] in src.out_connectors): + state['cur_in_conn_name'] = f"IN_{data_access}_p_{i}" + state['cur_out_conn_name'] = f"OUT_{data_access}_p_{i}" + i += 1 + state['cur_name_set'] = True + + inner_data_access = data_access if nsdfg_connector_name is None else nsdfg_connector_name + if isinstance(src, dace.nodes.AccessNode): + return None + elif isinstance(src, dace.nodes.NestedSDFG): + return inner_data_access + else: + return state['cur_out_conn_name'] + + an = parent_graph.add_access(data_access) + src = an + for it_id, parent_scope in enumerate(reversed(parent_scopes)): + dst = parent_scope + # Initialize state with a parent map + _get_in_conn_name(dst) + + parent_graph.add_edge( + src, + _get_out_conn_name(src), + dst, + _get_in_conn_name(dst), + _get_memlet(it_id, data_access, datadesc), + ) + # Ensure connectors exist + if not isinstance(src, dace.nodes.AccessNode): + src.add_out_connector(_get_out_conn_name(src), force=True) + if isinstance(dst, dace.nodes.NestedSDFG): + dst.add_in_connector(_get_in_conn_name(dst), force=True) + else: + dst.add_in_connector(_get_in_conn_name(dst)) + src = parent_scope + + # Connect final edge to the NestedSDFG + dst = nsdfg_node + parent_graph.add_edge( + src, + _get_out_conn_name(src), + dst, + _get_in_conn_name(dst), + _get_memlet(it_id, data_access, datadesc), + ) + if not isinstance(src, dace.nodes.AccessNode): + src.add_out_connector(_get_out_conn_name(src), force=True) + if isinstance(dst, dace.nodes.NestedSDFG): + dst.add_in_connector(_get_in_conn_name(dst), force=True) + else: + dst.add_in_connector(_get_in_conn_name(dst), force=True) + + # --- Optionally add output connection path --- + if add_to_output_too: + an = parent_graph.add_access(data_access) + dst = an + for it_id, parent_scope in enumerate(reversed(parent_scopes)): + src = parent_graph.exit_node(parent_scope) + parent_graph.add_edge( + src, + _get_out_conn_name(src), + dst, + _get_in_conn_name(dst), + _get_memlet(it_id, data_access, datadesc), + ) + if not isinstance(dst, dace.nodes.AccessNode): + dst.add_in_connector(_get_in_conn_name(dst), force=True) + if isinstance(src, dace.nodes.NestedSDFG): + src.add_out_connector(_get_out_conn_name(src), force=True) + else: + src.add_out_connector(_get_out_conn_name(src), ) + dst = src + src = nsdfg_node + parent_graph.add_edge( + src, + _get_out_conn_name(src), + dst, + _get_in_conn_name(dst), + _get_memlet(it_id, data_access, datadesc), + ) + if not isinstance(dst, dace.nodes.AccessNode): + dst.add_in_connector(f"IN_{data_access}_p", force=True) + src.add_out_connector(_get_out_conn_name(dst)) + + parent_graph.sdfg.save("x.sdfg") + + # Re-propagate memlets when subsets are explicit + if add_with_exact_subset: + propagate_memlets_state(parent_graph.sdfg, parent_graph) + + # Add any free symbols from array shapes/strides to the nested SDFG + new_symbols = set() + for data_access, desc in zip(non_transient_data, descs): + if desc is None: + desc = parent_graph.sdfg.arrays[data_access] + data_free_syms = set() + for dim, stride in zip(desc.shape, desc.strides): + dim_expr = dace.symbolic.SymExpr(dim) + stride_expr = dace.symbolic.SymExpr(stride) + if not isinstance(stride_expr, 
int): + data_free_syms |= stride_expr.free_symbols + if not isinstance(dim_expr, int): + data_free_syms |= dim_expr.free_symbols + new_symbols |= data_free_syms + + defined_syms = parent_graph.symbols_defined_at(nsdfg_node) + for sym in new_symbols: + if str(sym) not in nsdfg_node.sdfg.symbols: + nsdfg_node.sdfg.add_symbol(str(sym), defined_syms[str(sym)]) + if str(sym) not in nsdfg_node.symbol_mapping: + nsdfg_node.symbol_mapping[str(sym)] = str(sym) + + +def token_replace_dict(code: str, repldict: Dict[str, str]) -> str: + # Split while keeping delimiters + tokens = re.split(r'(\s+|[()\[\]])', code) + + # Replace tokens that exactly match src + tokens = [repldict[token.strip()] if token.strip() in repldict else token for token in tokens] + + # Recombine everything + return ''.join(tokens).strip() + + +def token_match(string_to_check: str, pattern_str: str) -> str: + # Split while keeping delimiters + tokens = re.split(r'(\s+|[()\[\]])', string_to_check) + + # Replace tokens that exactly match src + tokens = {token.strip() for token in tokens} + + return pattern_str in tokens + + +def token_split(string_to_check: str, pattern_str: str) -> Set[str]: + # Split while keeping delimiters + tokens = re.split(r'(\s+|[()\[\]])', string_to_check) + + # Replace tokens that exactly match src + tokens = {token.strip() for token in tokens} + + return tokens + + +def token_split_variable_names(string_to_check: str) -> Set[str]: + # Split while keeping delimiters + tokens = re.split(r'(\s+|[()\[\]])', string_to_check) + + # Replace tokens that exactly match src + tokens = {token.strip() for token in tokens if token not in ["[", "]", "(", ")"] and token.isidentifier()} + + return tokens + + +def replace_length_one_arrays_with_scalars(sdfg: dace.SDFG, recursive: bool = True, transient_only: bool = False): + scalarized_arrays = set() + for arr_name, arr in [(k, v) for k, v in sdfg.arrays.items()]: + if isinstance(arr, dace.data.Array) and (arr.shape == (1, ) or arr.shape == [ + 1, + ]): + if (not transient_only) or arr.transient: + sdfg.remove_data(arr_name, False) + sdfg.add_scalar(name=arr_name, + dtype=arr.dtype, + storage=arr.storage, + transient=arr.transient, + lifetime=arr.lifetime, + debuginfo=arr.debuginfo, + find_new_name=False) + scalarized_arrays.add(arr_name) + print(f"Making {arr_name} into scalar") + + # Replace [0] accesses of scalars (formerly array ones) on interstate edges + for edge in sdfg.all_interstate_edges(): + new_dict = dict() + for k, v in edge.data.assignments.items(): + nv = v + for scalar_name in scalarized_arrays: + if f"{scalar_name}[0]" in nv: + nv = nv.replace(f"{scalar_name}[0]", scalar_name) + new_dict[k] = nv + edge.data.assignments = new_dict + + # Replace [0] accesses of scalars (formerly array ones) on IfBlocks + for node in sdfg.all_control_flow_blocks(): + if isinstance(node, ConditionalBlock): + for cond, body in node.branches: + if cond is None: + continue + nlc = cond.as_string if isinstance(cond, CodeBlock) else str(cond) + for scalar_name in scalarized_arrays: + if f"{scalar_name}[0]" in nlc: + nlc = nlc.replace(f"{scalar_name}[0]", scalar_name) + cond = CodeBlock(nlc, cond.language if isinstance(cond, CodeBlock) else dace.dtypes.Language.Python) + + # Replace [0] accesses of scalars (formerly array ones) on LoopRegions + for node in sdfg.all_control_flow_regions(): + if isinstance(node, LoopRegion): + nlc = node.loop_condition.as_string if isinstance(node.loop_condition, CodeBlock) else str( + node.loop_condition) + for scalar_name in scalarized_arrays: + if 
f"{scalar_name}[0]" in nlc: + nlc = nlc.replace(f"{scalar_name}[0]", scalar_name) + node.loop_condition = CodeBlock( + nlc, node.loop_condition.language + if isinstance(node.loop_condition, CodeBlock) else dace.dtypes.Language.Python) + + if recursive: + for state in sdfg.all_states(): + for node in state.nodes(): + if isinstance(node, dace.nodes.NestedSDFG): + replace_length_one_arrays_with_scalars(node.sdfg, recursive=True, transient_only=True) + + +def connect_array_names(sdfg: dace.SDFG, local_storage: dace.dtypes.StorageType, src_storage: dace.dtypes.StorageType, + local_name_prefix: str): + + array_name_dict = dict() + for state in sdfg.all_states(): + for node in state.nodes(): + if isinstance(node, dace.nodes.AccessNode): + local_arr = state.sdfg.arrays[node.data] + print(local_arr.storage) + if local_arr.storage == local_storage: + assert len(state.in_edges(node)) <= 1 + # Reads + for ie in state.in_edges(node): + if ie.data.data is not None and ie.data.data != node.data: + src_data = state.sdfg.arrays[ie.data.data] + print(src_data) + if src_data.storage == src_storage: + assert node.data not in array_name_dict + array_name_dict[node.data] = ie.data.data + # Writes + for oe in state.out_edges(node): + if oe.data.data is not None and oe.data.data != node.data: + dst_data = state.sdfg.arrays[oe.data.data] + print(dst_data) + if dst_data.storage == src_storage: + assert node.data not in array_name_dict + array_name_dict[node.data] = oe.data.data + + print(array_name_dict) + repldict = {k: f"{local_name_prefix}{v}" for k, v in array_name_dict.items()} + + sdfg.replace_dict(repldict, replace_keys=True) + sdfg.validate() + + +def tasklet_has_symbol(tasklet: dace.nodes.Tasklet, symbol_str: str) -> bool: + if tasklet.code.language == dace.dtypes.Language.Python: + try: + sym_expr = dace.symbolic.SymExpr(tasklet.code.as_astring) + return (symbol_str in {str(s) for s in sym_expr.free_symbols}) + except Exception as e: + return token_match(tasklet.code.as_string, symbol_str) + else: + return token_match(tasklet.code.as_string, symbol_str) + + +def replace_code(code_str: str, code_lang: dace.dtypes.Language, repldict: Dict[str, str]) -> str: + + def _str_replace(lhs: str, rhs: str) -> str: + code_str = token_replace_dict(rhs, repldict) + return f"{lhs.strip()} = {code_str.strip()}" + + if code_lang == dace.dtypes.Language.Python: + try: + lhs, rhs = code_str.split(" = ") + lhs = lhs.strip() + rhs = rhs.strip() + except Exception as e: + try: + new_rhs_sym_expr = dace.symbolic.SymExpr(code_str).subs(repldict) + printer = BracketFunctionPrinter({'strict': False}) + cleaned_expr = printer.doprint(new_rhs_sym_expr).strip() + return f"{cleaned_expr}" + except Exception as e: + return _str_replace(code_str) + try: + new_rhs_sym_expr = dace.symbolic.SymExpr(rhs).subs(repldict) + printer = BracketFunctionPrinter({'strict': False}) + cleaned_expr = printer.doprint(new_rhs_sym_expr).strip() + return f"{lhs.strip()} = {cleaned_expr}" + except Exception as e: + return _str_replace(rhs) + else: + return _str_replace(rhs) + + +def tasklet_replace_code(tasklet: dace.nodes.Tasklet, repldict: Dict[str, str]): + new_code = replace_code(tasklet.code.as_string, tasklet.code.language, repldict) + tasklet.code = CodeBlock(code=new_code, language=tasklet.code.language) + + +def extract_bracket_tokens(s: str) -> list[tuple[str, list[str]]]: + """ + Extracts all contents inside [...] along with the token before the '[' as the name. + + Args: + s (str): Input string. 
+def extract_bracket_tokens(s: str) -> Dict[str, str]:
+    """
+    Extracts all contents inside [...] along with the token before the '[' as the name.
+
+    Args:
+        s (str): Input string.
+
+    Returns:
+        Dict mapping each name token to the string inside its brackets.
+    """
+    results = []
+
+    # Pattern to match name[content_inside]
+    pattern = re.compile(r'(\b\w+)\[([^\]]*?)\]')
+
+    for match in pattern.finditer(s):
+        name = match.group(1)  # token before '['
+        content = match.group(2).split()  # split content inside brackets into tokens
+
+        results.append((name, " ".join(content)))
+
+    return {k: v for (k, v) in results}
+
+
+def remove_bracket_tokens(s: str) -> str:
+    """
+    Removes all [...] patterns from the string.
+
+    Args:
+        s (str): Input string.
+
+    Returns:
+        str: String with all [...] removed.
+    """
+    return re.sub(r'\[.*?\]', '', s)
+
+
+def generate_assignment_as_tasklet_in_state(state: dace.SDFGState, lhs: str, rhs: str):
+    rhs = rhs.strip()
+    rhs_sym_expr = dace.symbolic.SymExpr(rhs).evalf()
+    lhs = lhs.strip()
+    lhs_sym_expr = dace.symbolic.SymExpr(lhs).evalf()
+
+    in_connectors = dict()
+    out_connectors = dict()
+
+    # Get functions for indirect accesses
+    i = 0
+    for free_sym in rhs_sym_expr.free_symbols.union({f.func for f in rhs_sym_expr.atoms(Function)}):
+        if str(free_sym) in state.sdfg.arrays:
+            in_connectors[str(free_sym)] = f"_in_{free_sym}_{i}"
+            i += 1
+    for free_sym in lhs_sym_expr.free_symbols.union({f.func for f in lhs_sym_expr.atoms(Function)}):
+        if str(free_sym) in state.sdfg.arrays:
+            out_connectors[str(free_sym)] = f"_out_{free_sym}_{i}"
+            i += 1
+
+    if in_connectors == {} and out_connectors == {}:
+        raise Exception("Generated tasklet has neither in nor out connectors")
+
+    # Process interstate edge, extract brackets for access patterns
+    in_access_exprs = extract_bracket_tokens(token_replace_dict(rhs, in_connectors))
+    out_access_exprs = extract_bracket_tokens(token_replace_dict(lhs, out_connectors))
+    lhs = remove_bracket_tokens(token_replace_dict(lhs, out_connectors))
+    rhs = remove_bracket_tokens(token_replace_dict(rhs, in_connectors))
+
+    # Add the tasklet
+    t = state.add_tasklet(name=f"assign_{lhs}",
+                          inputs=set(in_connectors.values()),
+                          outputs=set(out_connectors.values()),
+                          code=f"{lhs} = {rhs}")
+
+    # Add connectors and accesses
+    in_access_dict = dict()
+    out_access_dict = dict()
+    for k, v in in_connectors.items():
+        in_access_dict[v] = state.add_access(k)
+    for k, v in out_connectors.items():
+        out_access_dict[v] = state.add_access(k)
+
+    # Add in and out connections
+    for k, v in in_access_dict.items():
+        data_name = v.data
+        access_str = in_access_exprs.get(k)
+        if access_str is None:
+            access_str = "0"
+        state.add_edge(v, None, t, k, dace.memlet.Memlet(expr=f"{data_name}[{access_str}]"))
+    for k, v in out_access_dict.items():
+        data_name = v.data
+        access_str = out_access_exprs.get(k)
+        if access_str is None:
+            access_str = "0"
+        state.add_edge(t, k, v, None, dace.memlet.Memlet(expr=f"{data_name}[{access_str}]"))
+
+
+def _find_parent_state(root_sdfg: dace.SDFG, node: dace.nodes.NestedSDFG):
+    if node is not None:
+        # Find parent state of that node
+        for n, g in root_sdfg.all_nodes_recursive():
+            if n == node:
+                parent_state = g
+                return parent_state
+    return None
+
+
+def get_num_parent_map_scopes(root_sdfg: dace.SDFG, node: dace.nodes.MapEntry, parent_state: dace.SDFGState):
+    scope_dict = parent_state.scope_dict()
+    num_parent_maps = 0
+    cur_node = node
+    while scope_dict[cur_node] is not None:
+        if isinstance(scope_dict[cur_node], dace.nodes.MapEntry):
+            num_parent_maps += 1
+        cur_node = scope_dict[cur_node]
+
+    # Check parent nsdfg
+    parent_nsdfg_node = parent_state.sdfg.parent_nsdfg_node
+    parent_nsdfg_parent_state =
_find_parent_state(root_sdfg, parent_nsdfg_node) + + while parent_nsdfg_node is not None: + scope_dict = parent_nsdfg_parent_state.scope_dict() + cur_node = parent_nsdfg_node + while scope_dict[cur_node] is not None: + if isinstance(scope_dict[cur_node], dace.nodes.MapEntry): + num_parent_maps += 1 + cur_node = scope_dict[cur_node] + parent_nsdfg_node = parent_nsdfg_parent_state.sdfg.parent_nsdfg_node + parent_nsdfg_parent_state = _find_parent_state(root_sdfg, parent_nsdfg_node) + + return num_parent_maps + + +def get_num_parent_map_and_loop_scopes(root_sdfg: dace.SDFG, node: dace.nodes.MapEntry, parent_state: dace.SDFGState): + return len(get_parent_map_and_loop_scopes(root_sdfg, node, parent_state)) + + +def get_parent_map_and_loop_scopes(root_sdfg: dace.SDFG, node: dace.nodes.MapEntry | ControlFlowRegion + | dace.nodes.Tasklet | ConditionalBlock, parent_state: dace.SDFGState): + scope_dict = parent_state.scope_dict() if parent_state is not None else None + num_parent_maps_and_loops = 0 + cur_node = node + parent_scopes = list() + + if isinstance(cur_node, (dace.nodes.MapEntry, dace.nodes.Tasklet)): + while scope_dict[cur_node] is not None: + if isinstance(scope_dict[cur_node], dace.nodes.MapEntry): + num_parent_maps_and_loops += 1 + parent_scopes.append(scope_dict[cur_node]) + cur_node = scope_dict[cur_node] + + parent_graph = parent_state.parent_graph if parent_state is not None else node.parent_graph + parent_sdfg = parent_state.sdfg if parent_state is not None else node.parent_graph.sdfg + while parent_graph != parent_sdfg: + if isinstance(parent_graph, LoopRegion): + num_parent_maps_and_loops += 1 + parent_scopes.append(parent_graph) + parent_graph = parent_graph.parent_graph + + # Check parent nsdfg + parent_nsdfg_node = parent_sdfg.parent_nsdfg_node + parent_nsdfg_parent_state = _find_parent_state(root_sdfg, parent_nsdfg_node) + + while parent_nsdfg_node is not None and parent_nsdfg_parent_state is not None: + scope_dict = parent_nsdfg_parent_state.scope_dict() + cur_node = parent_nsdfg_node + while scope_dict[cur_node] is not None: + if isinstance(scope_dict[cur_node], dace.nodes.MapEntry): + num_parent_maps_and_loops += 1 + parent_scopes.append(scope_dict[cur_node]) + cur_node = scope_dict[cur_node] + + parent_graph = parent_nsdfg_parent_state.parent_graph + parent_sdfg = parent_graph.sdfg + while parent_graph != parent_sdfg: + if isinstance(parent_graph, LoopRegion): + num_parent_maps_and_loops += 1 + parent_scopes.append(parent_graph) + parent_graph = parent_graph.parent_graph + + parent_nsdfg_node = parent_sdfg.parent_nsdfg_node + parent_nsdfg_parent_state = _find_parent_state(root_sdfg, parent_nsdfg_node) + + return parent_scopes + + +def get_parent_maps(root_sdfg: dace.SDFG, node: dace.nodes.MapEntry, parent_state: dace.SDFGState): + maps = [] + scope_dict = parent_state.scope_dict() + cur_node = node + while scope_dict[cur_node] is not None: + if isinstance(scope_dict[cur_node], dace.nodes.MapEntry): + maps.append((cur_node, parent_state)) + cur_node = scope_dict[cur_node] + + parent_graph = parent_state.parent_graph + while parent_graph != parent_state.sdfg: + if isinstance(parent_graph, LoopRegion): + pass + parent_graph = parent_graph.parent_graph + + # Check parent nsdfg + parent_nsdfg_node = parent_state.sdfg.parent_nsdfg_node + parent_nsdfg_parent_state = _find_parent_state(root_sdfg, parent_nsdfg_node) + + while parent_nsdfg_node is not None: + scope_dict = parent_nsdfg_parent_state.scope_dict() + cur_node = parent_nsdfg_node + while scope_dict[cur_node] is not 
None: + if isinstance(scope_dict[cur_node], dace.nodes.MapEntry): + maps.append((cur_node, parent_state)) + cur_node = scope_dict[cur_node] + parent_nsdfg_node = parent_nsdfg_parent_state.sdfg.parent_nsdfg_node + parent_nsdfg_parent_state = _find_parent_state(root_sdfg, parent_nsdfg_node) + + return maps + + +def duplicate_memlets_sharing_single_in_connector(state: dace.SDFGState, map_entry: dace.nodes.MapEntry): + + def _find_new_name(base: str, existing_names: Set[str]) -> str: + i = 0 + candidate = f"{base}_d_{i}" + while candidate in existing_names: + i += 1 + candidate = f"{base}_d_{i}" + return candidate + + for out_conn in list(map_entry.out_connectors.keys()): + out_edges_of_out_conn = set(state.out_edges_by_connector(map_entry, out_conn)) + if len(out_edges_of_out_conn) > 1: + base_in_edge = out_edges_of_out_conn.pop() + + # Get all parent maps (including this) + parent_maps: Set[dace.nodes.MapEntry] = {map_entry} + sdict = state.scope_dict() + parent_map = sdict[map_entry] + while parent_map is not None: + parent_maps.add(parent_map) + parent_map = sdict[parent_map] + + # Need it to find unique names + all_existing_connector_names = set() + for map_entry in parent_maps: + for in_conn in map_entry.in_connectors: + all_existing_connector_names.add(in_conn[len("IN_"):]) + for out_conn in map_entry.out_connectors: + all_existing_connector_names.add(out_conn[len("OUT_"):]) + + # Base path + memlet_paths = [] + path = state.memlet_path(base_in_edge) + source_node = path[0].src + memlet_paths.append(path) + while sdict[source_node] is not None: + if not isinstance(source_node, (dace.nodes.AccessNode, dace.nodes.MapEntry)): + print(source_node) + raise Exception( + f"In the path from map entry to the top level scope, only access nodes and other map entries may appear, got: {source_node}" + ) + in_edges = state.in_edges(source_node) + if isinstance(source_node, dace.nodes.MapEntry) and len(in_edges) != 1: + in_edges = list(state.in_edges_by_connector(source_node, "IN_" + path[-1].src_conn[len("OUT_"):])) + if isinstance(source_node, dace.nodes.AccessNode) and len(in_edges) != 1: + raise Exception( + "In the path from map entry to the top level scope, the intermediate access nodes need to have in and out degree (by connector) 1" + ) + + in_edge = in_edges[0] + path = state.memlet_path(in_edge) + source_node = path[0].src + memlet_paths.append(path) + #print(source_node) + + # Need to duplicate the out edges + for e in list(out_edges_of_out_conn): + state.remove_edge(e) + + for edge_to_duplicate in out_edges_of_out_conn: + base = edge_to_duplicate.src_conn[len("OUT_"):] + new_connector_base = _find_new_name(base, all_existing_connector_names) + all_existing_connector_names.add(new_connector_base) + + node_map = dict() + for i, subpath in enumerate(memlet_paths): + for j, e in enumerate(reversed(subpath)): + # We work by adding an in edge + in_name = f"IN_{new_connector_base}" + out_name = f"OUT_{new_connector_base}" + + if e.src_conn is not None: + out_conn = out_name if e.src_conn.startswith("OUT_") else e.src_conn + else: + out_conn = None + + if e.dst_conn is not None: + if e.src == map_entry: + in_conn = edge_to_duplicate.dst_conn + else: + in_conn = in_name if e.dst_conn.startswith("IN_") else e.dst_conn + else: + in_conn = None + + if isinstance(e.src, dace.nodes.MapEntry): + src_node = e.src + elif isinstance(e.src, dace.nodes.AccessNode): + if e.src in node_map: + src_node = node_map[e.src] + else: + a = state.add_access(e.src.data) + node_map[e.src] = a + src_node = a + else: + 
src_node = e.src + + if isinstance(e.dst, dace.nodes.MapEntry): + dst_node = e.dst + elif isinstance(e.dst, dace.nodes.AccessNode): + if e.dst in node_map: + dst_node = node_map[e.dst] + else: + a = state.add_access(e.dst.data) + node_map[e.dst] = a + dst_node = a + else: + dst_node = e.dst + + # Above the first map, always add the complete subset and then call memlet propagation + if e.src is map_entry: + data = copy.deepcopy(edge_to_duplicate.data) + else: + data = dace.memlet.Memlet.from_array(e.data.data, state.sdfg.arrays[e.data.data]) + + state.add_edge(src_node, out_conn, dst_node, in_conn, data) + + if out_conn is not None and out_conn not in src_node.out_connectors: + src_node.add_out_connector(out_conn, force=True) + if in_conn is not None and in_conn not in dst_node.in_connectors: + dst_node.add_in_connector(in_conn, force=True) + + # If we duplicate an access node, we should add correct dependency edges + if i == len(memlet_paths) - 1: + if j == len(subpath) - 1: + # Source node + origin_source_node = e.src + for ie in state.in_edges(origin_source_node): + state.add_edge(ie.src, None, src_node, None, dace.memlet.Memlet(None)) + + propagate_memlets_state(state.sdfg, state) diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 6347ad4d5d..ba99bf78c2 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -8,6 +8,8 @@ import networkx as nx import time +import sympy + import dace.sdfg.nodes from dace.codegen import compiled_sdfg as csdfg from dace.sdfg.graph import MultiConnectorEdge @@ -594,139 +596,6 @@ def merge_maps( return merged_entry, merged_exit -def canonicalize_memlet_trees_for_scope( - state: SDFGState, - scope_node: Union[nd.EntryNode, nd.ExitNode], -) -> int: - """Canonicalize the Memlet trees of scope nodes. - - The function will modify all Memlets that are adjacent to `scope_node` - such that the Memlet always refers to the data that is on the outside. - This function only operates on a single scope. - - :param state: The SDFG state in which the scope to consolidate resides. - :param scope_node: The scope node whose edges will be consolidated. - :return: Number of modified Memlets. - - :note: This is the "historical" expected format of Memlet trees at scope nodes, - which was present before the introduction of `other_subset`. Running this - transformation might fix some issues. 
- """ - if isinstance(scope_node, nd.EntryNode): - may_have_dynamic_map_range = True - is_downward_tree = True - outer_edges = state.in_edges(scope_node) - get_outer_edge_connector = lambda e: e.dst_conn - inner_edges_for = lambda conn: state.out_edges_by_connector(scope_node, conn) - inner_prefix = 'OUT_' - outer_prefix = 'IN_' - - def get_outer_data(e: MultiConnectorEdge[dace.Memlet]): - mpath = state.memlet_path(e) - assert isinstance(mpath[0].src, nd.AccessNode) - return mpath[0].src.data - - else: - may_have_dynamic_map_range = False - is_downward_tree = False - outer_edges = state.out_edges(scope_node) - get_outer_edge_connector = lambda e: e.src_conn - inner_edges_for = lambda conn: state.in_edges_by_connector(scope_node, conn) - inner_prefix = 'IN_' - outer_prefix = 'OUT_' - - def get_outer_data(e: MultiConnectorEdge[dace.Memlet]): - mpath = state.memlet_path(e) - assert isinstance(mpath[-1].dst, nd.AccessNode) - return mpath[-1].dst.data - - def swap_prefix(conn: str) -> str: - if conn.startswith(inner_prefix): - return outer_prefix + conn[len(inner_prefix):] - else: - assert conn.startswith( - outer_prefix), f"Expected connector to start with '{outer_prefix}', but it was '{conn}'." - return inner_prefix + conn[len(outer_prefix):] - - modified_memlet = 0 - for outer_edge in outer_edges: - outer_edge_connector = get_outer_edge_connector(outer_edge) - if may_have_dynamic_map_range and (not outer_edge_connector.startswith(outer_prefix)): - continue - assert outer_edge_connector.startswith(outer_prefix) - corresponding_inner_connector = swap_prefix(outer_edge_connector) - - # In case `scope_node` is at the global scope it should be enough to run - # `outer_edge.data.data` but this way it is more in line with consolidate. - outer_data = get_outer_data(outer_edge) - - for inner_edge in inner_edges_for(corresponding_inner_connector): - for mtree in state.memlet_tree(inner_edge).traverse_children(include_self=True): - medge: MultiConnectorEdge[dace.Memlet] = mtree.edge - if medge.data.data == outer_data: - # This edge is already referring to the outer data, so no change is needed. - continue - - # Now we have to extract subset from the Memlet. - if is_downward_tree: - subset = medge.data.get_src_subset(medge, state) - other_subset = medge.data.dst_subset - else: - subset = medge.data.get_dst_subset(medge, state) - other_subset = medge.data.src_subset - - # Now for an update. - medge.data._data = outer_data - medge.data._subset = subset - medge.data._other_subset = other_subset - medge.data.try_initialize(state.sdfg, state, medge) - modified_memlet += 1 - - return modified_memlet - - -def canonicalize_memlet_trees( - sdfg: 'dace.SDFG', - starting_scope: Optional['dace.sdfg.scope.ScopeTree'] = None, -) -> int: - """Canonicalize the Memlet trees of all scopes in the SDFG. - - This function runs `canonicalize_memlet_trees_for_scope()` on all scopes - in the SDFG. Note that this function does not recursively processes - nested SDFGs. - - :param sdfg: The SDFG to consolidate. - :param starting_scope: If not None, starts with a certain scope. Note in that - mode only the state in which the scope is located will be processes. - :return: Number of modified Memlets. 
- """ - - total_modified_memlets = 0 - for state in sdfg.states(): - # Start bottom-up - if starting_scope is not None and starting_scope.entry not in state.nodes(): - continue - - queue = [starting_scope] if starting_scope else state.scope_leaves() - next_queue = [] - while len(queue) > 0: - for scope in queue: - if scope.entry is not None: - total_modified_memlets += canonicalize_memlet_trees_for_scope(state, scope.entry) - if scope.exit is not None: - total_modified_memlets += canonicalize_memlet_trees_for_scope(state, scope.exit) - if scope.parent is not None: - next_queue.append(scope.parent) - queue = next_queue - next_queue = [] - - if starting_scope is not None: - # No need to traverse other states - break - - return total_modified_memlets - - def consolidate_edges_scope(state: SDFGState, scope_node: Union[nd.EntryNode, nd.ExitNode]) -> int: """ Union scope-entering memlets relating to the same data node in a scope. @@ -2548,11 +2417,12 @@ def _specialize_scalar_impl(root: 'dace.SDFG', sdfg: 'dace.SDFG', scalar_name: s # -> For 2: Rm. dynamic in connector, remove the edge and the node if the degree is None # 3. Access Node # -> If access node is used then e.g. [scalar] -> [tasklet] - # -> then [tasklet(assign const value)] -> [access node] -> [tasklet] + # -> then create a [tasklet] that uses the scalar_val as a constant value inside + import dace.sdfg.construction_utils as cutil def repl_code_block_or_str(input: Union[CodeBlock, str], src: str, dst: str): if isinstance(input, CodeBlock): - return CodeBlock(input.as_string.replace(src, dst)) + return CodeBlock(cutil.replace_code(input.as_string, input.language, {src: dst}), input.language) else: return input.replace(src, dst) @@ -2583,22 +2453,16 @@ def repl_code_block_or_str(input: Union[CodeBlock, str], src: str, dst: str): assert e.data.data == scalar_name if isinstance(e.dst, nd.Tasklet): - assign_tasklet = state.add_tasklet(f"assign_{scalar_name}", - inputs={}, - outputs={"_out"}, - code=f"_out = {scalar_val}") - tmp_name = f"__tmp_{scalar_name}_{c}" - c += 1 - copydesc = copy.deepcopy(sdfg.arrays[scalar_name]) - copydesc.transient = True - copydesc.storage = dace.StorageType.Register - sdfg.add_datadesc(tmp_name, copydesc) - scl_an = state.add_access(tmp_name) + in_tasklet_name = e.dst_conn + new_code = CodeBlock(code=cutil.replace_code(e.dst.code.as_string, e.dst.code.language, + {in_tasklet_name: scalar_val}), + language=e.dst.code.language) + e.dst.code = new_code state.remove_edge(e) - state.add_edge(assign_tasklet, "_out", scl_an, None, dace.memlet.Memlet.from_array(tmp_name, copydesc)) - state.add_edge(scl_an, None, dst, e.dst_conn, dace.memlet.Memlet.from_array(tmp_name, copydesc)) if e.src_conn is not None: src.remove_out_connector(e.src_conn) + if e.dst_conn is not None: + dst.remove_in_connector(e.dst_conn) else: state.remove_edge(e) if e.src_conn is not None: @@ -2657,7 +2521,101 @@ def repl_code_block_or_str(input: Union[CodeBlock, str], src: str, dst: str): _specialize_scalar_impl(root, nsdfg, scalar_name, scalar_val) -def specialize_scalar(sdfg: 'dace.SDFG', scalar_name: str, scalar_val: Union[float, int, str]): - assert isinstance(scalar_name, str) - assert isinstance(scalar_val, (float, int, str)) +def specialize_scalar(sdfg: 'dace.SDFG', scalar_name: str, scalar_val: Union[float, int, str, sympy.Number]): + assert isinstance(scalar_name, str), f"Expected scalar name to be str got {type(scalar_val)}" + + def _sympy_to_python_number(val): + """Convert any SymPy numeric type to a native Python int or float.""" + 
+        if isinstance(val, sympy.Integer):
+            return int(val)
+        elif isinstance(val, (sympy.Float, sympy.Rational)):
+            return float(val)
+        elif isinstance(val, sympy.Number):
+            # Fallback for any other sympy numeric type
+            return float(val.evalf())
+        return val  # unchanged if not a number
+
+    assert isinstance(
+        scalar_val,
+        (float, int, str,
+         sympy.Number)), f"Expected scalar value to be float, int, str, or sympy.Number, got {type(scalar_val)}"
+    if not isinstance(scalar_val, (float, int, str)):
+        if isinstance(scalar_val, sympy.Number):
+            scalar_val = _sympy_to_python_number(scalar_val)
+    _specialize_scalar_impl(sdfg, sdfg, scalar_name, scalar_val)
+
+
+def demote_symbol_to_scalar(sdfg: 'dace.SDFG', symbol_str: str, default_type: 'dace.dtypes.typeclass' = None):
+    import dace.sdfg.construction_utils as cutil
+    if default_type is None:
+        default_type = dace.int32
+
+    # If the symbol is known, reuse its declared type; otherwise fall back to the default type
+    if symbol_str in sdfg.symbols:
+        sym_dtype = sdfg.symbols[symbol_str]
+    else:
+        print(
+            f"Symbol {symbol_str} not in the symbols of {sdfg.label} ({sdfg.symbols}), setting to default type {default_type}"
+        )
+        sym_dtype = default_type
+
+    # If top-level and in the free symbols,
+    # or not top-level and in the symbol mapping, the resulting scalar must be non-transient
+    # TODO:
+    is_top_level = sdfg.parent_nsdfg_node is None
+    is_transient = not ((is_top_level and symbol_str in sdfg.free_symbols) or
+                        ((not is_top_level) and symbol_str in sdfg.parent_nsdfg_node.symbol_mapping))
+
+    if is_transient is False:
+        raise Exception("Symbol-to-scalar demotion only works if the resulting scalar would be transient")
+
+    if symbol_str in sdfg.symbols:
+        sdfg.remove_symbol(symbol_str)
+    sdfg.add_scalar(name=symbol_str, dtype=sym_dtype, storage=dace.dtypes.StorageType.Register, transient=is_transient)
+
+    # For any tasklet that uses the symbol - make an access node to the scalar and connect through an in connector
+    # 1. Replace all appearances of the symbol name
+    # 2.1 If symbol <- expr appears on any interstate edge
+    # 2.2 Add a new state before edge.dst, and add the assignment to the scalar there
+
+    # 1
+    # Replace all code in tasklets and access nodes
+    for g in sdfg.all_states():
+        for n in g.nodes():
+            if isinstance(n, dace.nodes.Tasklet):
+                assert isinstance(g, dace.SDFGState)
+                sdict = g.scope_dict()
+                if cutil.tasklet_has_symbol(n, symbol_str):
+                    # 2. If used in a tasklet, replace the symbol name with an in connector and add an access to the scalar
+
+                    # Sanity check: no tasklet should assign to a symbol
+                    cutil.tasklet_replace_code(n, {symbol_str: f"_in_{symbol_str}"})
+                    n.add_in_connector(f"_in_{symbol_str}")
+                    access = g.add_access(symbol_str)
+                    g.add_edge(access, None, n, f"_in_{symbol_str}", dace.memlet.Memlet(expr=f"{symbol_str}[0]"))
+                    # If parent scope is not None add a dependency edge to it
+                    if sdict[n] is not None:
+                        g.add_edge(sdict[n], None, access, None, dace.memlet.Memlet())
+
+    # 2
+    for e in sdfg.all_interstate_edges():
+        matching_assignments = {(k, v) for k, v in e.data.assignments.items() if k.strip() == symbol_str}
+        if len(matching_assignments) > 0:
+            # Add them to the next state
+            state = e.dst.parent_graph.add_state_before(e.dst,
+                                                        label=f"_{e.dst}_sym_assign",
+                                                        is_start_block=e.dst.parent_graph.start_block == e.dst)
+            # Go through all matching assignments
+            # Add symbols etc.
as necessary + for k, v in matching_assignments: + del e.data.assignments[k] + symbol_str = k.strip() + if symbol_str in state.sdfg.symbols: + state.sdfg.remove_symbol(symbol_str) + if symbol_str not in g.sdfg.arrays: + state.sdfg.add_scalar(name=symbol_str, + dtype=sym_dtype, + storage=dace.dtypes.StorageType.Register, + transient=is_transient) + cutil.generate_assignment_as_tasklet_in_state(state, k, v) diff --git a/dace/transformation/interstate/branch_elimination.py b/dace/transformation/interstate/branch_elimination.py new file mode 100644 index 0000000000..dd21b65305 --- /dev/null +++ b/dace/transformation/interstate/branch_elimination.py @@ -0,0 +1,1985 @@ +import ast +import copy +import numpy +from sympy import Eq, Equality, Function, Integer, preorder_traversal, pycode +import re +import sympy +import dace +from dace import properties, transformation +from dace import InterstateEdge +from dace.dtypes import typeclass +from dace.properties import CodeBlock +from dace.sdfg.sdfg import SDFG +from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, LoopRegion, SDFGState +import dace.sdfg.utils as sdutil +import dace.sdfg.construction_utils as cutil +from typing import Tuple, Set, Union +from dace.symbolic import pystr_to_symbolic +from dace.transformation.passes import FuseStates + + +def remove_symbol_assignments(graph: ControlFlowRegion, sym_name: str): + for e in graph.all_interstate_edges(): + new_assignments = dict() + for k, v in e.data.assignments.items(): + if k != sym_name: + new_assignments[k] = v + e.data.assignments = new_assignments + + +class DivEps(ast.NodeTransformer): + + def __init__(self, eps_node, tasklet_code_str, mode="add"): + """ + mode: 'add' -> use (x + eps) + 'max' -> use max(x, eps) + """ + self.eps_node = eps_node + self.mode = mode + self.tasklet_code_str = tasklet_code_str + + def visit_BinOp(self, node): + self.generic_visit(node) + if isinstance(node.op, ast.Div): + if self.mode == "add": + print(f"Changing {self.tasklet_code_str} to have +{ast.unparse(self.eps_node)} " + f"to avoid NaN/inf floating-point exception in division!") + node.right = ast.BinOp(left=node.right, op=ast.Add(), right=self.eps_node) + elif self.mode == "max": + print(f"Changing {self.tasklet_code_str} to have max(..., {ast.unparse(self.eps_node)}) " + f"to avoid NaN/inf floating-point exception in division!") + node.right = ast.Call( + func=ast.Name(id="max", ctx=ast.Load()), + args=[node.right, self.eps_node], + keywords=[], + ) + return node + + def visit_Call(self, node): + self.generic_visit(node) + + # determine the name of the called function + func_name = None + if isinstance(node.func, ast.Attribute): + func_name = node.func.attr + elif isinstance(node.func, ast.Name): + func_name = node.func.id + + if func_name in ("log", "log10", "log2"): + if node.args: + if self.mode == "add": + print( + f"Changing {self.tasklet_code_str} to have ({ast.unparse(node.args[0])}) + {ast.unparse(self.eps_node)} " + f"inside {func_name}() to avoid log(0) NaN!") + node.args[0] = ast.BinOp( + left=node.args[0], + op=ast.Add(), + right=self.eps_node, + ) + elif self.mode == "max": + print( + f"Changing {self.tasklet_code_str} to have max(({ast.unparse(node.args[0])}), {ast.unparse(self.eps_node)}) " + f"inside {func_name}() to avoid log(0) NaN!") + node.args[0] = ast.Call( + func=ast.Name(id="max", ctx=ast.Load()), + args=[node.args[0], self.eps_node], + keywords=[], + ) + + # ensure it's a bare function name (no np.log) + if isinstance(node.func, ast.Attribute): + node.func = 
ast.Name(id=node.func.attr, ctx=ast.Load())
+
+        return node
+
+
+@properties.make_properties
+@dace.transformation.explicit_cf_compatible
+class BranchElimination(transformation.MultiStateTransformation):
+    """
+    Translates a pair of SIMT-style branches into a form compatible with SIMD instructions:
+    ```
+    if (cond){
+        out1[address1] = computation1(...);
+    } else {
+        out1[address1] = computation2(...);
+    }
+    ```
+
+    If the write sets of the left and right branches are the same
+    (for all write accesses),
+    we can transform the if branch to:
+    ```
+    fcond = cond ? 1.0 : 0.0;
+    out1[address1] = computation1(...) * fcond + (1.0 - fcond) * computation2(...);
+    ```
+
+    For the single-branch case:
+    ```
+    out1[address1] = computation1(...) * fcond + (1.0 - fcond) * out1[address1];
+    ```
+
+    Also supported:
+    ```
+    if (cond){
+        out1[address1] = computation1(...);
+    } else {
+        out1[address2] = computation2(...);
+    }
+    ```
+    If `address1` and `address2` are completely disjoint, the pass handles this
+    by chaining the if branches and treating them as conditionals with a single branch:
+    ```
+    if (cond){
+        out1[address1] = computation1(...);
+    }
+    if (! cond) {
+        out1[address2] = computation2(...);
+    }
+    ```
+
+    The pass also supports multiple writes if the conditional has just one branch:
+    ```
+    if (cond){
+        out1[address1] = computation1(...);
+        out1[address2] = computation2(...);
+    }
+    ```
+
+    If the pattern is (the addresses do not need to be disjoint):
+    ```
+    if (cond){
+        out1[address1] = computation1(...);
+        out1[address2] = computation2(...);
+    } else {
+        out1[address3] = computation1(...);
+        out1[address4] = computation2(...);
+    }
+    ```
+    then the conditional is first converted to:
+    ```
+    if (cond){
+        out1[address1] = computation1(...);
+        out1[address2] = computation2(...);
+    }
+    if (not cond) {
+        out1[address3] = computation1(...);
+        out1[address4] = computation2(...);
+    }
+    ```
+
+    This eliminates branching by duplicating the computation of each branch,
+    but makes it possible to vectorize the computation.
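+
+    A minimal usage sketch (illustrative; assumes an SDFG `sdfg` whose control
+    flow contains an eligible `ConditionalBlock`):
+    ```
+    from dace.transformation.interstate.branch_elimination import BranchElimination
+    sdfg.apply_transformations_repeated(BranchElimination)
+    ```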
+ """ + conditional = transformation.PatternNode(ConditionalBlock) + parent_nsdfg_state = properties.Property(dtype=SDFGState, allow_none=True, default=None) + eps_operator_type_for_log_and_div = properties.Property(dtype=str, + allow_none=False, + default="add", + choices=["max", "add"]) + + @classmethod + def expressions(cls): + return [sdutil.node_path_graph(cls.conditional)] + + def _check_reuse(self, sdfg: dace.SDFG, orig_state: dace.SDFGState, diff_set: Set[str]): + for graph in sdfg.all_control_flow_regions(): + if (not isinstance(graph, dace.SDFGState)) and orig_state in graph.all_states(): + continue + if graph == orig_state: + continue + read_set, write_set = graph.read_and_write_sets() + if any({k in read_set or k in write_set for k in diff_set}): + return True + + return False + + def _symbol_appears_as_read(self, cfg: ControlFlowRegion, symbol_name: str) -> bool: + # Symbol can be read on an interstate edge, appear in a conditional block's conditions, loop regions condition / update + # Appear in shape of an array, in the expression of maps or in taskelts, passed to nested SDFGs + + # Interstate edge reads + for e in cfg.all_interstate_edges(): + for v in e.data.assignments.values(): + if symbol_name in dace.symbolic.symbols_in_code(v): + return True + + # Conditional Block + for cb in cfg.all_control_flow_blocks(): + if isinstance(cb, ConditionalBlock): + for cond, _ in cb.branches: + if cond is None: + continue + if symbol_name in dace.symbolic.symbols_in_code(cond.as_string): + return True + + # Memlets + for state in cfg.all_states(): + for edge in state.edges(): + if edge.data.data is not None: + for (b, e, s) in edge.data.subset: + if hasattr(b, "free_symbols") and symbol_name in {str(sym) for sym in b.free_symbols}: + return True + if hasattr(e, "free_symbols") and symbol_name in {str(sym) for sym in e.free_symbols}: + return True + if hasattr(s, "free_symbols") and symbol_name in {str(sym) for sym in s.free_symbols}: + return True + # Loop + for lr in cfg.all_control_flow_regions(): + if isinstance(lr, LoopRegion): + if symbol_name in dace.symbolic.symbols_in_code(lr.init_statement.as_string): + return True + if symbol_name in dace.symbolic.symbols_in_code(lr.update_statement.as_string): + return True + if symbol_name in dace.symbolic.symbols_in_code(lr.loop_condition.as_string): + return True + + # Arrays + for arr in cfg.sdfg.arrays.values(): + for dim, stride in zip(arr.shape, arr.strides): + if hasattr(dim, "free_symbols") and symbol_name in {str(sym) for sym in dim.free_symbols}: + return True + if hasattr(stride, "free_symbols") and symbol_name in {str(sym) for sym in stride.free_symbols}: + return True + + # Maps + for state in cfg.all_states(): + for node in state.nodes(): + if isinstance(node, dace.nodes.MapEntry): + for (b, e, s) in node.map.range: + if hasattr(b, "free_symbols") and symbol_name in {str(sym) for sym in b.free_symbols}: + return True + if hasattr(e, "free_symbols") and symbol_name in {str(sym) for sym in e.free_symbols}: + return True + if hasattr(s, "free_symbols") and symbol_name in {str(sym) for sym in s.free_symbols}: + return True + + # Takslets + for state in cfg.all_states(): + for node in state.nodes(): + if isinstance(node, dace.nodes.Tasklet): + if symbol_name in dace.symbolic.symbols_in_code(node.code.as_string): + return True + + return False + + def _extract_bracket_content(self, s: str): + pattern = r"(\w+)\[([^\]]*)\]" + matches = re.findall(pattern, s) + + extracted = dict() + for name, content in matches: + if name in 
extracted:
+                raise Exception("Repeated assignment in interstate edge, not supported")
+            extracted[name] = "[" + content + "]"
+
+        # Remove [...] from the string
+        cleaned = re.sub(pattern, r"\1", s)
+
+        return cleaned, extracted
+
+    def _move_interstate_assignment_to_state(self, state: dace.SDFGState, rhs: str, lhs: str):
+        # Parse the rhs of the interstate edge assignment to a symbolic expression.
+        # Array accesses appear as functions, e.g., arr1[i, j] is treated as a function arr1(i, j);
+        # functions that are native Python operators (e.g., AND) are not considered functions here.
+        rhs_as_symexpr = dace.symbolic.SymExpr(rhs)
+        free_vars = {str(sym)
+                     for sym in rhs_as_symexpr.free_symbols}.union({
+                         str(node.func)
+                         for node in preorder_traversal(rhs_as_symexpr) if isinstance(node, Function)
+                     }) - {"AND", "OR", "NOT", "and", "or", "not", "And", "Or", "Not"}
+
+        # For array accesses such as arr1(i, j), get the dictionary that maps name to accesses, {arr1: [i, j]},
+        # and remove the array accesses
+        cleaned, extracted_subsets = self._extract_bracket_content(rhs)
+
+        # Collect inputs we need
+        arr_inputs = {var for var in free_vars if var in state.sdfg.arrays}
+        # Generate the scalar for the float constant
+        float_lhs_name, float_lhs = state.sdfg.add_scalar(
+            name="float_" + lhs,
+            dtype=dace.dtypes.float64,
+            storage=dace.dtypes.StorageType.Register,
+            transient=True,
+            find_new_name=True,
+        )
+
+        # For all non-symbol (=array/scalar) inputs replace the name with the connector version.
+        # The connector version is `_in_{arr_input}_{offset}` (to not repeat the connectors)
+        symbol_inputs = {var for var in free_vars if var not in state.sdfg.arrays}
+        for i, arr_input in enumerate(arr_inputs):
+            cleaned = cutil.token_replace_dict(cleaned, {arr_input: f"_in_{arr_input}_{i}"})
+
+        assert arr_inputs.union(symbol_inputs) == free_vars
+
+        tasklet = state.add_tasklet(name=f"ieassign_{lhs}_to_{float_lhs_name}_scalar",
+                                    inputs={f"_in_{arr_input}_{i}"
+                                            for i, arr_input in enumerate(arr_inputs)},
+                                    outputs={f"_out_{float_lhs_name}"},
+                                    code=f"_out_{float_lhs_name} = ({cleaned})")
+
+        for i, arr_name in enumerate(arr_inputs):
+            an_of_arr_name = {n for n in state.nodes() if isinstance(n, dace.nodes.AccessNode) and n.data == arr_name}
+            last_access = None
+            if len(an_of_arr_name) == 0:
+                last_access = state.add_access(arr_name)
+            elif len(an_of_arr_name) == 1:
+                last_access = an_of_arr_name.pop()
+            else:
+                # Get the last sink access node, if it exists, to avoid data races.
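+                # (Done below via a BFS from each source node: per weakly connected
+                # component, the last-visited access of `arr_name` is taken.)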
+ # All access nodes of the same data should be in the same weakly connected component + # Otherwise it would be code-gen dependent data race + source_nodes = {n for n in state.nodes() if state.in_degree(n) == 0} + ordered_an_of_arr_names = dict() + + last_accessess = dict() + for src_node in source_nodes: + ordered_an_of_arr_names[src_node] = list(state.bfs_nodes(src_node)) + access_to_arr_name = [ + n for n in ordered_an_of_arr_names[src_node] + if isinstance(n, dace.nodes.AccessNode) and n.data == arr_name + ] + if len(access_to_arr_name) > 0: + last_accessess[src_node] = access_to_arr_name[-1] + + assert len(last_accessess) == 1 + last_access = next(iter(last_accessess.values())) + + # Use interstate edge's access expression if exists, otherwise use [0,...,0] + new_subset = extracted_subsets.get(arr_name, None) + if new_subset is None: + new_subset = "[" + ",".join(["0" for _ in state.sdfg.arrays[arr_name].shape]) + "]" + #subset_str = interstate_index_to_subset_str(new_subset) + state.add_edge(last_access, None, tasklet, "_in_" + arr_name + f"_{i}", + dace.memlet.Memlet(f"{arr_name}{new_subset}")) + + # Convert boolean to float type + arr = state.sdfg.arrays[arr_name] + if arr.dtype == dace.bool: + arr.dtype = dace.float64 + + state.add_edge(tasklet, f"_out_{float_lhs_name}", state.add_access(float_lhs_name), None, + dace.memlet.Memlet(float_lhs_name)) + + return float_lhs_name + + def _is_disjoint_subset(self, state0: SDFGState, state1: SDFGState) -> bool: + state0_writes = set() + state1_writes = set() + state0_write_subsets = dict() + state1_write_subsets = dict() + read_sets0, write_sets0 = state0.read_and_write_sets() + read_sets1, write_sets1 = state1.read_and_write_sets() + joint_writes = write_sets0.intersection(write_sets1) + + # Remove ignored writes + ignores_writes = self.collect_ignored_writes(state0).union(self.collect_ignored_writes(state1)) + joint_writes = joint_writes.difference(ignores_writes) + + for write in joint_writes: + state0_accesses = {n for n in state0.nodes() if isinstance(n, dace.nodes.AccessNode) and n.data == write} + state1_accesses = {n for n in state1.nodes() if isinstance(n, dace.nodes.AccessNode) and n.data == write} + + for state_writes, state_accesses, state, state_write_subsets in [ + (state0_writes, state0_accesses, state0, state0_write_subsets), + (state1_writes, state1_accesses, state1, state1_write_subsets) + ]: + state_write_edges = set() + for an in state_accesses: + state_write_edges |= {e for e in state.in_edges(an) if e.data.data is not None} + # If there are multiple write edges again we would need to know the order + state_writes |= {e.data.data for e in state_write_edges} + for e in state_write_edges: + if e.data.data is None: + continue + assert (e.data.subset.num_elements_exact() == 1) + if e.data.data not in state_write_subsets: + state_write_subsets[e.data.data] = set() + state_write_subsets[e.data.data].add(e.data.subset) + + # Build symmetric difference of subsets + try: + all_keys = set(state0_write_subsets) | set(state1_write_subsets) + intersects = {k: False for k in all_keys} + for name, subsets0 in state0_write_subsets.items(): + if name in state1_write_subsets: + subsets1 = state1_write_subsets[name] + else: + subsets1 = set() + for other_subset in subsets1: + for subset0 in subsets0: + if subset0.intersects(other_subset): + intersects[name] = True + for name, subsets1 in state1_write_subsets.items(): + if name in state0_write_subsets: + subsets0 = state0_write_subsets[name] + else: + subsets0 = set() + for other_subset 
in subsets0:
+                    for subset1 in subsets1:
+                        if subset1.intersects(other_subset):
+                            intersects[name] = True
+
+            if any(intersects.values()):
+                return False
+        except Exception as e:
+            print(f"Intersects call resulted in an exception: {e}")
+            return False
+
+        return True
+
+    def collect_accesses(self, state: dace.SDFGState, name: str):
+        """Return all AccessNodes in the state that access a given data name."""
+        return {n for n in state.nodes() if isinstance(n, dace.nodes.AccessNode) and n.data == name}
+
+    def collect_write_accesses(self, state: dace.SDFGState, name: str):
+        """
+        Return AccessNodes that write to a given data name.
+
+        A node is considered a write access if:
+        - It has incoming edges (i.e., receives data)
+        - At least one incoming edge carries actual data
+        - It either has no outgoing edges, or refers to an array or non-transient data
+        """
+        accesses = self.collect_accesses(state, name)
+        result = set()
+        for a in accesses:
+            has_incoming_data = any(e.data.data is not None for e in state.in_edges(a))
+            array = state.sdfg.arrays[a.data]
+            if (state.in_degree(a) > 0 and has_incoming_data
+                    and (state.out_degree(a) == 0 or isinstance(array, dace.data.Array) or array.transient is False)):
+                result.add(a)
+        return result
+
+    def collect_ignored_write_accesses(self, state: dace.SDFGState):
+        """
+        Return AccessNodes that are *write accesses* but do not meet the criteria
+        for considered writes; see :func:`collect_write_accesses` for what is
+        considered a write.
+        """
+        all_write_accesses = {
+            a
+            for a in state.nodes()
+            if isinstance(a, dace.nodes.AccessNode) and state.in_degree(a) > 0 and any(e.data.data is not None
+                                                                                       for e in state.in_edges(a))
+        }
+
+        considered_write_accesses = {
+            a
+            for a in all_write_accesses
+            if (state.out_degree(a) == 0 or isinstance(state.sdfg.arrays[a.data], dace.data.Array)
+                or not state.sdfg.arrays[a.data].transient)
+        }
+
+        return all_write_accesses - considered_write_accesses
+
+    def collect_ignored_writes(self, state: dace.SDFGState):
+        ignored_write_access_nodes = self.collect_ignored_write_accesses(state)
+        return {an.data for an in ignored_write_access_nodes}
+
+    def ignored_accesses_are_reused(self, states: Set[dace.SDFGState]):
+        """
+        Check whether ignored write accesses are used elsewhere in the SDFG.
+        Implemented by emptying the given states in a copy of the SDFG and then
+        checking the read and write sets of that copy.
+
+        Returns:
+            A tuple (reused, data): ``reused`` is True if any ignored data is
+            later read or written in another state, and ``data`` is the set of
+            such data names.
+ """ + ignored_accesses = set() + for state in states: + ignored_accesses = ignored_accesses.union(self.collect_ignored_write_accesses(state)) + ignored_data = {a.data for a in ignored_accesses} + + # Work on a copy to safely remove nodes + copy_sdfg = copy.deepcopy(state.sdfg) + sdutil.set_nested_sdfg_parent_references(copy_sdfg) + + # Remove all nodes from the target state in the copy + labels = {(s.label, s.parent_graph.label, s.sdfg.label) for s in states} + for st in copy_sdfg.all_states(): + label_tuple = (st.label, st.parent_graph.label, st.sdfg.label) + if label_tuple in labels: + for n in list(st.nodes()): + st.remove_node(n) + + read_set, write_set = copy_sdfg.read_and_write_sets() + ignored_in_read = read_set & ignored_data + ignored_in_write = write_set & ignored_data + + return bool(ignored_in_read or ignored_in_write), ignored_in_read.union(ignored_in_write) + + def symbol_reused_outside_conditional(self, sym_name: str): + copy_sdfg = copy.deepcopy(self.conditional.sdfg) + sdutil.set_nested_sdfg_parent_references(copy_sdfg) + + # Remove all nodes from the target state in the copy + conditional_label_tuple = (self.conditional.label, self.conditional.parent_graph.label + if self.conditional.parent_graph is not None else "", self.conditional.sdfg.label) + for st in copy_sdfg.all_control_flow_regions(): + label_tuple = (st.label, st.parent_graph.label if st.parent_graph is not None else "", st.sdfg.label) + if label_tuple == conditional_label_tuple: + ies = st.parent_graph.in_edges(st) + oes = st.parent_graph.out_edges(st) + empty_state = st.parent_graph.add_state(label=f"empty_replacement_{st.label}", + is_start_block=st.start_block) + st.parent_graph.remove_node(st) + for ie in ies: + st.parent_graph.add_edge(ie.src, empty_state, copy.deepcopy(ie.data)) + for oe in oes: + st.parent_graph.add_edge(empty_state, oe.dst, copy.deepcopy(oe.data)) + break + + return self._symbol_appears_as_read(copy_sdfg, sym_name) + + def only_top_level_tasklets(self, graph: ControlFlowRegion): + checked_at_least_one_tasklet = False + # Can be applied should ensure this + + # Having something other than a state is a problem + if set(graph.all_states()) != set(graph.nodes()): + return False + + for state in graph.all_states(): + # The function to get parent map and loop scopes is expensive so lets try map-libnodes first entry first + for node in state.nodes(): + if isinstance(node, (dace.nodes.MapEntry, dace.nodes.LibraryNode)): + return False + for node in state.nodes(): + if isinstance(node, dace.nodes.Tasklet): + parent_maps = cutil.get_parent_map_and_loop_scopes(root_sdfg=graph.sdfg, + node=node, + parent_state=state) + checked_at_least_one_tasklet = True + if len(parent_maps) > 0: + return False + + # If no tasklet has been checked + return checked_at_least_one_tasklet + + def has_no_top_level_tasklets(self, graph: ControlFlowRegion): + for state in graph.all_states(): + for node in state.nodes(): + if isinstance(node, dace.nodes.Tasklet): + parent_maps = cutil.get_parent_map_and_loop_scopes(root_sdfg=graph.sdfg, + node=node, + parent_state=state) + if len(parent_maps) == 0: + return False + + # If no tasklet has been checked + return True + + def add_conditional_write_combination(self, new_state: dace.SDFGState, + state0_in_new_state_write_access: dace.nodes.AccessNode, + state1_in_new_state_write_access: dace.nodes.AccessNode, + cond_var_as_float_name: str, write_name: str, index: int): + """ + Adds a conditional write combination mechanism to merge outputs from two branches. 
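+
+        Illustrative example (values assumed): a condition scalar of 1.0 selects
+        the if-branch value and 0.0 selects the else-branch value, since
+        1.0 * tmp1 + (1.0 - 1.0) * tmp2 == tmp1 and
+        0.0 * tmp1 + (1.0 - 0.0) * tmp2 == tmp2.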
+ + This function creates temporary scalars and a combine tasklet to perform: + result = float_cond * tmp1 + (1 - float_cond) * tmp2 + + Args: + new_state: The SDFG state where the combination will be added + state0_write_access: Write access node from state0 (original state) + state1_write_access: Write access node from state1 (original state) + state0_in_new_state_write_access: Write access node from state0 in new_state + state1_in_new_state_write_access: Write access node from state1 in new_state + cond_var_as_float_name: Name of the float condition variable + write_name: Base name for the write operation (used in naming) + index: Index for unique naming of temporaries + + Returns: + tuple: (combine_tasklet, tmp1_access, tmp2_access, float_cond_access) + """ + + # Prepare write subset for non-tasklet inputs + ies0 = new_state.in_edges(state0_in_new_state_write_access) + ies1 = new_state.in_edges(state1_in_new_state_write_access) + #assert len(ies0) == len(ies1) + #assert len(ies0) == 1 + + ie0, ie1 = ies0[0], ies1[0] + assert (ie0.data.subset == ie1.data.subset or ie0.data.subset == ie1.data.other_subset + or ie0.data.other_subset == ie1.data.subset + ), f"{ie0.data.subset} =? {ie1.data.subset} ; {ie0.data.other_subset} =? {ie1.data.other_subset}" + write_subset: dace.subsets.Range = ie0.data.subset + assert write_subset.num_elements_exact() == 1 + + # Generate unique names for temporary scalars + tmp1_name = f"if_body_tmp_{index}" + tmp2_name = f"else_body_tmp_{index}" + + # Get array information for dtype + arr = new_state.sdfg.arrays[state0_in_new_state_write_access.data] + + # 1. Add temporary scalar between state0's write to array (tmp1) + tmp1_name, tmp1_scalar = new_state.sdfg.add_scalar( + name=tmp1_name, + dtype=arr.dtype, + storage=dace.dtypes.StorageType.Default, + transient=True, + find_new_name=True, + ) + + # 2. Add temporary scalar between state1's write to array (tmp2) + tmp2_name, tmp2_scalar = new_state.sdfg.add_scalar( + name=tmp2_name, + dtype=arr.dtype, + storage=dace.dtypes.StorageType.Default, + transient=True, + find_new_name=True, + ) + + # 3. Add the combine tasklet which performs: float_cond1 * tmp1 + (1 - float_cond1) * tmp2 + combine_tasklet = new_state.add_tasklet( + name=f"combine_branch_values_for_{write_name}_{index}", + inputs={"_in_left", "_in_right", "_in_factor"}, + outputs={"_out"}, + code="_out = (_in_factor * _in_left) + ((1.0 - _in_factor) * _in_right)") + + # 4. Redirect the writes to access nodes with temporary scalars for each state + # Handle state0's write + ies = new_state.in_edges(state0_in_new_state_write_access) + #assert len(ies) == 1, f"Expected 1 input edge to state0 write access, got {len(ies)}" + for ie in ies: + tmp1_access = new_state.add_access(tmp1_name) + new_state.add_edge(ie.src, ie.src_conn, tmp1_access, None, dace.memlet.Memlet(f"{tmp1_name}")) + new_state.remove_edge(ie) + + # Handle state1's write + ies = new_state.in_edges(state1_in_new_state_write_access) + #assert len(ies) == 1, f"Expected 1 input edge to state1 write access, got {len(ies)}: {ies}" + for ie in ies: + tmp2_access = new_state.add_access(tmp2_name) + new_state.add_edge(ie.src, ie.src_conn, tmp2_access, None, dace.memlet.Memlet(f"{tmp2_name}")) + + # 5. Remove write of state1 + new_state.remove_edge(ie) + + # 6. 
Redirect tmp scalars to the combine tasklet and then to the old write + float_cond_access = new_state.add_access(cond_var_as_float_name) + + # Connect inputs to combine tasklet + for tmp_access, connector in [(tmp1_access, "_in_left"), (tmp2_access, "_in_right"), + (float_cond_access, "_in_factor")]: + new_state.add_edge(tmp_access, None, combine_tasklet, connector, dace.memlet.Memlet(tmp_access.data)) + + # Connect combine tasklet output to the final write access + if (ie.data.data != state0_in_new_state_write_access.data): + raise Exception("?") + + # In case other-subset + if ie.data.data == state0_in_new_state_write_access.data: + subset_to_use = ie.data.subset + else: + assert ie.data.other_subset is not None + subset_to_use = ie.data.other_subset + new_state.add_edge( + combine_tasklet, "_out", state0_in_new_state_write_access, None, + dace.memlet.Memlet(data=state0_in_new_state_write_access.data, subset=copy.deepcopy(subset_to_use))) + + return combine_tasklet, tmp1_access, tmp2_access, float_cond_access + + def has_wcr_edges(self, state: dace.SDFGState): + for e in state.edges(): + if e.data.wcr is not None or e.data.wcr_nonatomic: + return True + return False + + def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + # Works for if-else branches or only if branches + # Sanity checks for the sdfg and graph parameters + assert sdfg == self.conditional.sdfg + assert graph == self.conditional.parent_graph + + if sdfg.parent_nsdfg_node is not None: + if self.parent_nsdfg_state is None: + print("[can_be_applied] Nested SDFGs need to provide the parent state of the parent nsdfg node") + return False + + if len(self.conditional.branches) > 2: + print("[can_be_applied] More than two branches – only supports if/else or single if.") + return False + + if len(self.conditional.branches) == 2: + tup0 = self.conditional.branches[0] + tup1 = self.conditional.branches[1] + (cond0, body0) = tup0[0], tup0[1] + (cond1, body1) = tup1[0], tup1[1] + if cond0 is not None and cond1 is not None: + print("[can_be_applied] Both branches have conditions – not a simple if/else.") + return False + assert not (cond0 is None and cond1 is None) + + # Works if the branch bodies have a single state each + for i, body in enumerate([body0, body1]): + if len(body.nodes()) != 1 or not isinstance(body.nodes()[0], SDFGState): + print(f"[can_be_applied] Branch {i} does not have exactly one SDFGState node.") + return False + + # Check write sets are equivalent + state0: SDFGState = body0.nodes()[0] + state1: SDFGState = body1.nodes()[0] + + # Need to consist of top level tasklets + all_top_level0 = self.only_top_level_tasklets(body0) + if not all_top_level0: + print(f"[can_be_applied] All tasklets need to be top level. Not the case for body {body0}.") + return False + all_top_level1 = self.only_top_level_tasklets(body1) + if not all_top_level1: + print(f"[can_be_applied] All tasklets need to be top level. 
Not the case for body {body1}.")
+                return False
+
+            if self.has_wcr_edges(state0) or self.has_wcr_edges(state1):
+                print("[can_be_applied] Has WCR edges.")
+                return False
+
+            read_sets0, write_sets0 = state0.read_and_write_sets()
+            read_sets1, write_sets1 = state1.read_and_write_sets()
+
+            joint_writes = write_sets0.intersection(write_sets1)
+            diff_state0 = write_sets0.difference(write_sets1)
+            diff_state1 = write_sets1.difference(write_sets0)
+
+            # For joint writes, ensure the write subsets are always the same
+            for write in joint_writes:
+                state0_accesses = self.collect_accesses(state0, write)
+                state1_accesses = self.collect_accesses(state1, write)
+
+                state0_write_accesses = self.collect_write_accesses(state0, write)
+                state1_write_accesses = self.collect_write_accesses(state1, write)
+
+                for state, accesses in [(state0, state0_accesses), (state1, state1_accesses)]:
+                    for access in accesses:
+                        arr = state.sdfg.arrays[access.data]
+                        if arr.dtype not in dace.dtypes.FLOAT_TYPES and arr.dtype not in dace.dtypes.INT_TYPES:
+                            print(
+                                f"[can_be_applied] Storage of '{write}' is not a float/int type, it has type: {arr.dtype}"
+                            )
+                            return False
+
+                state0_writes = set()
+                state1_writes = set()
+                state0_write_subsets = dict()
+                state1_write_subsets = dict()
+                for state_writes, state_accesses, state, state_write_subsets in [
+                    (state0_writes, state0_accesses, state0, state0_write_subsets),
+                    (state1_writes, state1_accesses, state1, state1_write_subsets)
+                ]:
+                    state_write_edges = set()
+                    for an in state_accesses:
+                        state_write_edges |= {e for e in state.in_edges(an) if e.data.data is not None}
+                    # If there are multiple write edges we would again need to know the order
+                    state_writes |= {e.data.data for e in state_write_edges}
+                    for e in state_write_edges:
+                        if e.data.data is None:
+                            continue
+                        if e.data.data not in state_write_subsets:
+                            state_write_subsets[e.data.data] = set()
+                        state_write_subsets[e.data.data].add(e.data.subset)
+
+                subsets_disjoint = self._is_disjoint_subset(state0, state1)
+                if len(state0_write_subsets) == 1 and len(state1_write_subsets) == 1:
+                    if not subsets_disjoint:
+                        # If data and subsets are the same it is ok, otherwise not
+                        k1, v1 = next(iter(state0_write_subsets.items()))
+                        k2, v2 = next(iter(state1_write_subsets.items()))
+                        if k1 != k2 or v1 != v2:
+                            return False
+
+                # If there is more than one write we can't fuse them without knowing the order,
+                # unless the subsets are disjoint
+                if (any(len(v) > 1 for k, v in state0_write_subsets.items())
+                        or any(len(v) > 1 for k, v in state1_write_subsets.items())):
+                    if not subsets_disjoint:
+                        print(
+                            f"[can_be_applied] Multiple write edges for '{write}' in one branch and subsets not disjoint."
+                        )
+                        return False
+
+                # If the written data differ between the branches we can't fuse either
+                if state0_writes != state1_writes:
+                    if not subsets_disjoint:
+                        print(
+                            f"[can_be_applied] Write subsets differ (and not disjoint) for '{write}' between branches.")
+                        return False
+
+                if len(state0_write_accesses) > 1 or len(state1_write_accesses) > 1:
+                    if not subsets_disjoint:
+                        print(
+                            f"[can_be_applied] Multiple write AccessNodes found for '{write}' in one of the branches."
+                        )
+                        return False
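+
+            # Illustrative example of the distinction checked above: if the
+            # if-branch writes A[0] and the else-branch writes A[1], the write
+            # subsets are disjoint and the bodies can simply run one after the
+            # other; if both branches write A[0], the results must be blended by
+            # the combine tasklet, which requires a single, equal-subset write
+            # per branch.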
+
+            # If the diff states only write transient scalars or arrays, it is probably ok (permissive)
+            if diff_state0 or diff_state1:
+                if self._check_reuse(sdfg, state0, diff_state0):
+                    print(f"[can_be_applied] Branch 0 writes to non-reusable data: {diff_state0}")
+                    return False
+                if self._check_reuse(sdfg, state1, diff_state1):
+                    print(f"[can_be_applied] Branch 1 writes to non-reusable data: {diff_state1}")
+                    return False
+
+            if not permissive:
+                has_reuse_on_ignored, ignored_but_reused_data = self.ignored_accesses_are_reused({state0, state1})
+                if has_reuse_on_ignored:
+                    print(
+                        f"[can_be_applied] Ignored data is used elsewhere (read or written): {ignored_but_reused_data}")
+                    return False
+
+        elif len(self.conditional.branches) == 1:
+            tup0: Tuple[properties.CodeBlock, ControlFlowRegion] = self.conditional.branches[0]
+            cond0, body0 = tup0[0], tup0[1]
+
+            # Works if the branch body has a single state
+            if len(body0.nodes()) != 1 or not isinstance(body0.nodes()[0], SDFGState):
+                print("[can_be_applied] Single branch does not have exactly one SDFGState node.")
+                return False
+
+            state0: SDFGState = body0.nodes()[0]
+
+            read_sets0, write_sets0 = state0.read_and_write_sets()
+
+            # Needs to consist of top-level tasklets
+            all_top_level0 = self.only_top_level_tasklets(body0)
+            if not all_top_level0:
+                print(f"[can_be_applied] All tasklets need to be top level. Not the case for body {body0}.")
+                return False
+
+            if self.has_wcr_edges(state0):
+                print("[can_be_applied] Has WCR edges.")
+                return False
+
+            # For all writes, ensure the write edges are one-element writes
+            for write in write_sets0:
+                state0_accesses = self.collect_accesses(state0, write)
+                state0_write_accesses = self.collect_write_accesses(state0, write)
+
+                for state, accesses in [
+                    (state0, state0_accesses),
+                ]:
+                    for access in accesses:
+                        arr = state.sdfg.arrays[access.data]
+                        if arr.dtype not in dace.dtypes.FLOAT_TYPES and arr.dtype not in dace.dtypes.INT_TYPES:
+                            print(
+                                f"[can_be_applied] Storage of '{write}' is not a float/int type, it has type: {arr.dtype}"
+                            )
+                            return False
+
+                state0_writes = set()
+                for state_writes, state_accesses, state in [(state0_writes, state0_accesses, state0)]:
+                    state_write_edges = set()
+                    for an in state_accesses:
+                        state_write_edges |= {e for e in state.in_edges(an) if e.data.data is not None}
+
+                    state_writes |= {e.data.data for e in state_write_edges}
+                    for e in state_write_edges:
+                        if e.data.data is None:
+                            continue
+                        if e.data.subset.num_elements_exact() != 1:
+                            print(
+                                f"[can_be_applied] All write edges need to be exactly one-element writes for '{write}' (edge {e} is problematic)."
+                            )
+
+            # It is ok to have ignored-but-reused data for this case, so
+            # ignored_accesses_are_reused is not checked here.
+
+        if permissive is False:
+            if self.condition_has_map_param():
+                print(
+                    "[can_be_applied] A map parameter is used in the conditional. This will likely result in out-of-bounds accesses. 
Enable permissive if you want to risk it" + ) + return False + + print(f"[can_be_applied] to {self.conditional} is True") + return True + + def _scalar_is_assigned_symbolic_value(self, state: SDFGState, + node: dace.nodes.AccessNode) -> Union[None, Tuple[str, dace.nodes.Tasklet]]: + # Check if the scalars are really needed + # If scalar has single source, which is a tasklet with 0 in degree then it is using some symbol value + # then we can use the symbolic expression instead + if state.in_degree(node) != 1: + return None + ies = state.in_edges(node) + if len(ies) != 1: + return None + ie = ies[0] + if not isinstance(ie.src, dace.nodes.Tasklet): + return None + t = ie.src + if state.in_degree(t) > 0: + return None + if len(t.in_connectors) > 0: + return None + return t.code.as_string.split("=")[-1].strip(), t + + def _try_simplify_combine_tasklet(self, state: SDFGState, node: dace.nodes.Tasklet): + if node.language != dace.dtypes.Language.Python: + return + + for ie in state.in_edges(node): + if isinstance(ie.src, dace.nodes.AccessNode): + rettup = self._scalar_is_assigned_symbolic_value(state, ie.src) + if rettup is not None: + rhs_str: str = rettup[0] + tasklet: dace.nodes.Tasklet = rettup[1] + state.remove_node(tasklet) + node.remove_in_connector(ie.dst_conn) + lhs, rhs = node.code.as_string.split("=") + lhs = lhs.strip() + rhs = rhs.strip() + rhs_expr = dace.symbolic.SymExpr(rhs) + code = sympy.nsimplify(rhs_expr) + # Use rational until the very end and then call evalf to get rational to flaot to avoid accumulating errors + code = sympy.nsimplify(code.subs(ie.dst_conn, rhs_str)).evalf() + new_code_str = lhs + " = " + pycode(code) + node.code = CodeBlock(new_code_str) + + state.remove_edge(ie) + if state.degree(ie.src) == 0: + state.remove_node(ie.src) + + def _try_fuse(self, graph: ControlFlowRegion, new_state: SDFGState, cond_prep_state: SDFGState): + if len({ + t + for t in cond_prep_state.nodes() + if isinstance(t, dace.nodes.Tasklet) or isinstance(t, dace.nodes.NestedSDFG) + }) == 1: + assign_tasklets = { + t + for t in cond_prep_state.nodes() if isinstance(t, dace.nodes.Tasklet) and t.label.startswith("ieassign") + } + + assert cond_prep_state.out_degree(next(iter(assign_tasklets))) == 1 + assign_tasklet = assign_tasklets.pop() + oe = cond_prep_state.out_edges(assign_tasklet)[0] + + dst_data = oe.dst.data + + source_nodes = [ + n for n in new_state.nodes() + if isinstance(n, dace.nodes.AccessNode) and n.data == dst_data and new_state.in_degree(n) == 0 + ] + + ies = cond_prep_state.in_edges(assign_tasklet) + + cp_tasklet = copy.deepcopy(assign_tasklet) + new_state.add_node(cp_tasklet) + tmp_name, tmp_arr = new_state.sdfg.add_scalar(name="tmp_ieassign", + dtype=new_state.sdfg.arrays[dst_data].dtype, + transient=True, + find_new_name=True) + tmp_an = new_state.add_access(tmp_name) + new_state.add_edge(cp_tasklet, next(iter(cp_tasklet.out_connectors)), tmp_an, None, + dace.memlet.Memlet(tmp_an.data)) + for src_node in source_nodes: + at = new_state.add_tasklet(name="assign", inputs={"_in"}, outputs={"_out"}, code="_out = _in") + new_state.add_edge(tmp_an, None, at, "_in", dace.memlet.Memlet(tmp_an.data)) + new_state.add_edge(at, "_out", src_node, None, dace.memlet.Memlet(src_node.data)) + + for ie in ies: + assert isinstance(ie.src, dace.nodes.AccessNode) + assert cp_tasklet in new_state.nodes() + new_ie_src = new_state.add_access(ie.src.data) + new_state.add_edge(new_ie_src, ie.src_conn, cp_tasklet, ie.dst_conn, copy.deepcopy(ie.data)) + cond_prep_state.remove_edge(ie) + if 
cond_prep_state.degree(ie.src) == 0: + cond_prep_state.remove_node(ie.src) + cond_prep_state.remove_edge(oe) + cond_prep_state.remove_node(oe.dst) + cond_prep_state.remove_node(assign_tasklet) + + # If no nodes inside the state and if it only connects to the new state and has no interstate assignments we can delete it + if len(cond_prep_state.nodes()) == 0: + if graph.out_degree(cond_prep_state) == 1 and {new_state + } == {e.dst + for e in graph.out_edges(cond_prep_state)}: + oes = graph.out_edges(cond_prep_state) + ies = graph.in_edges(cond_prep_state) + all_assignments = [oe.data.assignments for oe in oes] + if all({d == dict() for d in all_assignments}): + was_start_block = cond_prep_state == graph.start_block + graph.remove_node(cond_prep_state) + for ie in ies: + graph.add_edge(ie.src, new_state, copy.deepcopy(ie.data)) + if was_start_block: + oes2 = graph.out_edges(new_state) + graph.remove_node(new_state) + graph.add_node(new_state, is_start_block=True) + for oe2 in oes2: + graph.add_edge(new_state, oe2.dst, copy.deepcopy(oe2.data)) + + def _extract_condition_var_and_assignment(self, graph: ControlFlowRegion) -> Tuple[str, str]: + non_none_conds = [cond for cond, _ in self.conditional.branches if cond is not None] + assert len(non_none_conds) == 1 + cond = non_none_conds.pop() + cond_code_str = cond.as_string + cond_code_symexpr = pystr_to_symbolic(cond_code_str, simplify=False) + #assert len(cond_code_symexpr.free_symbols) == 1, f"{cond_code_symexpr}, {cond_code_symexpr.free_symbols}" + + # Find values assigned to the symbols + #print("Condition as symexpr:", cond_code_symexpr) + free_syms = {str(s).strip() for s in cond_code_symexpr.free_symbols if str(s) in graph.sdfg.symbols} + #print("free_syms:", free_syms) + sym_val_map = dict() + symbolic_sym_val_map = dict() + nodes_to_check = {self.conditional} + + # Do reverse BFS from the sink node to get all possible interstate assignments + while nodes_to_check: + node_to_check = nodes_to_check.pop() + ies = {ie for ie in graph.in_edges(node_to_check)} + for ie in ies: + for k, v in ie.data.assignments.items(): + if k in free_syms and k not in sym_val_map: + # in case if Eq((x + 1 > b), 1) sympy will have a problem + expr = pystr_to_symbolic(v, simplify=False) + #if isinstance(expr, Eq): + sym_val_map[k] = v + # symbolic_sym_val_map[k] = expr + # else: if not an Equality expression, subs / simplify will cause issue + # As relationals can't be integers in Sympy system + # But DaCe uses 1 == True, 0 == False + nodes_to_check = nodes_to_check.union({ie.src for ie in ies}) + + # The symbol can in free symbols + assigned_syms = {k for k in sym_val_map} + unassigned_syms = free_syms - assigned_syms + #print("assigned_syms:", assigned_syms) + #print("unassigned_syms:", unassigned_syms) + #print("sym_val_map:", sym_val_map) + + # If 1 free symbol, easy it means it is condition variable, + # otherwise get the left most + if len(cond_code_symexpr.free_symbols) == 1: + cond_var = str(next(iter(cond_code_symexpr.free_symbols))) + else: + cond_var = cutil.token_split_variable_names(cond_code_str).pop() + + # If the sym_map has any functions, then we need to drop, e.g. 
array access + new_sym_val_map = dict() + for k, v in sym_val_map.items(): + vv = dace.symbolic.SymExpr(v) + funcs = [e for e in vv.atoms(Function)] + #print(f"Functions in {v} are {funcs}") + if len(funcs) == 0: + new_sym_val_map[str(k)] = str(v) + sym_val_map = new_sym_val_map + + # Subsitute using replace dict to avoid problems + self.conditional.replace_dict(sym_val_map) + + #print( + # f"Subs {cond_code_symexpr}, with sym map ({sym_val_map}) -> {cond_code_symexpr.subs(sym_val_map)} | {cond_code_symexpr.xreplace(symbolic_sym_val_map)}" + #) + #cond_assignment = pycode(cond_code_symexpr.subs(sym_val_map)) + new_conds = {c.as_string for c, b in self.conditional.branches if c is not None} + new_cond = new_conds.pop() + #raise Exception(cond_var, new_cond) + + return cond_var, new_cond + + def _generate_identity_write(self, state: SDFGState, arr_name: str, subset: dace.subsets.Range): + accessed_data = {n.data for n in state.nodes() if isinstance(n, dace.nodes.AccessNode)} + #if arr_name in accessed_data: + # print( + # "Adding_a_new identity assign - even though present. Allowed if the ConditionalBlock only has 1 branch." + # ) + + an1 = state.add_access(arr_name) + an2 = state.add_access(arr_name) + + node_labels = {n.label for n in state.nodes()} + candidate = "identity_assign_0" + i = 0 + while candidate in node_labels: + i += 1 + candidate = f"identity_assign_{i}" + + assign_t = state.add_tasklet(name=candidate, inputs={"_in"}, outputs={"_out"}, code="_out = _in") + + subset_str = ",".join([f"{b}:{e+1}:{s}" for (b, e, s) in subset]) + + state.add_edge(an1, None, assign_t, "_in", dace.memlet.Memlet(f"{arr_name}[{subset_str}]")) + state.add_edge(assign_t, "_out", an2, None, dace.memlet.Memlet(f"{arr_name}[{subset_str}]")) + + return [an1, assign_t, an2] + + ni = 0 + + def _split_branches(self, parent_graph: ControlFlowRegion, if_block: ConditionalBlock): + # Create two new conditional blocks with single branches each + tup0 = if_block.branches[0] + tup1 = if_block.branches[1] + (cond0, body0) = tup0[0], tup0[1] + (cond1, body1) = tup1[0], tup1[1] + + cond = cond0 if cond0 is not None else cond1 + body = body0 if cond0 is None else body1 + + if_block.remove_branch(body) + assert cond.language == dace.dtypes.Language.Python + + if_out_edges = parent_graph.out_edges(if_block) + + new_if_block = ConditionalBlock(label=f"{if_block.label}_negated", sdfg=parent_graph.sdfg, parent=parent_graph) + + # Get the condition assignment of the if-block to copy the symbol type + # We add its negation to the new branch (e.g. 
expr == 0 instead of expr == 1 which is the usual one) + cond_var, cond_assignment = self._extract_condition_var_and_assignment(if_block) + + new_if_block.add_branch(condition=CodeBlock(f"({cond_assignment}) == 0"), branch=body) + + parent_graph.add_node(new_if_block) + + for oe in if_out_edges: + parent_graph.remove_edge(oe) + parent_graph.add_edge(new_if_block, oe.dst, copy.deepcopy(oe.data)) + + # Do not use negation assignments={f"{negated_name}": f"not ({cond.as_string})"} + # Crates issue when simplifying with sympy + parent_graph.add_edge(if_block, new_if_block, dace.sdfg.InterstateEdge()) + + parent_graph.reset_cfg_list() + + return if_block, new_if_block + + def move_interstate_assignments_from_empty_start_states_to_front_of_conditional(self, graph: ControlFlowRegion, + conditional: ConditionalBlock): + # Pattern 1 + # If start block is a state and is empty and has assignments only, + # We can add them before + applied = False + + if len(self.conditional.branches) == 1: + cond, body = self.conditional.branches[0] + + first_start_block = body.start_block + start_block = body.start_block + # While empty state with assignments + while (len(start_block.nodes()) == 0 and isinstance(start_block, dace.SDFGState) + and body.out_degree(start_block) == 1): + + assert body.in_degree(start_block) == 0 + # RM and copy assignments + oe = body.out_edges(start_block)[0] + assignments = oe.data.assignments + body.remove_node(start_block) + # Update start block + start_block = oe.dst + + # Add state before if cond + graph.add_state_before(self.conditional, + label=start_block.label + "_p", + assignments=copy.deepcopy(assignments), + is_start_block=graph.start_block == self.conditional) + applied = True + + if start_block != first_start_block: + # Rm and Add conditional to get the start block correct + oes = body.out_edges(start_block) + cpnode = copy.deepcopy(start_block) + body.remove_node(start_block) + body.add_node(cpnode, is_start_block=True) + + for oe in oes: + body.add_edge(cpnode, oe.dst, copy.deepcopy(oe.data)) + + graph.sdfg.reset_cfg_list() + sdutil.set_nested_sdfg_parent_references(graph.sdfg) + graph.sdfg.validate() + + return applied + + def duplicate_condition_across_all_top_level_nodes_if_line_graph_and_empty_interstate_edges( + self, graph: ControlFlowRegion): + # Pattern 2 + # If all top-level nodes are connected through empty interstate edges + # And we have a line graph, put each state to the same if condition + applied = False + print(len(self.conditional.branches)) + if len(self.conditional.branches) == 1: + cond, body = self.conditional.branches[0] + + nodes = [n for n in body.bfs_nodes()] + print(len(nodes), nodes) + if len(nodes) <= 1: + return False + + in_degree_leq_one = all({body.in_degree(n) <= 1 for n in nodes}) + out_degree_leq_one = all({body.out_degree(n) <= 1 for n in nodes}) + edges = body.edges() + # Can support if they are not fully empty + all_edges_empty = all({e.data.assignments == dict() for e in edges}) + + if in_degree_leq_one and out_degree_leq_one: #and all_edges_empty: + # Put all nodes into their own if condition + node_to_add_after = self.conditional + # First node gets to stay + for ci, node in enumerate(nodes[1:]): + # Get edge data to copy + + if not all_edges_empty: + cfg_in_edges = body.in_edges(node) + assert len(cfg_in_edges) <= 1, f"{cfg_in_edges}" + cfg_in_edge = cfg_in_edges[0] if len(cfg_in_edges) == 1 else None + cfg_out_edges = body.out_edges(node) + assert len(cfg_out_edges) <= 1, f"{cfg_out_edges}" + cfg_out_edge = cfg_out_edges[0] if 
len(cfg_out_edges) == 1 else None + + body.remove_node(node) + + parent_graph = self.conditional.parent_graph + + is_empty_state = isinstance(node, dace.SDFGState) and len(node.nodes()) == 0 + # If state is empty do not wrap it in a conditional region + if not is_empty_state: + copy_conditional = ConditionalBlock(label=self.conditional.label + f"_v_{ci}", + sdfg=self.conditional.sdfg, + parent=parent_graph) + + cfg = ControlFlowRegion(label=self.conditional.label + f"_v_{ci}_body", + sdfg=self.conditional.sdfg, + parent=copy_conditional) + cfg.add_node(copy.deepcopy(node)) + copy_conditional.add_branch(condition=copy.deepcopy(cond), branch=cfg) + else: + copy_conditional = copy.deepcopy(node) + + parent_graph.add_node(copy_conditional, False, False) + + for oe in parent_graph.out_edges(node_to_add_after): + parent_graph.remove_edge(oe) + parent_graph.add_edge(copy_conditional, oe.dst, copy.deepcopy(oe.data)) + + # Find the edge between the + parent_graph.add_edge(node_to_add_after, copy_conditional, InterstateEdge()) + + if not all_edges_empty: + if cfg_in_edge is not None: + pre_assign = parent_graph.add_state_before( + state=copy_conditional, + label=f"pre_assign_{copy_conditional.label}", + is_start_block=parent_graph.start_block == copy_conditional, + assignments=cfg_in_edge.data.assignments) + if cfg_out_edge is not None: + post_assign = parent_graph.add_state_after(state=copy_conditional, + label=f"post_assign_{copy_conditional.label}", + is_start_block=False, + assignments=cfg_out_edge.data.assignments) + node_to_add_after = post_assign + else: + node_to_add_after = copy_conditional + else: + node_to_add_after = copy_conditional + applied = True + + graph.sdfg.reset_cfg_list() + sdutil.set_nested_sdfg_parent_references(graph.sdfg) + return applied + + def sequentialize_if_else_branch_if_disjoint_subsets(self, graph: ControlFlowRegion): + # Disjoint subsets do not require the combine tasklet. 
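+        # (With disjoint subsets, each branch writes different elements, so
+        # executing both bodies one after the other cannot overwrite the other
+        # branch's result.)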
+        # Therefore we need to split:
+        # if (cond1) {
+        #    body1
+        # } else {
+        #    body2
+        # }
+        # into two sequential ifs:
+        # if (cond1) {
+        #    body1
+        # }
+        # neg_cond1 = !cond1
+        # if (!cond1) {
+        #    body2
+        # }
+        # And thus we can use the single-branch implementation twice
+        if self.can_be_applied(graph=graph, expr_index=0, sdfg=graph.sdfg, permissive=False):
+            if len(self.conditional.branches) == 2:
+                tup0 = self.conditional.branches[0]
+                tup1 = self.conditional.branches[1]
+                (cond0, body0) = tup0[0], tup0[1]
+                (cond1, body1) = tup1[0], tup1[1]
+                state0: SDFGState = body0.nodes()[0]
+                state1: SDFGState = body1.nodes()[0]
+                if self._is_disjoint_subset(state0, state1):  # Then we need to sequentialize the branches
+                    first_if, second_if = self._split_branches(parent_graph=graph, if_block=self.conditional)
+                    return first_if, second_if
+        return None, None
+
+    def demote_branch_only_symbols_appearing_only_a_single_branch_to_scalars_and_try_fuse(
+            self, graph: ControlFlowRegion, sdfg: dace.SDFG):
+        applied = False
+        for branch, body in self.conditional.branches:
+            # Two states, where the first state is empty and only carries interstate assignments
+            if (len(body.nodes()) == 2 and all({isinstance(n, dace.SDFGState)
+                                                for n in body.nodes()}) and len(body.start_block.nodes()) == 0
+                    and len(body.edges()) == 1):

+                edge = body.edges()[0]
+                # Check that the symbols are not used anywhere else
+                symbols_reused = False
+                symbols_defined = set()
+                for k, v in edge.data.assignments.items():
+                    symbols_reused |= self.symbol_reused_outside_conditional(k)
+                    symbols_defined.add(k)
+                    if symbols_reused:
+                        break
+
+                if len(symbols_defined) > 1:
+                    print("Demoting more than one symbol in this clean-up is not well tested; skipping")
+                    continue
+
+                applied = True
+
+                if not symbols_reused:
+                    # Then demote all symbols
+                    if len(symbols_defined) == 0:
+                        # The first state is empty; just drop it
+                        start_block, other_state = list(body.bfs_nodes())[0:2]
+                        assert body.start_block == start_block
+                        body.remove_node(body.start_block)
+                        body.remove_node(other_state)
+                        body.add_node(other_state, is_start_block=True)
+                        applied = True
+                        continue
+                    assert len(symbols_defined) == 1
+                    for i, symbol_str in enumerate(symbols_defined):
+                        # The symbol might not be defined in the SDFG (only defined through an interstate edge)
+                        sdutil.demote_symbol_to_scalar(sdfg, symbol_str, dace.float64)
+                        # Get the edges of the first two nodes
+                        edges = list(body.all_edges(*(list(body.bfs_nodes()))[0:2]))
+                        if len(edges) == 2:
+                            edge0, edge1 = edges
+                            for k, v in edge1.data.assignments.items():
+                                assert k not in edge0.data.assignments
+                                edge0.data.assignments[k] = v
+
+                        # State fusion will fail but we know it is fine
+                        # Copy all access nodes to the next state, connect the sink node from prev. state
+                        # to the next state
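+                        # Illustrative example: if the assignment state writes the
+                        # demoted scalar "s" (a sink AccessNode) and the next state
+                        # reads "s" (a source AccessNode), the read is rewired to the
+                        # copied sink node so the write-before-read order is preserved
+                        # inside the single merged state.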
+                        body.reset_cfg_list()
+                        assignment_state, other_state = list(body.bfs_nodes())[1:3]
+                        node_map = cutil.copy_state_contents(assignment_state, other_state)
+                        # Multiple symbols -> multiple sink nodes
+                        sink_nodes = {n for n in assignment_state.nodes() if assignment_state.out_degree(n) == 0}
+
+                        for sink_node in sink_nodes:
+                            sink_data = sink_node.data
+                            sink_node_in_other_state = node_map[sink_node]
+
+                            # Find matching source nodes with the same name
+                            source_nodes = {
+                                n
+                                for n in other_state.nodes() if isinstance(n, dace.nodes.AccessNode)
+                                and n.data == sink_data and n not in node_map.values()
+                            }
+
+                            # Reconnect edges to the new source node
+                            for source_node in source_nodes:
+                                out_edges = other_state.out_edges(source_node)
+                                for out_edge in out_edges:
+                                    other_state.remove_edge(out_edge)
+                                    other_state.add_edge(sink_node_in_other_state, out_edge.src_conn, out_edge.dst,
+                                                         out_edge.dst_conn, copy.deepcopy(out_edge.data))
+                                other_state.remove_node(source_node)
+
+                        # Remove both nodes to change the start block;
+                        # the old node is not needed anymore
+                        if i != len(symbols_defined) - 1:
+                            oes = body.out_edges(body.start_block)
+                            assert len(oes) == 1
+                            body.remove_node(assignment_state)
+                            body.add_edge(body.start_block, other_state, copy.deepcopy(oes.pop().data))
+                        else:
+                            body.remove_node(body.start_block)
+                            body.remove_node(other_state)
+                            body.remove_node(assignment_state)
+                            body.add_node(other_state, is_start_block=True)
+                        applied = True
+        return applied
+
+    def demote_branch_only_symbols_appering_on_both_branches_to_scalars_and_try_fuse(
+            self, graph: ControlFlowRegion, sdfg: dace.SDFG):
+        applied = False
+        if len(self.conditional.branches) != 2:
+            return False
+        (cond0, body0), (cond1, body1) = self.conditional.branches
+        # Two states, where the first state is empty and only carries interstate assignments
+        if (len(body0.nodes()) == 2 and all({isinstance(n, dace.SDFGState)
+                                             for n in body0.nodes()}) and len(body0.start_block.nodes()) == 0
+                and len(body0.edges()) == 1 and len(body1.nodes()) == 2
+                and all({isinstance(n, dace.SDFGState)
+                         for n in body1.nodes()}) and len(body1.start_block.nodes()) == 0 and len(body1.edges()) == 1):
+            edge0 = body0.edges()[0]
+            edge1 = body1.edges()[0]
+
+            # Collect the symbols defined on each branch
+            symbols_defined0 = set()
+            symbols_defined1 = set()
+            for k, v in edge0.data.assignments.items():
+                symbols_defined0.add(k)
+            for k, v in edge1.data.assignments.items():
+                symbols_defined1.add(k)
+
+            if symbols_defined1 != symbols_defined0:
+                return False
+
+            symbols_defined = symbols_defined0
+
+            # Then demote all symbols
+            for symbol_str in symbols_defined:
+                sdutil.demote_symbol_to_scalar(sdfg, symbol_str, dace.float64)
+
+            # Get the edges into and out of the first two nodes
+            edge0_0, edge0_1 = list(body0.all_edges(*(list(body0.bfs_nodes()))[0:2]))
+            edge1_0, edge1_1 = list(body1.all_edges(*(list(body1.bfs_nodes()))[0:2]))
+            assert edge0_0.data.assignments == dict()
+            assert edge1_1.data.assignments == dict()
+            assert edge0_1.data.assignments == dict()
+            assert edge1_0.data.assignments == dict()
+
+            # State fusion will fail but we know it is fine
+            # Copy all access nodes to the next state, connect the sink node from prev. 
state + # to the next state + body0.reset_cfg_list() + body1.reset_cfg_list() + + for body in [body0, body1]: + #print("CCC", body) + assignment_state, other_state = list(body.bfs_nodes())[1:3] + node_map = cutil.copy_state_contents(assignment_state, other_state) + # Multiple symbols -> multiple sink nodes + + sink_nodes = {n for n in assignment_state.nodes() if assignment_state.out_degree(n) == 0} + #print("Sink nodes:", sink_nodes, " of:", assignment_state.nodes()) + + for sink_node in sink_nodes: + sink_data = sink_node.data + sink_node_in_other_state = node_map[sink_node] + + # Find matching source nodes with same name + source_nodes = { + n + for n in other_state.nodes() + if isinstance(n, dace.nodes.AccessNode) and n.data == sink_data and n not in node_map.values() + } + + # Reconnect edges to the new source node + for source_node in source_nodes: + if source_node == sink_node_in_other_state: + continue + out_edges = other_state.out_edges(source_node) + for out_edge in out_edges: + other_state.remove_edge(out_edge) + other_state.add_edge(sink_node_in_other_state, out_edge.src_conn, out_edge.dst, + out_edge.dst_conn, copy.deepcopy(out_edge.data)) + + other_state.remove_node(source_node) + + # Remove both nodes to change the start block + # Old node is not needed enaymore + body.remove_node(body.start_block) + body.remove_node(other_state) + body.remove_node(assignment_state) + body.add_node(other_state, is_start_block=True) + + FuseStates().apply_pass(self.conditional.sdfg, {}) + applied = True + return applied + + def try_clean(self, graph: ControlFlowRegion, sdfg: SDFG, lift_multi_state: bool = False): + assert graph == self.conditional.parent_graph + assert sdfg == self.conditional.parent_graph.sdfg + # Some patterns that we can clean + + applied = False + applied |= self.move_interstate_assignments_from_empty_start_states_to_front_of_conditional( + graph, self.conditional) + sdfg.validate() + + # If not top-level tasklets the others wont work + if not all({self.only_top_level_tasklets(body) for cond, body in self.conditional.branches}): + return False + + # Implicitly done by can-be-applied + # self.sequentialize_if_else_branch_if_disjoint_subsets(graph) + + applied |= self.demote_branch_only_symbols_appering_on_both_branches_to_scalars_and_try_fuse(graph, sdfg) + sdfg.validate() + + applied |= self.demote_branch_only_symbols_appearing_only_a_single_branch_to_scalars_and_try_fuse(graph, sdfg) + sdfg.validate() + + if lift_multi_state: + applied |= self.duplicate_condition_across_all_top_level_nodes_if_line_graph_and_empty_interstate_edges( + graph) + sdfg.validate() + + return applied + + _processed_tasklets = set() + + def make_division_tasklets_safe_for_unconditional_execution( + self, + state: dace.SDFGState, + precision: typeclass, + ): + + tasklets = {n for n in state.nodes() if isinstance(n, dace.nodes.Tasklet) and n not in self._processed_tasklets} + self._processed_tasklets = self._processed_tasklets.union(tasklets) + + def _add_eps(expr_str: str, eps: str): + eps_node = ast.Name(id=eps, ctx=ast.Load()) + + # It might better to use max instead of DIV + tree = ast.parse(expr_str, mode='exec') + tree = DivEps(eps_node=eps_node, tasklet_code_str=expr_str, + mode=self.eps_operator_type_for_log_and_div).visit(tree) + return ast.unparse(tree).strip() + + if precision == dace.float64: + eps = numpy.finfo( + numpy.float64 + ).tiny * 8 # Having the min as the limit can still result with NaN due to how string stuff works + else: + eps = numpy.finfo(numpy.float32).tiny * 8 + + 
has_division = False + for tasklet in tasklets: + tasklet_code_str = tasklet.code.as_string + if tasklet.code.language != dace.dtypes.Language.Python: + continue + e = str(eps) + new_code = _add_eps(tasklet_code_str, e) + if new_code != tasklet.code.as_string: + tasklet_code_str = CodeBlock(new_code) + tasklet.code = tasklet_code_str + has_division = True + + return has_division + + def get_free_syms_outside_calls(self, expr, sdfg_symbols): + # Collect all free symbols + all_syms = expr.free_symbols + + # Collect symbols appearing in function calls + func_arg_syms = set() + for node in sympy.preorder_traversal(expr): + if isinstance(node, sympy.Function): + for arg in node.args: + func_arg_syms |= arg.free_symbols + + # Keep only symbols that are not used as function arguments + outside_syms = all_syms - func_arg_syms + + # Filter to symbols known to the SDFG + return {str(s).strip() for s in outside_syms if str(s) in sdfg_symbols} + + def condition_has_map_param(self): + # Can be applied should ensure this + root_sdfg = self.conditional.sdfg if self.parent_nsdfg_state is None else self.parent_nsdfg_state.sdfg + + all_parent_maps_and_loops = cutil.get_parent_map_and_loop_scopes(root_sdfg=root_sdfg, + node=self.conditional, + parent_state=None) + + all_params = set() + for map_or_loop in all_parent_maps_and_loops: + if isinstance(map_or_loop, dace.nodes.MapEntry): + all_params = all_params.union(map_or_loop.map.params) + else: + assert isinstance(map_or_loop, LoopRegion) + all_params.add(map_or_loop.loop_variable) + + graph = self.conditional.parent_graph + + non_none_conds = [cond for cond, _ in self.conditional.branches if cond is not None] + assert len(non_none_conds) == 1 + cond = non_none_conds.pop() + cond_code_str = cond.as_string + cond_code_symexpr = pystr_to_symbolic(cond_code_str, simplify=False) + # assert len(cond_code_symexpr.free_symbols) == 1, f"{cond_code_symexpr}, {cond_code_symexpr.free_symbols}" + + # Find values assigned to the symbols + # print("Condition as symexpr:", cond_code_symexpr) + + # Ignore free symbols appearing in functionc alls. + # e.g. 
_for_it_32 < klev is a real free symbol and is a problem + # A[_for_it_32] (which becomes A(_for_it_32)) like a function call si not a problem + free_syms = self.get_free_syms_outside_calls(cond_code_symexpr, self.conditional.sdfg.symbols) + + nodes_to_check = {self.conditional} + while nodes_to_check: + node_to_check = nodes_to_check.pop() + ies = {ie for ie in graph.in_edges(node_to_check)} + for ie in ies: + for k, v in ie.data.assignments.items(): + if k in free_syms: + # in case if Eq((x + 1 > b), 1) sympy will have a problem + expr = pystr_to_symbolic(v, simplify=False) + free_syms = free_syms.union( + self.get_free_syms_outside_calls(expr, self.conditional.sdfg.symbols)) + nodes_to_check = nodes_to_check.union({ie.src for ie in ies}) + + return all_params.intersection(free_syms) != set() + + def apply(self, graph: ControlFlowRegion, sdfg: SDFG): + # If CFG has 1 or two branches + # If two branches then the write sets to sink nodes are the same + + # Strategy copy the nodes of the states to the new state + # If we have 1 state we could essentially mimic the same behaviour + # by making so such that the state only has copies for the writes + assert graph == self.conditional.parent_graph + cond_var, cond_assignment = self._extract_condition_var_and_assignment(graph) + orig_cond_var = cond_var + + if len(self.conditional.branches) == 2: + tup0 = self.conditional.branches[0] + tup1 = self.conditional.branches[1] + (cond0, body0) = tup0[0], tup0[1] + (cond1, body1) = tup1[0], tup1[1] + + state0: SDFGState = body0.nodes()[0] + state1: SDFGState = body1.nodes()[0] + + # Disjoint subsets do not require the combine tasklet. + # Therefore we need to split: + # if (cond1) { + # body1 + # } else { + # body2 + # } + # into two sequential ifs: + # if (cond1) { + # body1 + # } + # neg_cond1 = !cond1 + # if (!cond1) { + # body2 + # } + # And thus we can use twice the single branch imlpementaiton + first_if, second_if = self.sequentialize_if_else_branch_if_disjoint_subsets(graph) + + if first_if is not None and second_if is not None: + t1 = BranchElimination() + t1.conditional = first_if + if sdfg.parent_nsdfg_node is not None: + t1.parent_nsdfg_state = self.parent_nsdfg_state + # Right now can_be_applied False because the is reused, but we do not care about + # this type of a reuse - so call permissive=True + assert t1.can_be_applied(graph=graph, expr_index=0, sdfg=sdfg, permissive=True) + t1.apply(graph=graph, sdfg=sdfg) + + t2 = BranchElimination() + t2.conditional = second_if + if sdfg.parent_nsdfg_node is not None: + t2.parent_nsdfg_state = self.parent_nsdfg_state + # Right now can_be_applied False because the is reused, but we do not care about + # this type of a reuse - so call permissive=True + assert t2.can_be_applied(graph=graph, expr_index=0, sdfg=sdfg, permissive=True) + t2.apply(graph=graph, sdfg=sdfg) + # Create two single branch SDFGs + # Then call apply on each one of them + return + + cond_prep_state = graph.add_state_before(self.conditional, + f"cond_prep_for_fused_{self.conditional}", + is_start_block=graph.in_degree(self.conditional) == 0) + cond_var_as_float_name = self._move_interstate_assignment_to_state(cond_prep_state, cond_assignment, cond_var) + + if len(self.conditional.branches) == 2: + tup0 = self.conditional.branches[0] + tup1 = self.conditional.branches[1] + (cond0, body0) = tup0[0], tup0[1] + (cond1, body1) = tup1[0], tup1[1] + + state0: SDFGState = body0.nodes()[0] + state1: SDFGState = body1.nodes()[0] + + if self._is_disjoint_subset(state0, state1): + raise 
Exception("The case shoudl have been handled by branch split") + + new_state = dace.SDFGState(f"fused_{state0.label}_and_{state1.label}") + state0_to_new_state_node_map = cutil.copy_state_contents(state0, new_state) + state1_to_new_state_node_map = cutil.copy_state_contents(state1, new_state) + + # State1, State0 write sets are data names which should be present in the new state too + read_sets0, write_sets0 = state0.read_and_write_sets() + read_sets0, write_sets1 = state1.read_and_write_sets() + + joint_writes = write_sets0.intersection(write_sets1) + + # Remove ignored writes from the set + ignored_writes0 = self.collect_ignored_writes(state0) + ignored_writes1 = self.collect_ignored_writes(state1) + ignored_writes = ignored_writes0.union(ignored_writes1) + joint_writes = joint_writes.difference(ignored_writes) + + graph.add_node(new_state) + for ie in graph.in_edges(self.conditional): + graph.add_edge(ie.src, new_state, copy.deepcopy(ie.data)) + for oe in graph.out_edges(self.conditional): + graph.add_edge(new_state, oe.dst, copy.deepcopy(oe.data)) + + for write in joint_writes: + state0_write_accesses = self.collect_write_accesses(state0, write) + state1_write_accesses = self.collect_write_accesses(state1, write) + + state0_write_accesses_in_new_state = {state0_to_new_state_node_map[n] for n in state0_write_accesses} + state1_write_accesses_in_new_state = {state1_to_new_state_node_map[n] for n in state1_write_accesses} + + # If it was a single branch we support multiple writes (just fuse last one) + # But this is two branches so we need to ensure single write access nodes + # We only support multiple-writes if the branch has only an if or else + # Sequentializing the branches would fix this issue (but would come at a performance cost) + assert len(state0_write_accesses_in_new_state) == 1, f"len({state0_write_accesses_in_new_state}) != 1" + assert len(state1_write_accesses_in_new_state) == 1, f"len({state1_write_accesses_in_new_state}) != 1" + assert len(state0_write_accesses) == 1 + + state0_write_access = state0_write_accesses.pop() + state0_in_new_state_write_access = state0_write_accesses_in_new_state.pop() + state1_in_new_state_write_access = state1_write_accesses_in_new_state.pop() + + combine_tasklet, tmp1_access, tmp2_access, float_cond_access = self.add_conditional_write_combination( + new_state=new_state, + state0_in_new_state_write_access=state0_in_new_state_write_access, + state1_in_new_state_write_access=state1_in_new_state_write_access, + cond_var_as_float_name=cond_var_as_float_name, + write_name=write, + index=0) + + new_state.remove_node(state1_in_new_state_write_access) + float_type = new_state.sdfg.arrays[float_cond_access.data].dtype + + has_divisions = self.make_division_tasklets_safe_for_unconditional_execution(new_state, float_type) + + if not has_divisions: + self._try_simplify_combine_tasklet(new_state, combine_tasklet) + + else: + assert len(self.conditional.branches) == 1 + tup0 = self.conditional.branches[0] + (cond0, body0) = tup0[0], tup0[1] + + state0: SDFGState = body0.nodes()[0] + state1 = SDFGState("tmp_branch", sdfg=state0.sdfg) + + new_state = dace.SDFGState(f"fused_{state0.label}_and_{state1.label}") + state0_to_new_state_node_map = cutil.copy_state_contents(state0, new_state) + + read_sets0, write_sets0 = state0.read_and_write_sets() + joint_writes = write_sets0.difference(self.collect_ignored_writes(state0)) + + # If there ignored but reused data add them too + # It is ok to have reused data for this case + _, reused_but_ignored = 
self.ignored_accesses_are_reused({state0}) + joint_writes = joint_writes.union(reused_but_ignored) + + new_joint_writes = copy.deepcopy(joint_writes) + new_reads = dict() + for write in joint_writes: + state0_write_accesses = self.collect_write_accesses(state0, write) + state0_write_accesses_in_new_state = {state0_to_new_state_node_map[n] for n in state0_write_accesses} + + for state0_write_access in state0_write_accesses: + ies = state0.in_edges(state0_write_access) + assert len(ies) >= 1 + ie = ies[0] + if ie.data.data is not None: + # Other subset + if ie.data.data == write: + subset_to_use = ie.data.subset + else: + assert ie.data.other_subset is not None + subset_to_use = ie.data.other_subset + an1, tasklet, an2 = self._generate_identity_write(state1, write, subset_to_use) + new_reads[state0_write_access] = (write, ie.data, (an1, tasklet, an2)) + + # Copy over all identify writes + state1_to_new_state_node_map = cutil.copy_state_contents(state1, new_state) + read_sets0, write_sets1 = state1.read_and_write_sets() + joint_writes = new_joint_writes + + # New joint writes require adding reads to a previously output-only connector, + # we nede to do this if the data is not transient + if new_reads: + if graph.sdfg.parent_nsdfg_node is not None: + parent_nsdfg_node = graph.sdfg.parent_nsdfg_node + parent_nsdfg_state = self.parent_nsdfg_state + for i, (new_read_name, new_read_memlet, nodes) in enumerate(new_reads.values()): + assert new_read_name in graph.sdfg.arrays + new_read_is_transient = graph.sdfg.arrays[new_read_name].transient + if new_read_is_transient: + continue + if new_read_name not in parent_nsdfg_node.in_connectors: + write_edges = set( + parent_nsdfg_state.out_edges_by_connector(parent_nsdfg_node, new_read_name)) + assert len(write_edges) == 1, f"{write_edges} of new_read: {new_read_name}" + write_edge = write_edges.pop() + write_subset: dace.subsets.Range = write_edge.data.subset + # This is not necessarily true because the subset connection can be the full set + # assert write_subset.num_elements_exact() == 1, f"{new_read_name}: {write_subset}: {write_subset.num_elements_exact()}, ()" + use_exact_subset = write_subset.num_elements_exact() == 1 + cutil.insert_non_transient_data_through_parent_scopes( + non_transient_data={write_edge.data.data}, + nsdfg_node=parent_nsdfg_node, + parent_graph=parent_nsdfg_state, + parent_sdfg=parent_nsdfg_state.sdfg, + add_to_output_too=False, + add_with_exact_subset=use_exact_subset, + exact_subset=copy.deepcopy(write_subset), + nsdfg_connector_name=new_read_name) + + graph.add_node(new_state) + for ie in graph.in_edges(self.conditional): + graph.add_edge(ie.src, new_state, copy.deepcopy(ie.data)) + for oe in graph.out_edges(self.conditional): + graph.add_edge(new_state, oe.dst, copy.deepcopy(oe.data)) + for write in joint_writes: + state0_write_accesses = self.collect_write_accesses(state0, write) + + #state0_write_accesses_in_new_state = {state0_to_new_state_node_map[n] for n in state0_write_accesses} + + # For each combining needed, we generate a tasklet on state1 + # Then we copy it over to the new state to combine them together + # Because of this we check the new_reads and its mapping to the new state + for i, (state0_write_access) in enumerate(state0_write_accesses): + new_read_name, new_read_memlet, nodes = new_reads[state0_write_access] + state1_in_new_state_write_access: dace.nodes.AccessNode = state1_to_new_state_node_map[nodes[-1]] + state0_in_new_state_write_access = state0_to_new_state_node_map[state0_write_access] + assert 
state0_in_new_state_write_access in new_state.nodes()
+                assert state1_in_new_state_write_access in new_state.nodes()
+
+                # State1's access should have setzero set to True to avoid writing garbage values
+                state1_in_new_state_write_access.setzero = True
+
+                combine_tasklet, tmp1_access, tmp2_access, float_cond_access = self.add_conditional_write_combination(
+                    new_state=new_state,
+                    state0_in_new_state_write_access=state0_in_new_state_write_access,
+                    state1_in_new_state_write_access=state1_in_new_state_write_access,
+                    cond_var_as_float_name=cond_var_as_float_name,
+                    write_name=write,
+                    index=i)
+
+                # If we detect a previous write, for correctness we need to connect to that
+                self._connect_rhs_identity_assignment_to_previous_read(new_state=new_state,
+                                                                       rhs_access=tmp2_access,
+                                                                       data=state0_in_new_state_write_access.data,
+                                                                       combine_tasklet=combine_tasklet,
+                                                                       skip_set={tmp2_access})
+
+                new_state.remove_node(state1_in_new_state_write_access)
+            float_type = new_state.sdfg.arrays[float_cond_access.data].dtype
+
+            has_divisions = self.make_division_tasklets_safe_for_unconditional_execution(new_state, float_type)
+
+            if not has_divisions:
+                self._try_simplify_combine_tasklet(new_state, combine_tasklet)
+
+        if self.parent_nsdfg_state is not None:
+            self.parent_nsdfg_state.sdfg.validate()
+        else:
+            self.conditional.sdfg.validate()
+
+        # Remove the conditional and clean up its symbols if they are not used anymore
+        conditional_strs = {cond.as_string for cond, _ in self.conditional.branches if cond is not None}
+        conditional_symbols = set()
+        graph.remove_node(self.conditional)
+
+        for cond_str in conditional_strs:
+            conditional_symbols = conditional_symbols.union(
+                {str(s)
+                 for s in dace.symbolic.SymExpr(cond_str).free_symbols})
+        conditional_symbols.add(orig_cond_var)
+
+        if self.parent_nsdfg_state is not None:
+            self.parent_nsdfg_state.sdfg.validate()
+        else:
+            self.conditional.sdfg.validate()
+        # The name says symbols, but each entry could refer to an array too
+        for sym_name in conditional_symbols:
+            if not self._symbol_appears_as_read(graph.sdfg, sym_name):
+                remove_symbol_assignments(graph.sdfg, sym_name)
+                if sym_name in graph.sdfg.symbols:
+                    graph.sdfg.remove_symbol(sym_name)
+                if graph.sdfg.parent_nsdfg_node is not None:
+                    if sym_name in graph.sdfg.parent_nsdfg_node.symbol_mapping:
+                        del graph.sdfg.parent_nsdfg_node.symbol_mapping[sym_name]
+
+        if self.parent_nsdfg_state is not None:
+            self.parent_nsdfg_state.sdfg.validate()
+        else:
+            self.conditional.sdfg.validate()
+
+        self._try_fuse(graph, new_state, cond_prep_state)
+
+        graph.sdfg.reset_cfg_list()
+        sdutil.set_nested_sdfg_parent_references(graph.sdfg)
+
+        if self.parent_nsdfg_state is not None:
+            self.parent_nsdfg_state.sdfg.validate()
+        else:
+            self.conditional.sdfg.validate()
+
+    def _find_previous_write(self, state: dace.SDFGState, sink: dace.nodes.Tasklet, data: str,
+                             skip_set: Set[dace.nodes.Node]):
+        nodes_to_check = {ie.src for ie in state.in_edges(sink) if ie.src not in skip_set}
+        while nodes_to_check:
+            node_to_check = nodes_to_check.pop()
+            if node_to_check in skip_set:
+                continue
+            if isinstance(node_to_check, dace.nodes.AccessNode) and node_to_check.data == data:
+                return node_to_check
+            nodes_to_check = nodes_to_check.union(
+                {ie.src
+                 for ie in state.in_edges(node_to_check) if ie.src not in skip_set})
+        return None
+
+    def _connect_rhs_identity_assignment_to_previous_read(self, new_state: dace.SDFGState,
+                                                          rhs_access: dace.nodes.AccessNode, data: str,
+                                                          combine_tasklet: dace.nodes.Tasklet,
+                                                          skip_set: Set[dace.nodes.Node]):
+        identity_rhs_access = self._find_previous_write(new_state, rhs_access, data, set())
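+        # Illustrative scenario: if an earlier combine tasklet in this state has
+        # already written data "A", the identity read feeding this combine must
+        # come from that previous write rather than a fresh AccessNode, or the
+        # second blend would read a stale value of "A".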
+        assert identity_rhs_access is not None
+        previous_write = self._find_previous_write(new_state, combine_tasklet, data,
+                                                   skip_set.union({identity_rhs_access, rhs_access}))
+
+        if previous_write is not None:
+            assert identity_rhs_access != previous_write
+
+            # Remove the identity rhs access and rewrite the edge
+            assert len(new_state.out_edges(identity_rhs_access)) == 1
+            oe = new_state.out_edges(identity_rhs_access)[0]
+            new_state.remove_edge(oe)
+            assert oe.src_conn is None
+            new_state.add_edge(previous_write, oe.src_conn, oe.dst, oe.dst_conn, copy.deepcopy(oe.data))
+            if new_state.degree(oe.src) == 0:
+                new_state.remove_node(oe.src)
diff --git a/dace/transformation/passes/__init__.py b/dace/transformation/passes/__init__.py
index 8d0c023a51..7d7b7002ff 100644
--- a/dace/transformation/passes/__init__.py
+++ b/dace/transformation/passes/__init__.py
@@ -14,5 +14,5 @@ from .simplify import SimplifyPass
 from .symbol_propagation import SymbolPropagation
 from .transient_reuse import TransientReuse
-
+from .eliminate_branches import EliminateBranches
 from .util import available_passes
diff --git a/dace/transformation/passes/eliminate_branches.py b/dace/transformation/passes/eliminate_branches.py
new file mode 100644
index 0000000000..2ad086bac2
--- /dev/null
+++ b/dace/transformation/passes/eliminate_branches.py
@@ -0,0 +1,71 @@
+# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved.
+from dace import properties, SDFG, nodes
+from dace.sdfg.state import ConditionalBlock
+from dace.transformation import pass_pipeline as ppl
+from dace.transformation.transformation import explicit_cf_compatible
+from typing import Union, Optional
+
+
+@properties.make_properties
+@explicit_cf_compatible
+class EliminateBranches(ppl.Pass):
+    try_clean = properties.Property(dtype=bool, default=False, allow_none=False)
+    clean_only = properties.Property(dtype=bool, default=False, allow_none=True)
+    permissive = properties.Property(dtype=bool, default=False, allow_none=False)
+    eps_operator_type_for_log_and_div = properties.Property(dtype=str, default="add", allow_none=True)
+
+    def modifies(self) -> ppl.Modifies:
+        return ppl.Modifies.CFG
+
+    def should_reapply(self, modified: ppl.Modifies) -> bool:
+        return False
+
+    def _apply_eliminate_branches(self, root: SDFG, sdfg: SDFG, parent_nsdfg_state: Union[SDFG, None] = None):
+        """Apply the BranchElimination transformation to all eligible conditionals."""
+        from dace.transformation.interstate import branch_elimination
+        # Pattern matching with conditional branches does not work (as of 9.10.25), so we avoid it.
+        # Depending on the nesting depth, we need to apply the transformation that many times,
+        # because it only runs on top-level ConditionalBlocks.
+        changed = True
+        while changed:
+            changed = False
+
+            for node in sdfg.all_control_flow_blocks():
+                if isinstance(node, ConditionalBlock):
+                    t = branch_elimination.BranchElimination()
+                    t.conditional = node
+                    t.eps_operator_type_for_log_and_div = self.eps_operator_type_for_log_and_div
+
+                    if self.try_clean:
+                        t.try_clean(node.parent_graph, sdfg, True)
+                        node = t.conditional
+
+            if not self.clean_only:
+                for node in sdfg.all_control_flow_blocks():
+                    if isinstance(node, ConditionalBlock):
+                        t = branch_elimination.BranchElimination()
+                        t.conditional = node
+                        if node.sdfg.parent_nsdfg_node is not None:
+                            t.parent_nsdfg_state = parent_nsdfg_state
+                        t.eps_operator_type_for_log_and_div = self.eps_operator_type_for_log_and_div
+                        if t.can_be_applied(graph=node.parent_graph,
+                                            expr_index=0,
+                                            sdfg=node.sdfg,
+                                            permissive=self.permissive):
+                            t.apply(graph=node.parent_graph, sdfg=node.sdfg)
+                            changed = True
+
+            for state in sdfg.all_states():
+                for node in state.nodes():
+                    if isinstance(node, nodes.NestedSDFG):
+                        changed |= self._apply_eliminate_branches(root, node.sdfg, state)
+
+        return changed
+
+    def apply_pass(self, sdfg: SDFG, _) -> Optional[int]:
+        if self.clean_only is True:
+            self.try_clean = True
+        return self._apply_eliminate_branches(sdfg, sdfg, None)
+
+    def report(self, pass_retval: int) -> str:
+        return f'Fused (and removed) {pass_retval} branches.'
diff --git a/dace/transformation/passes/explicit_vectorization_cpu.py b/dace/transformation/passes/explicit_vectorization_cpu.py
new file mode 100644
index 0000000000..91c553518c
--- /dev/null
+++ b/dace/transformation/passes/explicit_vectorization_cpu.py
@@ -0,0 +1,402 @@
+# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved.
+import dace
+from typing import Dict, Iterator
+from dace.transformation import Pass, pass_pipeline as ppl
+from dace.transformation.pass_pipeline import Modifies
+from dace.transformation.passes.clean_data_to_scalar_slice_to_tasklet_pattern import CleanDataToScalarSliceToTaskletPattern
+from dace.transformation.passes.duplicate_all_memlets_sharing_same_in_connector import DuplicateAllMemletsSharingSingleMapOutConnector
+from dace.transformation.passes.split_tasklets import SplitTasklets
+from dace.transformation.passes.tasklet_preprocessing_passes import PowerOperatorExpansion, RemoveFPTypeCasts, RemoveIntTypeCasts
+from dace.transformation.passes import InlineSDFGs
+from dace.transformation.passes.explicit_vectorization import ExplicitVectorization
+from dace.transformation.passes.eliminate_branches import EliminateBranches
+from dace.transformation.passes.remove_redundant_assignment_tasklets import RemoveRedundantAssignmentTasklets
+import dace.sdfg.utils as sdutil
+
+
+class ExplicitVectorizationPipelineCPU(ppl.Pipeline):
+    _cpu_global_code = """
+#include <cmath>
+#include <algorithm>
+
+
+#if defined(__clang__)
+    #define _dace_vectorize_hint
+    #define _dace_vectorize "clang loop vectorize(enable) vectorize_width({vector_width})"
+#elif defined(__GNUC__)
+    #define _dace_vectorize_hint
+    #define _dace_vectorize "omp simd simdlen({vector_width})"
+#else
+    #define _dace_vectorize_hint
+    #define _dace_vectorize "omp simd simdlen({vector_width})"
+#endif
+
+
+template <typename T>
+inline void vector_mult(T * __restrict__ c, const T * __restrict__ a, const T * __restrict__ b) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        c[i] = a[i] * b[i];
+    }}
+}}
+
+template <typename T>
+inline void vector_mult_w_scalar(T * __restrict__ b, const T * __restrict__ a, const T constant) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        b[i] = a[i] * constant;
+    }}
+}}
+
+template <typename T>
+inline void vector_add(T * __restrict__ c, const T * __restrict__ a, const T * __restrict__ b) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        c[i] = a[i] + b[i];
+    }}
+}}
+
+template <typename T>
+inline void vector_add_w_scalar(T * __restrict__ b, const T * __restrict__ a, const T constant) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        b[i] = a[i] + constant;
+    }}
+}}
+
+template <typename T>
+inline void vector_sub(T * __restrict__ c, const T * __restrict__ a, const T * __restrict__ b) {{
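+    // The two pragmas below request vectorization of this fixed-trip-count
+    // loop; _dace_vectorize carries the compiler-specific hint text selected
+    // by the #defines at the top of this header.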
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        c[i] = a[i] - b[i];
+    }}
+}}
+
+template <typename T>
+inline void vector_sub_w_scalar(T * __restrict__ b, const T * __restrict__ a, const T constant) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        b[i] = a[i] - constant;
+    }}
+}}
+
+template <typename T>
+inline void vector_sub_w_scalar_c(T * __restrict__ b, const T constant, const T * __restrict__ a) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        b[i] = constant - a[i];
+    }}
+}}
+
+template <typename T>
+inline void vector_div(T * __restrict__ c, const T * __restrict__ a, const T * __restrict__ b) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        c[i] = a[i] / b[i];
+    }}
+}}
+
+template <typename T>
+inline void vector_div_w_scalar(T * __restrict__ b, const T * __restrict__ a, const T constant) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        b[i] = a[i] / constant;
+    }}
+}}
+
+template <typename T>
+inline void vector_div_w_scalar_c(T * __restrict__ b, const T constant, const T * __restrict__ a) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        b[i] = constant / a[i];
+    }}
+}}
+
+template <typename T>
+inline void vector_copy(T * __restrict__ dst, const T * __restrict__ src) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        dst[i] = src[i];
+    }}
+}}
+
+// ---- Additional elementwise ops ----
+
+template <typename T>
+inline void vector_exp(T * __restrict__ out, const T * __restrict__ a) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = std::exp(a[i]);
+    }}
+}}
+
+template <typename T>
+inline void vector_log(T * __restrict__ out, const T * __restrict__ a) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = std::log(a[i]);
+    }}
+}}
+
+template <typename T>
+inline void vector_min(T * __restrict__ out, const T * __restrict__ a, const T * __restrict__ b) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = std::min(a[i], b[i]);
+    }}
+}}
+
+template <typename T>
+inline void vector_min_w_scalar(T * __restrict__ out, const T * __restrict__ a, const T constant) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = std::min(a[i], constant);
+    }}
+}}
+
+template <typename T>
+inline void vector_max(T * __restrict__ out, const T * __restrict__ a, const T * __restrict__ b) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = std::max(a[i], b[i]);
+    }}
+}}
+
+template <typename T>
+inline void vector_max_w_scalar(T * __restrict__ out, const T * __restrict__ a, const T constant) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = std::max(a[i], constant);
+    }}
+}}
+
+template <typename T>
+inline void vector_gt(T * __restrict__ out, const T * __restrict__ a, const T * __restrict__ b) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = (a[i] > b[i]) ? 1.0 : 0.0;
+    }}
+}}
+
+template <typename T>
+inline void vector_gt_w_scalar(T * __restrict__ out, const T * __restrict__ a, const T constant) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = (a[i] > constant) ? 1.0 : 0.0;
+    }}
+}}
+
+template <typename T>
+inline void vector_gt_w_scalar_c(T * __restrict__ out, const T constant, const T * __restrict__ a) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = (constant > a[i]) ? 1.0 : 0.0;
+    }}
+}}
+
+template <typename T>
+inline void vector_lt(T * __restrict__ out, const T * __restrict__ a, const T * __restrict__ b) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = (a[i] < b[i]) ? 1.0 : 0.0;
+    }}
+}}
+
+template <typename T>
+inline void vector_lt_w_scalar(T * __restrict__ out, const T * __restrict__ a, const T constant) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = (a[i] < constant) ? 1.0 : 0.0;
+    }}
+}}
+
+template <typename T>
+inline void vector_lt_w_scalar_c(T * __restrict__ out, const T constant, const T * __restrict__ a) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = (constant < a[i]) ? 1.0 : 0.0;
+    }}
+}}
+
+template <typename T>
+inline void vector_ge(T * __restrict__ out, const T * __restrict__ a, const T * __restrict__ b) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = (a[i] >= b[i]) ? 1.0 : 0.0;
+    }}
+}}
+
+template <typename T>
+inline void vector_ge_w_scalar(T * __restrict__ out, const T * __restrict__ a, const T constant) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = (a[i] >= constant) ? 1.0 : 0.0;
+    }}
+}}
+
+template <typename T>
+inline void vector_ge_w_scalar_c(T * __restrict__ out, const T constant, const T * __restrict__ a) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = (constant >= a[i]) ? 1.0 : 0.0;
+    }}
+}}
+
+template <typename T>
+inline void vector_le(T * __restrict__ out, const T * __restrict__ a, const T * __restrict__ b) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = (a[i] <= b[i]) ? 1.0 : 0.0;
+    }}
+}}
+
+template <typename T>
+inline void vector_le_w_scalar(T * __restrict__ out, const T * __restrict__ a, const T constant) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = (a[i] <= constant) ? 1.0 : 0.0;
+    }}
+}}
+
+template <typename T>
+inline void vector_le_w_scalar_c(T * __restrict__ out, const T constant, const T * __restrict__ a) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = (constant <= a[i]) ? 1.0 : 0.0;
+    }}
+}}
+
+template <typename T>
+inline void vector_eq(T * __restrict__ out, const T * __restrict__ a, const T * __restrict__ b) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = (a[i] == b[i]) ? 1.0 : 0.0;
+    }}
+}}
+
+template <typename T>
+inline void vector_eq_w_scalar(T * __restrict__ out, const T * __restrict__ a, const T constant) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = (a[i] == constant) ? 1.0 : 0.0;
+    }}
+}}
+
+
+template <typename T>
+inline void vector_ne(T * __restrict__ out, const T * __restrict__ a, const T * __restrict__ b) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = (a[i] != b[i]) ? 1.0 : 0.0;
+    }}
+}}
+
+template <typename T>
+inline void vector_ne_w_scalar(T * __restrict__ out, const T * __restrict__ a, const T constant) {{
+    #pragma _dace_vectorize_hint
+    #pragma _dace_vectorize
+    for (int i = 0; i < {vector_width}; i++) {{
+        out[i] = (a[i] != constant) ? 1.0 : 0.0;
+    }}
+}}
+"""
+
+    def __init__(self, vector_width):
+        passes = [
+            EliminateBranches(),
+            RemoveFPTypeCasts(),
+            RemoveIntTypeCasts(),
+            PowerOperatorExpansion(),
+            SplitTasklets(),
+            CleanDataToScalarSliceToTaskletPattern(),
+            InlineSDFGs(),
+            DuplicateAllMemletsSharingSingleMapOutConnector(),
+            ExplicitVectorization(
+                templates={
+                    "*": "vector_mult({lhs}, {rhs1}, {rhs2});",
+                    "+": "vector_add({lhs}, {rhs1}, {rhs2});",
+                    "-": "vector_sub({lhs}, {rhs1}, {rhs2});",
+                    "/": "vector_div({lhs}, {rhs1}, {rhs2});",
+                    "=": "vector_copy({lhs}, {rhs1});",
+                    "log": "vector_log({lhs}, {rhs1});",
+                    "exp": "vector_exp({lhs}, {rhs1});",
+                    "min": "vector_min({lhs}, {rhs1}, {rhs2});",
+                    "max": "vector_max({lhs}, {rhs1}, {rhs2});",
+                    ">": "vector_gt({lhs}, {rhs1}, {rhs2});",
+                    "<": "vector_lt({lhs}, {rhs1}, {rhs2});",
+                    ">=": "vector_ge({lhs}, {rhs1}, {rhs2});",
+                    "<=": "vector_le({lhs}, {rhs1}, {rhs2});",
+                    "==": "vector_eq({lhs}, {rhs1}, {rhs2});",
+                    "!=": "vector_ne({lhs}, {rhs1}, {rhs2});",
+                    # scalar variants type 1
+                    "c*": "vector_mult_w_scalar({lhs}, {rhs1}, {constant});",
+                    "c+": "vector_add_w_scalar({lhs}, {rhs1}, {constant});",
+                    "c-": "vector_sub_w_scalar({lhs}, {rhs1}, {constant});",
+                    "c/": "vector_div_w_scalar({lhs}, {rhs1}, {constant});",
+                    "cmin": "vector_min_w_scalar({lhs}, {rhs1}, {constant});",
+                    "cmax": "vector_max_w_scalar({lhs}, {rhs1}, {constant});",
+                    "c>": "vector_gt_w_scalar({lhs}, {rhs1}, {constant});",
+                    "c<": "vector_lt_w_scalar({lhs}, {rhs1}, {constant});",
+                    "c>=": "vector_ge_w_scalar({lhs}, {rhs1}, {constant});",
+                    "c<=": "vector_le_w_scalar({lhs}, {rhs1}, {constant});",
+                    "c==": "vector_eq_w_scalar({lhs}, {rhs1}, {constant});",
+                    "c!=": "vector_ne_w_scalar({lhs}, {rhs1}, {constant});",
+                    # scalar variants type 2 for non-commutative ops
+                    "-c": "vector_sub_w_scalar_c({lhs}, {constant}, {rhs1});",
+                    "/c": "vector_div_w_scalar_c({lhs}, {constant}, {rhs1});",
+                    ">c": "vector_gt_w_scalar_c({lhs}, {constant}, {rhs1});",
+                    ">=c": "vector_ge_w_scalar_c({lhs}, {constant}, {rhs1});",
+                    "<=c": "vector_le_w_scalar_c({lhs}, {constant}, {rhs1});",
+                },
+                vector_width=vector_width,
+                vector_input_storage=dace.dtypes.StorageType.Register,
+                vector_output_storage=dace.dtypes.StorageType.Register,
+                global_code=ExplicitVectorizationPipelineCPU._cpu_global_code.format(vector_width=vector_width),
+                global_code_location="frame",
+                vector_op_numeric_type=dace.float64)
+        ]
+        super().__init__(passes)
+
+    def iterate_over_passes(self, sdfg: dace.SDFG) -> Iterator[Pass]:
+        """
+        Iterates over the passes in the pipeline, potentially multiple times, depending on which
+        elements were modified by each pass.
+        Note that this method may be overridden by subclasses to modify the pass order.
+
+        :param sdfg: The SDFG on which the pipeline is currently being applied
+        """
+        for p in self.passes:
+            p: Pass
+            yield p
diff --git a/dace/transformation/passes/explicit_vectorization_gpu.py b/dace/transformation/passes/explicit_vectorization_gpu.py
new file mode 100644
index 0000000000..22ab47853e
--- /dev/null
+++ b/dace/transformation/passes/explicit_vectorization_gpu.py
@@ -0,0 +1,129 @@
+# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved.
+from typing import Dict, Iterator
+import dace
+from dace.transformation import Pass, pass_pipeline as ppl
+from dace.transformation.pass_pipeline import Modifies
+from dace.transformation.passes.clean_data_to_scalar_slice_to_tasklet_pattern import CleanDataToScalarSliceToTaskletPattern
+from dace.transformation.passes.duplicate_all_memlets_sharing_same_in_connector import DuplicateAllMemletsSharingSingleMapOutConnector
+from dace.transformation.passes.split_tasklets import SplitTasklets
+from dace.transformation.passes.tasklet_preprocessing_passes import PowerOperatorExpansion, RemoveFPTypeCasts, RemoveIntTypeCasts
+from dace.transformation.passes import InlineSDFGs
+from dace.transformation.passes.explicit_vectorization import ExplicitVectorization
+from dace.transformation.passes.eliminate_branches import EliminateBranches
+from dace.transformation.passes.remove_redundant_assignment_tasklets import RemoveRedundantAssignmentTasklets
+import dace.sdfg.utils as sdutil
+
+
+class ExplicitVectorizationPipelineGPU(ppl.Pipeline):
+    _gpu_global_code = """
+template <typename T>
+__host__ __device__ __forceinline__ void vector_mult(T * __restrict__ c, const T * __restrict__ a, const T * __restrict__ b) {{
+    #pragma omp unroll
+    for (int i = 0; i < {vector_width}; i++) {{
+        c[i] = a[i] * b[i];
+    }}
+}}
+
+template <typename T>
+__host__ __device__ __forceinline__ void vector_mult_w_scalar(T * __restrict__ b, const T * __restrict__ a, const T constant) {{
+    T cReg[{vector_width}];
+    #pragma omp unroll
+    for (int i = 0; i < {vector_width}; i++) {{
+        cReg[i] = constant;
+    }}
+    #pragma omp unroll
+    for (int i = 0; i < {vector_width}; i++) {{
+        b[i] = a[i] * cReg[i];
+    }}
+}}
+
+template <typename T>
+__host__ __device__ __forceinline__ void vector_add(T * __restrict__ c, const T * __restrict__ a, const T * __restrict__ b) {{
+    #pragma omp unroll
+    for (int i = 0; i < {vector_width}; i++) {{
+        c[i] = a[i] + b[i];
+    }}
+}}
+
+template <typename T>
+__host__ __device__ __forceinline__ void vector_add_w_scalar(T * __restrict__ b, const T * __restrict__ a, const T constant) {{
+    T cReg[{vector_width}];
+    #pragma omp unroll
+    for (int i = 0; i < {vector_width}; i++) {{
+        cReg[i] = constant;
+    }}
+    #pragma omp unroll
+    for (int i = 0; i < {vector_width}; i++) {{
+        b[i] = a[i] + cReg[i];
+    }}
+}}
+
+template <typename T>
+__host__ __device__ __forceinline__ void vector_div(T * __restrict__ c, const T * __restrict__ a, const T * __restrict__ b) {{
+    #pragma omp unroll
+    for (int i = 0; i < {vector_width}; i++) {{
+        c[i] = a[i] / b[i];
+    }}
+}}
+
+template <typename T>
+__host__ __device__ __forceinline__ void vector_div_w_scalar(T * __restrict__ b, const T * __restrict__ a, const T constant) {{
+    T cReg[{vector_width}];
+    #pragma omp unroll
+    for (int i = 0; i < {vector_width}; i++) {{
+        cReg[i] = constant;
+    }}
+    #pragma omp unroll
+    for (int i = 0; i < {vector_width}; i++) {{
+        b[i] = a[i] / cReg[i];
+    }}
+}}
+
+template <typename T>
+__host__ __device__ __forceinline__ void vector_copy(T * __restrict__ dst, const T * __restrict__ src) {{
+    #pragma omp unroll
+    for (int i = 0; i < {vector_width}; i++) {{
+        dst[i] = src[i];
+    }}
+}}
+"""
+
+    def __init__(self, vector_width):
+        passes = [
+            EliminateBranches(),
+            RemoveFPTypeCasts(),
+            RemoveIntTypeCasts(),
+            PowerOperatorExpansion(),
+            SplitTasklets(),
+            CleanDataToScalarSliceToTaskletPattern(),
+            InlineSDFGs(),
+            DuplicateAllMemletsSharingSingleMapOutConnector(),
+            ExplicitVectorization(
+                templates={
+                    "*": "vector_mult({lhs}, {rhs1}, {rhs2});",
+                    "+": "vector_add({lhs}, {rhs1}, {rhs2});",
+                    "=": "vector_copy({lhs}, {rhs1});",
+                    "c+": "vector_add_w_scalar({lhs}, {rhs1}, {constant});",
+                    "c*": "vector_mult_w_scalar({lhs}, {rhs1}, {constant});",
+                },
+                vector_width=vector_width,
+                vector_input_storage=dace.dtypes.StorageType.Register,
+                vector_output_storage=dace.dtypes.StorageType.Register,
+                global_code=ExplicitVectorizationPipelineGPU._gpu_global_code.format(vector_width=vector_width),
+                global_code_location="frame",
+                vector_op_numeric_type=dace.float64,
+            )
+        ]
+        super().__init__(passes)
+
+    def iterate_over_passes(self, sdfg: dace.SDFG) -> Iterator[Pass]:
+        """
+        Iterates over the passes in the pipeline, potentially multiple times, depending on which
+        elements were modified by each pass.
+        Note that this method may be overridden by subclasses to modify the pass order.
+
+        :param sdfg: The SDFG on which the pipeline is currently being applied
+        """
+        for p in self.passes:
+            p: Pass
+            yield p
diff --git a/tests/transformations/interstate/branch_elimination_test.py b/tests/transformations/interstate/branch_elimination_test.py
new file mode 100644
index 0000000000..202edce85e
--- /dev/null
+++ b/tests/transformations/interstate/branch_elimination_test.py
@@ -0,0 +1,2136 @@
+import copy
+import numpy as np
+import dace
+import pytest
+from dace.properties import CodeBlock
+from dace.sdfg import InterstateEdge
+from dace.sdfg.state import ConditionalBlock, ControlFlowRegion
+from dace.transformation.interstate import branch_elimination
+from dace.transformation.passes import ConstantPropagation, EliminateBranches
+from dace.transformation.passes.scalar_to_symbol import ScalarToSymbolPromotion
+from dace.transformation.passes.symbol_propagation import SymbolPropagation
+
+N = 32
+S = 32
+S1 = 0
+S2 = 32
+
+
+@dace.program
+def branch_dependent_value_write(
+    a: dace.float64[N, N],
+    b: dace.float64[N, N],
+    c: dace.float64[N, N],
+    d: dace.float64[N, N],
+):
+    for i, j in dace.map[0:N:1, 0:N:1]:
+        if a[i, j] > 0.5:
+            c[i, j] = a[i, j] * b[i, j]
+            d[i, j] = 1 - c[i, j]
+        else:
+            c[i, j] = 0.0
+            d[i, j] = 0.0
+
+
+@dace.program
+def branch_dependent_value_write_with_transient_reuse(
+    a: dace.float64[N, N],
+    b: dace.float64[N, N],
+    c: dace.float64[N, N],
+):
+    for i, j in dace.map[0:N:1, 0:N:1]:
+        if a[i, j] > 0.5:
+            c[i, j] = a[i, j] * b[i, j]
+            c_scl = 1 - c[i, j]
+        else:
+            c[i, j] = 0.1
+            c_scl = 1.0
+
+        if b[i, j] > 0.5:
+            c[i, j] = c_scl * c[i, j]
+        else:
+            c[i, j] = -1.5 * c_scl * c[i, j]
+
+
+@dace.program
+def branch_dependent_value_write_two(
+    a: dace.float64[N, N],
+    b: dace.float64[N, N],
+    c: dace.float64[N, N],
+    d: dace.float64[N, N],
+):
+    for i, j in dace.map[0:N:1, 0:N:1]:
+        if a[i, j] > 0.3:
+            b[i, j] = 1.1
+            d[i, j] = 0.8
+        else:
+            b[i, j] = -1.1
+            d[i, j] = 2.2
+        c[i, j] = max(0, b[i, j])
+        d[i, j] = max(0, d[i, j])
+
+
+@dace.program
+def weird_condition(a: dace.float64[N, N], b: dace.float64[N, N], ncldtop: dace.int64):
+    for i, j in dace.map[0:N:1, 0:N:1]:
+        cond = i + 1 > ncldtop
+        if cond == 1:
+            a[i, j] = 1.1
+            b[i, j] = 0.8
+
+
+@dace.program
+def multi_state_branch_body(a: dace.float64[N, N], b: dace.float64[N, N], c: dace.float64[N, N], d: dace.float64[N, N],
+ s: dace.int64): + for i, j in dace.map[0:N:1, 0:N:1]: + if a[i, j] > 0.3: + b[i + s, j + s] = 1.1 + d[i + s, j + s] = 0.8 + else: + b[i + s, j + s] = -1.1 + d[i + s, j + s] = 2.2 + c[i + s, j + s] = max(0, b[i + s, j + s]) + d[i + s, j + s] = max(0, d[i + s, j + s]) + + +@dace.program +def nested_if(a: dace.float64[N, N], b: dace.float64[N, N], c: dace.float64[N, N], d: dace.float64[N, N], + s: dace.int64): + for i, j in dace.map[0:N:1, 0:N:1]: + if a[i, j] > 0.3: + if s == 0: + b[i, j] = 1.1 + d[i, j] = 0.8 + else: + b[i, j] = 1.2 + d[i, j] = 0.9 + else: + b[i, j] = -1.1 + d[i, j] = 2.2 + c[i + s, j + s] = max(0, b[i + s, j + s]) + d[i + s, j + s] = max(0, d[i + s, j + s]) + + +SN = dace.symbol("SN") + + +@dace.program +def condition_on_bounds(a: dace.float64[SN, SN], b: dace.float64[SN, SN], c: dace.float64[SN, SN], + d: dace.float64[SN, SN], s: dace.int64): + for i in dace.map[0:2]: + j = 1 + c1 = (i < s) + if c1: + b[i, j] = 1.1 * c[i + 1, j] * a[i + 1, j] + d[i, j] = 0.8 * c[i + 1, j] * a[i + 1, j] + + +@dace.program +def nested_if_two(a: dace.float64[N, N], b: dace.float64[N, N], c: dace.float64[N, N], d: dace.float64[N, N]): + for i, j in dace.map[0:N:1, 0:N:1]: + c[i, j] = b[i, j] * c[i, j] + d[i, j] = b[i, j] * d[i, j] + if a[i, j] > 0.3: + if b[i, j] > 0.1: + b[i, j] = 1.1 + d[i, j] = 0.8 + else: + b[i, j] = 1.2 + d[i, j] = 0.9 + else: + b[i, j] = -1.1 + d[i, j] = 2.2 + c[i, j] = max(0, b[i, j]) + d[i, j] = max(0, d[i, j]) + + +@dace.program +def branch_dependent_value_write_single_branch( + a: dace.float64[N, N], + b: dace.float64[N, N], + d: dace.float64[N, N], +): + for i, j in dace.map[0:N:1, 0:N:1]: + if a[i, j] < 0.65: + b[i, j] = 0.0 + d[i, j] = b[i, j] + a[i, j] # ensure d is always written for comparison + + +@dace.program +def branch_dependent_value_write_single_branch_nonzero_write( + a: dace.float64[N, N], + b: dace.float64[N, N], + d: dace.float64[N, N], +): + for i, j in dace.map[0:N:1, 0:N:1]: + if a[i, j] < 0.65: + b[i, j] = 1.2 * a[i, j] + d[i, j] = b[i, j] + a[i, j] # ensure d is always written for comparison + + +@dace.program +def complicated_if( + a: dace.float64[N, N], + b: dace.float64[N, N], + d: dace.float64[N, N], +): + for i, j in dace.map[0:N:1, 0:N:1]: + if (a[i, j] + d[i, j] * 3.0) < 0.65: + b[i, j] = 0.0 + d[i, j] = b[i, j] + a[i, j] # ensure d is always written for comparison + + +@dace.program +def tasklets_in_if( + a: dace.float64[S, S], + b: dace.float64[S, S], + d: dace.float64[S, S], + c: dace.float64, +): + for i in dace.map[S1:S2:1]: + for j in dace.map[S1:S2:1]: + if a[i, j] > c: + b[i, j] = b[i, j] + d[i, j] + else: + b[i, j] = b[i, j] - d[i, j] + b[i, j] = (1 - a[i, j]) * c + + +@dace.program +def single_branch_connectors( + a: dace.float64[S, S], + b: dace.float64[S, S], + d: dace.float64[S, S], + c: dace.float64, +): + for i in dace.map[S1:S2:1]: + for j in dace.map[S1:S2:1]: + if a[i, j] > c: + b[i, j] = d[i, j] + + +@dace.program +def if_over_map( + a: dace.float64[S, S], + b: dace.float64[S, S], + d: dace.float64[S, S], + c: dace.float64, +): + if c == 2.0: + for i in dace.map[S1:S2:1]: + for j in dace.map[S1:S2:1]: + cond = (a[0, 0] + b[0, 0] + d[0, 0]) > 2.0 + if cond: + b[i, j] = d[i, j] + + +@dace.program +def if_over_map_with_top_level_tasklets( + a: dace.float64[S, S], + b: dace.float64[S, S], + d: dace.float64[S, S], + c: dace.float64, +): + if c == 2.0: + for i in dace.map[S1:S2:1]: + for j in dace.map[S1:S2:1]: + cond = (b[0, 0] + d[0, 0]) > 2.0 + if cond: + b[i, j] = d[i, j] + a[2, 2] = 3.0 + a[1, 1] = 3.0 + a[4, 4] = 3.0 
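+
+
+# Illustrative sketch (not exercised by these tests): BranchElimination turns a
+# data-dependent branch into unconditional code that blends both branch results
+# through a float-valued condition. The helper below is a hypothetical numpy
+# analogue of that combine step, included only to document the semantics the
+# tests in this file compare against.
+def _predicated_select(cond, then_val, else_val):
+    # Condition as a 0.0/1.0 mask; both branch values are evaluated
+    # unconditionally, mirroring the generated combine tasklet.
+    cond_f = np.asarray(cond, dtype=np.float64)
+    return cond_f * then_val + (1.0 - cond_f) * else_val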
+
+
+@dace.program
+def disjoint_subsets(
+    if_cond_58: dace.int32,
+    A: dace.float64[N],
+    B: dace.float64[N, 3, 3],
+    C: dace.float64[N, 3, 3],
+    E: dace.float64[N],
+):
+    i = 4
+    if if_cond_58 == 1:
+        B[i, 2, 0] = A[i] + B[i, 2, 0]
+        B[i, 2, 0] = E[i] + B[i, 2, 0]
+        B[i, 0, 2] = B[i, 2, 0] + C[i, 0, 2]
+        B[i, 0, 2] = A[i] + B[i, 0, 2]
+    else:
+        B[i, 1, 0] = A[i] + B[i, 1, 0]
+        B[i, 1, 0] = E[i] + B[i, 1, 0]
+        B[i, 0, 1] = B[i, 2, 0] + C[i, 0, 1]
+        B[i, 0, 1] = A[i] + B[i, 0, 1]
+
+
+@dace.program
+def disjoint_subsets_two(
+    if_cond_58: dace.int32,
+    A: dace.float64[N],
+    B: dace.float64[N, 3, 3],
+    C: dace.float64[N, 3, 3],
+    D: dace.float64[N, 3, 3],
+    E: dace.float64[N],
+    F: dace.float64[N, 3, 3],
+):
+    for i in dace.map[0:N:1]:
+        if if_cond_58 == 1:
+            B[i, 2, 0] = A[i] + B[i, 2, 0]
+            C[i, 2, 0] = E[i] + B[i, 2, 0]
+            D[i, 0, 2] = A[i] + C[i, 0, 2]
+            F[i, 0, 2] = A[i] + D[i, 0, 2]
+        else:
+            B[i, 1, 0] = A[i] + B[i, 1, 0]
+            C[i, 1, 0] = E[i] + B[i, 1, 0]
+            D[i, 0, 1] = A[i] + C[i, 0, 1]
+            F[i, 0, 1] = A[i] + D[i, 0, 1]
+
+
+def _get_parent_state(sdfg: dace.SDFG, nsdfg_node: dace.nodes.NestedSDFG):
+    for n, g in sdfg.all_nodes_recursive():
+        if n == nsdfg_node:
+            return g
+    return None
+
+
+def apply_branch_elimination(sdfg, nestedness: int = 1):
+    """Apply the BranchElimination transformation to all eligible conditionals."""
+    # Pattern matching with conditional branches does not work (as of 9.10.25), so we avoid it
+    for i in range(nestedness):
+        for node, graph in sdfg.all_nodes_recursive():
+            parent_nsdfg_node = graph.sdfg.parent_nsdfg_node
+            parent_state = None
+            if parent_nsdfg_node is not None:
+                parent_state = _get_parent_state(sdfg, parent_nsdfg_node)
+            if isinstance(node, ConditionalBlock):
+                t = branch_elimination.BranchElimination()
+                if t.can_be_applied_to(graph.sdfg, conditional=node, options={"parent_nsdfg_state": parent_state}):
+                    t.apply_to(graph.sdfg, conditional=node, options={"parent_nsdfg_state": parent_state})
+
+
+def run_and_compare(
+    program,
+    num_expected_branches,
+    use_pass,
+    **arrays,
+):
+    # Run SDFG version (no transformation)
+    sdfg = program.to_sdfg()
+    sdfg.validate()
+    out_no_fuse = {k: v.copy() for k, v in arrays.items()}
+    sdfg(**out_no_fuse)
+    # Apply transformation
+    if use_pass:
+        fb = EliminateBranches()
+        fb.try_clean = True
+        fb.apply_pass(sdfg, {})
+    else:
+        apply_branch_elimination(sdfg, 2)
+    sdfg.name = sdfg.label + "_transformed"
+
+    # Run SDFG version (with transformation)
+    out_fused = {k: v.copy() for k, v in arrays.items()}
+
+    sdfg(**out_fused)
+
+    branch_code = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)}
+    assert len(
+        branch_code) == num_expected_branches, f"(actual) len({branch_code}) != (desired) {num_expected_branches}"
+
+    # Compare all arrays
+    for name in arrays.keys():
+        np.testing.assert_allclose(out_fused[name], out_no_fuse[name], atol=1e-12)
+
+
+def run_and_compare_sdfg(
+    sdfg,
+    permissive,
+    **arrays,
+):
+    # Run SDFG version (no transformation)
+    sdfg.validate()
+    out_no_fuse = {k: v.copy() for k, v in arrays.items()}
+    sdfg(**out_no_fuse)
+
+    # Run SDFG version (with transformation)
+    fb = EliminateBranches()
+    fb.try_clean = True
+    fb.permissive = permissive
+    fb.apply_pass(sdfg, {})
+    out_fused = {k: v.copy() for k, v in arrays.items()}
+    sdfg(**out_fused)
+
+    # Compare all arrays
+    for name in arrays.keys():
+        np.testing.assert_allclose(out_no_fuse[name], out_fused[name], atol=1e-12)
+
+
+@pytest.mark.parametrize("use_pass_flag", [True, False])
+def test_branch_dependent_value_write(use_pass_flag):
+    a = np.random.rand(N, N)
+    b = np.random.rand(N, N)
+    c = np.zeros((N, N))
+    d = np.zeros((N, N))
+    run_and_compare(branch_dependent_value_write, 0, use_pass_flag, a=a, b=b, c=c, d=d)
+
+
+def test_weird_condition():
+    a = np.random.rand(N, N)
+    b = np.random.rand(N, N)
+    ncldtop = np.array([N // 2], dtype=np.int64)
+    run_and_compare(weird_condition, 1, False, a=a, b=b, ncldtop=ncldtop[0])
+
+
+@pytest.mark.parametrize("use_pass_flag", [True, False])
+def test_branch_dependent_value_write_two(use_pass_flag):
+    a = np.random.choice([0.001, 3.0], size=(N, N))
+    b = np.zeros((N, N))
+    c = np.zeros((N, N))
+    d = np.zeros((N, N))
+    run_and_compare(branch_dependent_value_write_two, 0, use_pass_flag, a=a, b=b, c=c, d=d)
+
+
+@pytest.mark.parametrize("use_pass_flag", [True, False])
+def test_branch_dependent_value_write_single_branch(use_pass_flag):
+    a = np.random.choice([0.001, 3.0], size=(N, N))
+    b = np.random.choice([0.001, 5.0], size=(N, N))
+    d = np.zeros((N, N))
+    run_and_compare(branch_dependent_value_write_single_branch, 0, use_pass_flag, a=a, b=b, d=d)
+
+
+@pytest.mark.parametrize("use_pass_flag", [True, False])
+def test_complicated_if(use_pass_flag):
+    a = np.random.choice([0.001, 3.0], size=(N, N))
+    b = np.random.choice([0.001, 5.0], size=(N, N))
+    d = np.zeros((N, N))
+    run_and_compare(complicated_if, 0, use_pass_flag, a=a, b=b, d=d)
+
+
+@pytest.mark.parametrize("use_pass_flag", [True, False])
+def test_multi_state_branch_body(use_pass_flag):
+    a = np.random.choice([0.001, 3.0], size=(N, N))
+    b = np.random.choice([0.001, 5.0], size=(N, N))
+    c = np.random.choice([0.001, 5.0], size=(N, N))
+    d = np.zeros((N, N))
+    s = np.zeros((1, )).astype(np.int64)
+    run_and_compare(multi_state_branch_body, 1, use_pass_flag, a=a, b=b, c=c, d=d, s=s[0])
+
+
+@pytest.mark.parametrize("use_pass_flag", [True, False])
+def test_nested_if(use_pass_flag):
+    a = np.random.choice([0.001, 3.0], size=(N, N))
+    b = np.random.choice([0.001, 5.0], size=(N, N))
+    c = np.random.choice([0.001, 5.0], size=(N, N))
+    d = np.random.choice([0.001, 5.0], size=(N, N))
+    s = np.zeros((1, )).astype(np.int64)
+    run_and_compare(nested_if, 0, use_pass_flag, a=a, b=b, c=c, d=d, s=s[0])
+
+
+def test_condition_on_bounds():
+    a = np.random.choice([0.001, 3.0], size=(2, 2))
+    b = np.random.choice([0.001, 5.0], size=(2, 2))
+    c = np.random.choice([0.001, 5.0], size=(2, 2))
+    d = np.random.choice([0.001, 5.0], size=(2, 2))
+
+    sdfg = condition_on_bounds.to_sdfg()
+    sdfg.validate()
+    arrays = {"a": a, "b": b, "c": c, "d": d}
+    out_no_fuse = {k: v.copy() for k, v in arrays.items()}
+    sdfg(a=out_no_fuse["a"], b=out_no_fuse["b"], c=out_no_fuse["c"], d=out_no_fuse["d"], s=1, SN=2)
+    # Apply transformation
+    EliminateBranches().apply_pass(sdfg, {})
+    sdfg.validate()
+    out_fused = {k: v.copy() for k, v in arrays.items()}
+
+    nsdfgs = {(n, g) for n, g in sdfg.all_nodes_recursive() if isinstance(n, dace.nodes.NestedSDFG)}
+    assert len(nsdfgs) == 1  # can_be_applied should return False here
+
+
+def test_nested_if_two():
+    a = np.random.choice([0.001, 3.0], size=(N, N))
+    b = np.random.choice([0.001, 5.0], size=(N, N))
+    c = np.random.choice([0.001, 5.0], size=(N, N))
+    d = np.random.choice([0.001, 5.0], size=(N, N))
+    run_and_compare(nested_if_two, 0, True, a=a, b=b, c=c, d=d)
+
+
+@pytest.mark.parametrize("use_pass_flag", [True, False])
+def test_tasklets_in_if(use_pass_flag):
+    a = np.random.choice([0.001, 3.0], size=(N, N))
+    b = np.random.choice([0.001, 5.0],
size=(N, N)) + c = np.zeros((1, )) + d = np.zeros((N, N)) + run_and_compare(tasklets_in_if, 0, use_pass_flag, a=a, b=b, d=d, c=c[0]) + + +@pytest.mark.parametrize("use_pass_flag", [True, False]) +def test_branch_dependent_value_write_single_branch_nonzero_write(use_pass_flag): + a = np.random.choice([0.001, 3.0], size=(N, N)) + b = np.random.choice([0.001, 5.0], size=(N, N)) + d = np.random.choice([0.001, 5.0], size=(N, N)) + run_and_compare(branch_dependent_value_write_single_branch_nonzero_write, 0, use_pass_flag, a=a, b=b, d=d) + + +def test_branch_dependent_value_write_with_transient_reuse(): + a = np.random.choice([0.001, 3.0], size=(N, N)) + b = np.random.choice([0.001, 3.0], size=(N, N)) + c = np.random.choice([0.001, 3.0], size=(N, N)) + run_and_compare(branch_dependent_value_write_with_transient_reuse, 0, True, a=a, b=b, c=c) + + +@pytest.mark.parametrize("use_pass_flag", [True, False]) +def test_single_branch_connectors(use_pass_flag): + a = np.random.choice([0.001, 3.0], size=(N, N)) + b = np.random.choice([0.001, 5.0], size=(N, N)) + d = np.random.choice([0.001, 5.0], size=(N, N)) + c = np.random.randn(1, ) + + sdfg = single_branch_connectors.to_sdfg() + sdfg.validate() + arrays = {"a": a, "b": b, "c": c, "d": d} + out_no_fuse = {k: v.copy() for k, v in arrays.items()} + sdfg(a=out_no_fuse["a"], b=out_no_fuse["b"], c=out_no_fuse["c"][0], d=out_no_fuse["d"]) + # Apply transformation + if use_pass_flag: + EliminateBranches().apply_pass(sdfg, {}) + else: + apply_branch_elimination(sdfg, 2) + + # Run SDFG version (with transformation) + out_fused = {k: v.copy() for k, v in arrays.items()} + sdfg(a=out_fused["a"], b=out_fused["b"], c=out_fused["c"][0], d=out_fused["d"]) + + branch_code = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} + assert len(branch_code) == 0, f"(actual) len({branch_code}) != (desired) 0" + + # Compare all arrays + for name in arrays.keys(): + np.testing.assert_allclose(out_no_fuse[name], out_fused[name], atol=1e-12) + + nsdfgs = {(n, g) for n, g in sdfg.all_nodes_recursive() if isinstance(n, dace.nodes.NestedSDFG)} + assert len(nsdfgs) == 1 + nsdfg, parent_state = nsdfgs.pop() + assert len(nsdfg.in_connectors) == 4, f"{nsdfg.in_connectors}, length is not 4 but {len(nsdfg.in_connectors)}" + assert len(nsdfg.out_connectors) == 1, f"{nsdfg.out_connectors}, length is not 1 but {len(nsdfg.out_connectors)}" + + +@pytest.mark.parametrize("use_pass_flag", [True, False]) +def test_disjoint_subsets(use_pass_flag): + if_cond_58 = np.array([1], dtype=np.int32) + A = np.random.choice([0.001, 3.0], size=(N, )) + B = np.random.randn(N, 3, 3) + C = np.random.randn(N, 3, 3) + E = np.random.choice([0.001, 3.0], size=(N, 3, 3)) + run_and_compare(disjoint_subsets, 0, use_pass_flag, A=A, B=B, C=C, E=E, if_cond_58=if_cond_58[0]) + + +@dace.program +def _multi_state_nested_if( + A: dace.float64[ + N, + ], + B: dace.float64[N, 3, 3], + C: dace.float64[ + N, + ], + if_cond_1: dace.float64, + offset: dace.int64, +): + if if_cond_1 > 1.0: + _if_cond_2 = C[offset] + if _if_cond_2 > 1.0: + B[6, 1, 0] = A[6] + B[6, 1, 0] + B[6, 0, 1] = A[6] + B[6, 0, 1] + else: + B[6, 2, 0] = A[6] + B[6, 2, 0] + B[6, 0, 2] = A[6] + B[6, 0, 2] + + +def test_try_clean(): + sdfg1 = _multi_state_nested_if.to_sdfg() + cblocks = {n for n, g in sdfg1.all_nodes_recursive() if isinstance(n, ConditionalBlock)} + assert len(cblocks) == 2 + + for cblock in cblocks: + parent_sdfg = cblock.parent_graph.sdfg + parent_graph = cblock.parent_graph + xform = 
branch_elimination.BranchElimination()
+        xform.conditional = cblock
+        xform.try_clean(graph=parent_graph, sdfg=parent_sdfg)
+
+    # try_clean should have moved states out; the conditionals themselves remain
+    cblocks = {n for n, g in sdfg1.all_nodes_recursive() if isinstance(n, ConditionalBlock)}
+    assert len(cblocks) == 2
+    # A state must have been moved before the conditional
+    assert isinstance(sdfg1.start_block, dace.SDFGState)
+    sdfg1.validate()
+
+    fbpass = EliminateBranches()
+    fbpass.try_clean = False
+    fbpass.apply_pass(sdfg1, {})
+    cblocks = {n for n, g in sdfg1.all_nodes_recursive() if isinstance(n, ConditionalBlock)}
+    # 1 left, because the if branch now has 2 states
+    assert len(cblocks) == 1, f"{cblocks}"
+    sdfg1.validate()
+
+    for cblock in cblocks:
+        parent_sdfg = cblock.parent_graph.sdfg
+        parent_graph = cblock.parent_graph
+        xform = branch_elimination.BranchElimination()
+        xform.conditional = cblock
+        applied = xform.try_clean(graph=parent_graph, sdfg=parent_sdfg, lift_multi_state=False)
+        assert applied is False
+        applied = xform.try_clean(graph=parent_graph, sdfg=parent_sdfg, lift_multi_state=True)
+        assert applied is True
+        sdfg1.validate()
+
+    fbpass = EliminateBranches()
+    fbpass.try_clean = False
+    fbpass.apply_pass(sdfg1, {})
+    cblocks = {n for n, g in sdfg1.all_nodes_recursive() if isinstance(n, ConditionalBlock)}
+    assert len(cblocks) == 0, f"{cblocks}"
+    sdfg1.validate()
+
+    if_cond_1 = np.array([1.2], dtype=np.float64)
+    offset = np.array([0], dtype=np.int64)
+    A = np.random.choice([0.001, 3.0], size=(N, ))
+    B = np.random.randn(N, 3, 3)
+    C = np.random.choice([0.001, 3.0], size=(N, ))
+    run_and_compare_sdfg(sdfg1, permissive=False, A=A, B=B, C=C, if_cond_1=if_cond_1[0], offset=offset[0])
+
+
+def test_try_clean_as_pass():
+    # This test checks the different configurations of try_clean; applicability depends on the SDFG and the pass
+    sdfg = _multi_state_nested_if.to_sdfg()
+    fbpass = EliminateBranches()
+    fbpass.clean_only = True
+    fbpass.try_clean = False
+    fbpass.apply_pass(sdfg, {})
+    cblocks = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)}
+    assert len(cblocks) == 2, f"{cblocks}"
+    fbpass.clean_only = False
+    fbpass.try_clean = False
+    fbpass.apply_pass(sdfg, {})
+    cblocks = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)}
+    assert len(cblocks) == 1, f"{cblocks}"
+    fbpass.clean_only = False
+    fbpass.try_clean = False
+    fbpass.apply_pass(sdfg, {})
+    cblocks = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)}
+    assert len(cblocks) == 1, f"{cblocks}"
+    fbpass.clean_only = False
+    fbpass.try_clean = True
+    fbpass.apply_pass(sdfg, {})
+    cblocks = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)}
+    assert len(cblocks) == 0, f"{cblocks}"
+    sdfg.validate()
+
+    if_cond_1 = np.array([1.2], dtype=np.float64)
+    offset = np.array([0], dtype=np.int64)
+    A = np.random.choice([0.001, 3.0], size=(N, ))
+    B = np.random.randn(N, 3, 3)
+    C = np.random.choice([0.001, 3.0], size=(N, ))
+    run_and_compare_sdfg(sdfg, permissive=False, A=A, B=B, C=C, if_cond_1=if_cond_1[0], offset=offset[0])
+
+
+def _get_sdfg_with_interstate_array_condition():
+    sdfg = dace.SDFG("sd1")
+    sdfg.add_array("llindex", (4, 4, 4), dtype=dace.int64)
+    sdfg.add_array("zratio", (4, 4, 4), dtype=dace.float64)
+    sdfg.add_array("zsolqa", (4, 4, 4), dtype=dace.float64)
+    sdfg.add_scalar(
+        "zzratio",
+        dtype=dace.float64,
+        transient=True,
+        storage=dace.dtypes.StorageType.Register,
+    )
+    sdfg.add_symbol("_if_cond_1", dace.int64)
+    s1 =
sdfg.add_state("s1", is_start_block=True) + cb1 = ConditionalBlock("cb1", sdfg, sdfg) + cfg = ControlFlowRegion("cfg1", sdfg, cb1) + cb1.add_branch(condition=CodeBlock("_if_cond_1 == 1"), branch=cfg) + s2 = cfg.add_state("s2", is_start_block=True) + + sdfg.add_edge(s1, cb1, dace.InterstateEdge(assignments={"_if_cond_1": "llindex[2,2,2]"})) + + z1 = s1.add_access("zratio") + zz1 = s1.add_access("zzratio") + zz2 = s2.add_access("zzratio") + zs1 = s2.add_access("zsolqa") + zs2 = s2.add_access("zsolqa") + + t1 = s1.add_tasklet("T_1", {"_in_zratio"}, {"_out_zzratio"}, "_out_zzratio = _in_zratio") + t2 = s2.add_tasklet("T_2", {"_in_zzratio", "_in_zsolqa"}, {"_out_zsolqa"}, "_out_zsolqa = _in_zzratio * _in_zsolqa") + + s1.add_edge(z1, None, t1, "_in_zratio", dace.memlet.Memlet("zratio[3,3,3]")) + s1.add_edge(t1, "_out_zzratio", zz1, None, dace.memlet.Memlet("zzratio[0]")) + s2.add_edge(zz2, None, t2, "_in_zzratio", dace.memlet.Memlet("zzratio[0]")) + s2.add_edge(zs1, None, t2, "_in_zsolqa", dace.memlet.Memlet("zsolqa[2,2,2]")) + s2.add_edge(t2, "_out_zsolqa", zs2, None, dace.memlet.Memlet("zsolqa[2,2,2]")) + + sdfg.validate() + return sdfg + + +def test_sdfg_with_interstate_array_condition(): + sdfg = _get_sdfg_with_interstate_array_condition() + llindex = np.ones(shape=(4, 4, 4), dtype=np.int64) + zsolqa = np.random.choice([0.001, 3.0], size=(4, 4, 4)) + zratio = np.random.choice([0.001, 3.0], size=(4, 4, 4)) + run_and_compare_sdfg( + sdfg, + permissive=False, + llindex=llindex, + zsolqa=zsolqa, + zratio=zratio, + ) + + for n, g in sdfg.all_nodes_recursive(): + if isinstance(n, dace.nodes.Tasklet): + assert "[" not in n.code.as_string, f"Tasklet {n} has code: {n.code.as_string}" + assert "]" not in n.code.as_string, f"Tasklet {n} has code: {n.code.as_string}" + + +@dace.program +def repeated_condition_variables( + a: dace.float64[N, N], + b: dace.float64[N, N], + c: dace.float64[N, N], + conds: dace.float64[4, N], +): + for i, j in dace.map[0:N:1, 0:N:1]: + cond_1 = conds[0, j] + if cond_1 > 0.5: + c[i, j] = a[i, j] * b[i, j] + cond_1 = conds[1, j] + if cond_1 > 0.5: + c[i, j] = a[i, j] * b[i, j] + cond_1 = conds[2, j] + if cond_1 > 0.5: + c[i, j] = a[i, j] * b[i, j] + cond_1 = conds[3, j] + if cond_1 > 0.5: + c[i, j] = a[i, j] * b[i, j] + + +def test_repeated_condition_variables(): + a = np.random.choice([0.001, 3.0], size=(N, N)) + b = np.random.choice([0.001, 3.0], size=(N, N)) + c = np.random.choice([0.001, 3.0], size=(N, N)) + conds = np.random.choice([1.0, 3.0], size=(4, N)) + run_and_compare(repeated_condition_variables, 0, True, a=a, b=b, c=c, conds=conds) + + +def _find_state(root_sdfg: dace.SDFG, node): + for n, g in root_sdfg.all_nodes_recursive(): + if n == node: + return g + return None + + +def test_if_over_map(): + sdfg = if_over_map.to_sdfg() + cblocks = {n for n in sdfg.all_control_flow_regions() if isinstance(n, ConditionalBlock)} + assert len(cblocks) == 1 + inner_cblocks = { + n + for n, g in sdfg.all_nodes_recursive() + if isinstance(n, ConditionalBlock) and g is not None and g.sdfg.parent_nsdfg_node is not None + } + assert len(inner_cblocks) == 1 + + xform = branch_elimination.BranchElimination() + xform.conditional = cblocks.pop() + xform.parent_nsdfg_state = None + assert xform.can_be_applied(graph=xform.conditional.parent_graph, + expr_index=0, + sdfg=xform.conditional.parent_graph.sdfg) is False + + xform = branch_elimination.BranchElimination() + xform.conditional = inner_cblocks.pop() + xform.parent_nsdfg_state = _find_state(sdfg, 
xform.conditional.sdfg.parent_nsdfg_node)
+    assert xform.can_be_applied(graph=xform.conditional.parent_graph,
+                                expr_index=0,
+                                sdfg=xform.conditional.parent_graph.sdfg) is True
+
+
+def test_if_over_map_with_top_level_tasklets():
+    sdfg = if_over_map_with_top_level_tasklets.to_sdfg()
+    cblocks = {n for n in sdfg.all_control_flow_regions() if isinstance(n, ConditionalBlock)}
+    assert len(cblocks) == 1
+    inner_cblocks = {
+        n
+        for n, g in sdfg.all_nodes_recursive()
+        if isinstance(n, ConditionalBlock) and g is not None and g.sdfg.parent_nsdfg_node is not None
+    }
+    assert len(inner_cblocks) == 1
+
+    xform = branch_elimination.BranchElimination()
+    xform.conditional = cblocks.pop()
+    xform.parent_nsdfg_state = None
+    assert xform.can_be_applied(graph=xform.conditional.parent_graph,
+                                expr_index=0,
+                                sdfg=xform.conditional.parent_graph.sdfg) is False
+
+    xform = branch_elimination.BranchElimination()
+    xform.conditional = inner_cblocks.pop()
+    xform.parent_nsdfg_state = None
+    assert xform.can_be_applied(graph=xform.conditional.parent_graph,
+                                expr_index=0,
+                                sdfg=xform.conditional.parent_graph.sdfg) is False
+
+    xform.parent_nsdfg_state = _find_state(sdfg, xform.conditional.sdfg.parent_nsdfg_node)
+    assert xform.can_be_applied(graph=xform.conditional.parent_graph,
+                                expr_index=0,
+                                sdfg=xform.conditional.parent_graph.sdfg) is True
+
+
+def test_can_be_applied_parameters_on_nested_sdfg():
+    sdfg = nested_if.to_sdfg()
+    cblocks = {n for n in sdfg.all_control_flow_regions() if isinstance(n, ConditionalBlock)}
+    assert len(cblocks) == 0
+    inner_cblocks = {
+        n
+        for n, g in sdfg.all_nodes_recursive()
+        if isinstance(n, ConditionalBlock) and g is not None and g.sdfg.parent_nsdfg_node is not None
+    }
+    assert len(inner_cblocks) == 2
+
+    full_inner_cblocks = {
+        n
+        for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock) and g is not None and g != g.sdfg
+    }
+    assert len(full_inner_cblocks) == 1
+
+    upper_inner_cblocks = {
+        n
+        for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock) and g is not None and g == g.sdfg
+    }
+    assert len(upper_inner_cblocks) == 1
+
+    xform = branch_elimination.BranchElimination()
+    xform.conditional = full_inner_cblocks.pop()
+
+    assert xform.can_be_applied(graph=xform.conditional.parent_graph,
+                                expr_index=0,
+                                sdfg=xform.conditional.parent_graph.sdfg) is False
+
+    xform.parent_nsdfg_state = _find_state(sdfg, xform.conditional.sdfg.parent_nsdfg_node)
+
+    assert xform.can_be_applied(graph=xform.conditional.parent_graph,
+                                expr_index=0,
+                                sdfg=xform.conditional.parent_graph.sdfg) is True
+
+
+@dace.program
+def non_trivial_subset_after_combine_tasklet(
+    a: dace.float64[N, N],
+    b: dace.float64[N, N],
+    c: dace.float64[N, N],
+    d: dace.float64[N, N],
+    e: dace.float64[N, N],
+    f: dace.float64,
+    g: dace.float64[N, N],
+):
+    _if_cond_1 = a[1, 2] > (1 - f)
+    if _if_cond_1:
+        tc1 = b[3, 4] + c[3, 4] + d[3, 4]
+        tc2 = tc1 * 2.3
+        tc3 = max(0.0, tc2)
+        g[6, 6] = tc3
+    else:
+        tc1 = b[3, 4] + c[3, 4] + e[5, 5]
+        tc2 = tc1 * 2.3
+        tc3 = max(0.767, tc2)
+        tc4 = e[4, 4] * tc3
+        g[6, 6] = tc4
+
+
+def test_non_trivial_subset_after_combine_tasklet():
+    A = np.random.choice([0.001, 5.0], size=(N, N))
+    B = np.random.choice([0.001, 5.0], size=(N, N))
+    C = np.random.choice([0.001, 5.0], size=(N, N))
+    D = np.random.choice([0.001, 5.0], size=(N, N))
+    E = np.random.choice([0.001, 5.0], size=(N, N))
+    F = np.random.randn(1, )
+    G = np.random.choice([0.001, 5.0], size=(N, N))
+    run_and_compare(
+        non_trivial_subset_after_combine_tasklet,
+        0,
+        True,
+        a=A,
+        b=B,
+        c=C,
+        d=D,
+        e=E,
+        f=F[0],
+        g=G,
+    )
+
+
+@dace.program
+def split_on_disjoint_subsets(
+    a: dace.float64[N, N, 2],
+    b: dace.float64[N, N],
+    c: dace.float64[N, N],
+    d: dace.float64,
+):
+    _if_cond_1 = b[1, 2] > (1.0 - d)
+    if _if_cond_1:
+        tc1 = b[6, 6] * c[6, 6]
+        a[6, 6, 0] = tc1
+    else:
+        tc1 = b[6, 6] * c[6, 6]
+        a[6, 6, 1] = tc1
+        c[3, 3] = 0.0
+        b[4, 4] = 0.0
+
+
+@dace.program
+def split_on_disjoint_subsets_nested(
+    a: dace.float64[N, N, 2],
+    b: dace.float64[N, N],
+    c: dace.float64[N, N],
+    d: dace.float64,
+):
+    for i in dace.map[0:N]:
+        _if_cond_1 = b[i, 2] > (1.0 - d)
+        if _if_cond_1:
+            tc1 = b[i, 6] * c[i, 6]
+            a[i, 6, 0] = tc1
+        else:
+            tc1 = b[i, 6] * c[i, 6]
+            a[i, 6, 1] = tc1
+            c[i, 3] = 0.0
+            b[i, 4] = 0.0
+
+
+def test_split_on_disjoint_subsets():
+    A = np.random.choice([0.001, 5.0], size=(N, N, 2))
+    B = np.random.choice([0.001, 5.0], size=(N, N))
+    C = np.random.choice([0.001, 5.0], size=(N, N))
+    D = np.ones([
+        1,
+    ], dtype=np.float64)
+    sdfg = split_on_disjoint_subsets.to_sdfg()
+
+    # _is_disjoint_subset needs to return True here
+    cblocks = {n for n in sdfg.all_control_flow_regions() if isinstance(n, ConditionalBlock)}
+    assert len(cblocks) == 1
+    cblock = cblocks.pop()
+
+    (cond0, body0), (cond1, body1) = cblock.branches[0:2]
+    assert len(body0.nodes()) == 1
+    assert len(body1.nodes()) == 1
+    state0 = body0.nodes()[0]
+    state1 = body1.nodes()[0]
+    assert isinstance(state0, dace.SDFGState)
+    assert isinstance(state1, dace.SDFGState)
+
+    xform = branch_elimination.BranchElimination()
+    xform.conditional = cblock
+    assert xform._is_disjoint_subset(state0, state1) is True
+
+    # Splitting makes the conditional no longer applicable as-is
+    xform._split_branches(cblock.parent_graph, cblock)
+
+    sdfg.validate()
+
+    run_and_compare(
+        split_on_disjoint_subsets,
+        0,
+        True,
+        a=A,
+        b=B,
+        c=C,
+        d=D[0],
+    )
+
+
+def test_split_on_disjoint_subsets_nested():
+    A = np.random.choice([0.001, 5.0], size=(N, N, 2))
+    B = np.random.choice([0.001, 5.0], size=(N, N))
+    C = np.random.choice([0.001, 5.0], size=(N, N))
+    D = np.ones([
+        1,
+    ], dtype=np.float64)
+    sdfg = split_on_disjoint_subsets_nested.to_sdfg()
+
+    # _is_disjoint_subset needs to return True here
+    cblocks = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)}
+    assert len(cblocks) == 1
+    cblock = cblocks.pop()
+
+    (cond0, body0), (cond1, body1) = cblock.branches[0:2]
+    assert len(body0.nodes()) == 1
+    assert len(body1.nodes()) == 1
+    state0 = body0.nodes()[0]
+    state1 = body1.nodes()[0]
+    assert isinstance(state0, dace.SDFGState)
+    assert isinstance(state1, dace.SDFGState)
+
+    xform = branch_elimination.BranchElimination()
+    xform.conditional = cblock
+    assert xform._is_disjoint_subset(state0, state1) is True
+
+    # Splitting makes the conditional no longer applicable as-is
+    xform._split_branches(cblock.parent_graph, cblock)
+
+    sdfg.validate()
+
+    run_and_compare(
+        split_on_disjoint_subsets_nested,
+        0,
+        True,
+        a=A,
+        b=B,
+        c=C,
+        d=D[0],
+    )
+
+
+@dace.program
+def write_to_transient(
+    a: dace.float64[N, N],
+    b: dace.float64[N, N],
+    d: dace.int64,
+):
+    for i in dace.map[0:N]:
+        zmdn = 0.0
+        _if_cond_1 = d < 5.0
+        if _if_cond_1:
+            tc1 = b[i, 6] + a[i, 6]
+            zmdn = tc1
+        b[i, 3] = zmdn
+        b[i, 3] = zmdn
+
+
+@dace.program
+def write_to_transient_two(
+    a: dace.float64[N, N],
+    b: dace.float64[N, N],
+    d: dace.int64,
+):
+    for i in dace.map[0:N]:
+        zmdn = 0.0
+        _if_cond_1 = d < 5.0
+        if _if_cond_1:
+            tc1 = b[i, 6] + a[i, 6]
+            zmdn = tc1
+        else:
+            zmdn = 1.0
+        b[i, 3] = zmdn
+        b[i, 3] = zmdn
+
+
+def test_write_to_transient():
+    A = np.random.choice([0.001, 5.0], size=(N, N))
+    B = np.random.choice([0.001, 5.0], size=(N, N))
+    D = np.ones([
+        1,
+    ], dtype=np.float64)
+    run_and_compare(
+        write_to_transient,
+        0,
+        True,
+        a=A,
+        b=B,
+        d=D[0],
+    )
+
+
+def test_write_to_transient_two():
+    A = np.random.choice([0.001, 5.0], size=(N, N))
+    B = np.random.choice([0.001, 5.0], size=(N, N))
+    D = np.ones([
+        1,
+    ], dtype=np.float64)
+    run_and_compare(
+        write_to_transient_two,
+        0,
+        True,
+        a=A,
+        b=B,
+        d=D[0],
+    )
+
+
+def test_double_empty_state():
+    A = np.random.choice([0.001, 5.0], size=(N, N))
+    B = np.random.choice([0.001, 5.0], size=(N, N))
+    D = np.ones([
+        1,
+    ], dtype=np.float64)
+    sdfg = write_to_transient_two.to_sdfg()
+
+    nested_sdfgs = {(n, g) for (n, g) in sdfg.all_nodes_recursive() if isinstance(n, dace.nodes.NestedSDFG)}
+
+    sdfg.add_state_before(sdfg.start_block, label=f"empty_prepadding_{sdfg.label}", is_start_block=True)
+
+    for nsdfg, parent_state in nested_sdfgs:
+        nsdfg.sdfg.add_state_before(nsdfg.sdfg.start_block,
+                                    label=f"empty_prepadding_{nsdfg.sdfg.label}",
+                                    is_start_block=True)
+
+    run_and_compare_sdfg(
+        sdfg,
+        permissive=False,
+        a=A,
+        b=B,
+        d=D[0],
+    )
+
+
+@dace.program
+def complicated_pattern_for_manual_clean_up_one(
+    a: dace.float64[N, N],
+    b: dace.float64[N, N],
+    c: dace.float64[N, N],
+    d: dace.float64,
+):
+    for i in dace.map[0:N]:
+        _if_cond_1 = i < d
+        if _if_cond_1:
+            zalfaw = a[i, 0]
+            tc1 = b[i, 6] + zalfaw
+            tc2 = b[i, 3] * zalfaw + tc1
+            c[i, 0] = tc2
+        else:
+            c[i, 0] = 0.0
+
+
+def test_complicated_pattern_for_manual_clean_up_one():
+    A = np.random.choice([0.001, 5.0], size=(N, N))
+    B = np.random.choice([0.001, 5.0], size=(N, N))
+    C = np.random.choice([0.001, 5.0], size=(N, N))
+    D = np.ones([
+        1,
+    ], dtype=np.float64)
+    sdfg = complicated_pattern_for_manual_clean_up_one.to_sdfg()
+
+    nested_sdfgs = {(n, g) for (n, g) in sdfg.all_nodes_recursive() if isinstance(n, dace.nodes.NestedSDFG)}
+
+    # Force scalar promotion like in CloudSC
+    ssp = ScalarToSymbolPromotion()
+    ssp.integers_only = False
+    ssp.transients_only = True
+    scalar_names = {
+        arr_name
+        for arr_name, arr in sdfg.arrays.items()
+        if isinstance(arr, dace.data.Scalar) or (isinstance(arr, dace.data.Array) and (arr.shape == [
+            1,
+        ] or arr.shape == (1, )))
+    }.difference({"zalfaw"})
+    ssp.ignore = scalar_names
+    ssp.apply_pass(sdfg, {})
+
+    for nsdfg, parent_state in nested_sdfgs:
+        ssp = ScalarToSymbolPromotion()
+        ssp.integers_only = False
+        ssp.transients_only = True
+        scalar_names = {
+            arr_name
+            for arr_name, arr in nsdfg.sdfg.arrays.items()
+            if isinstance(arr, dace.data.Scalar) or (isinstance(arr, dace.data.Array) and (arr.shape == [
+                1,
+            ] or arr.shape == (1, )))
+        }.difference({"zalfaw"})
+        ssp.ignore = scalar_names
+        ssp.apply_pass(nsdfg.sdfg, {})
+
+    for nsdfg, parent_state in nested_sdfgs:
+        for cb in nsdfg.sdfg.all_control_flow_regions():
+            if isinstance(cb, ConditionalBlock):
+                xform = branch_elimination.BranchElimination()
+                xform.parent_nsdfg_state = parent_state
+                xform.conditional = cb
+                assert xform.can_be_applied(graph=cb.parent_graph, expr_index=0, sdfg=cb.sdfg) is False
+                assert xform.symbol_reused_outside_conditional(sym_name="zalfaw") is False
+                assert len(cb.branches) == 2
+                # The clean-up should be able to catch this pattern
+                xform.demote_branch_only_symbols_appearing_only_a_single_branch_to_scalars_and_try_fuse(
+                    graph=cb.parent_graph, sdfg=cb.sdfg)
+                (cond0, body0), (cond1, body1) = cb.branches[0:2]
+                assert len(body0.nodes()) == 1
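+                # After demoting the branch-local symbol and fusing, each branch
+                # body should have collapsed into a single SDFG state: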
assert len(body1.nodes()) == 1 + assert all({isinstance(n, dace.SDFGState) for n in body0.nodes()}) + assert all({isinstance(n, dace.SDFGState) for n in body1.nodes()}) + + +def test_try_clean_on_complicated_pattern_for_manual_clean_up_one(): + A = np.random.choice([0.001, 5.0], size=(N, N)) + B = np.random.choice([0.001, 5.0], size=(N, N)) + C = np.random.choice([0.001, 5.0], size=(N, N)) + D = np.ones([ + 1, + ], dtype=np.float64) + sdfg = complicated_pattern_for_manual_clean_up_one.to_sdfg() + + nested_sdfgs = {(n, g) for (n, g) in sdfg.all_nodes_recursive() if isinstance(n, dace.nodes.NestedSDFG)} + + # Force scalar promotion like in CloudSC + ssp = ScalarToSymbolPromotion() + ssp.integers_only = False + ssp.transients_only = True + scalar_names = { + arr_name + for arr_name, arr in sdfg.arrays.items() + if isinstance(arr, dace.data.Scalar) or (isinstance(arr, dace.data.Array) and (arr.shape == [ + 1, + ] or arr.shape == (1, ))) + }.difference({"zalfaw"}) + ssp.ignore = scalar_names + ssp.apply_pass(sdfg, {}) + + for nsdfg, parent_state in nested_sdfgs: + ssp = ScalarToSymbolPromotion() + ssp.integers_only = False + ssp.transients_only = True + scalar_names = { + arr_name + for arr_name, arr in nsdfg.sdfg.arrays.items() + if isinstance(arr, dace.data.Scalar) or (isinstance(arr, dace.data.Array) and (arr.shape == [ + 1, + ] or arr.shape == (1, ))) + }.difference({"zalfaw"}) + ssp.ignore = scalar_names + ssp.apply_pass(nsdfg.sdfg, {}) + + run_and_compare_sdfg(sdfg, permissive=True, a=A, b=B, c=C, d=D[0]) + + branch_code = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} + assert len(branch_code) == 0, f"(actual) len({branch_code}) != (desired) {0}" + + +@dace.program +def complicated_pattern_for_manual_clean_up_two( + a: dace.float64[N, N], + b: dace.float64[N, N], + c: dace.float64[N, N], + d: dace.float64, + e: dace.float64, +): + for i in dace.map[0:N]: + _if_cond_1 = i < d + if _if_cond_1: + zlcrit = d + zalfaw = a[i, 0] + tc1 = b[i, 6] + zalfaw + tc2 = b[i, 3] * zalfaw + tc1 + c[i, 0] = tc2 + else: + zlcrit = e + c[i, 0] = 0.0 + a[i, 3] = zlcrit * 2.0 + + +def test_try_clean_on_complicated_pattern_for_manual_clean_up_two(): + A = np.random.choice([0.001, 5.0], size=(N, N)) + B = np.random.choice([0.001, 5.0], size=(N, N)) + C = np.random.choice([0.001, 5.0], size=(N, N)) + D = np.ones([ + 1, + ], dtype=np.float64) + E = np.ones([ + 1, + ], dtype=np.float64) + sdfg = complicated_pattern_for_manual_clean_up_two.to_sdfg() + + nested_sdfgs = {(n, g) for (n, g) in sdfg.all_nodes_recursive() if isinstance(n, dace.nodes.NestedSDFG)} + + # Force scalar promotion like in CloudSC + ssp = ScalarToSymbolPromotion() + ssp.integers_only = False + ssp.transients_only = True + scalar_names = { + arr_name + for arr_name, arr in sdfg.arrays.items() + if isinstance(arr, dace.data.Scalar) or (isinstance(arr, dace.data.Array) and (arr.shape == [ + 1, + ] or arr.shape == (1, ))) + }.difference({"zlcrit"}) + ssp.ignore = scalar_names + ssp.apply_pass(sdfg, {}) + + for nsdfg, parent_state in nested_sdfgs: + ssp = ScalarToSymbolPromotion() + ssp.integers_only = False + ssp.transients_only = True + scalar_names = { + arr_name + for arr_name, arr in nsdfg.sdfg.arrays.items() + if isinstance(arr, dace.data.Scalar) or (isinstance(arr, dace.data.Array) and (arr.shape == [ + 1, + ] or arr.shape == (1, ))) + }.difference({"zlcrit"}) + ssp.ignore = scalar_names + ssp.apply_pass(nsdfg.sdfg, {}) + + run_and_compare_sdfg(sdfg, permissive=True, a=A, b=B, c=C, d=D[0], e=E[0]) + + branch_code 
= {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} + assert len(branch_code) == 0, f"(actual) len({branch_code}) != (desired) {0}" + + +@dace.program +def single_assignment(a: dace.float64[ + N, +], _if_cond_1: dace.float64): + for i in dace.map[0:N]: + if _if_cond_1 > 0.0: + a[i] = 0.0 + + +@dace.program +def single_assignment_cond_from_scalar(a: dace.float64[512]): + for i in dace.map[0:256]: + _if_cond_1 = a[256 + i] + if _if_cond_1 > 0.0: + a[i] = 0.0 + + +def test_single_assignment(): + if_cond_1 = np.array([1], dtype=np.float64) + A = np.ones(shape=(N, ), dtype=np.float64) + run_and_compare(single_assignment, 0, True, a=A, _if_cond_1=if_cond_1[0]) + + +def test_single_assignment_cond_from_scalar(): + A = np.ones(shape=(512, ), dtype=np.float64) + before = single_assignment_cond_from_scalar.to_sdfg() + before.name = "non_fusion_single_assignment_cond_from_scalar" + before.compile() + run_and_compare(single_assignment_cond_from_scalar, 0, True, a=A) + + +def _get_sdfg_with_condition_from_transient_scalar() -> dace.SDFG: + sdfg = dace.SDFG("sd1") + + sdfg.add_scalar("zacond_0", transient=True, dtype=dace.float64) + sdfg.add_scalar("_if_cond_41", transient=True, dtype=dace.float64) + sdfg.add_array("zsolac", (N, ), dace.float64) + sdfg.add_array("zlcond2", (N, ), dace.float64) + sdfg.add_array("za", (N, ), dace.float64) + sdfg.add_symbol("_if_cond_42", dace.float64) + + s1 = sdfg.add_state("s1", is_start_block=True) + cb1 = ConditionalBlock(label="cb1", sdfg=sdfg, parent=sdfg) + cfg1 = ControlFlowRegion(label="cfg1", sdfg=sdfg, parent=cb1) + cb1.add_branch(condition=CodeBlock("_if_cond_41 == 1"), branch=cfg1) + cb2 = ConditionalBlock(label="cb2", sdfg=sdfg, parent=sdfg) + cfg2 = ControlFlowRegion(label="cfg2", sdfg=sdfg, parent=cb2) + cb2.add_branch(condition=CodeBlock("_if_cond_42 == 1"), branch=cfg2) + s2 = cfg1.add_state("s2", is_start_block=True) + s3 = cfg2.add_state("s3", is_start_block=True) + + sdfg.add_edge(s1, cb1, InterstateEdge()) + sdfg.add_edge(cb1, cb2, InterstateEdge()) + s4 = sdfg.add_state_after(state=cb2, label="s4") + + # _if_cond_42 is a symbol (free symbol) + # Calculate _if_cond_41, zacond_0 in s1 + # Calculate zacond_1 in s2 + # Calculate zsolac using zacond_0 + + t1 = s1.add_tasklet(name="t1", + inputs={"_in1", "_in2"}, + outputs={"_out"}, + code="_out = ((_in1 < 0.3) and (_in2 < 0.5))") + s1.add_edge(s1.add_access("za"), None, t1, "_in1", dace.memlet.Memlet("za[4]")) + s1.add_edge(s1.add_access("zlcond2"), None, t1, "_in2", dace.memlet.Memlet("zlcond2[4]")) + s1.add_edge(t1, "_out", s1.add_access("_if_cond_41"), None, dace.memlet.Memlet("_if_cond_41[0]")) + + t2 = s2.add_tasklet(name="t2_1", inputs=set(), outputs={"_out"}, code="_out = 0.0") + s2.add_edge(t2, "_out", s2.add_access("zacond_0"), None, dace.memlet.Memlet("zacond_0[0]")) + + t2_2 = s2.add_tasklet(name="t2_2", inputs=set(), outputs={"_out"}, code="_out = 0.0") + s2.add_edge(t2_2, "_out", s2.add_access("zlcond2"), None, dace.memlet.Memlet("zlcond2[4]")) + + t3 = s3.add_tasklet(name="t3", inputs=set(), outputs={"_out"}, code="_out = 0.5") + s3.add_edge(t3, "_out", s3.add_access("zacond_0"), None, dace.memlet.Memlet("zacond_0[0]")) + + t4 = s4.add_tasklet(name="t4", inputs={"_in1", "_in2"}, outputs={"_out"}, code="_out = _in1 + _in2") + s4.add_edge(t4, "_out", s4.add_access("zsolac"), None, dace.memlet.Memlet("zsolac[4]")) + s4.add_edge(s4.add_access("zsolac"), None, t4, "_in1", dace.memlet.Memlet("zsolac[4]")) + s4.add_edge(s4.add_access("zacond_0"), None, t4, "_in2", 
dace.memlet.Memlet("zacond_0[0]")) + + sdfg.validate() + return sdfg + + +def test_condition_from_transient_scalar(): + zsolac = np.random.choice([8.0, 11.0], size=(N, )) + zlcond2 = np.random.choice([8.0, 11.0], size=(N, )) + za = np.random.choice([8.0, 11.0], size=(N, )) + _if_cond_42 = np.random.choice([8.0, 11.0], size=(1, )) + sdfg = _get_sdfg_with_condition_from_transient_scalar() + + run_and_compare_sdfg(sdfg, permissive=False, zsolac=zsolac, zlcond2=zlcond2, za=za, _if_cond_42=_if_cond_42[0]) + + branch_code = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} + assert len(branch_code) == 0, f"(actual) len({branch_code}) != (desired) {0}" + + +def _get_disjoint_chain_sdfg() -> dace.SDFG: + sd1 = dace.SDFG("disjoint_chain") + cb1 = ConditionalBlock("cond_if_cond_58", sdfg=sd1, parent=sd1) + ss1 = sd1.add_state(label="pre", is_start_block=True) + sd1.add_node(cb1, is_start_block=False) + + cfg1 = ControlFlowRegion(label="cond_58_true", sdfg=sd1, parent=cb1) + s1 = cfg1.add_state("main_1", is_start_block=True) + cfg2 = ControlFlowRegion(label="cond_58_false", sdfg=sd1, parent=cb1) + s2 = cfg2.add_state("main_2", is_start_block=True) + + cb1.add_branch( + condition=CodeBlock("_if_cond_58 == 1"), + branch=cfg1, + ) + cb1.add_branch( + condition=None, + branch=cfg2, + ) + for arr_name, shape in [ + ("zsolqa", (N, 5, 5)), + ("zrainaut", (N, )), + ("zrainacc", (N, )), + ("ztp1", (N, )), + ]: + sd1.add_array(arr_name, shape, dace.float64) + sd1.add_scalar("rtt", dace.float64) + sd1.add_symbol("_if_cond_58", dace.float64) + sd1.add_symbol("_for_it_52", dace.int64) + sd1.add_edge(src=ss1, dst=cb1, data=InterstateEdge(assignments={ + "_if_cond_58": "ztp1[_for_it_52] <= rtt", + }, )) + + for state, d1_access_str, zsolqa_access_str, zsolqa_access_str_rev in [ + (s1, "_for_it_52", "_for_it_52,3,0", "_for_it_52,0,3"), (s2, "_for_it_52", "_for_it_52,2,0", "_for_it_52,0,2") + ]: + zrainaut = state.add_access("zrainaut") + zrainacc = state.add_access("zrainacc") + zsolqa1 = state.add_access("zsolqa") + zsolqa2 = state.add_access("zsolqa") + zsolqa3 = state.add_access("zsolqa") + zsolqa4 = state.add_access("zsolqa") + zsolqa5 = state.add_access("zsolqa") + for i, (tasklet_code, in1, instr1, in2, instr2, out, outstr) in enumerate([ + ("_out = _in1 + _in2", zrainaut, d1_access_str, zsolqa1, zsolqa_access_str, zsolqa2, zsolqa_access_str), + ("_out = _in1 + _in2", zrainacc, d1_access_str, zsolqa2, zsolqa_access_str, zsolqa3, zsolqa_access_str), + ("_out = (-_in1) + _in2", zrainaut, d1_access_str, zsolqa3, zsolqa_access_str_rev, zsolqa4, + zsolqa_access_str_rev), + ("_out = (-_in1) + _in2", zrainacc, d1_access_str, zsolqa4, zsolqa_access_str_rev, zsolqa5, + zsolqa_access_str_rev), + ]): + t1 = state.add_tasklet("t1", {"_in1", "_in2"}, {"_out"}, tasklet_code) + state.add_edge(in1, None, t1, "_in1", dace.memlet.Memlet(f"{in1.data}[{instr1}]")) + state.add_edge(in2, None, t1, "_in2", dace.memlet.Memlet(f"{in2.data}[{instr2}]")) + state.add_edge(t1, "_out", out, None, dace.memlet.Memlet(f"{out.data}[{outstr}]")) + + sd1.validate() + + sd2 = dace.SDFG("sd2") + p_s1 = sd2.add_state("p_s1", is_start_block=True) + + map_entry, map_exit = p_s1.add_map(name="map1", ndrange={"_for_it_52": dace.subsets.Range([(0, N - 1, 1)])}) + nsdfg = p_s1.add_nested_sdfg(sdfg=sd1, + inputs={"zsolqa", "ztp1", "zrainaut", "zrainacc", "rtt"}, + outputs={"zsolqa"}, + symbol_mapping={"_for_it_52": "_for_it_52"}) + for arr_name, shape in [("zsolqa", (N, 5, 5)), ("zrainaut", (N, )), ("zrainacc", (N, )), ("ztp1", 
(N, ))]: + sd2.add_array(arr_name, shape, dace.float64) + sd2.add_scalar("rtt", dace.float64) + for input_name in {"zsolqa", "ztp1", "zrainaut", "zrainacc", "rtt"}: + a = p_s1.add_access(input_name) + p_s1.add_edge(a, None, map_entry, f"IN_{input_name}", + dace.memlet.Memlet.from_array(input_name, sd2.arrays[input_name])) + p_s1.add_edge(map_entry, f"OUT_{input_name}", nsdfg, input_name, + dace.memlet.Memlet.from_array(input_name, sd2.arrays[input_name])) + map_entry.add_in_connector(f"IN_{input_name}") + map_entry.add_out_connector(f"OUT_{input_name}") + for output_name in {"zsolqa"}: + a = p_s1.add_access(output_name) + p_s1.add_edge(map_exit, f"OUT_{output_name}", a, None, + dace.memlet.Memlet.from_array(output_name, sd2.arrays[output_name])) + p_s1.add_edge(nsdfg, output_name, map_exit, f"IN_{output_name}", + dace.memlet.Memlet.from_array(output_name, sd2.arrays[output_name])) + map_exit.add_in_connector(f"IN_{output_name}") + map_exit.add_out_connector(f"OUT_{output_name}") + + nsdfg.sdfg.parent_nsdfg_node = nsdfg + + sd1.validate() + sd2.validate() + return sd2, p_s1 + + +@pytest.mark.parametrize("rtt_val", [0.0, 4.0, 6.0]) +def test_disjoint_chain_split_branch_only(rtt_val): + sdfg, nsdfg_parent_state = _get_disjoint_chain_sdfg() + zsolqa = np.random.choice([0.001, 5.0], size=(N, 5, 5)) + zrainacc = np.random.choice([0.001, 5.0], size=(N, )) + zrainaut = np.random.choice([0.001, 5.0], size=(N, )) + ztp1 = np.random.choice([3.5, 5.0], size=(N, )) + rtt = np.random.choice([rtt_val], size=(1, )) + + copy_sdfg = copy.deepcopy(sdfg) + arrays = {"zsolqa": zsolqa, "zrainacc": zrainacc, "zrainaut": zrainaut, "ztp1": ztp1, "rtt": rtt[0]} + + sdfg.validate() + out_no_fuse = {k: v.copy() for k, v in arrays.items()} + sdfg(**out_no_fuse) + + # Run SDFG version (with transformation) + xform = branch_elimination.BranchElimination() + cblocks = {n for n, g in copy_sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} + assert len(cblocks) == 1 + cblock = cblocks.pop() + + xform.conditional = cblock + xform.parent_nsdfg_state = nsdfg_parent_state + xform.sequentialize_if_else_branch_if_disjoint_subsets(cblock.parent_graph) + + out_fused = {k: v.copy() for k, v in arrays.items()} + copy_sdfg(**out_fused) + + for name in arrays.keys(): + np.testing.assert_allclose(out_no_fuse[name], out_fused[name], atol=1e-12) + + +@pytest.mark.parametrize("rtt_val", [0.0, 4.0, 6.0]) +def test_disjoint_chain(rtt_val): + sdfg, _ = _get_disjoint_chain_sdfg() + zsolqa = np.random.choice([0.001, 5.0], size=(N, 5, 5)) + zrainacc = np.random.choice([0.001, 5.0], size=(N, )) + zrainaut = np.random.choice([0.001, 5.0], size=(N, )) + ztp1 = np.random.choice([3.5, 5.0], size=(N, )) + rtt = np.random.choice([rtt_val], size=(1, )) + + run_and_compare_sdfg(sdfg, + permissive=False, + zsolqa=zsolqa, + zrainacc=zrainacc, + zrainaut=zrainaut, + ztp1=ztp1, + rtt=rtt[0]) + + +@dace.program +def pattern_from_cloudsc_one( + A: dace.float64[2, N, N], + B: dace.float64[N, N], + c: dace.float64, + D: dace.float64[N, N], + E: dace.float64[N, N], +): + for i in dace.map[0:N]: + for j in dace.map[0:N]: + B[i, j] = A[1, i, j] + A[0, i, j] + _if_cond_5 = B[i, j] > c + if _if_cond_5: + D[i, j] = B[i, j] / A[0, i, j] + E[i, j] = 1.0 - D[i, j] + else: + D[i, j] = 0.0 + E[i, j] = 0.0 + + +@pytest.mark.parametrize("c_val", [0.0, 1.0, 6.0]) +def test_pattern_from_cloudsc_one(c_val): + A = np.random.choice([0.001, 5.0], size=( + 2, + N, + N, + )) + B = np.random.choice([0.001, 5.0], size=(N, N)) + C = np.array([c_val], ) + D = 
np.random.choice([0.001, 5.0], size=(N, N)) + E = np.random.choice([0.001, 5.0], size=(N, N)) + + run_and_compare(pattern_from_cloudsc_one, 0, True, A=A, B=B, c=C[0], D=D, E=E) + + +@dace.program +def map_param_usage( + a: dace.float64[N, N], + b: dace.float64[N, N], + d: dace.float64[N, N], +): + for i in dace.map[0:N]: + _if_cond_1 = d[i, i] > 0.0 + if _if_cond_1: + tc1 = d[i, i] + a[i, i] + zmdn = tc1 + else: + zmdn = 0.0 + b[i, i] = zmdn + b[i, i] = zmdn + + +def test_can_be_applied_on_map_param_usage(): + A = np.random.choice([0.001, 5.0], size=( + N, + N, + )) + B = np.random.choice([0.001, 5.0], size=(N, N)) + D = np.random.choice([0.001, 5.0], size=(N, N)) + + sdfg = map_param_usage.to_sdfg() + + xform = branch_elimination.BranchElimination() + cblocks = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} + assert len(cblocks) == 1 + xform.conditional = cblocks.pop() + xform.parent_nsdfg_state = _find_state(sdfg, xform.conditional.sdfg.parent_nsdfg_node) + + assert xform.can_be_applied(xform.conditional.parent_graph, 0, xform.conditional.sdfg, False) + + run_and_compare(map_param_usage, 0, True, a=A, b=B, d=D) + + +def _get_safe_map_param_use_in_nested_sdfg() -> dace.SDFG: + inner_sdfg = dace.SDFG("inner") + outer_sdfg = dace.SDFG("outer") + + inner_symbol_mapping = { + "_for_it_37": "_for_it_37", + } + for arr_name in ["zsolac", "zacust", "zfinalsum"]: + inner_sdfg.add_array(arr_name, (N, ), dace.float64) + outer_sdfg.add_array(arr_name, (N, ), dace.float64) + inner_inputs = {"zsolac", "zacust", "zfinalsum"} + inner_outputs = {"zacust", "zsolac"} + + i_s1 = inner_sdfg.add_state("i_s1", is_start_block=True) + i_cb1 = ConditionalBlock("i_cb1", sdfg=inner_sdfg, parent=inner_sdfg) + inner_sdfg.add_node(i_cb1) + inner_sdfg.add_edge(i_s1, i_cb1, InterstateEdge(assignments={"_if_cond_22": "zfinalsum[_for_it_37] < 1e-14"})) + inner_sdfg.add_symbol("_if_cond_22", dace.int32) + inner_sdfg.add_symbol("_for_it_37", dace.int32) + + i_cfg1 = ControlFlowRegion("i_cfg1", sdfg=inner_sdfg, parent=i_cb1) + i_cfg1_s1 = i_cfg1.add_state("i_cfg1_s1", is_start_block=True) + + t1 = i_cfg1_s1.add_tasklet("t1", inputs={}, outputs={"_out"}, code="_out = 0.0") + i_cfg1_s1.add_edge(t1, "_out", i_cfg1_s1.add_access("zacust"), None, dace.memlet.Memlet("zacust[_for_it_37]")) + + i_cb1.add_branch(CodeBlock("_if_cond_22 == 1"), i_cfg1) + + i_s2 = inner_sdfg.add_state_after(i_cb1, label="i_s2") + t2 = i_s2.add_tasklet("t2", inputs={"_in1", "_in2"}, outputs={"_out"}, code="_out = _in1 + _in2") + for in_name, conn_name in [("zacust", "_in1"), ("zsolac", "_in2")]: + i_s2.add_edge(i_s2.add_access(in_name), None, t2, conn_name, dace.memlet.Memlet(f"{in_name}[_for_it_37]")) + i_s2.add_edge(t2, "_out", i_s2.add_access("zsolac"), None, dace.memlet.Memlet(f"zsolac[_for_it_37]")) + + o_s1 = outer_sdfg.add_state("o_s1", is_start_block=True) + nsdfg = o_s1.add_nested_sdfg(sdfg=inner_sdfg, + inputs=inner_inputs, + outputs=inner_outputs, + symbol_mapping=inner_symbol_mapping) + + map_entry, map_exit = o_s1.add_map(name="m1", ndrange={ + "_for_it_37": dace.subsets.Range([(0, N - 1, 1)]), + }) + for in_name in inner_inputs: + o_s1.add_edge(o_s1.add_access(in_name), None, map_entry, f"IN_{in_name}", + dace.memlet.Memlet.from_array(in_name, o_s1.sdfg.arrays[in_name])) + map_entry.add_in_connector(f"IN_{in_name}") + map_entry.add_out_connector(f"OUT_{in_name}") + o_s1.add_edge(map_entry, f"OUT_{in_name}", nsdfg, in_name, + dace.memlet.Memlet.from_array(in_name, o_s1.sdfg.arrays[in_name])) + for out_name in 
inner_outputs: + o_s1.add_edge(nsdfg, out_name, map_exit, f"IN_{out_name}", + dace.memlet.Memlet.from_array(out_name, o_s1.sdfg.arrays[out_name])) + map_exit.add_in_connector(f"IN_{out_name}") + map_exit.add_out_connector(f"OUT_{out_name}") + o_s1.add_edge(map_exit, f"OUT_{out_name}", o_s1.add_access(out_name), None, + dace.memlet.Memlet.from_array(out_name, o_s1.sdfg.arrays[out_name])) + + outer_sdfg.validate() + return outer_sdfg + + +def test_safe_map_param_use_in_nested_sdfg(): + sdfg = _get_safe_map_param_use_in_nested_sdfg() + sdfg.validate() + + for n, g in sdfg.all_nodes_recursive(): + if isinstance(n, ConditionalBlock): + xform = branch_elimination.BranchElimination() + xform.conditional = n + xform.parent_nsdfg_state = _find_state(sdfg, g.sdfg.parent_nsdfg_node) + assert xform.can_be_applied(graph=g, expr_index=0, sdfg=g.sdfg, permissive=False) + assert xform.can_be_applied(graph=g, expr_index=0, sdfg=g.sdfg, permissive=True) + + # "zsolac", "zacust", "zlfinalsum" + zsolac = np.random.choice([0.001, 5.0], size=(N, )) + zfinalsum = np.random.choice([0.001, 5.0], size=(N, )) + zacust = np.random.choice([0.001, 5.0], size=(N, )) + run_and_compare_sdfg(sdfg, False, zsolac=zsolac, zfinalsum=zfinalsum, zacust=zacust) + + +def _get_nsdfg_with_return(return_arr: bool) -> dace.SDFG: + inner_sdfg = dace.SDFG("inner") + outer_sdfg = dace.SDFG("outer") + + inner_symbol_mapping = {} + for outer_arr_name in ["ztp"]: + outer_sdfg.add_array(outer_arr_name, (N, N), dace.float64) + for outer_scalar_name in ["rtt"]: + outer_sdfg.add_scalar(outer_scalar_name, dace.float64) + outer_sdfg.add_array("zalfa_1", (1, ), dace.float64) + if return_arr: + inner_sdfg.add_array("foedelta__ret", (1, ), dace.float64) + else: + inner_sdfg.add_scalar("foedelta__ret", dace.float64) + for inner_scalar_name in ["ptare_var_0", "rtt_var_1"]: + inner_sdfg.add_scalar(inner_scalar_name, dace.float64) + for inner_tmp_name in ["tmp_call_103", "tmp_call_1"]: + inner_sdfg.add_scalar(inner_tmp_name, dace.float64, transient=True) + + inner_inputs = {"ptare_var_0", "rtt_var_1"} + inner_outputs = {"foedelta__ret"} + inner_to_outer_name_mapping_in = { + "ptare_var_0": ("rtt", "[0]"), + "rtt_var_1": ("ztp", "[4,4]"), + } + inner_to_outer_name_mapping_out = {"foedelta__ret": ("zalfa_1", "[0]")} + + i_s1 = inner_sdfg.add_state("i_s1", is_start_block=True) + i_cb1 = ConditionalBlock("i_cb1", sdfg=inner_sdfg, parent=inner_sdfg) + inner_sdfg.add_node(i_cb1) + inner_sdfg.add_edge(i_s1, i_cb1, InterstateEdge()) + + i_cfg1 = ControlFlowRegion("i_cfg1", sdfg=inner_sdfg, parent=i_cb1) + i_cfg1_s1 = i_cfg1.add_state("i_cfg1_s1", is_start_block=True) + i_cfg2 = ControlFlowRegion("i_cfg2", sdfg=inner_sdfg, parent=i_cb1) + i_cfg2_s1 = i_cfg2.add_state("i_cfg2_s1", is_start_block=True) + + t1 = i_cfg1_s1.add_tasklet("t1", inputs={}, outputs={"_out"}, code="_out = 0.0") + i_cfg1_s1.add_edge(t1, "_out", i_cfg1_s1.add_access("tmp_call_103"), None, dace.memlet.Memlet("tmp_call_103[0]")) + t2 = i_cfg2_s1.add_tasklet("t2", inputs={}, outputs={"_out"}, code="_out = 1.0") + taccess_1 = i_cfg2_s1.add_access("tmp_call_1") + i_cfg2_s1.add_edge(t2, "_out", taccess_1, None, dace.memlet.Memlet("tmp_call_1[0]")) + t3 = i_cfg2_s1.add_tasklet("t3", inputs={"_in1"}, outputs={"_out"}, code="_out = (- _in1)") + i_cfg2_s1.add_edge(taccess_1, None, t3, "_in1", dace.memlet.Memlet("tmp_call_1[0]")) + i_cfg2_s1.add_edge(t3, "_out", i_cfg2_s1.add_access("tmp_call_103"), None, dace.memlet.Memlet("tmp_call_103[0]")) + + i_cb1.add_branch(CodeBlock("(ptare_var_0 - rtt_var_1) 
>= 0.0"), i_cfg1) + i_cb1.add_branch(None, i_cfg2) + + i_s2 = inner_sdfg.add_state_after(i_cb1, label="i_s2") + t4 = i_s2.add_tasklet("t4", inputs={"_in1"}, outputs={"_out"}, code="_out = max(0.0, _in1)") + i_s2.add_edge(i_s2.add_access("tmp_call_103"), None, t4, "_in1", dace.memlet.Memlet("tmp_call_103[0]")) + i_s2.add_edge(t4, "_out", i_s2.add_access("foedelta__ret"), None, dace.memlet.Memlet("foedelta__ret[0]")) + + o_s1 = outer_sdfg.add_state("o_s1", is_start_block=True) + nsdfg = o_s1.add_nested_sdfg(sdfg=inner_sdfg, + inputs=inner_inputs, + outputs=inner_outputs, + symbol_mapping=inner_symbol_mapping) + + for inner_name, (outer_name, access_str) in inner_to_outer_name_mapping_in.items(): + o_s1.add_edge(o_s1.add_access(outer_name), None, nsdfg, inner_name, + dace.memlet.Memlet(f"{outer_name}{access_str}")) + for inner_name, (outer_name, access_str) in inner_to_outer_name_mapping_out.items(): + o_s1.add_edge(nsdfg, inner_name, o_s1.add_access(outer_name), None, + dace.memlet.Memlet(f"{outer_name}{access_str}")) + + outer_sdfg.validate() + return outer_sdfg + + +@pytest.mark.parametrize("ret_arr", [True, False]) +def test_nested_sdfg_with_return(ret_arr): + sdfg = _get_nsdfg_with_return(ret_arr) + sdfg.validate() + + for n, g in sdfg.all_nodes_recursive(): + if isinstance(n, ConditionalBlock): + xform = branch_elimination.BranchElimination() + xform.conditional = n + xform.parent_nsdfg_state = _find_state(sdfg, g.sdfg.parent_nsdfg_node) + assert xform.can_be_applied(graph=g, expr_index=0, sdfg=g.sdfg, permissive=False) + assert xform.can_be_applied(graph=g, expr_index=0, sdfg=g.sdfg, permissive=True) + + ztp = np.random.choice([0.001, 5.0], size=(N, N)) + rtt = np.random.choice([10.0, 15.0], size=(1, )) + zalfa_1 = np.array([999.9]) + arrays = {"ztp": ztp, "rtt": rtt[0], "zalfa_1": zalfa_1} + + # Run SDFG version (no transformation) + sdfg.validate() + out_no_fuse = {k: v.copy() for k, v in arrays.items()} + sdfg(**out_no_fuse) + assert out_no_fuse["zalfa_1"][0] != 999.9 + + # Run SDFG version (with transformation) + fb = EliminateBranches() + fb.try_clean = True + fb.permissive = False + fb.apply_pass(sdfg, {}) + out_fused = {k: v.copy() for k, v in arrays.items()} + sdfg(**out_fused) + assert out_fused["zalfa_1"][0] != 999.9 + + # Compare all arrays + for name in arrays.keys(): + np.testing.assert_allclose(out_no_fuse[name], out_fused[name], atol=1e-12) + + +@dace.program +def mid_sdfg(pap: dace.float64[N], ptsphy: dace.float64, r2es: dace.float64, r3ies: dace.float64, r4ies: dace.float64, + rcldtopcf: dace.float64, rd: dace.float64, rdepliqrefdepth: dace.float64, rdepliqrefrate: dace.float64, + rg: dace.float64, riceinit: dace.float64, rlmin: dace.float64, rlstt: dace.float64, rtt: dace.float64, + rv: dace.float64, za: dace.float64[N], zdp: dace.float64[N], zfokoop: dace.float64[N], + zicecld: dace.float64[N], zrho: dace.float64[N], ztp1: dace.float64[N], zcldtopdist: dace.float64[N], + zicenuclei: dace.float64[N], zqxfg: dace.float64[N], zsolqa: dace.float64[N]): + for it_47 in dace.map[ + 0:N:1, + ]: + # Ice nucleation and deposition + if ztp1[it_47] < rtt and zqxfg[it_47] > rlmin: + # Calculate ice saturation vapor pressure + tmp_arg_72 = (r3ies * (ztp1[it_47] - rtt)) / (ztp1[it_47] - r4ies) + zicenuclei[it_47] = 2.0 * np.exp(tmp_arg_72) + # Deposition calculation parameters + zadd = (1.6666666666667 * rlstt * (rlstt / ztp1[it_47])) + zbdd = (0.452488687782805 * pap[it_47] * rv * ztp1[it_47]) + # Update mixing ratios + zqxfg[it_47] = zqxfg[it_47] + zadd + zsolqa[it_47] = 
zqxfg[it_47] + zbdd
+
+
+@dace.program
+def huge_sdfg(pap: dace.float64[N], ptsphy: dace.float64, r2es: dace.float64, r3ies: dace.float64, r4ies: dace.float64,
+              rcldtopcf: dace.float64, rd: dace.float64, rdepliqrefdepth: dace.float64, rdepliqrefrate: dace.float64,
+              rg: dace.float64, riceinit: dace.float64, rlmin: dace.float64, rlstt: dace.float64, rtt: dace.float64,
+              rv: dace.float64, za: dace.float64[N], zdp: dace.float64[N], zfokoop: dace.float64[N],
+              zicecld: dace.float64[N], zrho: dace.float64[N], ztp1: dace.float64[N], zcldtopdist: dace.float64[N],
+              zicenuclei: dace.float64[N], zqxfg: dace.float64[N], zsolqa: dace.float64[N]):
+    for it_47 in dace.map[
+            0:N:1,
+    ]:
+        # Check if crossing the cloud top threshold. Note that with a single level index
+        # this condition is unsatisfiable, so the else branch always executes; the
+        # if/else pair is kept to exercise branch elimination on a branch-dependent write.
+        if za[it_47] < rcldtopcf and za[it_47] >= rcldtopcf:
+            zcldtopdist[it_47] = 0.0
+        else:
+            zcldtopdist[it_47] = zcldtopdist[it_47] + (zdp[it_47] / (rg * zrho[it_47]))
+
+        # Ice nucleation and deposition
+        if ztp1[it_47] < rtt and zqxfg[it_47] > rlmin:
+            # Calculate ice saturation vapor pressure
+            tmp_arg_72 = (r3ies * (ztp1[it_47] - rtt)) / (ztp1[it_47] - r4ies)
+            tmp_call_47 = r2es * np.exp(tmp_arg_72)
+            zvpice = (rv * tmp_call_47) / rd
+
+            # Calculate liquid vapor pressure
+            zvpliq = zfokoop[it_47] * np.log(zvpice)
+
+            # Ice nuclei concentration
+            tmp_arg_27 = -0.639 + ((-1.96 * zvpice + 1.96 * zvpliq) / zvpliq)
+            zicenuclei[it_47] = 1000.0 * np.exp(tmp_arg_27)
+
+            # Nucleation factor
+            zinfactor = min(1.0, 6.66666666666667e-05 * zicenuclei[it_47])
+
+            # Deposition calculation parameters
+            zadd = (1.6666666666667 * rlstt * (rlstt / (rv * ztp1[it_47]) - 1.0)) / ztp1[it_47]
+            zbdd = (0.452488687782805 * pap[it_47] * rv * ztp1[it_47]) / zvpice
+
+            tmp_call_49 = (zicenuclei[it_47] / zrho[it_47])
+            zcvds = (7.8 * tmp_call_49 * (zvpliq - zvpice)) / (zvpice * (zadd + zbdd))
+
+            # Initial ice content
+            zice0 = max(riceinit * zicenuclei[it_47] / zrho[it_47], zicecld[it_47])
+
+            # New ice after deposition
+            tmp_arg_30 = 0.666 * ptsphy * zcvds + zice0
+            zinew = tmp_arg_30**1.5
+
+            # Deposition amount
+            zdepos1 = max(0.0, za[it_47] * (zinew - zice0))
+            zdepos2 = min(zdepos1, 1.1)
+
+            # Apply nucleation factor and cloud top distance factor
+            tmp_arg_33 = zinfactor + (1.0 - zinfactor) * (rdepliqrefrate + zcldtopdist[it_47] / rdepliqrefdepth)
+            zdepos3 = zdepos2 * min(1.0, tmp_arg_33)
+
+            # Update mixing ratios
+            zqxfg[it_47] = zqxfg[it_47] + zdepos3
+            zsolqa[it_47] = zsolqa[it_47] + zdepos3
+
+
+@pytest.mark.parametrize("eps_operator_type_for_log_and_div", ["max", "add"])
+def test_huge_sdfg_with_log_exp_div(eps_operator_type_for_log_and_div: str):
+    """Run branch elimination on a large CloudSC-like kernel and compare against the untransformed SDFG."""
+
+    data = {
+        'ptsphy': np.float64(36.0),  # timestep (s)
+        'r2es': np.float64(6.11),  # saturation vapor pressure constant (hPa)
+        'r3ies': np.float64(12.0),  # ice saturation constant
+        'r4ies': np.float64(15.5),  # ice saturation constant
+        'rcldtopcf': np.float64(16.8),  # cloud top threshold
+        'rd': np.float64(287.0),  # gas constant for dry air (J/kg/K)
+        'rdepliqrefdepth': np.float64(20.0),  # reference depth
+        'rdepliqrefrate': np.float64(17.3),  # reference rate
+        'rg': np.float64(9.81),  # gravity (m/s²)
+        'riceinit': np.float64(5.3),  # initial ice content (kg/m³)
+        'rlmin': np.float64(3.9),  # minimum liquid water (kg/m³)
+        'rlstt': np.float64(2.5e6),  # latent heat (J/kg)
+        'rtt': np.float64(273.15),  # triple point temperature (K)
+        'rv': np.float64(461.5),  # gas constant for water vapor (J/kg/K)
+    }
+
+    # 1D arrays with safe ranges
+    rng = np.random.default_rng(0)
+
+    def 
safe_uniform(low, high, size):
+        """Avoid near-zero or extreme values that could cause NaN in log/div."""
+        return rng.uniform(low, high, size).astype(np.float64)
+
+    # State variables (N = grid size)
+    data['pap'] = safe_uniform(1.0, 2.0, (N, ))  # pressure-like
+    data['za'] = safe_uniform(0.9, 1.5, (N, ))  # altitude/cloud-top
+    data['ztp1'] = safe_uniform(260.0, 280.0, (N, ))  # temperature near freezing
+    data['zqxfg'] = safe_uniform(5.0, 11.0, (N, ))  # mixing ratios
+    data['zsolqa'] = safe_uniform(5.0, 11.0, (N, ))  # ice tendencies
+
+    data['zdp'] = safe_uniform(0.5, 2.0, (N, ))  # layer depth
+    data['zfokoop'] = safe_uniform(0.95, 1.05, (N, ))  # correction factor
+    data['zicecld'] = safe_uniform(10.0, 11.0, (N, ))  # cloud ice
+    data['zrho'] = safe_uniform(0.9, 1.2, (N, ))  # density
+    data['zcldtopdist'] = safe_uniform(0.1, 1.0, (N, ))  # distance to cloud top
+    data['zicenuclei'] = safe_uniform(1e2, 1e4, (N, ))  # ice nuclei concentration
+    sdfg = huge_sdfg.to_sdfg()
+    sdfg.validate()
+    ScalarToSymbolPromotion().apply_pass(sdfg, {})
+    sdfg.validate()
+    ConstantPropagation().apply_pass(sdfg, {})
+    sdfg.validate()
+    SymbolPropagation().apply_pass(sdfg, {})
+    sdfg.validate()
+    sdfg.auto_optimize(dace.dtypes.DeviceType.CPU, True, True)
+    sdfg.validate()
+    out_no_fuse = {k: v.copy() for k, v in data.items()}
+    sdfg(**out_no_fuse)
+    # Apply transformation
+    fb = EliminateBranches()
+    fb.try_clean = True
+    fb.eps_operator_type_for_log_and_div = eps_operator_type_for_log_and_div
+    fb.apply_pass(sdfg, {})
+    sdfg.name = sdfg.label + "_transformed"
+
+    cblocks = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)}
+    assert len(cblocks) == 0
+
+    # Run SDFG version (with transformation)
+    out_fused = {k: v.copy() for k, v in data.items()}
+
+    sdfg(**out_fused)
+
+    # Compare all arrays
+    for name in data.keys():
+        np.testing.assert_allclose(out_fused[name], out_no_fuse[name], atol=1e-12)
+
+
+@pytest.mark.parametrize("eps_operator_type_for_log_and_div", ["max", "add"])
+def test_mid_sdfg_with_log_exp_div(eps_operator_type_for_log_and_div: str):
+    """Run branch elimination on the reduced CloudSC-like kernel and compare against the untransformed SDFG."""
+
+    data = {
+        'ptsphy': np.float64(36.0),  # timestep (s)
+        'r2es': np.float64(6.11),  # saturation vapor pressure constant (hPa)
+        'r3ies': np.float64(12.0),  # ice saturation constant
+        'r4ies': np.float64(15.5),  # ice saturation constant
+        'rcldtopcf': np.float64(16.8),  # cloud top threshold
+        'rd': np.float64(287.0),  # gas constant for dry air (J/kg/K)
+        'rdepliqrefdepth': np.float64(20.0),  # reference depth
+        'rdepliqrefrate': np.float64(17.3),  # reference rate
+        'rg': np.float64(9.81),  # gravity (m/s²)
+        'riceinit': np.float64(5.3),  # initial ice content (kg/m³)
+        'rlmin': np.float64(3.9),  # minimum liquid water (kg/m³)
+        'rlstt': np.float64(2.5e6),  # latent heat (J/kg)
+        'rtt': np.float64(273.15),  # triple point temperature (K)
+        'rv': np.float64(461.5),  # gas constant for water vapor (J/kg/K)
+    }
+
+    # 1D arrays with safe ranges
+    rng = np.random.default_rng(0)
+
+    def safe_uniform(low, high, size):
+        """Avoid near-zero or extreme values that could cause NaN in log/div."""
+        return rng.uniform(low, high, size).astype(np.float64)
+
+    # State variables (N = grid size)
+    data['pap'] = safe_uniform(1.0, 2.0, (N, ))  # pressure-like
+    data['za'] = safe_uniform(0.9, 1.5, (N, ))  # altitude/cloud-top
+    data['ztp1'] = safe_uniform(260.0, 280.0, (N, ))  # temperature near freezing
+    data['zqxfg'] = safe_uniform(5.0, 11.0, (N, ))  # mixing ratios
+    data['zsolqa'] = safe_uniform(5.0, 11.0, (N, ))  # ice tendencies
+
+    data['zdp'] = safe_uniform(0.5, 2.0, (N, ))  # layer depth
+    data['zfokoop'] = safe_uniform(0.95, 1.05, (N, ))  # correction factor
+    data['zicecld'] = safe_uniform(10.0, 11.0, (N, ))  # cloud ice
+    data['zrho'] = safe_uniform(0.9, 1.2, (N, ))  # density
+    data['zcldtopdist'] = safe_uniform(0.1, 1.0, (N, ))  # distance to cloud top
+    data['zicenuclei'] = safe_uniform(1e2, 1e4, (N, ))  # ice nuclei concentration
+    sdfg = mid_sdfg.to_sdfg()
+    sdfg.validate()
+    ScalarToSymbolPromotion().apply_pass(sdfg, {})
+    sdfg.validate()
+    ConstantPropagation().apply_pass(sdfg, {})
+    sdfg.validate()
+    SymbolPropagation().apply_pass(sdfg, {})
+    sdfg.validate()
+    sdfg.auto_optimize(dace.dtypes.DeviceType.CPU, True, True)
+    sdfg.validate()
+    out_no_fuse = {k: v.copy() for k, v in data.items()}
+    sdfg(**out_no_fuse)
+    # Apply transformation
+    fb = EliminateBranches()
+    fb.try_clean = True
+    fb.eps_operator_type_for_log_and_div = eps_operator_type_for_log_and_div
+    fb.apply_pass(sdfg, {})
+    sdfg.name = sdfg.label + "_transformed"
+
+    cblocks = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)}
+    assert len(cblocks) == 0
+
+    # Run SDFG version (with transformation)
+    out_fused = {k: v.copy() for k, v in data.items()}
+
+    sdfg(**out_fused)
+
+    # Compare all arrays
+    for name in data.keys():
+        np.testing.assert_allclose(out_fused[name], out_no_fuse[name], atol=1e-12)
+
+
+@dace.program
+def wcr_edge(A: dace.float64[N, N]):
+    for i, j in dace.map[0:N, 0:N]:
+        cond = A[i, j]
+        if cond > 0.00000000001:
+            A[i, j] += 2.0
+
+
+@dace.program
+def loop_param_usage(A: dace.float64[6, N, N], B: dace.float64[N, N], C: dace.float64[N, N]):
+    for i in range(6):
+        for j in dace.map[0:N]:
+            for k in dace.map[0:N]:
+                if A[i, j, k] > 2.0:
+                    C[i, j] = 2.0 + C[i, j]
+
+
+def test_loop_param_usage():
+    A = np.random.choice([0.001, 5.0], size=(6, N, N))
+    B = np.random.choice([0.001, 5.0], size=(N, N))
+    C = np.random.choice([0.001, 5.0], size=(N, N))
+
+    sdfg = loop_param_usage.to_sdfg()
+    cblocks = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)}
+    assert len(cblocks) == 1
+
+    for cblock in cblocks:
+        xform = branch_elimination.BranchElimination()
+        xform.conditional = cblock
+        xform.parent_nsdfg_state = _find_state(
+            sdfg, cblock.sdfg.parent_nsdfg_node) if cblock.sdfg.parent_nsdfg_node is not None else None
+        assert xform.can_be_applied(cblock.parent_graph, 0, cblock.sdfg, False)
+        assert xform.can_be_applied(cblock.parent_graph, 0, cblock.sdfg, True)
+
+    run_and_compare_sdfg(sdfg, False, A=A, B=B, C=C)
+
+
+def test_can_be_applied_on_wcr_edge():
+    sdfg = wcr_edge.to_sdfg()
+
+    cblocks = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)}
+    assert len(cblocks) == 1
+
+    for cblock in cblocks:
+        xform = branch_elimination.BranchElimination()
+        xform.conditional = cblock
+        xform.parent_nsdfg_state = _find_state(
+            sdfg, cblock.sdfg.parent_nsdfg_node) if cblock.sdfg.parent_nsdfg_node is not None else None
+        assert not xform.can_be_applied(cblock.parent_graph, 0, cblock.sdfg, False)
+        assert not xform.can_be_applied(cblock.parent_graph, 0, cblock.sdfg, True)
+
+    from dace.transformation.dataflow.wcr_conversion import WCRToAugAssign
+    sdfg.apply_transformations_repeated(WCRToAugAssign)
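+    # A hedged note on why applicability flips: WCRToAugAssign rewrites the conditional
+    # update `A[i, j] += 2.0`, previously expressed as a memlet with
+    # wcr="lambda a, b: a + b", into a plain read-modify-write tasklet (roughly
+    # `_out = _in + 2.0` with ordinary read and write memlets on A[i, j]). Without the
+    # wcr edge the branch body becomes analyzable, so the checks below are expected to
+    # succeed where the ones above failed.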
+ + cblocks = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} + for cblock in cblocks: + xform = branch_elimination.BranchElimination() + xform.conditional = cblock + xform.parent_nsdfg_state = _find_state( + sdfg, cblock.sdfg.parent_nsdfg_node) if cblock.sdfg.parent_nsdfg_node is not None else None + assert xform.can_be_applied(cblock.parent_graph, 0, cblock.sdfg, False) is True + assert xform.can_be_applied(cblock.parent_graph, 0, cblock.sdfg, True) is True + + A = np.random.choice([0.001, 5.0], size=(N, N)) + + run_and_compare_sdfg(sdfg, False, A=A) + + +if __name__ == "__main__": + test_huge_sdfg_with_log_exp_div("max") + test_huge_sdfg_with_log_exp_div("add") + test_mid_sdfg_with_log_exp_div("max") + test_mid_sdfg_with_log_exp_div("add") + test_nested_sdfg_with_return(True) + test_nested_sdfg_with_return(False) + test_safe_map_param_use_in_nested_sdfg() + test_can_be_applied_on_map_param_usage() + test_pattern_from_cloudsc_one(0.0) + test_pattern_from_cloudsc_one(1.0) + test_pattern_from_cloudsc_one(6.0) + test_condition_on_bounds() + test_nested_if_two() + test_disjoint_chain_split_branch_only(0.0) + test_disjoint_chain_split_branch_only(4.0) + test_disjoint_chain_split_branch_only(6.0) + test_disjoint_chain(0.0) + test_disjoint_chain(4.0) + test_disjoint_chain(6.0) + test_condition_from_transient_scalar() + test_single_assignment() + test_single_assignment_cond_from_scalar() + test_sdfg_with_interstate_array_condition() + test_branch_dependent_value_write_with_transient_reuse() + test_try_clean() + test_try_clean_as_pass() + test_repeated_condition_variables() + test_weird_condition() + test_if_over_map() + test_if_over_map_with_top_level_tasklets() + test_can_be_applied_parameters_on_nested_sdfg() + test_non_trivial_subset_after_combine_tasklet() + test_split_on_disjoint_subsets() + test_split_on_disjoint_subsets_nested() + test_write_to_transient() + test_write_to_transient_two() + test_double_empty_state() + test_complicated_pattern_for_manual_clean_up_one() + test_try_clean_on_complicated_pattern_for_manual_clean_up_one() + test_try_clean_on_complicated_pattern_for_manual_clean_up_two() + for use_pass_flag in [True, False]: + test_branch_dependent_value_write(use_pass_flag) + test_branch_dependent_value_write_two(use_pass_flag) + test_branch_dependent_value_write_single_branch(use_pass_flag) + test_branch_dependent_value_write_single_branch_nonzero_write(use_pass_flag) + test_single_branch_connectors(use_pass_flag) + test_complicated_if(use_pass_flag) + test_multi_state_branch_body(use_pass_flag) + test_nested_if(use_pass_flag) + test_tasklets_in_if(use_pass_flag) + test_disjoint_subsets(use_pass_flag) diff --git a/tests/utils/generate_assignment_as_tasklet_instate_test.py b/tests/utils/generate_assignment_as_tasklet_instate_test.py new file mode 100644 index 0000000000..19d2fb999a --- /dev/null +++ b/tests/utils/generate_assignment_as_tasklet_instate_test.py @@ -0,0 +1,23 @@ +import dace +import dace.sdfg.construction_utils as cutil + + +def _get_sdfg() -> dace.SDFG: + sdfg = dace.SDFG("sd1") + s1 = sdfg.add_state("s1", is_start_block=True) + + sdfg.add_array("A", (5, 5), dace.float64) + sdfg.add_array("B", (5, 5), dace.float64) + sdfg.add_scalar("c", dace.float64) + return sdfg, s1 + + +def test_assignment_as_tasklet(): + sdfg, s1 = _get_sdfg() + sdfg.validate() + cutil.generate_assignment_as_tasklet_in_state(s1, "c", "A[4, 4] + 2.0 * B[1, 2]") + sdfg.validate() + + +if __name__ == "__main__": + test_assignment_as_tasklet() From 
f019f98bc9d123a94a38f14a66fc4be406fb63b2 Mon Sep 17 00:00:00 2001 From: Yakup Koray Budanaz Date: Wed, 29 Oct 2025 12:43:02 +0100 Subject: [PATCH 02/17] Rm explicit vectorize things --- .../passes/explicit_vectorization_cpu.py | 402 ------------------ .../passes/explicit_vectorization_gpu.py | 129 ------ 2 files changed, 531 deletions(-) delete mode 100644 dace/transformation/passes/explicit_vectorization_cpu.py delete mode 100644 dace/transformation/passes/explicit_vectorization_gpu.py diff --git a/dace/transformation/passes/explicit_vectorization_cpu.py b/dace/transformation/passes/explicit_vectorization_cpu.py deleted file mode 100644 index 91c553518c..0000000000 --- a/dace/transformation/passes/explicit_vectorization_cpu.py +++ /dev/null @@ -1,402 +0,0 @@ -# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. -import dace -from typing import Dict, Iterator -from dace.transformation import Pass, pass_pipeline as ppl -from dace.transformation.pass_pipeline import Modifies -from dace.transformation.passes.clean_data_to_scalar_slice_to_tasklet_pattern import CleanDataToScalarSliceToTaskletPattern -from dace.transformation.passes.duplicate_all_memlets_sharing_same_in_connector import DuplicateAllMemletsSharingSingleMapOutConnector -from dace.transformation.passes.split_tasklets import SplitTasklets -from dace.transformation.passes.tasklet_preprocessing_passes import PowerOperatorExpansion, RemoveFPTypeCasts, RemoveIntTypeCasts -from dace.transformation.passes import InlineSDFGs -from dace.transformation.passes.explicit_vectorization import ExplicitVectorization -from dace.transformation.passes.eliminate_branches import EliminateBranchesPass -from dace.transformation.passes.remove_redundant_assignment_tasklets import RemoveRedundantAssignmentTasklets -import dace.sdfg.utils as sdutil - - -class ExplicitVectorizationPipelineCPU(ppl.Pipeline): - _cpu_global_code = """ -#include - - -#if defined(__clang__) - #define _dace_vectorize_hint - #define _dace_vectorize "clang loop vectorize(enable) vectorize_width({vector_width}8)" -#elif defined(__GNUC__) - #define _dace_vectorize_hint - #define _dace_vectorize "omp simd simdlen({vector_width})" -#else - #define _dace_vectorize_hint - #define _dace_vectorize "omp simd simdlen({vector_width})" -#endif - - -template -inline void vector_mult(T * __restrict__ c, const T * __restrict__ a, const T * __restrict__ b) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - c[i] = a[i] * b[i]; - }} -}} - -template -inline void vector_mult_w_scalar(T * __restrict__ b, const T * __restrict__ a, const T constant) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - b[i] = a[i] * constant; - }} -}} - -template -inline void vector_add(T * __restrict__ c, const T * __restrict__ a, const T * __restrict__ b) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - c[i] = a[i] + b[i]; - }} -}} - -template -inline void vector_add_w_scalar(T * __restrict__ b, const T * __restrict__ a, const T constant) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - b[i] = a[i] + constant; - }} -}} - -template -inline void vector_sub(T * __restrict__ c, const T * __restrict__ a, const T * __restrict__ b) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - c[i] = a[i] - b[i]; - }} -}} - -template 
-inline void vector_sub_w_scalar(T * __restrict__ b, const T * __restrict__ a, const T constant) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - b[i] = a[i] - constant; - }} -}} - -template -inline void vector_sub_w_scalar_c(T * __restrict__ b, const T constant, const T * __restrict__ a) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - b[i] = constant - a[i]; - }} -}} - -template -inline void vector_div(T * __restrict__ c, const T * __restrict__ a, const T * __restrict__ b) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - c[i] = a[i] / b[i]; - }} -}} - -template -inline void vector_div_w_scalar(T * __restrict__ b, const T * __restrict__ a, const T constant) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - b[i] = a[i] / constant; - }} -}} - -template -inline void vector_div_w_scalar_c(T * __restrict__ b, const T constant, const T * __restrict__ a) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - b[i] = constant / a[i]; - }} -}} - -template -inline void vector_copy(T * __restrict__ dst, const T * __restrict__ src) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - dst[i] = src[i]; - }} -}} - -// ---- Additional elementwise ops ---- - -template -inline void vector_exp(T * __restrict__ out, const T * __restrict__ a) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = std::exp(a[i]); - }} -}} - -template -inline void vector_log(T * __restrict__ out, const T * __restrict__ a) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = std::log(a[i]); - }} -}} - -template -inline void vector_min(T * __restrict__ out, const T * __restrict__ a, const T * __restrict__ b) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = std::min(a[i], b[i]); - }} -}} - -template -inline void vector_min_w_scalar(T * __restrict__ out, const T * __restrict__ a, const T constant) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = std::min(a[i], constant); - }} -}} - -template -inline void vector_max(T * __restrict__ out, const T * __restrict__ a, const T * __restrict__ b) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = std::max(a[i], b[i]); - }} -}} - -template -inline void vector_max_w_scalar(T * __restrict__ out, const T * __restrict__ a, const T constant) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = std::max(a[i], constant); - }} -}} - -template -inline void vector_gt(T * __restrict__ out, const T * __restrict__ a, const T * __restrict__ b) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = (a[i] > b[i]) ? 1.0 : 0.0; - }} -}} - -template -inline void vector_gt_w_scalar(T * __restrict__ out, const T * __restrict__ a, const T constant) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = (a[i] > constant) ? 
1.0 : 0.0; - }} -}} - -template -inline void vector_gt_w_scalar_c(T * __restrict__ out, const T constant, const T * __restrict__ a) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = (constant > a[i]) ? 1.0 : 0.0; - }} -}} - -template -inline void vector_lt(T * __restrict__ out, const T * __restrict__ a, const T * __restrict__ b) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = (a[i] < b[i]) ? 1.0 : 0.0; - }} -}} - -template -inline void vector_lt_w_scalar(T * __restrict__ out, const T * __restrict__ a, const T constant) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = (a[i] < constant) ? 1.0 : 0.0; - }} -}} - -template -inline void vector_lt_w_scalar_c(T * __restrict__ out, const T constant, const T * __restrict__ a) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = (constant < a[i]) ? 1.0 : 0.0; - }} -}} - -template -inline void vector_ge(T * __restrict__ out, const T * __restrict__ a, const T * __restrict__ b) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = (a[i] >= b[i]) ? 1.0 : 0.0; - }} -}} - -template -inline void vector_ge_w_scalar(T * __restrict__ out, const T * __restrict__ a, const T constant) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = (a[i] >= constant) ? 1.0 : 0.0; - }} -}} - -template -inline void vector_ge_w_scalar_c(T * __restrict__ out, const T constant, const T * __restrict__ a) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = (constant >= a[i]) ? 1.0 : 0.0; - }} -}} - -template -inline void vector_le(T * __restrict__ out, const T * __restrict__ a, const T * __restrict__ b) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = (a[i] <= b[i]) ? 1.0 : 0.0; - }} -}} - -template -inline void vector_le_w_scalar(T * __restrict__ out, const T * __restrict__ a, const T constant) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = (a[i] <= constant) ? 1.0 : 0.0; - }} -}} - -template -inline void vector_le_w_scalar_c(T * __restrict__ out, const T constant, const T * __restrict__ a) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = (constant <= a[i]) ? 1.0 : 0.0; - }} -}} - -template -inline void vector_eq(T * __restrict__ out, const T * __restrict__ a, const T * __restrict__ b) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = (a[i] == b[i]) ? 1.0 : 0.0; - }} -}} - -template -inline void vector_eq_w_scalar(T * __restrict__ out, const T * __restrict__ a, const T constant) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = (a[i] == constant) ? 1.0 : 0.0; - }} -}} - - -template -inline void vector_ne(T * __restrict__ out, const T * __restrict__ a, const T * __restrict__ b) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = (a[i] != b[i]) ? 
1.0 : 0.0; - }} -}} - -template -inline void vector_ne_w_scalar(T * __restrict__ out, const T * __restrict__ a, const T constant) {{ - #pragma _dace_vectorize_hint - #pragma _dace_vectorize - for (int i = 0; i < {vector_width}; i++) {{ - out[i] = (a[i] != constant) ? 1.0 : 0.0; - }} -}} -""" - - def __init__(self, vector_width): - passes = [ - EliminateBranchesPass(), - RemoveFPTypeCasts(), - RemoveIntTypeCasts(), - PowerOperatorExpansion(), - SplitTasklets(), - CleanDataToScalarSliceToTaskletPattern(), - InlineSDFGs(), - DuplicateAllMemletsSharingSingleMapOutConnector(), - ExplicitVectorization( - templates={ - "*": "vector_mult({lhs}, {rhs1}, {rhs2});", - "+": "vector_add({lhs}, {rhs1}, {rhs2});", - "-": "vector_sub({lhs}, {rhs1}, {rhs2});", - "/": "vector_div({lhs}, {rhs1}, {rhs2});", - "=": "vector_copy({lhs}, {rhs1});", - "log": "vector_log({lhs}, {rhs1});", - "exp": "vector_exp({lhs}, {rhs1});", - "min": "vector_min({lhs}, {rhs1}, {rhs2});", - "max": "vector_max({lhs}, {rhs1}, {rhs2});", - ">": "vector_gt({lhs}, {rhs1}, {rhs2});", - "<": "vector_lt({lhs}, {rhs1}, {rhs2});", - ">=": "vector_ge({lhs}, {rhs1}, {rhs2});", - "<=": "vector_le({lhs}, {rhs1}, {rhs2});", - "==": "vector_eq({lhs}, {rhs1}, {rhs2});", - "!=": "vector_ne({lhs}, {rhs1}, {rhs2});", - # scalar variants type 1 - "c*": "vector_mult_w_scalar({lhs}, {rhs1}, {constant});", - "c+": "vector_add_w_scalar({lhs}, {rhs1}, {constant});", - "c-": "vector_sub_w_scalar({lhs}, {rhs1}, {constant});", - "c/": "vector_div_w_scalar({lhs}, {rhs1}, {constant});", - "cmin": "vector_min_w_scalar({lhs}, {rhs1}, {constant});", - "cmax": "vector_max_w_scalar({lhs}, {rhs1}, {constant});", - "c>": "vector_gt_w_scalar({lhs}, {rhs1}, {constant});", - "c<": "vector_lt_w_scalar({lhs}, {rhs1}, {constant});", - "c>=": "vector_ge_w_scalar({lhs}, {rhs1}, {constant});", - "c<=": "vector_le_w_scalar({lhs}, {rhs1}, {constant});", - "c==": "vector_eq_w_scalar({lhs}, {rhs1}, {constant});", - "c!=": "vector_ne_w_scalar({lhs}, {rhs1}, {constant});", - # scalar variants type 2 for non-commutative ops - "-c": "vector_sub_w_scalar_c({lhs}, {constant}, {rhs1});", - "/c": "vector_div_w_scalar_c({lhs}, {constant}, {rhs1});", - ">c": "vector_gt_w_scalar_c({lhs}, {constant}, {rhs1});", - "=c": "vector_ge_w_scalar_c({lhs}, {constant}, {rhs1});", - "<=c": "vector_le_w_scalar_c({lhs}, {constant}, {rhs1});", - }, - vector_width=vector_width, - vector_input_storage=dace.dtypes.StorageType.Register, - vector_output_storage=dace.dtypes.StorageType.Register, - global_code=ExplicitVectorizationPipelineCPU._cpu_global_code.format(vector_width=vector_width), - global_code_location="frame", - vector_op_numeric_type=dace.float64) - ] - super().__init__(passes) - - def iterate_over_passes(self, sdfg: dace.SDFG) -> Iterator[Pass]: - """ - Iterates over passes in the pipeline, potentially multiple times based on which elements were modified - in the pass. - Note that this method may be overridden by subclasses to modify pass order. - - :param sdfg: The SDFG on which the pipeline is currently being applied - """ - for p in self.passes: - p: Pass - yield p diff --git a/dace/transformation/passes/explicit_vectorization_gpu.py b/dace/transformation/passes/explicit_vectorization_gpu.py deleted file mode 100644 index 22ab47853e..0000000000 --- a/dace/transformation/passes/explicit_vectorization_gpu.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. 
-from typing import Dict, Iterator -import dace -from dace.transformation import Pass, pass_pipeline as ppl -from dace.transformation.pass_pipeline import Modifies -from dace.transformation.passes.clean_data_to_scalar_slice_to_tasklet_pattern import CleanDataToScalarSliceToTaskletPattern -from dace.transformation.passes.duplicate_all_memlets_sharing_same_in_connector import DuplicateAllMemletsSharingSingleMapOutConnector -from dace.transformation.passes.split_tasklets import SplitTasklets -from dace.transformation.passes.tasklet_preprocessing_passes import PowerOperatorExpansion, RemoveFPTypeCasts, RemoveIntTypeCasts -from dace.transformation.passes import InlineSDFGs -from dace.transformation.passes.explicit_vectorization import ExplicitVectorization -from dace.transformation.passes.eliminate_branches import EliminateBranchesPass -from dace.transformation.passes.remove_redundant_assignment_tasklets import RemoveRedundantAssignmentTasklets -import dace.sdfg.utils as sdutil - - -class ExplicitVectorizationPipelineGPU(ppl.Pipeline): - _gpu_global_code = """ -template -__host__ __device__ __forceinline__ void vector_mult(T * __restrict__ c, const T * __restrict__ a, const T * __restrict__ b) {{ - #pragma omp unroll - for (int i = 0; i < {vector_width}; i++) {{ - c[i] = a[i] * b[i]; - }} -}} - -template -__host__ __device__ __forceinline__ void vector_mult_w_scalar(T * __restrict__ b, const T * __restrict__ a, const T constant) {{ - T cReg[{vector_width}]; - #pragma omp unroll - for (int i = 0; i < {vector_width}; i++) {{ - cReg[i] = constant; - }} - #pragma omp unroll - for (int i = 0; i < {vector_width}; i++) {{ - b[i] = a[i] * cReg[i]; - }} -}} - -template -__host__ __device__ __forceinline__ void vector_add(T * __restrict__ c, const T * __restrict__ a, const T * __restrict__ b) {{ - #pragma omp unroll - for (int i = 0; i < {vector_width}; i++) {{ - c[i] = a[i] + b[i]; - }} -}} - -template -__host__ __device__ __forceinline__ void vector_add_w_scalar(T * __restrict__ b, const T * __restrict__ a, const T constant) {{ - T cReg[{vector_width}]; - #pragma omp unroll - for (int i = 0; i < {vector_width}; i++) {{ - cReg[i] = constant; - }} - #pragma omp unroll - for (int i = 0; i < {vector_width}; i++) {{ - b[i] = a[i] + cReg[i]; - }} -}} - -template -__host__ __device__ __forceinline__ void vector_div(T * __restrict__ c, const T * __restrict__ a, const T * __restrict__ b) {{ - #pragma omp unroll - for (int i = 0; i < {vector_width}; i++) {{ - c[i] = a[i] / b[i]; - }} -}} - -template -__host__ __device__ __forceinline__ void vector_div_w_scalar(T * __restrict__ b, const T * __restrict__ a, const T constant) {{ - T cReg[{vector_width}]; - #pragma omp unroll - for (int i = 0; i < {vector_width}; i++) {{ - cReg[i] = constant; - }} - #pragma omp unroll - for (int i = 0; i < {vector_width}; i++) {{ - b[i] = a[i] / cReg[i]; - }} -}} - -template -__host__ __device__ __forceinline__ void vector_copy(T * __restrict__ dst, const T * __restrict__ src) {{ - #pragma omp unroll - for (int i = 0; i < {vector_width}; i++) {{ - dst[i] = src[i]; - }} -}} -""" - - def __init__(self, vector_width): - passes = [ - EliminateBranchesPass(), - RemoveFPTypeCasts(), - RemoveIntTypeCasts(), - PowerOperatorExpansion(), - SplitTasklets(), - CleanDataToScalarSliceToTaskletPattern(), - InlineSDFGs(), - DuplicateAllMemletsSharingSingleMapOutConnector(), - ExplicitVectorization( - templates={ - "*": "vector_mult({lhs}, {rhs1}, {rhs2});", - "+": "vector_add({lhs}, {rhs1}, {rhs2});", - "=": "vector_copy({lhs}, {rhs1});", - "c+": 
"vector_add({lhs}, {rhs1}, {constant});", - "c*": "vector_mult({lhs}, {rhs1}, {constant});", - }, - vector_width=vector_width, - vector_input_storage=dace.dtypes.StorageType.Register, - vector_output_storage=dace.dtypes.StorageType.Register, - global_code=ExplicitVectorizationPipelineGPU._gpu_global_code.format(vector_width=vector_width), - global_code_location="frame", - vector_op_numeric_type=dace.float64, - ) - ] - super().__init__(passes) - - def iterate_over_passes(self, sdfg: dace.SDFG) -> Iterator[Pass]: - """ - Iterates over passes in the pipeline, potentially multiple times based on which elements were modified - in the pass. - Note that this method may be overridden by subclasses to modify pass order. - - :param sdfg: The SDFG on which the pipeline is currently being applied - """ - for p in self.passes: - p: Pass - yield p From eadeda58f39d5c8ca189d1e0ca61048dbd774cc3 Mon Sep 17 00:00:00 2001 From: Yakup Koray Budanaz Date: Wed, 29 Oct 2025 12:45:31 +0100 Subject: [PATCH 03/17] Fix --- dace/sdfg/construction_utils.py | 133 -------------------------------- dace/sdfg/utils.py | 2 +- 2 files changed, 1 insertion(+), 134 deletions(-) diff --git a/dace/sdfg/construction_utils.py b/dace/sdfg/construction_utils.py index 58c201441b..388d0c9d4c 100644 --- a/dace/sdfg/construction_utils.py +++ b/dace/sdfg/construction_utils.py @@ -763,136 +763,3 @@ def get_parent_maps(root_sdfg: dace.SDFG, node: dace.nodes.MapEntry, parent_stat parent_nsdfg_parent_state = _find_parent_state(root_sdfg, parent_nsdfg_node) return maps - - -def duplicate_memlets_sharing_single_in_connector(state: dace.SDFGState, map_entry: dace.nodes.MapEntry): - - def _find_new_name(base: str, existing_names: Set[str]) -> str: - i = 0 - candidate = f"{base}_d_{i}" - while candidate in existing_names: - i += 1 - candidate = f"{base}_d_{i}" - return candidate - - for out_conn in list(map_entry.out_connectors.keys()): - out_edges_of_out_conn = set(state.out_edges_by_connector(map_entry, out_conn)) - if len(out_edges_of_out_conn) > 1: - base_in_edge = out_edges_of_out_conn.pop() - - # Get all parent maps (including this) - parent_maps: Set[dace.nodes.MapEntry] = {map_entry} - sdict = state.scope_dict() - parent_map = sdict[map_entry] - while parent_map is not None: - parent_maps.add(parent_map) - parent_map = sdict[parent_map] - - # Need it to find unique names - all_existing_connector_names = set() - for map_entry in parent_maps: - for in_conn in map_entry.in_connectors: - all_existing_connector_names.add(in_conn[len("IN_"):]) - for out_conn in map_entry.out_connectors: - all_existing_connector_names.add(out_conn[len("OUT_"):]) - - # Base path - memlet_paths = [] - path = state.memlet_path(base_in_edge) - source_node = path[0].src - memlet_paths.append(path) - while sdict[source_node] is not None: - if not isinstance(source_node, (dace.nodes.AccessNode, dace.nodes.MapEntry)): - print(source_node) - raise Exception( - f"In the path from map entry to the top level scope, only access nodes and other map entries may appear, got: {source_node}" - ) - in_edges = state.in_edges(source_node) - if isinstance(source_node, dace.nodes.MapEntry) and len(in_edges) != 1: - in_edges = list(state.in_edges_by_connector(source_node, "IN_" + path[-1].src_conn[len("OUT_"):])) - if isinstance(source_node, dace.nodes.AccessNode) and len(in_edges) != 1: - raise Exception( - "In the path from map entry to the top level scope, the intermediate access nodes need to have in and out degree (by connector) 1" - ) - - in_edge = in_edges[0] - path = 
state.memlet_path(in_edge) - source_node = path[0].src - memlet_paths.append(path) - #print(source_node) - - # Need to duplicate the out edges - for e in list(out_edges_of_out_conn): - state.remove_edge(e) - - for edge_to_duplicate in out_edges_of_out_conn: - base = edge_to_duplicate.src_conn[len("OUT_"):] - new_connector_base = _find_new_name(base, all_existing_connector_names) - all_existing_connector_names.add(new_connector_base) - - node_map = dict() - for i, subpath in enumerate(memlet_paths): - for j, e in enumerate(reversed(subpath)): - # We work by adding an in edge - in_name = f"IN_{new_connector_base}" - out_name = f"OUT_{new_connector_base}" - - if e.src_conn is not None: - out_conn = out_name if e.src_conn.startswith("OUT_") else e.src_conn - else: - out_conn = None - - if e.dst_conn is not None: - if e.src == map_entry: - in_conn = edge_to_duplicate.dst_conn - else: - in_conn = in_name if e.dst_conn.startswith("IN_") else e.dst_conn - else: - in_conn = None - - if isinstance(e.src, dace.nodes.MapEntry): - src_node = e.src - elif isinstance(e.src, dace.nodes.AccessNode): - if e.src in node_map: - src_node = node_map[e.src] - else: - a = state.add_access(e.src.data) - node_map[e.src] = a - src_node = a - else: - src_node = e.src - - if isinstance(e.dst, dace.nodes.MapEntry): - dst_node = e.dst - elif isinstance(e.dst, dace.nodes.AccessNode): - if e.dst in node_map: - dst_node = node_map[e.dst] - else: - a = state.add_access(e.dst.data) - node_map[e.dst] = a - dst_node = a - else: - dst_node = e.dst - - # Above the first map, always add the complete subset and then call memlet propagation - if e.src is map_entry: - data = copy.deepcopy(edge_to_duplicate.data) - else: - data = dace.memlet.Memlet.from_array(e.data.data, state.sdfg.arrays[e.data.data]) - - state.add_edge(src_node, out_conn, dst_node, in_conn, data) - - if out_conn is not None and out_conn not in src_node.out_connectors: - src_node.add_out_connector(out_conn, force=True) - if in_conn is not None and in_conn not in dst_node.in_connectors: - dst_node.add_in_connector(in_conn, force=True) - - # If we duplicate an access node, we should add correct dependency edges - if i == len(memlet_paths) - 1: - if j == len(subpath) - 1: - # Source node - origin_source_node = e.src - for ie in state.in_edges(origin_source_node): - state.add_edge(ie.src, None, src_node, None, dace.memlet.Memlet(None)) - - propagate_memlets_state(state.sdfg, state) diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index ba99bf78c2..4c298cba50 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -2590,7 +2590,7 @@ def demote_symbol_to_scalar(sdfg: 'dace.SDFG', symbol_str: str, default_type: 'd # 2. 
If used in tasklet try to replace symbol name with an in connector and add an access to the scalar # Sanity check no tasklet should assign to a symbol - cutil.tasklet_replace_code(n.code.as_string, {symbol_str: f"_in_{symbol_str}"}) + cutil.tasklet_replace_code(n, {symbol_str: f"_in_{symbol_str}"}) n.add_in_connector(f"_in_{symbol_str}") access = g.add_access(symbol_str) g.add_edge(access, None, n, f"_in_{symbol_str}", dace.memlet.Memlet(expr=f"{symbol_str}[0]")) From 6a603c4f21f0403ff9aa1f2b11a8c0d0e4bdfe27 Mon Sep 17 00:00:00 2001 From: Yakup Koray Budanaz Date: Wed, 29 Oct 2025 12:54:32 +0100 Subject: [PATCH 04/17] Stuff --- dace/sdfg/utils.py | 173 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 164 insertions(+), 9 deletions(-) diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 4c298cba50..1bf4632860 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -8,8 +8,6 @@ import networkx as nx import time -import sympy - import dace.sdfg.nodes from dace.codegen import compiled_sdfg as csdfg from dace.sdfg.graph import MultiConnectorEdge @@ -596,6 +594,139 @@ def merge_maps( return merged_entry, merged_exit +def canonicalize_memlet_trees_for_scope( + state: SDFGState, + scope_node: Union[nd.EntryNode, nd.ExitNode], +) -> int: + """Canonicalize the Memlet trees of scope nodes. + + The function will modify all Memlets that are adjacent to `scope_node` + such that the Memlet always refers to the data that is on the outside. + This function only operates on a single scope. + + :param state: The SDFG state in which the scope to consolidate resides. + :param scope_node: The scope node whose edges will be consolidated. + :return: Number of modified Memlets. + + :note: This is the "historical" expected format of Memlet trees at scope nodes, + which was present before the introduction of `other_subset`. Running this + transformation might fix some issues. + """ + if isinstance(scope_node, nd.EntryNode): + may_have_dynamic_map_range = True + is_downward_tree = True + outer_edges = state.in_edges(scope_node) + get_outer_edge_connector = lambda e: e.dst_conn + inner_edges_for = lambda conn: state.out_edges_by_connector(scope_node, conn) + inner_prefix = 'OUT_' + outer_prefix = 'IN_' + + def get_outer_data(e: MultiConnectorEdge[dace.Memlet]): + mpath = state.memlet_path(e) + assert isinstance(mpath[0].src, nd.AccessNode) + return mpath[0].src.data + + else: + may_have_dynamic_map_range = False + is_downward_tree = False + outer_edges = state.out_edges(scope_node) + get_outer_edge_connector = lambda e: e.src_conn + inner_edges_for = lambda conn: state.in_edges_by_connector(scope_node, conn) + inner_prefix = 'IN_' + outer_prefix = 'OUT_' + + def get_outer_data(e: MultiConnectorEdge[dace.Memlet]): + mpath = state.memlet_path(e) + assert isinstance(mpath[-1].dst, nd.AccessNode) + return mpath[-1].dst.data + + def swap_prefix(conn: str) -> str: + if conn.startswith(inner_prefix): + return outer_prefix + conn[len(inner_prefix):] + else: + assert conn.startswith( + outer_prefix), f"Expected connector to start with '{outer_prefix}', but it was '{conn}'." 
+            return inner_prefix + conn[len(outer_prefix):]
+
+    modified_memlet = 0
+    for outer_edge in outer_edges:
+        outer_edge_connector = get_outer_edge_connector(outer_edge)
+        if may_have_dynamic_map_range and (not outer_edge_connector.startswith(outer_prefix)):
+            continue
+        assert outer_edge_connector.startswith(outer_prefix)
+        corresponding_inner_connector = swap_prefix(outer_edge_connector)
+
+        # In case `scope_node` is at the global scope it should be enough to run
+        # `outer_edge.data.data` but this way it is more in line with consolidate.
+        outer_data = get_outer_data(outer_edge)
+
+        for inner_edge in inner_edges_for(corresponding_inner_connector):
+            for mtree in state.memlet_tree(inner_edge).traverse_children(include_self=True):
+                medge: MultiConnectorEdge[dace.Memlet] = mtree.edge
+                if medge.data.data == outer_data:
+                    # This edge is already referring to the outer data, so no change is needed.
+                    continue
+
+                # Now we have to extract the subset from the Memlet.
+                if is_downward_tree:
+                    subset = medge.data.get_src_subset(medge, state)
+                    other_subset = medge.data.dst_subset
+                else:
+                    subset = medge.data.get_dst_subset(medge, state)
+                    other_subset = medge.data.src_subset
+
+                # Now for an update.
+                medge.data._data = outer_data
+                medge.data._subset = subset
+                medge.data._other_subset = other_subset
+                medge.data.try_initialize(state.sdfg, state, medge)
+                modified_memlet += 1
+
+    return modified_memlet
+
+
+def canonicalize_memlet_trees(
+    sdfg: 'dace.SDFG',
+    starting_scope: Optional['dace.sdfg.scope.ScopeTree'] = None,
+) -> int:
+    """Canonicalize the Memlet trees of all scopes in the SDFG.
+
+    This function runs `canonicalize_memlet_trees_for_scope()` on all scopes
+    in the SDFG. Note that this function does not recursively process
+    nested SDFGs.
+
+    :param sdfg: The SDFG to canonicalize.
+    :param starting_scope: If not None, starts with a certain scope. Note that in this
+        mode only the state in which the scope is located will be processed.
+    :return: Number of modified Memlets.
+    """
+
+    total_modified_memlets = 0
+    for state in sdfg.states():
+        # Start bottom-up
+        if starting_scope is not None and starting_scope.entry not in state.nodes():
+            continue
+
+        queue = [starting_scope] if starting_scope else state.scope_leaves()
+        next_queue = []
+        while len(queue) > 0:
+            for scope in queue:
+                if scope.entry is not None:
+                    total_modified_memlets += canonicalize_memlet_trees_for_scope(state, scope.entry)
+                if scope.exit is not None:
+                    total_modified_memlets += canonicalize_memlet_trees_for_scope(state, scope.exit)
+                if scope.parent is not None:
+                    next_queue.append(scope.parent)
+            queue = next_queue
+            next_queue = []
+
+        if starting_scope is not None:
+            # No need to traverse other states
+            break
+
+    return total_modified_memlets
+
+
 def consolidate_edges_scope(state: SDFGState, scope_node: Union[nd.EntryNode, nd.ExitNode]) -> int:
     """
     Union scope-entering memlets relating to the same data node in a scope.
@@ -2418,11 +2549,21 @@ def _specialize_scalar_impl(root: 'dace.SDFG', sdfg: 'dace.SDFG', scalar_name: s
     # 3. Access Node
     #    -> If access node is used then e.g. 
[scalar] -> [tasklet] # -> then create a [tasklet] that uses the scalar_val as a constant value inside - import dace.sdfg.construction_utils as cutil + import re + + def _token_replace(code: str, src: str, dst: str) -> str: + # Split while keeping delimiters + tokens = re.split(r'(\s+|[()\[\]])', code) + + # Replace tokens that exactly match src + tokens = [dst if token.strip() == src else token for token in tokens] + + # Recombine everything + return ''.join(tokens).strip() def repl_code_block_or_str(input: Union[CodeBlock, str], src: str, dst: str): if isinstance(input, CodeBlock): - return CodeBlock(cutil.replace_code(input.as_string, input.language, {src: dst}), input.language) + return CodeBlock(_token_replace(input.as_string, src, dst)) else: return input.replace(src, dst) @@ -2454,10 +2595,19 @@ def repl_code_block_or_str(input: Union[CodeBlock, str], src: str, dst: str): if isinstance(e.dst, nd.Tasklet): in_tasklet_name = e.dst_conn - new_code = CodeBlock(code=cutil.replace_code(e.dst.code.as_string, e.dst.code.language, - {in_tasklet_name: scalar_val}), - language=e.dst.code.language) - e.dst.code = new_code + if e.dst.code.language == dace.dtypes.Language.Python: + import sympy + lhs, rhs = e.dst.code.as_string.split("=") + lhs = lhs.strip() + rhs = rhs.strip() + subs_rhs = str(sympy.pycode(dace.symbolic.SymExpr(rhs).subs({in_tasklet_name: scalar_val}))).strip() + new_code = CodeBlock(code=f"{lhs} = {subs_rhs}", language=dace.dtypes.Language.Python) + e.dst.code = new_code + else: + + new_code = CodeBlock(code=_token_replace(e.dst.code.as_string, in_tasklet_name, scalar_val), + language=e.dst.code.language) + e.dst.code = new_code state.remove_edge(e) if e.src_conn is not None: src.remove_out_connector(e.src_conn) @@ -2521,7 +2671,9 @@ def repl_code_block_or_str(input: Union[CodeBlock, str], src: str, dst: str): _specialize_scalar_impl(root, nsdfg, scalar_name, scalar_val) -def specialize_scalar(sdfg: 'dace.SDFG', scalar_name: str, scalar_val: Union[float, int, str, sympy.Number]): +def specialize_scalar(sdfg: 'dace.SDFG', scalar_name: str, scalar_val: Union[float, int, str]): + import sympy + assert isinstance(scalar_name, str), f"Expected scalar name to be str got {type(scalar_val)}" def _sympy_to_python_number(val): @@ -2590,6 +2742,9 @@ def demote_symbol_to_scalar(sdfg: 'dace.SDFG', symbol_str: str, default_type: 'd # 2. If used in tasklet try to replace symbol name with an in connector and add an access to the scalar # Sanity check no tasklet should assign to a symbol + lhs, rhs = n.code.as_string.split(" = ") + tasklet_lhs = lhs.strip() + assert symbol_str not in tasklet_lhs cutil.tasklet_replace_code(n, {symbol_str: f"_in_{symbol_str}"}) n.add_in_connector(f"_in_{symbol_str}") access = g.add_access(symbol_str) From 358aef33340adcc45132e32a5266936161d9721c Mon Sep 17 00:00:00 2001 From: Yakup Koray Budanaz Date: Wed, 29 Oct 2025 13:02:32 +0100 Subject: [PATCH 05/17] Add classify tasklet --- dace/sdfg/tasklet_utils.py | 576 +++++++++++++++++++++++++++ tests/utils/classify_tasklet_test.py | 447 +++++++++++++++++++++ 2 files changed, 1023 insertions(+) create mode 100644 dace/sdfg/tasklet_utils.py create mode 100644 tests/utils/classify_tasklet_test.py diff --git a/dace/sdfg/tasklet_utils.py b/dace/sdfg/tasklet_utils.py new file mode 100644 index 0000000000..cc9d921a08 --- /dev/null +++ b/dace/sdfg/tasklet_utils.py @@ -0,0 +1,576 @@ +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. 
+"""
+Tasklet Classification Utilities
+
+This module provides utilities for analyzing and classifying DaCe tasklets based on their
+computational patterns. It parses tasklet code to determine the types of operations, operands,
+and constants involved, enabling automated code generation and optimization passes.
+
+The main functionality is the `classify_tasklet` function, which inspects a tasklet's code
+and metadata to determine its type (e.g., array-symbol operation, binary array operation)
+and extract relevant information such as operands, constants, and operations.
+"""
+
+import dace
+from typing import Dict, Tuple, Set
+from dace.properties import CodeBlock
+from enum import Enum
+import ast
+import typing
+
+
+class TaskletType(Enum):
+    """
+    Enumeration of supported tasklet computational patterns.
+
+    Each pattern represents a specific combination of input types (arrays, scalars, symbols)
+    and operation types (assignment, binary operation, unary operation).
+
+    Note: inside a tasklet you always have scalars; the classification is about the
+    types of the data descriptors behind the connectors.
+
+    Assignment Operations:
+        ARRAY_ARRAY_ASSIGNMENT: Direct array-to-array copy (e.g., a = b)
+        ARRAY_SYMBOL_ASSIGNMENT: Symbol/constant assignment to array (e.g., a = sym)
+        ARRAY_SCALAR_ASSIGNMENT: Scalar variable assignment to array (e.g., a = scl)
+        SCALAR_ARRAY_ASSIGNMENT: Array assignment to scalar variable (e.g., scl = a)
+        SCALAR_SCALAR_ASSIGNMENT: Scalar assignment to scalar variable (e.g., scl = scl)
+
+    Binary Operations with Arrays:
+        ARRAY_SYMBOL: Array with symbol/constant (e.g., out = arr + 5, out = arr * N)
+        ARRAY_SCALAR: Array with scalar variable (e.g., out = arr + scl)
+        ARRAY_ARRAY: Two arrays (e.g., out = arr1 + arr2)
+
+    Binary Operations with Scalars/Symbols:
+        SCALAR_SYMBOL: Scalar with symbol/constant (e.g., out = scl + 5)
+        SCALAR_SCALAR: Two scalars (e.g., out = scl1 + scl2)
+        SYMBOL_SYMBOL: Two symbols (e.g., out = sym1 + sym2)
+
+    Unary Operations (an operand appearing twice, e.g., out = arr * arr,
+    classifies as the corresponding binary type instead):
+        UNARY_ARRAY: Single array operand used once (e.g., out = abs(arr))
+        UNARY_SCALAR: Single scalar operand used once (e.g., out = abs(scl))
+        UNARY_SYMBOL: Single symbol operand used once (e.g., out = abs(sym))
+    """
+    ARRAY_ARRAY_ASSIGNMENT = "array_array_assignment"
+    ARRAY_SYMBOL_ASSIGNMENT = "array_symbol_assignment"
+    ARRAY_SCALAR_ASSIGNMENT = "array_scalar_assignment"
+    SCALAR_ARRAY_ASSIGNMENT = "scalar_array_assignment"
+    SCALAR_SCALAR_ASSIGNMENT = "scalar_scalar_assignment"
+    SCALAR_SYMBOL = "scalar_symbol"
+    ARRAY_SYMBOL = "array_symbol"
+    ARRAY_SCALAR = "array_scalar"
+    ARRAY_ARRAY = "array_array"
+    UNARY_ARRAY = "unary_array"
+    UNARY_SYMBOL = "unary_symbol"
+    UNARY_SCALAR = "unary_scalar"
+    SCALAR_SCALAR = "scalar_scalar"
+    SYMBOL_SYMBOL = "symbol_symbol"
+
+
+def _extract_constant_from_ast_str(src: str) -> str:
+    """
+    Extract a numeric constant from a Python code string using AST parsing.
+
+    Supports both direct constants (e.g., 42, 3.14) and unary operations on constants
+    (e.g., -5, +3.14). The function walks the AST tree to find constant nodes. 
+ + Args: + src: Python code string containing a constant (e.g., "x + 3.14" or "y - (-5)") + + Returns: + String representation of the constant value + + Raises: + ValueError: If no constant is found in the source string + + Examples: + >>> _extract_constant_from_ast_str("x + 3.14") + '3.14' + >>> _extract_constant_from_ast_str("y + (-5)") + '-5' + """ + tree = ast.parse(src) + + for node in ast.walk(tree): + if isinstance(node, ast.Constant): + return str(node.value) + elif isinstance(node, ast.UnaryOp) and isinstance(node.operand, ast.Constant): + if isinstance(node.op, ast.USub): + return f"-{node.operand.value}" + elif isinstance(node.op, ast.UAdd): + return str(node.operand.value) + + raise ValueError("No constant found") + + +def _extract_non_connector_syms_from_tasklet(node: dace.nodes.Tasklet) -> typing.Set[str]: + """ + Identify free symbols in tasklet code that are not input/output connectors. + + This function extracts all symbolic variables from the right-hand side of the tasklet's + code expression and filters out those that correspond to input/output connectors, + leaving only the actual free symbols (e.g., SDFG symbols or constants). + + Args: + node: The tasklet node to analyze (must be a Python tasklet) + + Returns: + Set of symbol names that appear in the code but are not connectors + + Examples: + For a tasklet "out = in_a + N" with connectors {in_a, out}, this returns {"N"} + For a tasklet "out = in_x * alpha + beta" with connectors {in_x, out}, this returns {"alpha", "beta"} + + Note: + Requires the tasklet to use Python language and have valid symbolic expressions. + """ + assert isinstance(node, dace.nodes.Tasklet) + assert node.code.language == dace.dtypes.Language.Python + connectors = {str(s) for s in set(node.in_connectors.keys()).union(set(node.out_connectors.keys()))} + code_rhs: str = node.code.as_string.split("=")[-1].strip() + all_syms = {str(s) for s in dace.symbolic.SymExpr(code_rhs).free_symbols} + real_free_syms = all_syms - connectors + free_non_connector_syms = {str(s) for s in real_free_syms} + return free_non_connector_syms + + +_BINOP_SYMBOLS = { + ast.Add: "+", + ast.Sub: "-", + ast.Mult: "*", + ast.Div: "/", +} +"""Mapping from AST binary operation nodes to their string representations.""" + +_UNARY_SYMBOLS = { + ast.UAdd: "+", + ast.USub: "-", +} +"""Mapping from AST unary operation nodes to their string representations.""" + +_CMP_SYMBOLS = { + ast.Gt: ">", + ast.Lt: "<", + ast.GtE: ">=", + ast.LtE: "<=", + ast.Eq: "==", + ast.NotEq: "!=", +} +"""Mapping from AST comparison operation nodes to their string representations.""" + +_SUPPORTED_OPS = {'*', '+', '-', '/', '>', '<', '>=', '<=', '==', '!='} +"""Set of supported binary and comparison operators.""" + +_SUPPORTED = {'*', '+', '-', '/', 'abs', 'exp', 'sqrt', 'log', 'ln', 'exp', 'pow', 'min', 'max'} +"""Set of all supported operations including functions.""" + + +def _extract_single_op(src: str, default_to_assignment: bool = False) -> str: + """ + Extract the single supported operation from Python code. + + Parses the code string and identifies exactly one supported operation. The operation + can be a binary operator (+, -, *, /), comparison operator (>, <, etc.), or a + function call (abs, exp, sqrt, etc.). 
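+
+    A minimal usage sketch (the input strings are illustrative, not taken
+    from the test suite):
+
+        _extract_single_op("out = in_a * in_b")  # returns "*"
+        _extract_single_op("out = sqrt(in_a)")   # returns "sqrt"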
+
+    Args:
+        src: Python code string (should be parseable into an AST) (e.g., "out = a + b" or "out = sqrt(x)")
+        default_to_assignment: If True, return "=" when no operation is found;
+                               if False, raise ValueError
+
+    Returns:
+        The operation symbol (e.g., "+", "*") or function name (e.g., "sqrt", "abs")
+
+    Note:
+        This function assumes the tasklet contains a single operation.
+        You can run the pass `SplitTasklets` to get such tasklets.
+    """
+    tree = ast.parse(src)
+    found = None
+
+    for node in ast.walk(tree):
+        op = None
+
+        if isinstance(node, ast.BinOp):
+            op = _BINOP_SYMBOLS.get(type(node.op), None)
+        elif isinstance(node, ast.UnaryOp):
+            op = _UNARY_SYMBOLS.get(type(node.op), None)
+        elif isinstance(node, ast.Compare):
+            assert len(node.ops) == 1
+            op = _CMP_SYMBOLS.get(type(node.ops[0]), None)
+        elif isinstance(node, ast.Call):
+            if isinstance(node.func, ast.Name):
+                op = node.func.id
+            elif isinstance(node.func, ast.Attribute):
+                op = node.func.attr
+
+        if op is None:
+            continue
+
+        if op not in _SUPPORTED:
+            print(f"Found unsupported op {op} in {src}")
+
+        if found is not None:
+            raise ValueError("More than one supported operation found")
+
+        found = op
+
+    code_rhs = src.split(" = ")[-1].strip()
+    try:
+        tree = ast.parse(code_rhs, mode="eval")
+        call_node = tree.body
+        if isinstance(call_node, ast.Call) and isinstance(call_node.func, ast.Name):
+            func_name = call_node.func.id
+            found = func_name
+    except SyntaxError:
+        pass
+
+    if found is None:
+        if default_to_assignment is True:
+            found = "="
+        else:
+            raise ValueError(f"No supported operation found for code_str: {src}")
+
+    return found
+
+
+def _match_connector_to_data(state: dace.SDFGState, tasklet: dace.nodes.Tasklet) -> Dict:
+    """
+    Map input connector names to their corresponding data descriptors.
+
+    Creates a dictionary that maps each input connector of the tasklet to its
+    associated data descriptor (array or scalar) by examining the incoming edges.
+
+    Args:
+        state: The SDFG state containing the tasklet
+        tasklet: The tasklet node whose connectors to map
+
+    Returns:
+        Dictionary mapping connector names (str) to data descriptors (dace.data.Data)
+
+    Examples:
+        For a tasklet with input connector "in_a" connected to array "A":
+        >>> _match_connector_to_data(state, tasklet)
+        {'in_a': <dace.data.Array ...>}
+    """
+    tdict = dict()
+    for ie in state.in_edges(tasklet):
+        if ie.data is not None:
+            tdict[ie.dst_conn] = state.sdfg.arrays[ie.data.data]
+
+    return tdict
+
+
+def _get_scalar_and_array_arguments(state: dace.SDFGState, tasklet: dace.nodes.Tasklet) -> Tuple[Set[str], Set[str]]:
+    """
+    Separate tasklet input connectors into scalars and arrays.
+
+    Returns:
+        Tuple of (scalar_connectors, array_connectors) where each is a set of connector names
+    """
+    tdict = _match_connector_to_data(state, tasklet)
+    scalars = {k for k, v in tdict.items() if isinstance(v, dace.data.Scalar)}
+    arrays = {k for k, v in tdict.items() if isinstance(v, dace.data.Array)}
+    return scalars, arrays
+
+
+def _reorder_rhs(code_str: str, op: str, rhs1: str, rhs2: str) -> Tuple[str, str]:
+    """
+    Determine the correct left-right ordering of operands based on their appearance in code.
+
+    For binary operations, this function analyzes the code to determine which operand
+    appears on the left side of the operator and which appears on the right. This is
+    important for non-commutative operations like subtraction and division. 
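+
+    A minimal usage sketch (operand names are illustrative):
+
+        _reorder_rhs("out = in_b - in_a", "-", "in_a", "in_b")  # returns ("in_b", "in_a")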
+
+    Args:
+        code_str: Full tasklet code string (e.g., "out = a - b")
+        op: Operation symbol (e.g., "-", "*", "min")
+        rhs1: First operand name
+        rhs2: Second operand name
+
+    Returns:
+        Tuple of (left_operand, right_operand) in the order they appear in the code
+
+    Note:
+        For function calls, uses AST parsing to extract arguments in order.
+        For operators, splits the code by the operator symbol.
+    """
+    code_rhs = code_str.split(" = ")[-1].strip()
+    if op not in _SUPPORTED_OPS:
+        try:
+            tree = ast.parse(code_rhs, mode="eval")
+            call_node = tree.body
+            if not isinstance(call_node, ast.Call):
+                raise ValueError(f"Expected a function call in expression: {code_rhs}")
+
+            args = [ast.get_source_segment(code_rhs, arg).strip() for arg in call_node.args]
+            left_string, right_string = args[0:2]
+            assert len(args) == 2
+        except SyntaxError as e:
+            raise ValueError(f"Failed to parse function expression: {code_rhs}") from e
+
+    else:
+        left_string, right_string = [cstr.strip() for cstr in code_rhs.split(op)]
+
+    if rhs1 in left_string and rhs2 in left_string:
+        raise Exception("SSA tasklet, rhs1 and rhs2 both can't appear on the left side of the operator")
+
+    if rhs1 in right_string and rhs2 in right_string:
+        raise Exception("SSA tasklet, rhs1 and rhs2 both can't appear on the right side of the operator")
+
+    if rhs1 in left_string and rhs2 in right_string:
+        return rhs1, rhs2
+
+    if rhs1 in right_string and rhs2 in left_string:
+        return rhs2, rhs1
+
+    if rhs1 not in left_string and rhs1 not in right_string:
+        raise Exception("SSA tasklet, rhs1 appears in none of the substrings")
+
+    if rhs2 not in left_string and rhs2 not in right_string:
+        raise Exception("SSA tasklet, rhs2 appears in none of the substrings")
+
+
+def count_name_occurrences(expr: str, name: str) -> int:
+    """
+    Count how many times a given variable name appears in an expression.
+
+    Uses AST parsing to accurately count variable name occurrences, distinguishing
+    between actual variable references and other uses of the same string.
+
+    Args:
+        expr: Expression to parse (e.g., "a + b * a")
+        name: Variable name to count (e.g., "a")
+
+    Returns:
+        Number of times the variable appears in the expression
+
+    Examples:
+        >>> count_name_occurrences("a + b * a", "a")
+        2
+        >>> count_name_occurrences("x * x * x", "x")
+        3
+        >>> count_name_occurrences("abs(y)", "y")
+        1
+
+    Note:
+        This is used to distinguish between unary operations (single occurrence)
+        and binary operations where the same operand appears twice (e.g., x * x).
+    """
+    tree = ast.parse(expr, mode="eval")
+    count = 0
+    for node in ast.walk(tree):
+        if isinstance(node, ast.Name) and node.id == name:
+            count += 1
+    return count
+
+
+def classify_tasklet(state: dace.SDFGState, node: dace.nodes.Tasklet) -> Dict:
+    """
+    Analyze a tasklet and return its classification with metadata.
+
+    This is the main entry point for tasklet classification. It inspects the tasklet's
+    code, input/output connectors, and data descriptors to determine the tasklet type
+    and extract relevant metadata for code generation. 
+ + Args: + state: The SDFG state containing the tasklet + node: The tasklet node to classify + + Returns: + Dictionary with the following keys: + - type (TaskletType): The classified tasklet type + - lhs (str): Output connector name (left-hand side variable) + - rhs1 (str or None): First input connector/operand name (left of the operator if both rhs1 and rhs2 are set) + - rhs2 (str or None): Second input connector/operand name (right of the operator if both rhs1 and rhs2 are set, can be same as rhs1) + - constant1 (str or None): First constant/symbol value (left of the operator if both c1 and c2 are set) + - constant2 (str or None): Second constant/symbol value (right of the operator if both c1 and c2 are set, can be same as c1) + - op (str): Operation symbol or function name + + Raises: + AssertionError: If tasklet has more than 1 output connector + NotImplementedError: If tasklet pattern is not supported + ValueError: If code cannot be parsed or contains unsupported operations + + Classification Logic: + (Output can be scalar / array) + Single Input (n_in == 1): + - Direct assignment: a = b + - Array/scalar with constant: a = b + 5 + - Array/scalar with symbol: a = b * N + - Unary operation: a = abs(b) or a = b * b + + Two Inputs (n_in == 2): + - Two arrays: a = b + c + - Array and scalar: a = b * scl + - Two scalars: a = scl1 + scl2 + + Zero Inputs (n_in == 0): + - Symbol assignment: a = N + - Two symbols: a = N + M + - Unary symbol: a = abs(N) + + Examples: + >>> # For tasklet "out = in_a + 5" + >>> result = classify_tasklet(state, tasklet_node) + >>> result + { + 'type': TaskletType.ARRAY_SYMBOL, + 'lhs': 'out', + 'rhs1': 'in_a', + 'rhs2': None, + 'constant1': '5', + 'constant2': None, + 'op': '+' + } + # For more see the unit tests + + Constraints: + - Tasklet must have exactly 1 output connector + - Tasklet must use Python language + - Code must contain at most one operation (See SplitTasklets pass to enforce this easily) + """ + in_conns = list(node.in_connectors.keys()) + out_conns = list(node.out_connectors.keys()) + n_in = len(in_conns) + n_out = len(out_conns) + + assert n_out <= 1, "Only support tasklets with at most 1 output in this pass" + lhs = next(iter(node.out_connectors.keys())) if n_out == 1 else None + + assert isinstance(node, dace.nodes.Tasklet) + code: CodeBlock = node.code + assert code.language == dace.dtypes.Language.Python + code_str: str = code.as_string + + info_dict = {"type": None, "lhs": lhs, "rhs1": None, "rhs2": None, "constant1": None, "constant2": None, "op": None} + + assert n_out == 1 + + if n_in == 1: + rhs = in_conns[0] + in_edges = {ie for ie in state.in_edges_by_connector(node, rhs)} + assert len(in_edges) == 1, f"expected 1 in-edge for connector {rhs}, found {len(in_edges)}" + rhs_data_name = in_edges.pop().data.data + rhs_data = state.sdfg.arrays[rhs_data_name] + out_edges = {oe for oe in state.out_edges_by_connector(node, lhs)} + assert len(out_edges) == 1, f"expected 1 out-edge for connector {lhs}, found {len(out_edges)}" + lhs_data_name = out_edges.pop().data.data + lhs_data = state.sdfg.arrays[lhs_data_name] + + if code_str == f"{lhs} = {rhs}" or code_str == f"{lhs} = {rhs};": + lhs_datadesc = lhs_data + rhs_datadesc = rhs_data + ttype = None + if isinstance(lhs_datadesc, dace.data.Array) and isinstance(rhs_datadesc, dace.data.Array): + ttype = TaskletType.ARRAY_ARRAY_ASSIGNMENT + elif isinstance(lhs_datadesc, dace.data.Array) and isinstance(rhs_datadesc, dace.data.Scalar): + ttype = TaskletType.ARRAY_SCALAR_ASSIGNMENT + elif 
isinstance(lhs_datadesc, dace.data.Scalar) and isinstance(rhs_datadesc, dace.data.Array):
+                ttype = TaskletType.SCALAR_ARRAY_ASSIGNMENT
+            elif isinstance(lhs_datadesc, dace.data.Scalar) and isinstance(rhs_datadesc, dace.data.Scalar):
+                ttype = TaskletType.SCALAR_SCALAR_ASSIGNMENT
+            else:
+                raise ValueError(f"Unsupported Assignment Type {lhs_datadesc} <- {rhs_datadesc}")
+            info_dict.update({"type": ttype, "op": "=", "rhs1": rhs})
+            return info_dict
+
+        has_constant = False
+        constant = None
+        try:
+            constant = _extract_constant_from_ast_str(code_str)
+            has_constant = True
+        except Exception:
+            has_constant = False
+
+        free_non_connector_syms = _extract_non_connector_syms_from_tasklet(node)
+        if len(free_non_connector_syms) == 1:
+            has_constant = True
+            constant = free_non_connector_syms.pop()
+
+        if not has_constant:
+            rhs_occurence_count = count_name_occurrences(code_str.split(" = ")[1].strip(), rhs)
+            if isinstance(rhs_data, dace.data.Array):
+                rhs2 = None if rhs_occurence_count == 1 else rhs
+                ttype = TaskletType.UNARY_ARRAY if rhs_occurence_count == 1 else TaskletType.ARRAY_ARRAY
+                info_dict.update({"type": ttype, "rhs1": rhs, "rhs2": rhs2, "op": _extract_single_op(code_str)})
+                return info_dict
+            elif isinstance(rhs_data, dace.data.Scalar):
+                rhs2 = None if rhs_occurence_count == 1 else rhs
+                ttype = TaskletType.UNARY_SCALAR if rhs_occurence_count == 1 else TaskletType.SCALAR_SCALAR
+                info_dict.update({"type": ttype, "rhs1": rhs, "rhs2": rhs2, "op": _extract_single_op(code_str)})
+                return info_dict
+            else:
+                raise Exception(f"Unhandled case in tasklet type (1) {rhs_data}, {type(rhs_data)}")
+        else:
+            if isinstance(rhs_data, dace.data.Array):
+                info_dict.update({
+                    "type": TaskletType.ARRAY_SYMBOL,
+                    "rhs1": rhs,
+                    "constant1": constant,
+                    "op": _extract_single_op(code_str)
+                })
+                return info_dict
+            elif isinstance(rhs_data, dace.data.Scalar):
+                info_dict.update({
+                    "type": TaskletType.SCALAR_SYMBOL,
+                    "rhs1": rhs,
+                    "constant1": constant,
+                    "op": _extract_single_op(code_str)
+                })
+                return info_dict
+            else:
+                raise Exception(f"Unhandled case in tasklet type (2) {rhs_data}, {type(rhs_data)}")
+
+    elif n_in == 2:
+        op = _extract_single_op(code_str)
+        rhs1, rhs2 = in_conns[0], in_conns[1]
+        rhs1, rhs2 = _reorder_rhs(code_str, op, rhs1, rhs2)
+
+        lhs = next(iter(node.out_connectors.keys()))
+        scalars, arrays = _get_scalar_and_array_arguments(state, node)
+        assert len(scalars) + len(arrays) == 2
+
+        if len(arrays) == 2 and len(scalars) == 0:
+            info_dict.update({"type": TaskletType.ARRAY_ARRAY, "rhs1": rhs1, "rhs2": rhs2, "op": op})
+            return info_dict
+        elif len(scalars) == 1 and len(arrays) == 1:
+            array_arg = next(iter(arrays))
+            scalar_arg = next(iter(scalars))
+            info_dict.update({"type": TaskletType.ARRAY_SCALAR, "rhs1": array_arg, "constant1": scalar_arg, "op": op})
+            return info_dict
+        elif len(scalars) == 2:
+            info_dict.update({"type": TaskletType.SCALAR_SCALAR, "rhs1": rhs1, "rhs2": rhs2, "op": op})
+            return info_dict
+
+    elif n_in == 0:
+        free_syms = _extract_non_connector_syms_from_tasklet(node)
+        assert len(free_syms) == 2 or len(free_syms) == 1, f"{str(free_syms)}"
+        if len(free_syms) == 2:
+            free_sym1 = free_syms.pop()
+            free_sym2 = free_syms.pop()
+            op = _extract_single_op(code_str, default_to_assignment=False)
+            free_sym1, free_sym2 = _reorder_rhs(code_str, op, free_sym1, free_sym2)
+            info_dict.update({
+                "type": TaskletType.SYMBOL_SYMBOL,
+                "constant1": free_sym1,
+                "constant2": free_sym2,
+                "op": _extract_single_op(code_str)
+            })
+            return info_dict
+        elif len(free_syms) == 1: 
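+            # A single free symbol: either a plain assignment (out = N) or a
+            # unary/self-binary use of it (out = abs(N), out = N * N).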
+ op = _extract_single_op(code_str, default_to_assignment=True) + if op == "=": + free_sym1 = free_syms.pop() + info_dict.update({"type": TaskletType.ARRAY_SYMBOL_ASSIGNMENT, "constant1": free_sym1, "op": "="}) + return info_dict + else: + free_sym1 = free_syms.pop() + rhs_occurence_count = count_name_occurrences(code_str.split(" = ")[1].strip(), free_sym1) + free_sym2 = None if rhs_occurence_count == 1 else free_sym1 + ttype = TaskletType.UNARY_SYMBOL if rhs_occurence_count == 1 else TaskletType.SYMBOL_SYMBOL + info_dict.update({"type": ttype, "constant1": free_sym1, "constant2": free_sym2, "op": op}) + return info_dict + + raise NotImplementedError("Unhandled case in detect tasklet type") diff --git a/tests/utils/classify_tasklet_test.py b/tests/utils/classify_tasklet_test.py new file mode 100644 index 0000000000..823356fa46 --- /dev/null +++ b/tests/utils/classify_tasklet_test.py @@ -0,0 +1,447 @@ +import pytest +import dace +import typing +import dace.sdfg.tasklet_utils as tutil + +tasklet_infos = [ + # === ARRAY + SYMBOL === + ("out = in_a + sym_b", "array", {"a"}, {}, {"sym_b"}, { + "type": tutil.TaskletType.ARRAY_SYMBOL, + "lhs": "out", + "rhs1": "in_a", + "rhs2": None, + "op": "+", + "constant1": "sym_b", + "constant2": None + }), + ("out = in_a - sym_b", "array", {"a"}, {}, {"sym_b"}, { + "type": tutil.TaskletType.ARRAY_SYMBOL, + "lhs": "out", + "rhs1": "in_a", + "rhs2": None, + "op": "-", + "constant1": "sym_b", + "constant2": None + }), + ("out = in_a * sym_b", "array", {"a"}, {}, {"sym_b"}, { + "type": tutil.TaskletType.ARRAY_SYMBOL, + "lhs": "out", + "rhs1": "in_a", + "rhs2": None, + "op": "*", + "constant1": "sym_b", + "constant2": None + }), + ("out = in_a / sym_b", "array", {"a"}, {}, {"sym_b"}, { + "type": tutil.TaskletType.ARRAY_SYMBOL, + "lhs": "out", + "rhs1": "in_a", + "rhs2": None, + "op": "/", + "constant1": "sym_b", + "constant2": None + }), + + # === ARRAY + CONSTANT === + ("out = in_a + 2", "array", {"a"}, {}, {}, { + "type": tutil.TaskletType.ARRAY_SYMBOL, + "lhs": "out", + "rhs1": "in_a", + "rhs2": None, + "op": "+", + "constant1": "2", + "constant2": None + }), + ("out = in_a * 3", "array", {"a"}, {}, {}, { + "type": tutil.TaskletType.ARRAY_SYMBOL, + "lhs": "out", + "rhs1": "in_a", + "rhs2": None, + "op": "*", + "constant1": "3", + "constant2": None + }), + ("out = in_a / 2.5", "array", {"a"}, {}, {}, { + "type": tutil.TaskletType.ARRAY_SYMBOL, + "lhs": "out", + "rhs1": "in_a", + "rhs2": None, + "op": "/", + "constant1": "2.5", + "constant2": None + }), + ("out = in_a - 5", "array", {"a"}, {}, {}, { + "type": tutil.TaskletType.ARRAY_SYMBOL, + "lhs": "out", + "rhs1": "in_a", + "rhs2": None, + "op": "-", + "constant1": "5", + "constant2": None + }), + + # === ARRAY + ARRAY === + ("out = in_a + in_b", "array", {"a", "b"}, {}, {}, { + "type": tutil.TaskletType.ARRAY_ARRAY, + "lhs": "out", + "rhs1": "in_a", + "rhs2": "in_b", + "op": "+", + "constant1": None, + "constant2": None + }), + ("out = in_a - in_b", "array", {"a", "b"}, {}, {}, { + "type": tutil.TaskletType.ARRAY_ARRAY, + "lhs": "out", + "rhs1": "in_a", + "rhs2": "in_b", + "op": "-", + "constant1": None, + "constant2": None + }), + ("out = in_a * in_b", "array", {"a", "b"}, {}, {}, { + "type": tutil.TaskletType.ARRAY_ARRAY, + "lhs": "out", + "rhs1": "in_a", + "rhs2": "in_b", + "op": "*", + "constant1": None, + "constant2": None + }), + ("out = in_a / in_b", "array", {"a", "b"}, {}, {}, { + "type": tutil.TaskletType.ARRAY_ARRAY, + "lhs": "out", + "rhs1": "in_a", + "rhs2": "in_b", + "op": "/", + "constant1": 
None, + "constant2": None + }), + + # === SCALAR + SYMBOL === + ("out = in_x + sym_y", "scalar", {}, {"x"}, {"sym_y"}, { + "type": tutil.TaskletType.SCALAR_SYMBOL, + "lhs": "out", + "rhs1": "in_x", + "rhs2": None, + "op": "+", + "constant1": "sym_y", + "constant2": None + }), + ("out = in_x * sym_y", "scalar", {}, {"x"}, {"sym_y"}, { + "type": tutil.TaskletType.SCALAR_SYMBOL, + "lhs": "out", + "rhs1": "in_x", + "rhs2": None, + "op": "*", + "constant1": "sym_y", + "constant2": None + }), + ("out = in_x - sym_y", "scalar", {}, {"x"}, {"sym_y"}, { + "type": tutil.TaskletType.SCALAR_SYMBOL, + "lhs": "out", + "rhs1": "in_x", + "rhs2": None, + "op": "-", + "constant1": "sym_y", + "constant2": None + }), + + # === SCALAR + SCALAR === + ("out = in_x + in_y", "scalar", {}, {"x", "y"}, {}, { + "type": tutil.TaskletType.SCALAR_SCALAR, + "lhs": "out", + "rhs1": "in_x", + "rhs2": "in_y", + "op": "+", + "constant1": None, + "constant2": None + }), + ("out = in_x * in_y", "scalar", {}, {"x", "y"}, {}, { + "type": tutil.TaskletType.SCALAR_SCALAR, + "lhs": "out", + "rhs1": "in_x", + "rhs2": "in_y", + "op": "*", + "constant1": None, + "constant2": None + }), + ("out = in_x / in_y", "scalar", {}, {"x", "y"}, {}, { + "type": tutil.TaskletType.SCALAR_SCALAR, + "lhs": "out", + "rhs1": "in_x", + "rhs2": "in_y", + "op": "/", + "constant1": None, + "constant2": None + }), + + # === SYMBOL + SYMBOL === + ("out = sym_a + sym_b", "scalar", {}, {}, {"sym_a", "sym_b"}, { + "type": tutil.TaskletType.SYMBOL_SYMBOL, + "lhs": "out", + "rhs1": None, + "rhs2": None, + "op": "+", + "constant1": "sym_a", + "constant2": "sym_b" + }), + ("out = sym_a * sym_b", "scalar", {}, {}, {"sym_a", "sym_b"}, { + "type": tutil.TaskletType.SYMBOL_SYMBOL, + "lhs": "out", + "rhs1": None, + "rhs2": None, + "op": "*", + "constant1": "sym_a", + "constant2": "sym_b" + }), + ("out = sym_a / sym_b", "scalar", {}, {}, {"sym_a", "sym_b"}, { + "type": tutil.TaskletType.SYMBOL_SYMBOL, + "lhs": "out", + "rhs1": None, + "rhs2": None, + "op": "/", + "constant1": "sym_a", + "constant2": "sym_b" + }), + + # === FUNCTIONAL / SUPPORTED OPS === + ("out = abs(in_a)", "array", {"a"}, {}, {}, { + "type": tutil.TaskletType.UNARY_ARRAY, + "lhs": "out", + "rhs1": "in_a", + "rhs2": None, + "op": "abs", + "constant1": None, + "constant2": None + }), + ("out = exp(in_a)", "array", {"a"}, {}, {}, { + "type": tutil.TaskletType.UNARY_ARRAY, + "lhs": "out", + "rhs1": "in_a", + "rhs2": None, + "op": "exp", + "constant1": None, + "constant2": None + }), + ("out = sqrt(in_a)", "array", {"a"}, {}, {}, { + "type": tutil.TaskletType.UNARY_ARRAY, + "lhs": "out", + "rhs1": "in_a", + "rhs2": None, + "op": "sqrt", + "constant1": None, + "constant2": None + }), + ("out = log(in_a)", "array", {"a"}, {}, {}, { + "type": tutil.TaskletType.UNARY_ARRAY, + "lhs": "out", + "rhs1": "in_a", + "rhs2": None, + "op": "log", + "constant1": None, + "constant2": None + }), + ("out = pow(in_a, 2)", "array", {"a"}, {}, {}, { + "type": tutil.TaskletType.ARRAY_SYMBOL, + "lhs": "out", + "rhs1": "in_a", + "rhs2": None, + "op": "pow", + "constant1": "2", + "constant2": None + }), + ("out = min(in_a, in_b)", "array", {"a", "b"}, {}, {}, { + "type": tutil.TaskletType.ARRAY_ARRAY, + "lhs": "out", + "rhs1": "in_a", + "rhs2": "in_b", + "op": "min", + "constant1": None, + "constant2": None + }), + ("out = max(in_a, in_b)", "array", {"a", "b"}, {}, {}, { + "type": tutil.TaskletType.ARRAY_ARRAY, + "lhs": "out", + "rhs1": "in_a", + "rhs2": "in_b", + "op": "max", + "constant1": None, + "constant2": None + }), + 
("out = abs(sym_a)", "array", {}, {}, {"sym_a"}, { + "type": tutil.TaskletType.UNARY_SYMBOL, + "lhs": "out", + "rhs1": None, + "rhs2": None, + "op": "abs", + "constant1": "sym_a", + "constant2": None + }), + ("out = exp(in_a)", "array", {"a"}, {}, {}, { + "type": tutil.TaskletType.UNARY_ARRAY, + "lhs": "out", + "rhs1": "in_a", + "rhs2": None, + "op": "exp", + "constant1": None, + "constant2": None + }), + ("out = sqrt(in_a)", "scalar", {}, {"a"}, {}, { + "type": tutil.TaskletType.UNARY_SCALAR, + "lhs": "out", + "rhs1": "in_a", + "rhs2": None, + "op": "sqrt", + "constant1": None, + "constant2": None + }), + + # === ASSIGNMENTS === + ("out = in_a", "array", {"a"}, {}, {}, { + "type": tutil.TaskletType.ARRAY_ARRAY_ASSIGNMENT, + "lhs": "out", + "rhs1": "in_a", + "rhs2": None, + "op": "=", + "constant1": None, + "constant2": None + }), + ("out = in_b", "array", {"b"}, {}, {}, { + "type": tutil.TaskletType.ARRAY_ARRAY_ASSIGNMENT, + "lhs": "out", + "rhs1": "in_b", + "rhs2": None, + "op": "=", + "constant1": None, + "constant2": None + }), + ("out = in_b", "array", {}, {"b"}, {}, { + "type": tutil.TaskletType.ARRAY_SCALAR_ASSIGNMENT, + "lhs": "out", + "rhs1": "in_b", + "rhs2": None, + "op": "=", + "constant1": None, + "constant2": None + }), + ("out = in_b", "scalar", {"b"}, {}, {}, { + "type": tutil.TaskletType.SCALAR_ARRAY_ASSIGNMENT, + "lhs": "out", + "rhs1": "in_b", + "rhs2": None, + "op": "=", + "constant1": None, + "constant2": None + }), + ("out = in_b", "scalar", {}, {"b"}, {}, { + "type": tutil.TaskletType.SCALAR_SCALAR_ASSIGNMENT, + "lhs": "out", + "rhs1": "in_b", + "rhs2": None, + "op": "=", + "constant1": None, + "constant2": None + }), + ("out = sym_a", "array", {}, {}, {"sym_a"}, { + "type": tutil.TaskletType.ARRAY_SYMBOL_ASSIGNMENT, + "lhs": "out", + "rhs1": None, + "rhs2": None, + "op": "=", + "constant1": "sym_a", + "constant2": None + }), + + # === SINGLE-INPUT TWO RHS CASE === + ("out = in_a * in_a", "array", {"a"}, {}, {}, { + "type": tutil.TaskletType.ARRAY_ARRAY, + "lhs": "out", + "rhs1": "in_a", + "rhs2": "in_a", + "op": "*", + "constant1": None, + "constant2": None + }), + ("out = in_a + in_a", "array", {"a"}, {}, {}, { + "type": tutil.TaskletType.ARRAY_ARRAY, + "lhs": "out", + "rhs1": "in_a", + "rhs2": "in_a", + "op": "+", + "constant1": None, + "constant2": None + }), + ("out = in_a + in_a", "array", {}, {"a"}, {}, { + "type": tutil.TaskletType.SCALAR_SCALAR, + "lhs": "out", + "rhs1": "in_a", + "rhs2": "in_a", + "op": "+", + "constant1": None, + "constant2": None + }), +] + +i = 0 + + +def _gen_sdfg( + tasklet_info: typing.Tuple[str, str, typing.Set[str], typing.Set[str], typing.Set[str], tutil.TaskletType] +) -> dace.SDFG: + global i + i += 1 + sdfg = dace.SDFG(f"sd{i}") + state = sdfg.add_state("s0", is_start_block=True) + + expr_str, out_type, in_arrays, in_scalars, in_symbols, _ = tasklet_info + + t1 = state.add_tasklet(name="t1", + inputs={f"in_{a}" + for a in in_arrays}.union({f"in_{a}" + for a in in_scalars}), + outputs={"out"}, + code=expr_str) + + for in_array in in_arrays: + sdfg.add_array(in_array, (1, ), dace.float64) + state.add_edge(state.add_access(in_array), None, t1, f"in_{in_array}", dace.memlet.Memlet(f"{in_array}[0]")) + for in_scalar in in_scalars: + sdfg.add_scalar(in_scalar, dace.float64) + state.add_edge(state.add_access(in_scalar), None, t1, f"in_{in_scalar}", dace.memlet.Memlet(f"{in_scalar}[0]")) + for in_symbol in in_symbols: + sdfg.add_symbol(in_symbol, dace.float64) + + if out_type == "array": + sdfg.add_array("O", (1, ), dace.float64) + 
else: + sdfg.add_scalar("O", dace.float64) + + state.add_edge(t1, "out", state.add_access("O"), None, dace.memlet.Memlet("O[0]" if out_type == "array" else "O")) + + sdfg.validate() + return sdfg + + +@pytest.mark.parametrize("tasklet_info", tasklet_infos) +def test_single_tasklet_split(tasklet_info): + sdfg = _gen_sdfg(tasklet_info) + sdfg.validate() + sdfg.compile() + + _, _, _, _, _, desired_tasklet_info = tasklet_info + + tasklets = {(n, g) for n, g in sdfg.all_nodes_recursive() if isinstance(n, dace.nodes.Tasklet)} + assert len(tasklets) == 1 + tasklet, state = tasklets.pop() + + tasklet_info_dict = tutil.classify_tasklet(state=state, node=tasklet) + print(desired_tasklet_info) + print(tasklet_info_dict) + + assert desired_tasklet_info == tasklet_info_dict, f"Expected: {desired_tasklet_info}, Got: {tasklet_info_dict}" + + +if __name__ == "__main__": + for config_tuple in tasklet_infos: + test_single_tasklet_split(config_tuple) From 60f7f9d8b4d1dfcd9d7fb293f0c774a90cc054d4 Mon Sep 17 00:00:00 2001 From: Yakup Koray Budanaz Date: Wed, 29 Oct 2025 14:03:19 +0100 Subject: [PATCH 06/17] Rm use of | from type hints --- dace/sdfg/construction_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/construction_utils.py b/dace/sdfg/construction_utils.py index 388d0c9d4c..a2085fafe8 100644 --- a/dace/sdfg/construction_utils.py +++ b/dace/sdfg/construction_utils.py @@ -684,8 +684,8 @@ def get_num_parent_map_and_loop_scopes(root_sdfg: dace.SDFG, node: dace.nodes.Ma return len(get_parent_map_and_loop_scopes(root_sdfg, node, parent_state)) -def get_parent_map_and_loop_scopes(root_sdfg: dace.SDFG, node: dace.nodes.MapEntry | ControlFlowRegion - | dace.nodes.Tasklet | ConditionalBlock, parent_state: dace.SDFGState): +def get_parent_map_and_loop_scopes(root_sdfg: dace.SDFG, node: Union[dace.nodes.MapEntry,ControlFlowRegion + ,dace.nodes.Tasklet, ConditionalBlock], parent_state: dace.SDFGState): scope_dict = parent_state.scope_dict() if parent_state is not None else None num_parent_maps_and_loops = 0 cur_node = node From ce952138f000b3bcd9b3b07bf1b297834c5114b2 Mon Sep 17 00:00:00 2001 From: Yakup Koray Budanaz Date: Wed, 29 Oct 2025 14:10:39 +0100 Subject: [PATCH 07/17] Add stuff --- dace/sdfg/construction_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/construction_utils.py b/dace/sdfg/construction_utils.py index a2085fafe8..7deb9f3f11 100644 --- a/dace/sdfg/construction_utils.py +++ b/dace/sdfg/construction_utils.py @@ -684,8 +684,9 @@ def get_num_parent_map_and_loop_scopes(root_sdfg: dace.SDFG, node: dace.nodes.Ma return len(get_parent_map_and_loop_scopes(root_sdfg, node, parent_state)) -def get_parent_map_and_loop_scopes(root_sdfg: dace.SDFG, node: Union[dace.nodes.MapEntry,ControlFlowRegion - ,dace.nodes.Tasklet, ConditionalBlock], parent_state: dace.SDFGState): +def get_parent_map_and_loop_scopes(root_sdfg: dace.SDFG, node: Union[dace.nodes.MapEntry, ControlFlowRegion, + dace.nodes.Tasklet, ConditionalBlock], + parent_state: dace.SDFGState): scope_dict = parent_state.scope_dict() if parent_state is not None else None num_parent_maps_and_loops = 0 cur_node = node From 19130546ad67eec2997c8c7546abc59a34517844 Mon Sep 17 00:00:00 2001 From: Yakup Koray Budanaz Date: Wed, 29 Oct 2025 15:33:37 +0100 Subject: [PATCH 08/17] Minor refactor --- dace/sdfg/construction_utils.py | 2 -- dace/transformation/interstate/branch_elimination.py | 1 - tests/transformations/interstate/branch_elimination_test.py | 3 +-- 3 files changed, 1 
insertion(+), 5 deletions(-)

diff --git a/dace/sdfg/construction_utils.py b/dace/sdfg/construction_utils.py
index 7deb9f3f11..4d683c02e7 100644
--- a/dace/sdfg/construction_utils.py
+++ b/dace/sdfg/construction_utils.py
@@ -338,8 +338,6 @@ def _get_out_conn_name(src, state=state):
             dst.add_in_connector(f"IN_{data_access}_p", force=True)
             src.add_out_connector(_get_out_conn_name(dst))
 
-    parent_graph.sdfg.save("x.sdfg")
-
     # Re-propagate memlets when subsets are explicit
     if add_with_exact_subset:
         propagate_memlets_state(parent_graph.sdfg, parent_graph)
diff --git a/dace/transformation/interstate/branch_elimination.py b/dace/transformation/interstate/branch_elimination.py
index dd21b65305..93e2c1315d 100644
--- a/dace/transformation/interstate/branch_elimination.py
+++ b/dace/transformation/interstate/branch_elimination.py
@@ -1377,7 +1377,6 @@ def demote_branch_only_symbols_appearing_only_a_single_branch_to_scalars_and_try
         # Copy all access nodes to the next state, connect the sink node from prev. state
         # to the next state
         body.reset_cfg_list()
-        body.sdfg.save("x2.sdfg")
         assignment_state, other_state = list(body.bfs_nodes())[1:3]
         node_map = cutil.copy_state_contents(assignment_state, other_state)
         # Multiple symbols -> multiple sink nodes
diff --git a/tests/transformations/interstate/branch_elimination_test.py b/tests/transformations/interstate/branch_elimination_test.py
index 202edce85e..570a334d07 100644
--- a/tests/transformations/interstate/branch_elimination_test.py
+++ b/tests/transformations/interstate/branch_elimination_test.py
@@ -558,7 +558,7 @@ def test_try_clean():
     xform.conditional = cblock
     xform.try_clean(graph=parent_graph, sdfg=parent_sdfg)
 
-    # Should have moe states before
+    # Should have moved states before
     cblocks = {n for n, g in sdfg1.all_nodes_recursive() if isinstance(n, ConditionalBlock)}
     assert len(cblocks) == 2
     # A state must have been moved before)
@@ -2036,7 +2036,6 @@ def test_loop_param_usage():
     C = np.random.choice([0.001, 5.0], size=(N, N))
 
     sdfg = loop_param_usage.to_sdfg()
-    sdfg.save("x.sdfg")
 
     cblocks = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)}
     assert len(cblocks) == 1

From e504b9d63c37da8a2f743abc62a4033072bbd976 Mon Sep 17 00:00:00 2001
From: Yakup Koray Budanaz
Date: Wed, 29 Oct 2025 17:06:48 +0100
Subject: [PATCH 09/17] Refactor copy/reuse on transformations

---
 .../interstate/branch_elimination_test.py     | 43 ++++++++++++-------
 1 file changed, 27 insertions(+), 16 deletions(-)

diff --git a/tests/transformations/interstate/branch_elimination_test.py b/tests/transformations/interstate/branch_elimination_test.py
index 570a334d07..6bdd7a5d07 100644
--- a/tests/transformations/interstate/branch_elimination_test.py
+++ b/tests/transformations/interstate/branch_elimination_test.py
@@ -319,21 +319,23 @@ def run_and_compare(
     sdfg.validate()
     out_no_fuse = {k: v.copy() for k, v in arrays.items()}
     sdfg(**out_no_fuse)
+
+    copy_sdfg = copy.deepcopy(sdfg)
+    copy_sdfg.name = sdfg.name + "_branch_eliminated"
 
     # Apply transformation
     if use_pass:
         fb = EliminateBranches()
         fb.try_clean = True
-        fb.apply_pass(sdfg, {})
+        fb.apply_pass(copy_sdfg, {})
     else:
-        apply_branch_elimination(sdfg, 2)
-    sdfg.name = sdfg.label + "_transformed"
+        apply_branch_elimination(copy_sdfg, 2)
 
     # Run SDFG version (with transformation)
     out_fused = {k: v.copy() for k, v in arrays.items()}
-    sdfg(**out_fused)
+    copy_sdfg(**out_fused)
 
-    branch_code = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)}
+    branch_code = {n for n, g in 
copy_sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} assert len( branch_code) == num_expected_branches, f"(actual) len({branch_code}) != (desired) {num_expected_branches}" @@ -353,17 +355,21 @@ def run_and_compare_sdfg( sdfg(**out_no_fuse) # Run SDFG version (with transformation) + copy_sdfg = copy.deepcopy(sdfg) + copy_sdfg.name = sdfg.name + "_branch_eliminated" fb = EliminateBranches() fb.try_clean = True fb.permissive = permissive - fb.apply_pass(sdfg, {}) + fb.apply_pass(copy_sdfg, {}) out_fused = {k: v.copy() for k, v in arrays.items()} - sdfg(**out_fused) + copy_sdfg(**out_fused) # Compare all arrays for name in arrays.keys(): np.testing.assert_allclose(out_no_fuse[name], out_fused[name], atol=1e-12) + return copy_sdfg + @pytest.mark.parametrize("use_pass_flag", [True, False]) def test_branch_dependent_value_write(use_pass_flag): @@ -1196,9 +1202,9 @@ def test_try_clean_on_complicated_pattern_for_manual_clean_up_one(): ssp.ignore = scalar_names ssp.apply_pass(nsdfg.sdfg, {}) - run_and_compare_sdfg(sdfg, permissive=True, a=A, b=B, c=C, d=D[0]) + transformed_sdfg = run_and_compare_sdfg(sdfg, permissive=True, a=A, b=B, c=C, d=D[0]) - branch_code = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} + branch_code = {n for n, g in transformed_sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} assert len(branch_code) == 0, f"(actual) len({branch_code}) != (desired) {0}" @@ -1266,9 +1272,9 @@ def test_try_clean_on_complicated_pattern_for_manual_clean_up_two(): ssp.ignore = scalar_names ssp.apply_pass(nsdfg.sdfg, {}) - run_and_compare_sdfg(sdfg, permissive=True, a=A, b=B, c=C, d=D[0], e=E[0]) + transformed_sdfg = run_and_compare_sdfg(sdfg, permissive=True, a=A, b=B, c=C, d=D[0], e=E[0]) - branch_code = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} + branch_code = {n for n, g in transformed_sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} assert len(branch_code) == 0, f"(actual) len({branch_code}) != (desired) {0}" @@ -1298,7 +1304,6 @@ def test_single_assignment(): def test_single_assignment_cond_from_scalar(): A = np.ones(shape=(512, ), dtype=np.float64) before = single_assignment_cond_from_scalar.to_sdfg() - before.name = "non_fusion_single_assignment_cond_from_scalar" before.compile() run_and_compare(single_assignment_cond_from_scalar, 0, True, a=A) @@ -1365,9 +1370,14 @@ def test_condition_from_transient_scalar(): _if_cond_42 = np.random.choice([8.0, 11.0], size=(1, )) sdfg = _get_sdfg_with_condition_from_transient_scalar() - run_and_compare_sdfg(sdfg, permissive=False, zsolac=zsolac, zlcond2=zlcond2, za=za, _if_cond_42=_if_cond_42[0]) + transformed_sdfg = run_and_compare_sdfg(sdfg, + permissive=False, + zsolac=zsolac, + zlcond2=zlcond2, + za=za, + _if_cond_42=_if_cond_42[0]) - branch_code = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} + branch_code = {n for n, g in transformed_sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} assert len(branch_code) == 0, f"(actual) len({branch_code}) != (desired) {0}" @@ -1474,6 +1484,7 @@ def test_disjoint_chain_split_branch_only(rtt_val): rtt = np.random.choice([rtt_val], size=(1, )) copy_sdfg = copy.deepcopy(sdfg) + copy_sdfg.name = sdfg.name + "_branch_eliminated" arrays = {"zsolqa": zsolqa, "zrainacc": zrainacc, "zrainaut": zrainaut, "ztp1": ztp1, "rtt": rtt[0]} sdfg.validate() @@ -1915,12 +1926,12 @@ def safe_uniform(low, high, size): sdfg.validate() out_no_fuse = {k: v.copy() for k, v in data.items()} 
sdfg(**out_no_fuse) + # Apply transformation fb = EliminateBranches() fb.try_clean = True fb.eps_operator_type_for_log_and_div = eps_operator_type_for_log_and_div fb.apply_pass(sdfg, {}) - sdfg.name = sdfg.label + "_transformed" cblocks = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} assert len(cblocks) == 0 @@ -1991,12 +2002,12 @@ def safe_uniform(low, high, size): sdfg.validate() out_no_fuse = {k: v.copy() for k, v in data.items()} sdfg(**out_no_fuse) + # Apply transformation fb = EliminateBranches() fb.try_clean = True fb.eps_operator_type_for_log_and_div = eps_operator_type_for_log_and_div fb.apply_pass(sdfg, {}) - sdfg.name = sdfg.label + "_transformed" cblocks = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} assert len(cblocks) == 0 From f5bdd912618c0cb2025b20bcc849e765cee637b7 Mon Sep 17 00:00:00 2001 From: Yakup Koray Budanaz Date: Wed, 29 Oct 2025 21:05:49 +0100 Subject: [PATCH 10/17] Try something to fix spurious fails in the runner --- .../interstate/branch_elimination_test.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/transformations/interstate/branch_elimination_test.py b/tests/transformations/interstate/branch_elimination_test.py index 6bdd7a5d07..683fd5b54b 100644 --- a/tests/transformations/interstate/branch_elimination_test.py +++ b/tests/transformations/interstate/branch_elimination_test.py @@ -322,6 +322,8 @@ def run_and_compare( copy_sdfg = copy.deepcopy(sdfg) copy_sdfg.name = sdfg.name + "_branch_eliminated" + del sdfg + # Apply transformation if use_pass: fb = EliminateBranches() @@ -343,6 +345,10 @@ def run_and_compare( for name in arrays.keys(): np.testing.assert_allclose(out_fused[name], out_no_fuse[name], atol=1e-12) + del out_no_fuse + del out_fused + del copy_sdfg + def run_and_compare_sdfg( sdfg, @@ -357,6 +363,8 @@ def run_and_compare_sdfg( # Run SDFG version (with transformation) copy_sdfg = copy.deepcopy(sdfg) copy_sdfg.name = sdfg.name + "_branch_eliminated" + del sdfg + fb = EliminateBranches() fb.try_clean = True fb.permissive = permissive @@ -368,6 +376,9 @@ def run_and_compare_sdfg( for name in arrays.keys(): np.testing.assert_allclose(out_no_fuse[name], out_fused[name], atol=1e-12) + del out_no_fuse + del out_fused + return copy_sdfg From d0e6f2c6a74c169347a34952e31d14f4c1d626ab Mon Sep 17 00:00:00 2001 From: Yakup Koray Budanaz Date: Fri, 31 Oct 2025 15:19:47 +0100 Subject: [PATCH 11/17] Merge --- dace/config_schema.yml | 4 +- dace/sdfg/construction_utils.py | 250 ++++----------------------- dace/sdfg/tasklet_utils.py | 249 ++++++++++++++++++++++++++ dace/sdfg/utils.py | 6 +- tests/utils/classify_tasklet_test.py | 16 +- 5 files changed, 293 insertions(+), 232 deletions(-) diff --git a/dace/config_schema.yml b/dace/config_schema.yml index a753a55f3b..4627ea1644 100644 --- a/dace/config_schema.yml +++ b/dace/config_schema.yml @@ -261,7 +261,7 @@ required: type: str title: Arguments description: Compiler argument flags - default: '-fopenmp -std=c++14 -fPIC -Wall -Wextra -O3 -march=native -ffast-math -Wno-unused-parameter -Wno-unused-label' + default: '-fopenmp -fPIC -Wall -Wextra -O3 -march=native -ffast-math -Wno-unused-parameter -Wno-unused-label' default_Windows: '/O2 /fp:fast /arch:AVX2 /D_USRDLL /D_WINDLL /D__restrict__=__restrict' libs: @@ -303,7 +303,7 @@ required: type: str title: nvcc Arguments description: Compiler argument flags for CUDA - default: '-Xcompiler -march=native --use_fast_math -Xcompiler -Wno-unused-parameter' + default: 
'--expt-relaxed-constexpr -Xcompiler -march=native --use_fast_math -Xcompiler -Wno-unused-parameter' default_Windows: '-O3 --use_fast_math' hip_args: diff --git a/dace/sdfg/construction_utils.py b/dace/sdfg/construction_utils.py index 4d683c02e7..890c033e36 100644 --- a/dace/sdfg/construction_utils.py +++ b/dace/sdfg/construction_utils.py @@ -1,28 +1,15 @@ -import re -from typing import Dict, Set, Union +# Copyright 2019-2025 ETH Zurich and the DaCe authors. All rights reserved. import dace +from typing import Dict, Set, Union import copy - from dace.sdfg import ControlFlowRegion from dace.sdfg.propagation import propagate_memlets_state import copy from dace.properties import CodeBlock from dace.sdfg.state import ConditionalBlock, LoopRegion - -import sympy -from sympy import symbols, Function - -from sympy.printing.pycode import PythonCodePrinter +from sympy import Function import dace.sdfg.utils as sdutil -from dace.transformation.passes import FuseStates - - -class BracketFunctionPrinter(PythonCodePrinter): - - def _print_Function(self, expr): - name = self._print(expr.func) - args = ", ".join([self._print(arg) for arg in expr.args]) - return f"{name}[{args}]" +from dace.sdfg.tasklet_utils import token_replace_dict, extract_bracket_tokens, remove_bracket_tokens def copy_state_contents(old_state: dace.SDFGState, new_state: dace.SDFGState) -> Dict[dace.nodes.Node, dace.nodes.Node]: @@ -100,7 +87,28 @@ def copy_graph_contents(old_graph: ControlFlowRegion, def move_branch_cfg_up_discard_conditions(if_block: ConditionalBlock, body_to_take: ControlFlowRegion): - # Sanity check the ensure apssed arguments are correct + """ + Moves a branch of a conditional block up in the control flow graph (CFG), + replacing the conditional with the selected branch, discarding + the conditional check and other branches. + + This operation: + - Copies all nodes and edges from the selected branch (`body_to_take`) into + the parent graph of the conditional. + - Connects all incoming edges of the original conditional block to the + start of the selected branch. + - Connects all outgoing edges of the original conditional block to the + end of the selected branch. + - Removes the original conditional block from the graph. + + Parameters: + -if_block : ConditionalBlock + The conditional block in the CFG whose branch is to be promoted. + -body_to_take : ControlFlowRegion + The branch of the conditional block to be moved up. Must be one of the + branches of `if_block`. 
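+
+    Example (hypothetical usage, unconditionally keeping the first branch):
+
+        _, body = if_block.branches[0]
+        move_branch_cfg_up_discard_conditions(if_block, body)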
+ """ + # Sanity check the ensure passed arguments are correct bodies = {b for _, b in if_block.branches} assert body_to_take in bodies assert isinstance(if_block, ConditionalBlock) @@ -139,11 +147,6 @@ def move_branch_cfg_up_discard_conditions(if_block: ConditionalBlock, body_to_ta graph.remove_node(if_block) -# Put map-body into NSDFG -# Convert Map to Loop -# Put map into NSDFG - - def insert_non_transient_data_through_parent_scopes(non_transient_data: Set[str], nsdfg_node: 'dace.nodes.NestedSDFG', parent_graph: 'dace.SDFGState', @@ -365,47 +368,6 @@ def _get_out_conn_name(src, state=state): nsdfg_node.symbol_mapping[str(sym)] = str(sym) -def token_replace_dict(code: str, repldict: Dict[str, str]) -> str: - # Split while keeping delimiters - tokens = re.split(r'(\s+|[()\[\]])', code) - - # Replace tokens that exactly match src - tokens = [repldict[token.strip()] if token.strip() in repldict else token for token in tokens] - - # Recombine everything - return ''.join(tokens).strip() - - -def token_match(string_to_check: str, pattern_str: str) -> str: - # Split while keeping delimiters - tokens = re.split(r'(\s+|[()\[\]])', string_to_check) - - # Replace tokens that exactly match src - tokens = {token.strip() for token in tokens} - - return pattern_str in tokens - - -def token_split(string_to_check: str, pattern_str: str) -> Set[str]: - # Split while keeping delimiters - tokens = re.split(r'(\s+|[()\[\]])', string_to_check) - - # Replace tokens that exactly match src - tokens = {token.strip() for token in tokens} - - return tokens - - -def token_split_variable_names(string_to_check: str) -> Set[str]: - # Split while keeping delimiters - tokens = re.split(r'(\s+|[()\[\]])', string_to_check) - - # Replace tokens that exactly match src - tokens = {token.strip() for token in tokens if token not in ["[", "]", "(", ")"] and token.isidentifier()} - - return tokens - - def replace_length_one_arrays_with_scalars(sdfg: dace.SDFG, recursive: bool = True, transient_only: bool = False): scalarized_arrays = set() for arr_name, arr in [(k, v) for k, v in sdfg.arrays.items()]: @@ -466,124 +428,6 @@ def replace_length_one_arrays_with_scalars(sdfg: dace.SDFG, recursive: bool = Tr replace_length_one_arrays_with_scalars(node.sdfg, recursive=True, transient_only=True) -def connect_array_names(sdfg: dace.SDFG, local_storage: dace.dtypes.StorageType, src_storage: dace.dtypes.StorageType, - local_name_prefix: str): - - array_name_dict = dict() - for state in sdfg.all_states(): - for node in state.nodes(): - if isinstance(node, dace.nodes.AccessNode): - local_arr = state.sdfg.arrays[node.data] - print(local_arr.storage) - if local_arr.storage == local_storage: - assert len(state.in_edges(node)) <= 1 - # Reads - for ie in state.in_edges(node): - if ie.data.data is not None and ie.data.data != node.data: - src_data = state.sdfg.arrays[ie.data.data] - print(src_data) - if src_data.storage == src_storage: - assert node.data not in array_name_dict - array_name_dict[node.data] = ie.data.data - # Writes - for oe in state.out_edges(node): - if oe.data.data is not None and oe.data.data != node.data: - dst_data = state.sdfg.arrays[oe.data.data] - print(dst_data) - if dst_data.storage == src_storage: - assert node.data not in array_name_dict - array_name_dict[node.data] = oe.data.data - - print(array_name_dict) - repldict = {k: f"{local_name_prefix}{v}" for k, v in array_name_dict.items()} - - sdfg.replace_dict(repldict, replace_keys=True) - sdfg.validate() - - -def tasklet_has_symbol(tasklet: dace.nodes.Tasklet, symbol_str: 
str) -> bool: - if tasklet.code.language == dace.dtypes.Language.Python: - try: - sym_expr = dace.symbolic.SymExpr(tasklet.code.as_astring) - return (symbol_str in {str(s) for s in sym_expr.free_symbols}) - except Exception as e: - return token_match(tasklet.code.as_string, symbol_str) - else: - return token_match(tasklet.code.as_string, symbol_str) - - -def replace_code(code_str: str, code_lang: dace.dtypes.Language, repldict: Dict[str, str]) -> str: - - def _str_replace(lhs: str, rhs: str) -> str: - code_str = token_replace_dict(rhs, repldict) - return f"{lhs.strip()} = {code_str.strip()}" - - if code_lang == dace.dtypes.Language.Python: - try: - lhs, rhs = code_str.split(" = ") - lhs = lhs.strip() - rhs = rhs.strip() - except Exception as e: - try: - new_rhs_sym_expr = dace.symbolic.SymExpr(code_str).subs(repldict) - printer = BracketFunctionPrinter({'strict': False}) - cleaned_expr = printer.doprint(new_rhs_sym_expr).strip() - return f"{cleaned_expr}" - except Exception as e: - return _str_replace(code_str) - try: - new_rhs_sym_expr = dace.symbolic.SymExpr(rhs).subs(repldict) - printer = BracketFunctionPrinter({'strict': False}) - cleaned_expr = printer.doprint(new_rhs_sym_expr).strip() - return f"{lhs.strip()} = {cleaned_expr}" - except Exception as e: - return _str_replace(rhs) - else: - return _str_replace(rhs) - - -def tasklet_replace_code(tasklet: dace.nodes.Tasklet, repldict: Dict[str, str]): - new_code = replace_code(tasklet.code.as_string, tasklet.code.language, repldict) - tasklet.code = CodeBlock(code=new_code, language=tasklet.code.language) - - -def extract_bracket_tokens(s: str) -> list[tuple[str, list[str]]]: - """ - Extracts all contents inside [...] along with the token before the '[' as the name. - - Args: - s (str): Input string. - - Returns: - List of tuples: [(name_token, string inside brackes)] - """ - results = [] - - # Pattern to match [content_inside] - pattern = re.compile(r'(\b\w+)\[([^\]]*?)\]') - - for match in pattern.finditer(s): - name = match.group(1) # token before '[' - content = match.group(2).split() # split content inside brackets into tokens - - results.append((name, " ".join(content))) - - return {k: v for (k, v) in results} - - -def remove_bracket_tokens(s: str) -> str: - """ - Removes all [...] patterns from the string. - - Args: - s (str): Input string. - - Returns: - str: String with all [...] removed. 
- """ - return re.sub(r'\[.*?\]', '', s) - - def generate_assignment_as_tasklet_in_state(state: dace.SDFGState, lhs: str, rhs: str): rhs = rhs.strip() rhs_sym_expr = dace.symbolic.SymExpr(rhs).evalf() @@ -642,40 +486,8 @@ def generate_assignment_as_tasklet_in_state(state: dace.SDFGState, lhs: str, rhs state.add_edge(t, k, v, None, dace.memlet.Memlet(expr=f"{data_name}[{access_str}]")) -def _find_parent_state(root_sdfg: dace.SDFG, node: dace.nodes.NestedSDFG): - if node is not None: - # Find parent state of that node - for n, g in root_sdfg.all_nodes_recursive(): - if n == node: - parent_state = g - return parent_state - return None - - def get_num_parent_map_scopes(root_sdfg: dace.SDFG, node: dace.nodes.MapEntry, parent_state: dace.SDFGState): - scope_dict = parent_state.scope_dict() - num_parent_maps = 0 - cur_node = node - while scope_dict[cur_node] is not None: - if isinstance(scope_dict[cur_node], dace.nodes.MapEntry): - num_parent_maps += 1 - cur_node = scope_dict[cur_node] - - # Check parent nsdfg - parent_nsdfg_node = parent_state.sdfg.parent_nsdfg_node - parent_nsdfg_parent_state = _find_parent_state(root_sdfg, parent_nsdfg_node) - - while parent_nsdfg_node is not None: - scope_dict = parent_nsdfg_parent_state.scope_dict() - cur_node = parent_nsdfg_node - while scope_dict[cur_node] is not None: - if isinstance(scope_dict[cur_node], dace.nodes.MapEntry): - num_parent_maps += 1 - cur_node = scope_dict[cur_node] - parent_nsdfg_node = parent_nsdfg_parent_state.sdfg.parent_nsdfg_node - parent_nsdfg_parent_state = _find_parent_state(root_sdfg, parent_nsdfg_node) - - return num_parent_maps + return len(get_parent_maps(root_sdfg, node, parent_state)) def get_num_parent_map_and_loop_scopes(root_sdfg: dace.SDFG, node: dace.nodes.MapEntry, parent_state: dace.SDFGState): @@ -707,7 +519,7 @@ def get_parent_map_and_loop_scopes(root_sdfg: dace.SDFG, node: Union[dace.nodes. # Check parent nsdfg parent_nsdfg_node = parent_sdfg.parent_nsdfg_node - parent_nsdfg_parent_state = _find_parent_state(root_sdfg, parent_nsdfg_node) + parent_nsdfg_parent_state = parent_state.sdfg.parent_graph while parent_nsdfg_node is not None and parent_nsdfg_parent_state is not None: scope_dict = parent_nsdfg_parent_state.scope_dict() @@ -727,7 +539,7 @@ def get_parent_map_and_loop_scopes(root_sdfg: dace.SDFG, node: Union[dace.nodes. 
parent_graph = parent_graph.parent_graph parent_nsdfg_node = parent_sdfg.parent_nsdfg_node - parent_nsdfg_parent_state = _find_parent_state(root_sdfg, parent_nsdfg_node) + parent_nsdfg_parent_state = parent_state.sdfg.parent_graph return parent_scopes @@ -749,7 +561,7 @@ def get_parent_maps(root_sdfg: dace.SDFG, node: dace.nodes.MapEntry, parent_stat # Check parent nsdfg parent_nsdfg_node = parent_state.sdfg.parent_nsdfg_node - parent_nsdfg_parent_state = _find_parent_state(root_sdfg, parent_nsdfg_node) + parent_nsdfg_parent_state = parent_state.sdfg.parent_graph while parent_nsdfg_node is not None: scope_dict = parent_nsdfg_parent_state.scope_dict() @@ -759,6 +571,6 @@ def get_parent_maps(root_sdfg: dace.SDFG, node: dace.nodes.MapEntry, parent_stat maps.append((cur_node, parent_state)) cur_node = scope_dict[cur_node] parent_nsdfg_node = parent_nsdfg_parent_state.sdfg.parent_nsdfg_node - parent_nsdfg_parent_state = _find_parent_state(root_sdfg, parent_nsdfg_node) + parent_nsdfg_parent_state = parent_state.sdfg.parent_graph return maps diff --git a/dace/sdfg/tasklet_utils.py b/dace/sdfg/tasklet_utils.py index cc9d921a08..fc22d7f597 100644 --- a/dace/sdfg/tasklet_utils.py +++ b/dace/sdfg/tasklet_utils.py @@ -11,6 +11,8 @@ and extract relevant information such as operands, constants, and operations. """ +import re +import sympy import dace from typing import Dict, Tuple, Set from dace.properties import CodeBlock @@ -65,6 +67,253 @@ class TaskletType(Enum): SYMBOL_SYMBOL = "symbol_symbol" +def token_replace_dict(code: str, repldict: Dict[str, str]) -> str: + """ + Replaces exact token matches in a code string using a replacement dictionary. + Tokens are split using whitespace and common delimiters (` `, `(`, `)`, `[`, `]`). + + Parameters + ---------- + code : str + The code string in which to replace tokens. + repldict : Dict[str, str] + Mapping from token names to their replacement strings. + + Returns + ------- + str + The code string with replacements applied. + """ + + # Split while keeping delimiters + tokens = re.split(r'(\s+|[()\[\]])', code) + + # Replace tokens that exactly match src + tokens = [repldict[token.strip()] if token.strip() in repldict else token for token in tokens] + + # Recombine everything + return ''.join(tokens).strip() + + +def token_match(string_to_check: str, pattern_str: str) -> str: + """ + Checks if a given pattern string exists as a token in the input string. + The input string is split on empty space and brackets (` `, `(`, `)`, `[`, `]`). + + Parameters + ---------- + string_to_check : str + The string to search for the token. + pattern_str : str + The token to search for. + + Returns + ------- + bool + True if the token exists, False otherwise. + """ + + # Split while keeping delimiters + tokens = re.split(r'(\s+|[()\[\]])', string_to_check) + + # Replace tokens that exactly match src + tokens = {token.strip() for token in tokens} + + return pattern_str in tokens + + +def token_split(string_to_check: str) -> Set[str]: + """ + Splits a string into a set of tokens, keeping delimiters, and returns all tokens. + The input string is split on empty space and brackets (` `, `(`, `)`, `[`, `]`). + + Parameters + ---------- + string_to_check : str + The string to split into tokens. + pattern_str : str + (Unused in this function, kept for consistency with token_match) + + Returns + ------- + Set[str] + The set of tokens extracted from the string. 
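+
+    Example
+    -------
+    Illustrative: ``sorted(token_split("a[i] + b") - {''})`` yields
+    ``['+', '[', ']', 'a', 'b', 'i']``.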
+    """
+    # Split while keeping delimiters
+    tokens = re.split(r'(\s+|[()\[\]])', string_to_check)
+
+    # Strip whitespace around each token
+    tokens = {token.strip() for token in tokens}
+
+    return tokens
+
+
+def token_split_variable_names(string_to_check: str) -> Set[str]:
+    """
+    Splits a string into variable name tokens, ignoring delimiters and non-identifiers.
+    Uses `str.isidentifier` on individual tokens.
+
+    The input string is split on empty space and brackets (` `, `(`, `)`, `[`, `]`).
+
+    Parameters
+    ----------
+    string_to_check : str
+        The string to split into tokens.
+
+    Returns
+    -------
+    Set[str]
+        The set of identifier tokens extracted from the string.
+    """
+    # Split while keeping delimiters
+    tokens = re.split(r'(\s+|[()\[\]])', string_to_check)
+
+    # Keep only tokens that are valid Python identifiers
+    tokens = {token.strip() for token in tokens if token not in ["[", "]", "(", ")"] and token.isidentifier()}
+
+    return tokens
+
+
+def tasklet_has_symbol(tasklet: dace.nodes.Tasklet, symbol_str: str) -> bool:
+    """
+    Checks if a symbol is present in a tasklet's code, using sympy-based symbolic analysis.
+    Function names and function arguments are checked too.
+
+    Parameters
+    ----------
+    tasklet : dace.nodes.Tasklet
+        The tasklet whose code to inspect.
+    symbol_str : str
+        The symbol name to search for.
+
+    Returns
+    -------
+    bool
+        True if the symbol exists in the tasklet's code, False otherwise.
+    """
+    if tasklet.code.language == dace.dtypes.Language.Python:
+        try:
+            lhs, rhs = tasklet.code.as_string.split(" = ", 1)
+            lhs = lhs.strip()
+            rhs = rhs.strip()
+            if symbol_str == lhs:
+                return True
+            sym_expr = dace.symbolic.SymExpr(rhs)
+            # free_symbols gives plain variables like 'b'
+            symbols = {str(s) for s in sym_expr.free_symbols}
+            # Collect function names as well (func.func is the function name)
+            for func in sym_expr.atoms(sympy.Function):
+                symbols.add(str(func.func))
+            return symbol_str in symbols
+        except Exception:
+            # Fall back to a purely textual token match
+            return token_match(tasklet.code.as_string, symbol_str)
+    else:
+        return token_match(tasklet.code.as_string, symbol_str)
+
+
+def replace_code(code_str: str, code_lang: dace.dtypes.Language, repldict: Dict[str, str]) -> str:
+    """
+    Replaces variables in a code string according to a replacement dictionary.
+    Supports Python symbolic substitution and fallback string-based replacement.
+
+    Parameters
+    ----------
+    code_str : str
+        The code string to modify.
+    code_lang : dace.dtypes.Language
+        The programming language of the code.
+    repldict : Dict[str, str]
+        Mapping from variable names to their replacements.
+
+    Returns
+    -------
+    str
+        The modified code string with replacements applied.
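+
+    Example
+    -------
+    Illustrative: ``replace_code("out = a + b", dace.dtypes.Language.Python, {"b": "c"})``
+    is expected to yield ``'out = a + c'``.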
+    """
+
+    def _str_replace(lhs: str, rhs: str) -> str:
+        code_str = token_replace_dict(rhs, repldict)
+        return f"{lhs.strip()} = {code_str.strip()}"
+
+    if code_lang == dace.dtypes.Language.Python:
+        try:
+            lhs, rhs = code_str.split(" = ")
+            lhs = lhs.strip()
+            rhs = rhs.strip()
+        except Exception:
+            try:
+                new_rhs_sym_expr = dace.symbolic.SymExpr(code_str).subs(repldict)
+                cleaned_expr = sympy.pycode(new_rhs_sym_expr, allow_unknown_functions=True).strip()
+                return cleaned_expr
+            except Exception:
+                return _str_replace(code_str)
+        try:
+            new_rhs_sym_expr = dace.symbolic.SymExpr(rhs).subs(repldict)
+            cleaned_expr = sympy.pycode(new_rhs_sym_expr, allow_unknown_functions=True).strip()
+            return f"{lhs.strip()} = {cleaned_expr}"
+        except Exception:
+            return _str_replace(rhs)
+    else:
+        return _str_replace(rhs)
+
+
+def tasklet_replace_code(tasklet: dace.nodes.Tasklet, repldict: Dict[str, str]):
+    """
+    Replaces symbols in a tasklet's code according to a replacement dictionary.
+    Updates the tasklet's code in place.
+
+    Parameters
+    ----------
+    tasklet : dace.nodes.Tasklet
+        The tasklet whose code to modify.
+    repldict : Dict[str, str]
+        Mapping from variable names to their replacements.
+
+    Returns
+    -------
+    None
+    """
+    new_code = replace_code(tasklet.code.as_string, tasklet.code.language, repldict)
+    tasklet.code = CodeBlock(code=new_code, language=tasklet.code.language)
+
+
+def extract_bracket_tokens(s: str) -> Dict[str, str]:
+    """
+    Extracts all contents inside [...] along with the token before the '[' as the name.
+
+    Args:
+        s (str): Input string.
+
+    Returns:
+        Dict mapping each name token to the string inside its brackets.
+    """
+    results = []
+
+    # Pattern to match name[content_inside]
+    pattern = re.compile(r'(\b\w+)\[([^\]]*?)\]')
+
+    for match in pattern.finditer(s):
+        name = match.group(1)  # token before '['
+        content = match.group(2).split()  # split content inside brackets into tokens
+
+        results.append((name, " ".join(content)))
+
+    return {k: v for (k, v) in results}
+
+
+def remove_bracket_tokens(s: str) -> str:
+    """
+    Removes all [...] patterns from the string.
+
+    Args:
+        s (str): Input string.
+
+    Returns:
+        str: String with all [...] removed.
+    """
+    return re.sub(r'\[.*?\]', '', s)
+
+
 def _extract_constant_from_ast_str(src: str) -> str:
     """
     Extract a numeric constant from a Python code string using AST parsing.
diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py
index 1bf4632860..3c94a1d4d7 100644
--- a/dace/sdfg/utils.py
+++ b/dace/sdfg/utils.py
@@ -2700,6 +2700,8 @@ def _sympy_to_python_number(val):
 
 def demote_symbol_to_scalar(sdfg: 'dace.SDFG', symbol_str: str, default_type: 'dace.dtypes.typeclass' = None):
     import dace.sdfg.construction_utils as cutil
+    import dace.sdfg.tasklet_utils as tutil
+
     if default_type is None:
         default_type = dace.int32
 
@@ -2738,14 +2740,14 @@ def demote_symbol_to_scalar(sdfg: 'dace.SDFG', symbol_str: str, default_type: 'd
         if isinstance(n, dace.nodes.Tasklet):
             assert isinstance(g, dace.SDFGState)
             sdict = g.scope_dict()
-            if cutil.tasklet_has_symbol(n, symbol_str):
+            if tutil.tasklet_has_symbol(n, symbol_str):
                 # 2. 
If used in tasklet try to replace symbol name with an in connector and add an access to the scalar # Sanity check no tasklet should assign to a symbol lhs, rhs = n.code.as_string.split(" = ") tasklet_lhs = lhs.strip() assert symbol_str not in tasklet_lhs - cutil.tasklet_replace_code(n, {symbol_str: f"_in_{symbol_str}"}) + tutil.tasklet_replace_code(n, {symbol_str: f"_in_{symbol_str}"}) n.add_in_connector(f"_in_{symbol_str}") access = g.add_access(symbol_str) g.add_edge(access, None, n, f"_in_{symbol_str}", dace.memlet.Memlet(expr=f"{symbol_str}[0]")) diff --git a/tests/utils/classify_tasklet_test.py b/tests/utils/classify_tasklet_test.py index 823356fa46..ce778d6ed0 100644 --- a/tests/utils/classify_tasklet_test.py +++ b/tests/utils/classify_tasklet_test.py @@ -383,15 +383,11 @@ }), ] -i = 0 - def _gen_sdfg( tasklet_info: typing.Tuple[str, str, typing.Set[str], typing.Set[str], typing.Set[str], tutil.TaskletType] ) -> dace.SDFG: - global i - i += 1 - sdfg = dace.SDFG(f"sd{i}") + sdfg = dace.SDFG(f"sd") state = sdfg.add_state("s0", is_start_block=True) expr_str, out_type, in_arrays, in_scalars, in_symbols, _ = tasklet_info @@ -423,14 +419,16 @@ def _gen_sdfg( return sdfg -@pytest.mark.parametrize("tasklet_info", tasklet_infos) +@pytest.mark.parametrize("tasklet_info", [(id, tasklet_info) for id, tasklet_info in enumerate(tasklet_infos)]) def test_single_tasklet_split(tasklet_info): - sdfg = _gen_sdfg(tasklet_info) + id, tasklet_info_tuple = tasklet_info + desired_tasklet_info = tasklet_info_tuple[-1] + + sdfg = _gen_sdfg(tasklet_info_tuple) + sdfg.name = f"tasklet_info_test_id_{id}" sdfg.validate() sdfg.compile() - _, _, _, _, _, desired_tasklet_info = tasklet_info - tasklets = {(n, g) for n, g in sdfg.all_nodes_recursive() if isinstance(n, dace.nodes.Tasklet)} assert len(tasklets) == 1 tasklet, state = tasklets.pop() From e7e0c85dae09906a189606c8fd0b28e46f226dfa Mon Sep 17 00:00:00 2001 From: Yakup Koray Budanaz Date: Fri, 31 Oct 2025 15:47:41 +0100 Subject: [PATCH 12/17] Major refactor --- dace/sdfg/construction_utils.py | 19 +- .../interstate/branch_elimination.py | 3 +- .../interstate/branch_elimination_test.py | 224 ++++++++++++++---- 3 files changed, 200 insertions(+), 46 deletions(-) diff --git a/dace/sdfg/construction_utils.py b/dace/sdfg/construction_utils.py index 890c033e36..fc0f7eff41 100644 --- a/dace/sdfg/construction_utils.py +++ b/dace/sdfg/construction_utils.py @@ -502,6 +502,12 @@ def get_parent_map_and_loop_scopes(root_sdfg: dace.SDFG, node: Union[dace.nodes. cur_node = node parent_scopes = list() + def _get_parent_state(sdfg: dace.SDFG, nsdfg_node: dace.nodes.NestedSDFG): + for n, g in sdfg.all_nodes_recursive(): + if n == nsdfg_node: + return g + return None + if isinstance(cur_node, (dace.nodes.MapEntry, dace.nodes.Tasklet)): while scope_dict[cur_node] is not None: if isinstance(scope_dict[cur_node], dace.nodes.MapEntry): @@ -519,7 +525,7 @@ def get_parent_map_and_loop_scopes(root_sdfg: dace.SDFG, node: Union[dace.nodes. # Check parent nsdfg parent_nsdfg_node = parent_sdfg.parent_nsdfg_node - parent_nsdfg_parent_state = parent_state.sdfg.parent_graph + parent_nsdfg_parent_state = _get_parent_state(root_sdfg, parent_nsdfg_node) while parent_nsdfg_node is not None and parent_nsdfg_parent_state is not None: scope_dict = parent_nsdfg_parent_state.scope_dict() @@ -539,12 +545,19 @@ def get_parent_map_and_loop_scopes(root_sdfg: dace.SDFG, node: Union[dace.nodes. 
parent_graph = parent_graph.parent_graph parent_nsdfg_node = parent_sdfg.parent_nsdfg_node - parent_nsdfg_parent_state = parent_state.sdfg.parent_graph + parent_nsdfg_parent_state = _get_parent_state(root_sdfg, parent_nsdfg_node) return parent_scopes def get_parent_maps(root_sdfg: dace.SDFG, node: dace.nodes.MapEntry, parent_state: dace.SDFGState): + + def _get_parent_state(sdfg: dace.SDFG, nsdfg_node: dace.nodes.NestedSDFG): + for n, g in sdfg.all_nodes_recursive(): + if n == nsdfg_node: + return g + return None + maps = [] scope_dict = parent_state.scope_dict() cur_node = node @@ -561,7 +574,7 @@ def get_parent_maps(root_sdfg: dace.SDFG, node: dace.nodes.MapEntry, parent_stat # Check parent nsdfg parent_nsdfg_node = parent_state.sdfg.parent_nsdfg_node - parent_nsdfg_parent_state = parent_state.sdfg.parent_graph + parent_nsdfg_parent_state = _get_parent_state(root_sdfg, parent_nsdfg_node) while parent_nsdfg_node is not None: scope_dict = parent_nsdfg_parent_state.scope_dict() diff --git a/dace/transformation/interstate/branch_elimination.py b/dace/transformation/interstate/branch_elimination.py index 93e2c1315d..d0469a81f0 100644 --- a/dace/transformation/interstate/branch_elimination.py +++ b/dace/transformation/interstate/branch_elimination.py @@ -13,6 +13,7 @@ from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, LoopRegion, SDFGState import dace.sdfg.utils as sdutil import dace.sdfg.construction_utils as cutil +import dace.sdfg.tasklet_utils as tutil from typing import Tuple, Set, Union from dace.symbolic import pystr_to_symbolic from dace.transformation.passes import FuseStates @@ -1062,7 +1063,7 @@ def _extract_condition_var_and_assignment(self, graph: ControlFlowRegion) -> Tup if len(cond_code_symexpr.free_symbols) == 1: cond_var = str(next(iter(cond_code_symexpr.free_symbols))) else: - cond_var = cutil.token_split_variable_names(cond_code_str).pop() + cond_var = tutil.token_split_variable_names(cond_code_str).pop() # If the sym_map has any functions, then we need to drop, e.g. 
array access new_sym_val_map = dict() diff --git a/tests/transformations/interstate/branch_elimination_test.py b/tests/transformations/interstate/branch_elimination_test.py index 683fd5b54b..6205bcfbe3 100644 --- a/tests/transformations/interstate/branch_elimination_test.py +++ b/tests/transformations/interstate/branch_elimination_test.py @@ -312,16 +312,18 @@ def run_and_compare( program, num_expected_branches, use_pass, + sdfg_name, **arrays, ): # Run SDFG version (no transformation) sdfg = program.to_sdfg() sdfg.validate() + sdfg.name = sdfg_name out_no_fuse = {k: v.copy() for k, v in arrays.items()} sdfg(**out_no_fuse) copy_sdfg = copy.deepcopy(sdfg) - copy_sdfg.name = sdfg.name + "_branch_eliminated" + copy_sdfg.name = sdfg_name + "_branch_eliminated" del sdfg # Apply transformation @@ -353,16 +355,18 @@ def run_and_compare( def run_and_compare_sdfg( sdfg, permissive, + sdfg_name, **arrays, ): # Run SDFG version (no transformation) sdfg.validate() + sdfg.name = sdfg_name out_no_fuse = {k: v.copy() for k, v in arrays.items()} sdfg(**out_no_fuse) # Run SDFG version (with transformation) copy_sdfg = copy.deepcopy(sdfg) - copy_sdfg.name = sdfg.name + "_branch_eliminated" + copy_sdfg.name = sdfg_name + "_branch_eliminated" del sdfg fb = EliminateBranches() @@ -388,14 +392,21 @@ def test_branch_dependent_value_write(use_pass_flag): b = np.random.rand(N, N) c = np.zeros((N, N)) d = np.zeros((N, N)) - run_and_compare(branch_dependent_value_write, 0, use_pass_flag, a=a, b=b, c=c, d=d) + run_and_compare(branch_dependent_value_write, + 0, + use_pass_flag, + f"branch_dependent_value_write_use_pass_{str(use_pass_flag).lower()}", + a=a, + b=b, + c=c, + d=d) def test_weird_condition(): a = np.random.rand(N, N) b = np.random.rand(N, N) ncldtop = np.array([N // 2], dtype=np.int64) - run_and_compare(weird_condition, 1, False, a=a, b=b, ncldtop=ncldtop[0]) + run_and_compare(weird_condition, 1, False, f"weird_condition", a=a, b=b, ncldtop=ncldtop[0]) @pytest.mark.parametrize("use_pass_flag", [True, False]) @@ -404,7 +415,14 @@ def test_branch_dependent_value_write_two(use_pass_flag): b = np.zeros((N, N)) c = np.zeros((N, N)) d = np.zeros((N, N)) - run_and_compare(branch_dependent_value_write_two, 0, use_pass_flag, a=a, b=b, c=c, d=d) + run_and_compare(branch_dependent_value_write_two, + 0, + use_pass_flag, + f"branch_dependent_value_write_two_use_pass_{str(use_pass_flag).lower()}", + a=a, + b=b, + c=c, + d=d) @pytest.mark.parametrize("use_pass_flag", [True, False]) @@ -412,7 +430,13 @@ def test_branch_dependent_value_write_single_branch(use_pass_flag): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 5.0], size=(N, N)) d = np.zeros((N, N)) - run_and_compare(branch_dependent_value_write_single_branch, 0, use_pass_flag, a=a, b=b, d=d) + run_and_compare(branch_dependent_value_write_single_branch, + 0, + use_pass_flag, + f"branch_dependent_value_write_single_branch_use_pass_{str(use_pass_flag).lower()}", + a=a, + b=b, + d=d) @pytest.mark.parametrize("use_pass_flag", [True, False]) @@ -420,7 +444,13 @@ def test_complicated_if(use_pass_flag): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 5.0], size=(N, N)) d = np.zeros((N, N)) - run_and_compare(complicated_if, 0, use_pass_flag, a=a, b=b, d=d) + run_and_compare(complicated_if, + 0, + use_pass_flag, + f"complicated_if_use_pass_{str(use_pass_flag).lower()}", + a=a, + b=b, + d=d) @pytest.mark.parametrize("use_pass_flag", [True, False]) @@ -430,7 +460,15 @@ def test_multi_state_branch_body(use_pass_flag): c 
= np.random.choice([0.001, 5.0], size=(N, N)) d = np.zeros((N, N)) s = np.zeros((1, )).astype(np.int64) - run_and_compare(multi_state_branch_body, 1 if use_pass_flag else 1, use_pass_flag, a=a, b=b, c=c, d=d, s=s[0]) + run_and_compare(multi_state_branch_body, + 1 if use_pass_flag else 1, + use_pass_flag, + f"multistate_branch_body_{str(use_pass_flag).lower()}", + a=a, + b=b, + c=c, + d=d, + s=s[0]) @pytest.mark.parametrize("use_pass_flag", [True, False]) @@ -440,7 +478,15 @@ def test_nested_if(use_pass_flag): c = np.random.choice([0.001, 5.0], size=(N, N)) d = np.random.choice([0.001, 5.0], size=(N, N)) s = np.zeros((1, )).astype(np.int64) - run_and_compare(nested_if, 0, use_pass_flag, a=a, b=b, c=c, d=d, s=s[0]) + run_and_compare(nested_if, + 0, + use_pass_flag, + f"nested_if_use_pass_{str(use_pass_flag).lower()}", + a=a, + b=b, + c=c, + d=d, + s=s[0]) def test_condition_on_bounds(): @@ -451,13 +497,13 @@ def test_condition_on_bounds(): sdfg = condition_on_bounds.to_sdfg() sdfg.validate() + sdfg.name = "condition_on_bounds" arrays = {"a": a, "b": b, "c": c, "d": d} out_no_fuse = {k: v.copy() for k, v in arrays.items()} sdfg(a=out_no_fuse["a"], b=out_no_fuse["b"], c=out_no_fuse["c"], d=out_no_fuse["d"], s=1, SN=2) # Apply transformation EliminateBranches().apply_pass(sdfg, {}) sdfg.validate() - out_fused = {k: v.copy() for k, v in arrays.items()} nsdfgs = {(n, g) for n, g in sdfg.all_nodes_recursive() if isinstance(n, dace.nodes.NestedSDFG)} assert len(nsdfgs) == 1 # Can be applied should return false @@ -468,7 +514,7 @@ def test_nested_if_two(): b = np.random.choice([0.001, 5.0], size=(N, N)) c = np.random.choice([0.001, 5.0], size=(N, N)) d = np.random.choice([0.001, 5.0], size=(N, N)) - run_and_compare(nested_if_two, 0, True, a=a, b=b, c=c, d=d) + run_and_compare(nested_if_two, 0, True, f"nested_if_two", a=a, b=b, c=c, d=d) @pytest.mark.parametrize("use_pass_flag", [True, False]) @@ -477,7 +523,14 @@ def test_tasklets_in_if(use_pass_flag): b = np.random.choice([0.001, 5.0], size=(N, N)) c = np.zeros((1, )) d = np.zeros((N, N)) - run_and_compare(tasklets_in_if, 0, use_pass_flag, a=a, b=b, d=d, c=c[0]) + run_and_compare(tasklets_in_if, + 0, + use_pass_flag, + f"tasklets_in_if_use_pass{str(use_pass_flag).lower()}", + a=a, + b=b, + d=d, + c=c[0]) @pytest.mark.parametrize("use_pass_flag", [True, False]) @@ -485,14 +538,26 @@ def test_branch_dependent_value_write_single_branch_nonzero_write(use_pass_flag) a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 5.0], size=(N, N)) d = np.random.choice([0.001, 5.0], size=(N, N)) - run_and_compare(branch_dependent_value_write_single_branch_nonzero_write, 0, use_pass_flag, a=a, b=b, d=d) + run_and_compare(branch_dependent_value_write_single_branch_nonzero_write, + 0, + use_pass_flag, + f"branch_dependent_value_write_single_branch_nonzero_write_use_pass_{str(use_pass_flag).lower()}", + a=a, + b=b, + d=d) def test_branch_dependent_value_write_with_transient_reuse(): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 3.0], size=(N, N)) c = np.random.choice([0.001, 3.0], size=(N, N)) - run_and_compare(branch_dependent_value_write_with_transient_reuse, 0, True, a=a, b=b, c=c) + run_and_compare(branch_dependent_value_write_with_transient_reuse, + 0, + True, + f"branch_dependent_value_write_with_transient_reuse", + a=a, + b=b, + c=c) @pytest.mark.parametrize("use_pass_flag", [True, False]) @@ -504,27 +569,30 @@ def test_single_branch_connectors(use_pass_flag): sdfg = single_branch_connectors.to_sdfg() 
sdfg.validate() + sdfg.name = f"test_single_branch_connectors_use_pass_{str(use_pass_flag).lower()}" arrays = {"a": a, "b": b, "c": c, "d": d} out_no_fuse = {k: v.copy() for k, v in arrays.items()} sdfg(a=out_no_fuse["a"], b=out_no_fuse["b"], c=out_no_fuse["c"][0], d=out_no_fuse["d"]) # Apply transformation + copy_sdfg = copy.deepcopy(sdfg) + copy_sdfg.name = f"test_single_branch_connectors_use_pass_{str(use_pass_flag).lower()}_branch_eliminated" if use_pass_flag: - EliminateBranches().apply_pass(sdfg, {}) + EliminateBranches().apply_pass(copy_sdfg, {}) else: - apply_branch_elimination(sdfg, 2) + apply_branch_elimination(copy_sdfg, 2) # Run SDFG version (with transformation) out_fused = {k: v.copy() for k, v in arrays.items()} - sdfg(a=out_fused["a"], b=out_fused["b"], c=out_fused["c"][0], d=out_fused["d"]) + copy_sdfg(a=out_fused["a"], b=out_fused["b"], c=out_fused["c"][0], d=out_fused["d"]) - branch_code = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} + branch_code = {n for n, g in copy_sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} assert len(branch_code) == 0, f"(actual) len({branch_code}) != (desired) 0" # Compare all arrays for name in arrays.keys(): np.testing.assert_allclose(out_no_fuse[name], out_fused[name], atol=1e-12) - nsdfgs = {(n, g) for n, g in sdfg.all_nodes_recursive() if isinstance(n, dace.nodes.NestedSDFG)} + nsdfgs = {(n, g) for n, g in copy_sdfg.all_nodes_recursive() if isinstance(n, dace.nodes.NestedSDFG)} assert len(nsdfgs) == 1 nsdfg, parent_state = nsdfgs.pop() assert len(nsdfg.in_connectors) == 4, f"{nsdfg.in_connectors}, length is not 4 but {len(nsdfg.in_connectors)}" @@ -538,7 +606,15 @@ def test_disjoint_subsets(use_pass_flag): B = np.random.randn(N, 3, 3) C = np.random.randn(N, 3, 3) E = np.random.choice([0.001, 3.0], size=(N, 3, 3)) - run_and_compare(disjoint_subsets, 0, use_pass_flag, A=A, B=B, C=C, E=E, if_cond_58=if_cond_58[0]) + run_and_compare(disjoint_subsets, + 0, + use_pass_flag, + f"disjoint_subsets_use_pass_{str(use_pass_flag).lower()}", + A=A, + B=B, + C=C, + E=E, + if_cond_58=if_cond_58[0]) @dace.program @@ -613,7 +689,14 @@ def test_try_clean(): A = np.random.choice([0.001, 3.0], size=(N, )) B = np.random.randn(N, 3, 3) C = np.random.choice([0.001, 3.0], size=(N, )) - run_and_compare_sdfg(sdfg1, permissive=False, A=A, B=B, C=C, if_cond_1=if_cond_1[0], offset=offset[0]) + run_and_compare_sdfg(sdfg1, + permissive=False, + sdfg_name=f"multi_state_nested_if_sdfg", + A=A, + B=B, + C=C, + if_cond_1=if_cond_1[0], + offset=offset[0]) def test_try_clean_as_pass(): @@ -647,7 +730,14 @@ def test_try_clean_as_pass(): A = np.random.choice([0.001, 3.0], size=(N, )) B = np.random.randn(N, 3, 3) C = np.random.choice([0.001, 3.0], size=(N, )) - run_and_compare_sdfg(sdfg, permissive=False, A=A, B=B, C=C, if_cond_1=if_cond_1[0], offset=offset[0]) + run_and_compare_sdfg(sdfg, + permissive=False, + sdfg_name=f"multi_state_nested_if_sdfg_try_clean_variant", + A=A, + B=B, + C=C, + if_cond_1=if_cond_1[0], + offset=offset[0]) def _get_sdfg_with_interstate_array_condition(): @@ -697,6 +787,7 @@ def test_sdfg_with_interstate_array_condition(): run_and_compare_sdfg( sdfg, permissive=False, + sdfg_name=f"sdfg_with_interstate_array_condition", llindex=llindex, zsolqa=zsolqa, zratio=zratio, @@ -735,7 +826,7 @@ def test_repeated_condition_variables(): b = np.random.choice([0.001, 3.0], size=(N, N)) c = np.random.choice([0.001, 3.0], size=(N, N)) conds = np.random.choice([1.0, 3.0], size=(4, N)) - 
run_and_compare(repeated_condition_variables, 0, True, a=a, b=b, c=c, conds=conds) + run_and_compare(repeated_condition_variables, 0, True, f"repeated_condition_variables", a=a, b=b, c=c, conds=conds) def _find_state(root_sdfg: dace.SDFG, node): @@ -875,6 +966,7 @@ def test_non_trivial_subset_after_combine_tasklet(): non_trivial_subset_after_combine_tasklet, 0, True, + f"non_trivial_subset_after_combine_tasklet", a=A, b=B, c=C, @@ -957,6 +1049,7 @@ def test_split_on_disjoint_subsets(): split_on_disjoint_subsets, 0, True, + f"split_on_disjoint_subsets", a=A, b=B, c=C, @@ -999,6 +1092,7 @@ def test_split_on_disjoint_subsets_nested(): split_on_disjoint_subsets_nested, 0, True, + f"split_on_disjoint_subsets_nested", a=A, b=B, c=C, @@ -1050,6 +1144,7 @@ def test_write_to_transient(): write_to_transient, 0, True, + f"write_to_transient", a=A, b=B, d=D[0], @@ -1066,6 +1161,7 @@ def test_write_to_transient_two(): write_to_transient_two, 0, True, + f"write_to_transient_two", a=A, b=B, d=D[0], @@ -1092,6 +1188,7 @@ def test_double_empty_state(): run_and_compare_sdfg( sdfg, permissive=False, + sdfg_name=f"double_empty_state", a=A, b=B, d=D[0], @@ -1213,7 +1310,13 @@ def test_try_clean_on_complicated_pattern_for_manual_clean_up_one(): ssp.ignore = scalar_names ssp.apply_pass(nsdfg.sdfg, {}) - transformed_sdfg = run_and_compare_sdfg(sdfg, permissive=True, a=A, b=B, c=C, d=D[0]) + transformed_sdfg = run_and_compare_sdfg(sdfg, + permissive=True, + sdfg_name="try_clean_on_complicated_pattern_for_manual_cleanup_one", + a=A, + b=B, + c=C, + d=D[0]) branch_code = {n for n, g in transformed_sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} assert len(branch_code) == 0, f"(actual) len({branch_code}) != (desired) {0}" @@ -1283,7 +1386,14 @@ def test_try_clean_on_complicated_pattern_for_manual_clean_up_two(): ssp.ignore = scalar_names ssp.apply_pass(nsdfg.sdfg, {}) - transformed_sdfg = run_and_compare_sdfg(sdfg, permissive=True, a=A, b=B, c=C, d=D[0], e=E[0]) + transformed_sdfg = run_and_compare_sdfg(sdfg, + permissive=True, + sdfg_name="try_clean_on_complicated_pattern_for_manual_cleanup_two", + a=A, + b=B, + c=C, + d=D[0], + e=E[0]) branch_code = {n for n, g in transformed_sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} assert len(branch_code) == 0, f"(actual) len({branch_code}) != (desired) {0}" @@ -1309,14 +1419,14 @@ def single_assignment_cond_from_scalar(a: dace.float64[512]): def test_single_assignment(): if_cond_1 = np.array([1], dtype=np.float64) A = np.ones(shape=(N, ), dtype=np.float64) - run_and_compare(single_assignment, 0, True, a=A, _if_cond_1=if_cond_1[0]) + run_and_compare(single_assignment, 0, True, "single_assignment", a=A, _if_cond_1=if_cond_1[0]) def test_single_assignment_cond_from_scalar(): A = np.ones(shape=(512, ), dtype=np.float64) before = single_assignment_cond_from_scalar.to_sdfg() before.compile() - run_and_compare(single_assignment_cond_from_scalar, 0, True, a=A) + run_and_compare(single_assignment_cond_from_scalar, 0, True, "single_assignment_cond_from_scalar", a=A) def _get_sdfg_with_condition_from_transient_scalar() -> dace.SDFG: @@ -1383,6 +1493,7 @@ def test_condition_from_transient_scalar(): transformed_sdfg = run_and_compare_sdfg(sdfg, permissive=False, + sdfg_name="condition_from_transient_scalar", zsolac=zsolac, zlcond2=zlcond2, za=za, @@ -1488,6 +1599,7 @@ def _get_disjoint_chain_sdfg() -> dace.SDFG: @pytest.mark.parametrize("rtt_val", [0.0, 4.0, 6.0]) def test_disjoint_chain_split_branch_only(rtt_val): sdfg, nsdfg_parent_state = 
_get_disjoint_chain_sdfg() + sdfg.name = f"disjoint_chain_split_branch_only_rtt_val_{str(rtt_val).replace('.','_')}" zsolqa = np.random.choice([0.001, 5.0], size=(N, 5, 5)) zrainacc = np.random.choice([0.001, 5.0], size=(N, )) zrainaut = np.random.choice([0.001, 5.0], size=(N, )) @@ -1530,6 +1642,7 @@ def test_disjoint_chain(rtt_val): run_and_compare_sdfg(sdfg, permissive=False, + sdfg_name=f"disjoint_chain_rtt_val_{str(rtt_val).replace('.', '_')}", zsolqa=zsolqa, zrainacc=zrainacc, zrainaut=zrainaut, @@ -1569,7 +1682,15 @@ def test_pattern_from_cloudsc_one(c_val): D = np.random.choice([0.001, 5.0], size=(N, N)) E = np.random.choice([0.001, 5.0], size=(N, N)) - run_and_compare(pattern_from_cloudsc_one, 0, True, A=A, B=B, c=C[0], D=D, E=E) + run_and_compare(pattern_from_cloudsc_one, + 0, + True, + f"pattern_from_cloudsc_one_c_val_{str(c_val).replace('.', '_')}", + A=A, + B=B, + c=C[0], + D=D, + E=E) @dace.program @@ -1607,7 +1728,7 @@ def test_can_be_applied_on_map_param_usage(): assert xform.can_be_applied(xform.conditional.parent_graph, 0, xform.conditional.sdfg, False) - run_and_compare(map_param_usage, 0, True, a=A, b=B, d=D) + run_and_compare(map_param_usage, 0, True, "can_be_applied_on_map_param_usage_tester", a=A, b=B, d=D) def _get_safe_map_param_use_in_nested_sdfg() -> dace.SDFG: @@ -1688,7 +1809,12 @@ def test_safe_map_param_use_in_nested_sdfg(): zsolac = np.random.choice([0.001, 5.0], size=(N, )) zfinalsum = np.random.choice([0.001, 5.0], size=(N, )) zacust = np.random.choice([0.001, 5.0], size=(N, )) - run_and_compare_sdfg(sdfg, False, zsolac=zsolac, zfinalsum=zfinalsum, zacust=zacust) + run_and_compare_sdfg(sdfg, + False, + f"safe_map_param_use_in_nested_sdfg", + zsolac=zsolac, + zfinalsum=zfinalsum, + zacust=zacust) def _get_nsdfg_with_return(return_arr: bool) -> dace.SDFG: @@ -1766,12 +1892,15 @@ def _get_nsdfg_with_return(return_arr: bool) -> dace.SDFG: def test_nested_sdfg_with_return(ret_arr): sdfg = _get_nsdfg_with_return(ret_arr) sdfg.validate() + sdfg.name = f"nested_sdfg_with_return_ret_arr_{str(ret_arr).lower()}" + copy_sdfg = copy.deepcopy(sdfg) + copy_sdfg.name = sdfg.name + "_branch_eliminated" - for n, g in sdfg.all_nodes_recursive(): + for n, g in copy_sdfg.all_nodes_recursive(): if isinstance(n, ConditionalBlock): xform = branch_elimination.BranchElimination() xform.conditional = n - xform.parent_nsdfg_state = _find_state(sdfg, g.sdfg.parent_nsdfg_node) + xform.parent_nsdfg_state = _find_state(copy_sdfg, g.sdfg.parent_nsdfg_node) assert xform.can_be_applied(graph=g, expr_index=0, sdfg=g.sdfg, permissive=False) assert xform.can_be_applied(graph=g, expr_index=0, sdfg=g.sdfg, permissive=True) @@ -1790,9 +1919,9 @@ def test_nested_sdfg_with_return(ret_arr): fb = EliminateBranches() fb.try_clean = True fb.permissive = False - fb.apply_pass(sdfg, {}) + fb.apply_pass(copy_sdfg, {}) out_fused = {k: v.copy() for k, v in arrays.items()} - sdfg(**out_fused) + copy_sdfg(**out_fused) assert out_fused["zalfa_1"][0] != 999.9 # Compare all arrays @@ -1924,7 +2053,9 @@ def safe_uniform(low, high, size): data['zrho'] = safe_uniform(0.9, 1.2, (N, )) # density data['zcldtopdist'] = safe_uniform(0.1, 1.0, (N, )) # distance to cloud top data['zicenuclei'] = safe_uniform(1e2, 1e4, (N, )) # ice nuclei concentration + sdfg = huge_sdfg.to_sdfg() + sdfg.name = f"huge_sdfg_with_log_exp_div_operator_{eps_operator_type_for_log_and_div}" sdfg.validate() #it_23: dace.int64, it_47: dace.int64 ScalarToSymbolPromotion().apply_pass(sdfg, {}) @@ -1939,18 +2070,20 @@ def safe_uniform(low, high, 
size): sdfg(**out_no_fuse) # Apply transformation + copy_sdfg = copy.deepcopy(sdfg) + copy_sdfg.name = f"huge_sdfg_with_log_exp_div_operator_{eps_operator_type_for_log_and_div}_branch_eliminated" fb = EliminateBranches() fb.try_clean = True fb.eps_operator_type_for_log_and_div = eps_operator_type_for_log_and_div - fb.apply_pass(sdfg, {}) + fb.apply_pass(copy_sdfg, {}) - cblocks = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} + cblocks = {n for n, g in copy_sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} assert len(cblocks) == 0 # Run SDFG version (with transformation) out_fused = {k: v.copy() for k, v in data.items()} - sdfg(**out_fused) + copy_sdfg(**out_fused) # Compare all arrays for name in data.keys(): @@ -2001,7 +2134,13 @@ def safe_uniform(low, high, size): data['zcldtopdist'] = safe_uniform(0.1, 1.0, (N, )) # distance to cloud top data['zicenuclei'] = safe_uniform(1e2, 1e4, (N, )) # ice nuclei concentration sdfg = mid_sdfg.to_sdfg() + sdfg.name = f"mid_sdfg_with_log_exp_div_operator_{eps_operator_type_for_log_and_div}" + copy_sdfg = copy.deepcopy(sdfg) + copy_sdfg.name = f"mid_sdfg_with_log_exp_div_operator_{eps_operator_type_for_log_and_div}_branch_eliminated" + sdfg.validate() + copy_sdfg.validate() + #it_23: dace.int64, it_47: dace.int64 ScalarToSymbolPromotion().apply_pass(sdfg, {}) sdfg.validate() @@ -2015,18 +2154,19 @@ def safe_uniform(low, high, size): sdfg(**out_no_fuse) # Apply transformation + fb = EliminateBranches() fb.try_clean = True fb.eps_operator_type_for_log_and_div = eps_operator_type_for_log_and_div - fb.apply_pass(sdfg, {}) + fb.apply_pass(copy_sdfg, {}) - cblocks = {n for n, g in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} + cblocks = {n for n, g in copy_sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)} assert len(cblocks) == 0 # Run SDFG version (with transformation) out_fused = {k: v.copy() for k, v in data.items()} - sdfg(**out_fused) + copy_sdfg(**out_fused) # Compare all arrays for name in data.keys(): @@ -2069,7 +2209,7 @@ def test_loop_param_usage(): assert xform.can_be_applied(cblock.parent_graph, 0, cblock.sdfg, False) is True assert xform.can_be_applied(cblock.parent_graph, 0, cblock.sdfg, True) is True - run_and_compare_sdfg(sdfg, False, A=A, B=B, C=C) + run_and_compare_sdfg(sdfg, False, "loop_param_usage", A=A, B=B, C=C) def test_can_be_applied_on_wcr_edge(): @@ -2100,7 +2240,7 @@ def test_can_be_applied_on_wcr_edge(): A = np.random.choice([0.001, 5.0], size=(N, N)) - run_and_compare_sdfg(sdfg, False, A=A) + run_and_compare_sdfg(sdfg, False, "can_be_applied_wcr", A=A) if __name__ == "__main__": From 9268704bc28a10a61b6c67a3472120e384ce3858 Mon Sep 17 00:00:00 2001 From: Yakup Koray Budanaz Date: Fri, 31 Oct 2025 15:58:47 +0100 Subject: [PATCH 13/17] Pull improvements to the tasklet classification --- dace/sdfg/tasklet_utils.py | 198 ++++++++++++++++++++++----- tests/utils/classify_tasklet_test.py | 133 +++++++++++++----- 2 files changed, 259 insertions(+), 72 deletions(-) diff --git a/dace/sdfg/tasklet_utils.py b/dace/sdfg/tasklet_utils.py index fc22d7f597..5979faac6b 100644 --- a/dace/sdfg/tasklet_utils.py +++ b/dace/sdfg/tasklet_utils.py @@ -4,11 +4,7 @@ This module provides utilities for analyzing and classifying DaCe tasklets based on their computational patterns. It parses tasklet code to determine the types of operations, operands, -and constants involved, enabling automated code generation and optimization passes. 

-
-The main functionality is the `classify_tasklet` function, which inspects a tasklet's code
-and metadata to determine its type (e.g., array-symbol operation, binary array operation)
-and extract relevant information such as operands, constants, and operations.
+and constants involved. It also provides utilities to further manipulate and analyze tasklets.
 """
 
 import re
 import sympy
 import dace
 from typing import Dict, Tuple, Set
 from dace.properties import CodeBlock
@@ -59,6 +55,7 @@ class TaskletType(Enum):
     SCALAR_SYMBOL = "scalar_symbol"
     ARRAY_SYMBOL = "array_symbol"
     ARRAY_SCALAR = "array_scalar"
+    SCALAR_ARRAY = "scalar_array"
     ARRAY_ARRAY = "array_array"
     UNARY_ARRAY = "unary_array"
     UNARY_SYMBOL = "unary_symbol"
@@ -350,6 +347,39 @@ def _extract_constant_from_ast_str(src: str) -> str:
     raise ValueError("No constant found")
 
 
+def _split_code_on_assignment(code_str: str) -> Tuple[str, str]:
+    """
+    Returns the LHS and RHS of the first assignment in a Python tasklet.
+
+    Args:
+        code_str: The tasklet code as a string.
+
+    Returns:
+        A tuple (lhs_str, rhs_str) where both are strings representing
+        the left-hand side and right-hand side of the first assignment.
+    """
+    # Parse the tasklet code into an AST
+    code_ast = ast.parse(code_str)
+
+    # Find the first assignment statement
+    assign_node = next((n for n in code_ast.body if isinstance(n, ast.Assign)), None)
+    if assign_node is None:
+        raise ValueError("No assignment found in tasklet code.")
+
+    # Convert LHS to string
+    lhs_node = assign_node.targets[0]  # handle simple assignments only
+    lhs_str = ast.unparse(lhs_node).strip()
+
+    # Convert RHS to string
+    rhs_node = assign_node.value
+    rhs_str = ast.unparse(rhs_node).strip()
+
+    assert isinstance(lhs_str, str)
+    assert isinstance(rhs_str, str)
+
+    return lhs_str, rhs_str
+
+
 def _extract_non_connector_syms_from_tasklet(node: dace.nodes.Tasklet) -> typing.Set[str]:
     """
     Identify free symbols in tasklet code that are not input/output connectors.
@@ -374,13 +404,46 @@ def _extract_non_connector_syms_from_tasklet(node: dace.nodes.Tasklet) -> typing
     assert isinstance(node, dace.nodes.Tasklet)
     assert node.code.language == dace.dtypes.Language.Python
     connectors = {str(s) for s in set(node.in_connectors.keys()).union(set(node.out_connectors.keys()))}
-    code_rhs: str = node.code.as_string.split("=")[-1].strip()
+    code_lhs, code_rhs = _split_code_on_assignment(node.code.as_string)
     all_syms = {str(s) for s in dace.symbolic.SymExpr(code_rhs).free_symbols}
     real_free_syms = all_syms - connectors
    free_non_connector_syms = {str(s) for s in real_free_syms}
     return free_non_connector_syms
 
 
+def _extract_non_connector_bound_syms_from_tasklet(code_str: str) -> typing.Set[str]:
+    """
+    Recursively extract all literal constants (numbers, strings, booleans, None)
+    from the given tasklet code.
+
+    Args:
+        code_str (str): The tasklet code to scan.
+
+    Returns:
+        Set of literal constants, rendered as strings.
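+
+    Example:
+        Illustrative: for the code ``out = a * 2.0 + 1`` this yields ``{'2.0', '1'}``.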
+ """ + constants = [] + node = ast.parse(code_str, mode="exec") + + class ConstantExtractor(ast.NodeVisitor): + + def visit_Constant(self, n): + constants.append(n.value) + + # For compatibility with Python <3.8 + def visit_Num(self, n): # type: ignore + constants.append(n.n) + + def visit_Str(self, n): # type: ignore + constants.append(n.s) + + def visit_NameConstant(self, n): # type: ignore + constants.append(n.value) + + ConstantExtractor().visit(node) + return {str(c) for c in constants} + + _BINOP_SYMBOLS = { ast.Add: "+", ast.Sub: "-", @@ -408,7 +471,9 @@ def _extract_non_connector_syms_from_tasklet(node: dace.nodes.Tasklet) -> typing _SUPPORTED_OPS = {'*', '+', '-', '/', '>', '<', '>=', '<=', '==', '!='} """Set of supported binary and comparison operators.""" -_SUPPORTED = {'*', '+', '-', '/', 'abs', 'exp', 'sqrt', 'log', 'ln', 'exp', 'pow', 'min', 'max'} +_SUPPORTED = { + '*', '+', '-', '/', '>', '<', '>=', '<=', '==', '!=', 'abs', 'exp', 'sqrt', 'log', 'ln', 'exp', 'pow', 'min', 'max' +} """Set of all supported operations including functions.""" @@ -432,7 +497,6 @@ def _extract_single_op(src: str, default_to_assignment: bool = False) -> str: This function assumes tasklet contains a single operation. You can run the pass `SplitTasklets` to get such tasklets. """ - print(f"Extract single op from {src}") tree = ast.parse(src) found = None @@ -472,7 +536,6 @@ def _extract_single_op(src: str, default_to_assignment: bool = False) -> str: func_name = call_node.func.id found = func_name except SyntaxError as e: - print(e) pass if found is None: @@ -498,10 +561,6 @@ def _match_connector_to_data(state: dace.SDFGState, tasklet: dace.nodes.Tasklet) Returns: Dictionary mapping connector names (str) to data descriptors (dace.data.Data) - Examples: - For a tasklet with input connector "in_a" connected to array "A": - >>> _match_connector_to_data(state, tasklet) - {'in_a': } """ tdict = dict() for ie in state.in_edges(tasklet): @@ -560,13 +619,19 @@ def _reorder_rhs(code_str: str, op: str, rhs1: str, rhs2: str) -> Tuple[str, str raise ValueError(f"Failed to parse function expression: {code_rhs}") from e else: - left_string, right_string = [cstr.strip() for cstr in code_rhs.split(op)] + left_string, right_string = [token_split(cstr.strip()) for cstr in code_rhs.split(op)] if rhs1 in left_string and rhs2 in left_string: - raise Exception("SSA tasklet, rhs1 and rhs2 both can't appear on left side of the operand") + if rhs1 != rhs2: + raise Exception( + "SSA tasklet, rhs1 and rhs2 both can't appear on left side of the operand (unless they are the same and repeated)" + ) if rhs1 in right_string and rhs2 in right_string: - raise Exception("SSA tasklet, rhs1 and rhs2 both can't appear on right side of the operand") + if rhs1 != rhs2: + raise Exception( + "SSA tasklet, rhs1 and rhs2 both can't appear on right side of the operand (unless they are the same and repeated)" + ) if rhs1 in left_string and rhs2 in right_string: return rhs1, rhs2 @@ -575,10 +640,14 @@ def _reorder_rhs(code_str: str, op: str, rhs1: str, rhs2: str) -> Tuple[str, str return rhs2, rhs1 if rhs1 not in left_string and rhs2 not in right_string: - raise Exception("SSA tasklet, rhs1 appears in none of the substrings") + raise Exception( + f"SSA tasklet, rhs1 appears in none of the substrings rhs1: {rhs1} string: {left_string} -op- {right_string}" + ) if rhs2 not in left_string and rhs2 not in right_string: - raise Exception("SSA tasklet, rhs2 appears in none of the substrings") + raise Exception( + f"SSA tasklet, rhs2 appears in none of 
the substrings, rhs2: {rhs2} string: {left_string} -op- {right_string}"
+        )
 
 
 def count_name_occurrences(expr: str, name: str) -> int:
@@ -600,8 +669,6 @@ def count_name_occurrences(expr: str, name: str) -> int:
         2
         >>> count_name_occurrences("x * x * x", "x")
         3
-        >>> count_name_occurrences("abs(y)", "y")
-        1
 
     Note:
         This is used to distinguish between unary operations (single occurrence)
@@ -631,12 +698,15 @@ def classify_tasklet(state: dace.SDFGState, node: dace.nodes.Tasklet) -> Dict:
         Dictionary with the following keys:
             - type (TaskletType): The classified tasklet type
             - lhs (str): Output connector name (left-hand side variable)
-            - rhs1 (str or None): First input connector/operand name (left of the operator if both rhs1 and rhs2 are set)
-            - rhs2 (str or None): Second input connector/operand name (right of the operator if both rhs1 and rhs2 are set, can be same as rhs1)
-            - constant1 (str or None): First constant/symbol value (left of the operator if both c1 and c2 are set)
-            - constant2 (str or None): Second constant/symbol value (right of the operator if both c1 and c2 are set, can be same as c1)
+            - rhs1 (str or None): Input connector/operand name left of the operator, or first function argument
+            - rhs2 (str or None): Input connector/operand name right of the operator, or second function argument
+            - constant1 (str or None): First constant/symbol value left of the operator, or first function argument
+            - constant2 (str or None): Second constant/symbol value right of the operator, or second function argument
             - op (str): Operation symbol or function name
 
+    Notes:
+        - c1/rhs1 always denote the operand left of the operator and c2/rhs2 the one on the right, regardless of the number of constants or expressions
+
     Raises:
         AssertionError: If tasklet has more than 1 output connector
         NotImplementedError: If tasklet pattern is not supported
@@ -669,8 +739,8 @@ def classify_tasklet(state: dace.SDFGState, node: dace.nodes.Tasklet) -> Dict:
             'lhs': 'out',
             'rhs1': 'in_a',
             'rhs2': None,
-            'constant1': '5',
-            'constant2': None,
+            'constant1': None,
+            'constant2': '5',
             'op': '+'
         }
         # For more see the unit tests
@@ -708,6 +778,7 @@ def classify_tasklet(state: dace.SDFGState, node: dace.nodes.Tasklet) -> Dict:
     lhs_data_name = out_edges.pop().data.data
     lhs_data = state.sdfg.arrays[lhs_data_name]
 
+    # For assignment operators, `op` is `=` and `rhs1` is always populated
     if code_str == f"{lhs} = {rhs}" or code_str == f"{lhs} = {rhs};":
         lhs_datadesc = lhs_data
         rhs_datadesc = rhs_data
@@ -739,6 +810,8 @@ def classify_tasklet(state: dace.SDFGState, node: dace.nodes.Tasklet) -> Dict:
             constant = free_non_connector_syms.pop()
 
     if not has_constant:
+        # If the rhs array appears repeatedly, the expression is of the form `a = b * b`;
+        # when it occurs exactly twice, repeat the `rhs` argument
        rhs_occurence_count = count_name_occurrences(code_str.split(" = ")[1].strip(), rhs)
         if isinstance(rhs_data, dace.data.Array):
             rhs2 = None if rhs_occurence_count == 1 else rhs
@@ -753,19 +826,30 @@ def classify_tasklet(state: dace.SDFGState, node: dace.nodes.Tasklet) -> Dict:
         else:
             raise Exception(f"Unhandled case in tasklet type (1) {rhs_data}, {type(rhs_data)}")
     else:
+        # Establish the correct operand order: left of the operator is `1`, right is `2`
+        op = _extract_single_op(code_str)
+        reordered = _reorder_rhs(code_str, op, rhs, constant)
+        rhs1 = rhs if reordered[0] == rhs else None
+        rhs2 = rhs if reordered[1] == rhs else None
+        constant1 = constant if reordered[0] == constant else None
+        constant2 = constant if reordered[1] == constant else None
         if 
isinstance(rhs_data, dace.data.Array): info_dict.update({ "type": TaskletType.ARRAY_SYMBOL, - "rhs1": rhs, - "constant1": constant, + "rhs1": rhs1, + "rhs2": rhs2, + "constant1": constant1, + "constant2": constant2, "op": _extract_single_op(code_str) }) return info_dict elif isinstance(rhs_data, dace.data.Scalar): info_dict.update({ "type": TaskletType.SCALAR_SYMBOL, - "rhs1": rhs, - "constant1": constant, + "rhs1": rhs1, + "rhs2": rhs2, + "constant1": constant1, + "constant2": constant2, "op": _extract_single_op(code_str) }) return info_dict @@ -787,7 +871,14 @@ def classify_tasklet(state: dace.SDFGState, node: dace.nodes.Tasklet) -> Dict: elif len(scalars) == 1 and len(arrays) == 1: array_arg = next(iter(arrays)) scalar_arg = next(iter(scalars)) - info_dict.update({"type": TaskletType.ARRAY_SCALAR, "rhs1": array_arg, "constant1": scalar_arg, "op": op}) + ttype = TaskletType.ARRAY_SCALAR if rhs1 == array_arg else TaskletType.SCALAR_ARRAY + if ttype == TaskletType.ARRAY_SCALAR: + assert rhs2 == scalar_arg + else: + assert rhs1 == scalar_arg + assert rhs1 is not None + assert rhs2 is not None + info_dict.update({"type": ttype, "rhs1": rhs1, "rhs2": rhs2, "op": op}) return info_dict elif len(scalars) == 2: info_dict.update({"type": TaskletType.SCALAR_SCALAR, "rhs1": rhs1, "rhs2": rhs2, "op": op}) @@ -795,11 +886,12 @@ def classify_tasklet(state: dace.SDFGState, node: dace.nodes.Tasklet) -> Dict: elif n_in == 0: free_syms = _extract_non_connector_syms_from_tasklet(node) - assert len(free_syms) == 2 or len(free_syms) == 1, f"{str(free_syms)}" + bound_syms = _extract_non_connector_bound_syms_from_tasklet(node.code.as_string) + op = _extract_single_op(code_str, default_to_assignment=True) if len(free_syms) == 2: + assert len(bound_syms) == 0 free_sym1 = free_syms.pop() free_sym2 = free_syms.pop() - op = _extract_single_op(code_str, default_to_assignment=False) free_sym1, free_sym2 = _reorder_rhs(code_str, op, free_sym1, free_sym2) info_dict.update({ "type": TaskletType.SYMBOL_SYMBOL, @@ -809,7 +901,6 @@ def classify_tasklet(state: dace.SDFGState, node: dace.nodes.Tasklet) -> Dict: }) return info_dict elif len(free_syms) == 1: - op = _extract_single_op(code_str, default_to_assignment=True) if op == "=": free_sym1 = free_syms.pop() info_dict.update({"type": TaskletType.ARRAY_SYMBOL_ASSIGNMENT, "constant1": free_sym1, "op": "="}) @@ -817,9 +908,42 @@ def classify_tasklet(state: dace.SDFGState, node: dace.nodes.Tasklet) -> Dict: else: free_sym1 = free_syms.pop() rhs_occurence_count = count_name_occurrences(code_str.split(" = ")[1].strip(), free_sym1) - free_sym2 = None if rhs_occurence_count == 1 else free_sym1 - ttype = TaskletType.UNARY_SYMBOL if rhs_occurence_count == 1 else TaskletType.SYMBOL_SYMBOL - info_dict.update({"type": ttype, "constant1": free_sym1, "constant2": free_sym2, "op": op}) + if rhs_occurence_count == 2: + assert len(bound_syms) == 0 + c1, c2 = free_sym1, None + ttype = TaskletType.UNARY_SYMBOL + else: + # It might be sym1 op 2.0 (constant literal, doesn't have to be 2.0) + # But also a function + assert len(bound_syms) <= 1 + if len(bound_syms) == 1: + bound_sym1 = bound_syms.pop() + # Make sure order is correct + c1, c2 = _reorder_rhs(code_str, op, free_sym1, bound_sym1) + ttype = TaskletType.SYMBOL_SYMBOL + else: + c1, c2 = free_sym1, None + ttype = TaskletType.UNARY_SYMBOL + + info_dict.update({"type": ttype, "constant1": c1, "constant2": c2, "op": op}) + return info_dict + else: + if len(bound_syms) == 2: + c1 = bound_syms.pop() + c2 = bound_syms.pop() + c1, c2 = 
_reorder_rhs(code_str, op, c1, c2) + if c1 == c2: + ttype = TaskletType.UNARY_SYMBOL + else: + ttype = TaskletType.SYMBOL_SYMBOL + info_dict.update({"type": ttype, "constant1": c1, "constant2": c2, "op": op}) + return info_dict + else: + assert len(bound_syms) == 1 + # Could be a function call on a constant like `f(2.0)` + c1 = bound_syms.pop() + ttype = TaskletType.UNARY_SYMBOL + info_dict.update({"type": ttype, "constant1": c1, "constant2": None, "op": op}) return info_dict raise NotImplementedError("Unhandled case in detect tasklet type") diff --git a/tests/utils/classify_tasklet_test.py b/tests/utils/classify_tasklet_test.py index ce778d6ed0..d855959a7c 100644 --- a/tests/utils/classify_tasklet_test.py +++ b/tests/utils/classify_tasklet_test.py @@ -11,8 +11,8 @@ "rhs1": "in_a", "rhs2": None, "op": "+", - "constant1": "sym_b", - "constant2": None + "constant1": None, + "constant2": "sym_b", }), ("out = in_a - sym_b", "array", {"a"}, {}, {"sym_b"}, { "type": tutil.TaskletType.ARRAY_SYMBOL, @@ -20,8 +20,8 @@ "rhs1": "in_a", "rhs2": None, "op": "-", - "constant1": "sym_b", - "constant2": None + "constant1": None, + "constant2": "sym_b" }), ("out = in_a * sym_b", "array", {"a"}, {}, {"sym_b"}, { "type": tutil.TaskletType.ARRAY_SYMBOL, @@ -29,8 +29,8 @@ "rhs1": "in_a", "rhs2": None, "op": "*", - "constant1": "sym_b", - "constant2": None + "constant1": None, + "constant2": "sym_b" }), ("out = in_a / sym_b", "array", {"a"}, {}, {"sym_b"}, { "type": tutil.TaskletType.ARRAY_SYMBOL, @@ -38,8 +38,8 @@ "rhs1": "in_a", "rhs2": None, "op": "/", - "constant1": "sym_b", - "constant2": None + "constant1": None, + "constant2": "sym_b" }), # === ARRAY + CONSTANT === @@ -49,8 +49,8 @@ "rhs1": "in_a", "rhs2": None, "op": "+", - "constant1": "2", - "constant2": None + "constant1": None, + "constant2": "2" }), ("out = in_a * 3", "array", {"a"}, {}, {}, { "type": tutil.TaskletType.ARRAY_SYMBOL, @@ -58,8 +58,8 @@ "rhs1": "in_a", "rhs2": None, "op": "*", - "constant1": "3", - "constant2": None + "constant1": None, + "constant2": "3" }), ("out = in_a / 2.5", "array", {"a"}, {}, {}, { "type": tutil.TaskletType.ARRAY_SYMBOL, @@ -67,8 +67,8 @@ "rhs1": "in_a", "rhs2": None, "op": "/", - "constant1": "2.5", - "constant2": None + "constant1": None, + "constant2": "2.5" }), ("out = in_a - 5", "array", {"a"}, {}, {}, { "type": tutil.TaskletType.ARRAY_SYMBOL, @@ -76,8 +76,8 @@ "rhs1": "in_a", "rhs2": None, "op": "-", - "constant1": "5", - "constant2": None + "constant1": None, + "constant2": "5" }), # === ARRAY + ARRAY === @@ -125,8 +125,8 @@ "rhs1": "in_x", "rhs2": None, "op": "+", - "constant1": "sym_y", - "constant2": None + "constant1": None, + "constant2": "sym_y" }), ("out = in_x * sym_y", "scalar", {}, {"x"}, {"sym_y"}, { "type": tutil.TaskletType.SCALAR_SYMBOL, @@ -134,8 +134,8 @@ "rhs1": "in_x", "rhs2": None, "op": "*", - "constant1": "sym_y", - "constant2": None + "constant1": None, + "constant2": "sym_y" }), ("out = in_x - sym_y", "scalar", {}, {"x"}, {"sym_y"}, { "type": tutil.TaskletType.SCALAR_SYMBOL, @@ -143,8 +143,8 @@ "rhs1": "in_x", "rhs2": None, "op": "-", - "constant1": "sym_y", - "constant2": None + "constant1": None, + "constant2": "sym_y" }), # === SCALAR + SCALAR === @@ -205,7 +205,7 @@ "constant2": "sym_b" }), - # === FUNCTIONAL / SUPPORTED OPS === + # === UNARY / FUNCTIONAL OPS === ("out = abs(in_a)", "array", {"a"}, {}, {}, { "type": tutil.TaskletType.UNARY_ARRAY, "lhs": "out", @@ -248,8 +248,8 @@ "rhs1": "in_a", "rhs2": None, "op": "pow", - "constant1": "2", - "constant2": None + "constant1": 
None,
+        "constant2": "2"
     }),
     ("out = min(in_a, in_b)", "array", {"a", "b"}, {}, {}, {
         "type": tutil.TaskletType.ARRAY_ARRAY,
@@ -278,15 +278,6 @@
         "constant1": "sym_a",
         "constant2": None
     }),
-    ("out = exp(in_a)", "array", {"a"}, {}, {}, {
-        "type": tutil.TaskletType.UNARY_ARRAY,
-        "lhs": "out",
-        "rhs1": "in_a",
-        "rhs2": None,
-        "op": "exp",
-        "constant1": None,
-        "constant2": None
-    }),
     ("out = sqrt(in_a)", "scalar", {}, {"a"}, {}, {
         "type": tutil.TaskletType.UNARY_SCALAR,
         "lhs": "out",
@@ -350,7 +341,7 @@
         "rhs2": None,
         "op": "=",
         "constant1": "sym_a",
-        "constant2": None
+        "constant2": None,
     }),

     # === SINGLE-INPUT TWO RHS CASE ===
@@ -381,6 +372,78 @@
         "constant1": None,
         "constant2": None
     }),
+    ("out = in_a - in_scl1", "array", {"a"}, {"scl1"}, {}, {
+        "type": tutil.TaskletType.ARRAY_SCALAR,
+        "lhs": "out",
+        "rhs1": "in_a",
+        "rhs2": "in_scl1",
+        "op": "-",
+        "constant1": None,
+        "constant2": None,
+    }),
+    ("out = in_scl1 - in_a", "array", {"a"}, {"scl1"}, {}, {
+        "type": tutil.TaskletType.SCALAR_ARRAY,
+        "lhs": "out",
+        "rhs1": "in_scl1",
+        "rhs2": "in_a",
+        "op": "-",
+        "constant1": None,
+        "constant2": None,
+    }),
+    ("out = in_scl1 - in_a", "scalar", {"a"}, {"scl1"}, {}, {
+        "type": tutil.TaskletType.SCALAR_ARRAY,
+        "lhs": "out",
+        "rhs1": "in_scl1",
+        "rhs2": "in_a",
+        "op": "-",
+        "constant1": None,
+        "constant2": None,
+    }),
+    ("out = 2.0 - 1.0", "scalar", {}, {}, {}, {
+        "type": tutil.TaskletType.SYMBOL_SYMBOL,
+        "lhs": "out",
+        "rhs1": None,
+        "rhs2": None,
+        "op": "-",
+        "constant1": "2.0",
+        "constant2": "1.0",
+    }),
+    ("out = 2.0 - sym2", "scalar", {}, {}, {"sym2"}, {
+        "type": tutil.TaskletType.SYMBOL_SYMBOL,
+        "lhs": "out",
+        "rhs1": None,
+        "rhs2": None,
+        "op": "-",
+        "constant1": "2.0",
+        "constant2": "sym2",
+    }),
+    ("out = sym2 * sym2", "scalar", {}, {}, {"sym2"}, {
+        "type": tutil.TaskletType.UNARY_SYMBOL,
+        "lhs": "out",
+        "rhs1": None,
+        "rhs2": None,
+        "op": "*",
+        "constant1": "sym2",
+        "constant2": None,
+    }),
+    ("out = exp(sym2)", "scalar", {}, {}, {"sym2"}, {
+        "type": tutil.TaskletType.UNARY_SYMBOL,
+        "lhs": "out",
+        "rhs1": None,
+        "rhs2": None,
+        "op": "exp",
+        "constant1": "sym2",
+        "constant2": None,
+    }),
+    ("out = exp(3.0)", "scalar", {}, {}, {}, {
+        "type": tutil.TaskletType.UNARY_SYMBOL,
+        "lhs": "out",
+        "rhs1": None,
+        "rhs2": None,
+        "op": "exp",
+        "constant1": "3.0",
+        "constant2": None,
+    }),
 ]

From c6f15af6120bd4f83563f12a86071a89bcc7fe1d Mon Sep 17 00:00:00 2001
From: Yakup Koray Budanaz
Date: Fri, 31 Oct 2025 17:46:18 +0100
Subject: [PATCH 14/17] Try to fix failing SDFG serialization check
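Compile the baseline SDFG and its transformed deep copy at explicit points
instead of relying on lazy compilation at call time: the baseline is compiled
up front and the copy only after the transformation has been applied, so each
comparison runs a binary that matches its graph. Since the serialization
round-trip check runs during code generation, compiling at well-defined points
should also keep that check away from graphs the test is still mutating.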
---
 .../interstate/branch_elimination_test.py | 38 +++++++++----------
 1 file changed, 18 insertions(+), 20 deletions(-)

diff --git a/tests/transformations/interstate/branch_elimination_test.py b/tests/transformations/interstate/branch_elimination_test.py
index 6205bcfbe3..53377d23c1 100644
--- a/tests/transformations/interstate/branch_elimination_test.py
+++ b/tests/transformations/interstate/branch_elimination_test.py
@@ -319,12 +319,15 @@ def run_and_compare(
     sdfg = program.to_sdfg()
     sdfg.validate()
     sdfg.name = sdfg_name
-    out_no_fuse = {k: v.copy() for k, v in arrays.items()}
-    sdfg(**out_no_fuse)

     copy_sdfg = copy.deepcopy(sdfg)
     copy_sdfg.name = sdfg_name + "_branch_eliminated"
-    del sdfg
+
+    # Compile the baseline now; the transformed copy is compiled after the transformation
+    c_sdfg = sdfg.compile()
+
+    out_no_fuse = {k: v.copy() for k, v in arrays.items()}
+    out_fused = {k: v.copy() for k, v in arrays.items()}

     # Apply transformation
     if use_pass:
@@ -334,10 +337,9 @@
     else:
         apply_branch_elimination(copy_sdfg, 2)

-    # Run SDFG version (with transformation)
-    out_fused = {k: v.copy() for k, v in arrays.items()}
-
-    copy_sdfg(**out_fused)
+    c_copy_sdfg = copy_sdfg.compile()
+    c_sdfg(**out_no_fuse)
+    c_copy_sdfg(**out_fused)

     branch_code = {n for n, g in copy_sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)}
     assert len(
@@ -347,10 +349,6 @@
     for name in arrays.keys():
         np.testing.assert_allclose(out_fused[name], out_no_fuse[name], atol=1e-12)

-    del out_no_fuse
-    del out_fused
-    del copy_sdfg
-

 def run_and_compare_sdfg(
     sdfg,
@@ -361,28 +359,28 @@
     # Run SDFG version (no transformation)
     sdfg.validate()
     sdfg.name = sdfg_name
-    out_no_fuse = {k: v.copy() for k, v in arrays.items()}
-    sdfg(**out_no_fuse)
-    # Run SDFG version (with transformation)
     copy_sdfg = copy.deepcopy(sdfg)
     copy_sdfg.name = sdfg_name + "_branch_eliminated"
-    del sdfg
+
+    c_sdfg = sdfg.compile()
+
+    out_no_fuse = {k: v.copy() for k, v in arrays.items()}
+    out_fused = {k: v.copy() for k, v in arrays.items()}

     fb = EliminateBranches()
     fb.try_clean = True
     fb.permissive = permissive
     fb.apply_pass(copy_sdfg, {})
-    out_fused = {k: v.copy() for k, v in arrays.items()}
-    copy_sdfg(**out_fused)
+
+    c_copy_sdfg = copy_sdfg.compile()
+    c_sdfg(**out_no_fuse)
+    c_copy_sdfg(**out_fused)

     # Compare all arrays
     for name in arrays.keys():
         np.testing.assert_allclose(out_no_fuse[name], out_fused[name], atol=1e-12)

-    del out_no_fuse
-    del out_fused
-
     return copy_sdfg

From 47ba69017c322475d91d5e70d44cb8434871bb41 Mon Sep 17 00:00:00 2001
From: Yakup Koray Budanaz
Date: Fri, 31 Oct 2025 18:07:13 +0100
Subject: [PATCH 15/17] Fix colliding sdfg names

---
 dace/codegen/codegen.py | 41 ++++++++++++++++++++++++++++-------------
 1 file changed, 28 insertions(+), 13 deletions(-)

diff --git a/dace/codegen/codegen.py b/dace/codegen/codegen.py
index 3ccbb56dc6..e9f61ee6e2 100644
--- a/dace/codegen/codegen.py
+++ b/dace/codegen/codegen.py
@@ -172,23 +172,38 @@ def generate_code(sdfg: SDFG, validate=True) -> List[CodeObject]:
         import filecmp
         import shutil
         import tempfile
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            sdfg.save(f'{tmp_dir}/test.sdfg', hash=False)
-            sdfg2 = SDFG.from_file(f'{tmp_dir}/test.sdfg')
-            sdfg2.save(f'{tmp_dir}/test2.sdfg', hash=False)
+        import os
+        with tempfile.NamedTemporaryFile(suffix="_.sdfg", delete=False) as tmp1, \
+             tempfile.NamedTemporaryFile(suffix="_.sdfg", delete=False) as tmp2:
+            tmp1_path = tmp1.name
+            tmp2_path = tmp2.name
+
+        try:
+            sdfg.save(tmp1_path, hash=False)
+            sdfg2 = 
SDFG.from_file(tmp1_path) + sdfg2.save(tmp2_path, hash=False) + print('Testing SDFG serialization...') - if not filecmp.cmp(f'{tmp_dir}/test.sdfg', f'{tmp_dir}/test2.sdfg'): - with open(f'{tmp_dir}/test.sdfg', 'r') as f1: - with open(f'{tmp_dir}/test2.sdfg', 'r') as f2: - diff = difflib.unified_diff(f1.readlines(), - f2.readlines(), - fromfile='test.sdfg (first save)', - tofile='test2.sdfg (after roundtrip)') + if not filecmp.cmp(tmp1_path, tmp2_path): + with open(tmp1_path, 'r') as f1, open(tmp2_path, 'r') as f2: + diff = difflib.unified_diff(f1.readlines(), + f2.readlines(), + fromfile='test.sdfg (first save)', + tofile='test2.sdfg (after roundtrip)') diff = ''.join(diff) - shutil.move(f'{tmp_dir}/test.sdfg', 'test.sdfg') - shutil.move(f'{tmp_dir}/test2.sdfg', 'test2.sdfg') + + shutil.copy(tmp1_path, 'test.sdfg') + shutil.copy(tmp2_path, 'test2.sdfg') raise RuntimeError(f'SDFG serialization failed - files do not match:\n{diff}') + finally: + # Clean up the temporary files + try: + os.remove(tmp1_path) + os.remove(tmp2_path) + except OSError: + pass + if config.Config.get_bool('optimizer', 'detect_control_flow'): # NOTE: This should likely be done either earlier in the future, or changed entirely in modular codegen. # It is being done here to ensure that for now the semantics of the setting are preserved and legacy tests, From 2db5a920fa5384b5d03d5dc4ef283a3d1c844423 Mon Sep 17 00:00:00 2001 From: Yakup Koray Budanaz Date: Sat, 8 Nov 2025 20:55:29 +0100 Subject: [PATCH 16/17] Add no autoopt marker and mark branch elimination tests --- pytest.ini | 1 + tests/conftest.py | 35 +++++++++++++++ .../interstate/branch_elimination_test.py | 43 +++++++++++++++++++ 3 files changed, 79 insertions(+) diff --git a/pytest.ini b/pytest.ini index b0aa6e9b8f..cf3276f6d1 100644 --- a/pytest.ini +++ b/pytest.ini @@ -16,6 +16,7 @@ markers = hptt: Test requires the HPTT library (select with '-m "hptt') long: Test runs for a long time and is skipped in CI (select with '-m "long"') sequential: Test must be run sequentially (select with '-m "sequential"') + no_autoopt: skip this test if auto-optimization is enabled (select with '-m "no_autoopt"') python_files = *_test.py *_cudatest.py diff --git a/tests/conftest.py b/tests/conftest.py index 57f611ce66..bbb9a7d3a0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,3 +13,38 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): if config.option.markexpr == 'mpi': if exitstatus in (pytest.ExitCode.TESTS_FAILED, pytest.ExitCode.INTERNAL_ERROR, pytest.ExitCode.INTERRUPTED): os._exit(1) + + +@pytest.hookimpl() +def pytest_collection_modifyitems(config, items): + """Automatically skip tests marked with @pytest.mark.no_autoopt + when auto-optimization is enabled.""" + import dace + + # 1. Determine if autooptimize is active + autoopt_env = os.environ.get("DACE_optimizer_autooptimize") + autoopt_enabled = False + + if autoopt_env is not None: + autoopt_enabled = autoopt_env.lower() in ("1", "true", "yes", "on") + elif dace is not None: + try: + autoopt_enabled = bool(dace.Config.get("optimizer", "autooptimize")) + except Exception: + autoopt_enabled = False + + # 2. If not enabled, nothing to skip + if not autoopt_enabled: + return + + # 3. 
Apply skip mark to all tests with @pytest.mark.no_autoopt + skip_marker = pytest.mark.skip(reason="Skipped because autooptimize is enabled") + skipped = 0 + for item in items: + if "no_autoopt" in item.keywords: + item.add_marker(skip_marker) + skipped += 1 + + if skipped: + config.pluginmanager.get_plugin("terminalreporter").write_line( + f"[pytest] autooptimize enabled — skipped {skipped} test(s) with @no_autoopt") diff --git a/tests/transformations/interstate/branch_elimination_test.py b/tests/transformations/interstate/branch_elimination_test.py index 53377d23c1..27b47c48fb 100644 --- a/tests/transformations/interstate/branch_elimination_test.py +++ b/tests/transformations/interstate/branch_elimination_test.py @@ -385,6 +385,7 @@ def run_and_compare_sdfg( @pytest.mark.parametrize("use_pass_flag", [True, False]) +@pytest.mark.no_autoopt def test_branch_dependent_value_write(use_pass_flag): a = np.random.rand(N, N) b = np.random.rand(N, N) @@ -400,6 +401,7 @@ def test_branch_dependent_value_write(use_pass_flag): d=d) +@pytest.mark.no_autoopt def test_weird_condition(): a = np.random.rand(N, N) b = np.random.rand(N, N) @@ -408,6 +410,7 @@ def test_weird_condition(): @pytest.mark.parametrize("use_pass_flag", [True, False]) +@pytest.mark.no_autoopt def test_branch_dependent_value_write_two(use_pass_flag): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.zeros((N, N)) @@ -424,6 +427,7 @@ def test_branch_dependent_value_write_two(use_pass_flag): @pytest.mark.parametrize("use_pass_flag", [True, False]) +@pytest.mark.no_autoopt def test_branch_dependent_value_write_single_branch(use_pass_flag): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 5.0], size=(N, N)) @@ -438,6 +442,7 @@ def test_branch_dependent_value_write_single_branch(use_pass_flag): @pytest.mark.parametrize("use_pass_flag", [True, False]) +@pytest.mark.no_autoopt def test_complicated_if(use_pass_flag): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 5.0], size=(N, N)) @@ -452,6 +457,7 @@ def test_complicated_if(use_pass_flag): @pytest.mark.parametrize("use_pass_flag", [True, False]) +@pytest.mark.no_autoopt def test_multi_state_branch_body(use_pass_flag): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 5.0], size=(N, N)) @@ -470,6 +476,7 @@ def test_multi_state_branch_body(use_pass_flag): @pytest.mark.parametrize("use_pass_flag", [True, False]) +@pytest.mark.no_autoopt def test_nested_if(use_pass_flag): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 5.0], size=(N, N)) @@ -487,6 +494,7 @@ def test_nested_if(use_pass_flag): s=s[0]) +@pytest.mark.no_autoopt def test_condition_on_bounds(): a = np.random.choice([0.001, 3.0], size=(2, 2)) b = np.random.choice([0.001, 5.0], size=(2, 2)) @@ -507,6 +515,7 @@ def test_condition_on_bounds(): assert len(nsdfgs) == 1 # Can be applied should return false +@pytest.mark.no_autoopt def test_nested_if_two(): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 5.0], size=(N, N)) @@ -516,6 +525,7 @@ def test_nested_if_two(): @pytest.mark.parametrize("use_pass_flag", [True, False]) +@pytest.mark.no_autoopt def test_tasklets_in_if(use_pass_flag): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 5.0], size=(N, N)) @@ -532,6 +542,7 @@ def test_tasklets_in_if(use_pass_flag): @pytest.mark.parametrize("use_pass_flag", [True, False]) +@pytest.mark.no_autoopt def 
test_branch_dependent_value_write_single_branch_nonzero_write(use_pass_flag): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 5.0], size=(N, N)) @@ -545,6 +556,7 @@ def test_branch_dependent_value_write_single_branch_nonzero_write(use_pass_flag) d=d) +@pytest.mark.no_autoopt def test_branch_dependent_value_write_with_transient_reuse(): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 3.0], size=(N, N)) @@ -559,6 +571,7 @@ def test_branch_dependent_value_write_with_transient_reuse(): @pytest.mark.parametrize("use_pass_flag", [True, False]) +@pytest.mark.no_autoopt def test_single_branch_connectors(use_pass_flag): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 5.0], size=(N, N)) @@ -598,6 +611,7 @@ def test_single_branch_connectors(use_pass_flag): @pytest.mark.parametrize("use_pass_flag", [True, False]) +@pytest.mark.no_autoopt def test_disjoint_subsets(use_pass_flag): if_cond_58 = np.array([1], dtype=np.int32) A = np.random.choice([0.001, 3.0], size=(N, )) @@ -637,6 +651,7 @@ def _multi_state_nested_if( B[6, 0, 2] = A[6] + B[6, 0, 2] +@pytest.mark.no_autoopt def test_try_clean(): sdfg1 = _multi_state_nested_if.to_sdfg() cblocks = {n for n, g in sdfg1.all_nodes_recursive() if isinstance(n, ConditionalBlock)} @@ -697,6 +712,7 @@ def test_try_clean(): offset=offset[0]) +@pytest.mark.no_autoopt def test_try_clean_as_pass(): # This is a test to check the different configurations of try clean, applicability depends on the SDFG and the pass sdfg = _multi_state_nested_if.to_sdfg() @@ -777,6 +793,7 @@ def _get_sdfg_with_interstate_array_condition(): return sdfg +@pytest.mark.no_autoopt def test_sdfg_with_interstate_array_condition(): sdfg = _get_sdfg_with_interstate_array_condition() llindex = np.ones(shape=(4, 4, 4), dtype=np.int64) @@ -819,6 +836,7 @@ def repeated_condition_variables( c[i, j] = a[i, j] * b[i, j] +@pytest.mark.no_autoopt def test_repeated_condition_variables(): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 3.0], size=(N, N)) @@ -834,6 +852,7 @@ def _find_state(root_sdfg: dace.SDFG, node): return None +@pytest.mark.no_autoopt def test_if_over_map(): sdfg = if_over_map.to_sdfg() cblocks = {n for n in sdfg.all_control_flow_regions() if isinstance(n, ConditionalBlock)} @@ -860,6 +879,7 @@ def test_if_over_map(): sdfg=xform.conditional.parent_graph.sdfg) is True +@pytest.mark.no_autoopt def test_if_over_map_with_top_level_tasklets(): sdfg = if_over_map.to_sdfg() cblocks = {n for n in sdfg.all_control_flow_regions() if isinstance(n, ConditionalBlock)} @@ -891,6 +911,7 @@ def test_if_over_map_with_top_level_tasklets(): sdfg=xform.conditional.parent_graph.sdfg) is True +@pytest.mark.no_autoopt def test_can_be_applied_parameters_on_nested_sdfg(): sdfg = nested_if.to_sdfg() cblocks = {n for n in sdfg.all_control_flow_regions() if isinstance(n, ConditionalBlock)} @@ -952,6 +973,7 @@ def non_trivial_subset_after_combine_tasklet( g[6, 6] = tc4 +@pytest.mark.no_autoopt def test_non_trivial_subset_after_combine_tasklet(): A = np.random.choice([0.001, 5.0], size=(N, N)) B = np.random.choice([0.001, 5.0], size=(N, N)) @@ -1012,6 +1034,7 @@ def split_on_disjoint_subsets_nested( b[i, 4] = 0.0 +@pytest.mark.no_autoopt def test_split_on_disjoint_subsets(): A = np.random.choice([0.001, 5.0], size=(N, N, 2)) B = np.random.choice([0.001, 5.0], size=(N, N)) @@ -1055,6 +1078,7 @@ def test_split_on_disjoint_subsets(): ) +@pytest.mark.no_autoopt def test_split_on_disjoint_subsets_nested(): A = 
np.random.choice([0.001, 5.0], size=(N, N, 2)) B = np.random.choice([0.001, 5.0], size=(N, N)) @@ -1132,6 +1156,7 @@ def write_to_transient_two( b[i, 3] = zmdn +@pytest.mark.no_autoopt def test_write_to_transient(): A = np.random.choice([0.001, 5.0], size=(N, N)) B = np.random.choice([0.001, 5.0], size=(N, N)) @@ -1149,6 +1174,7 @@ def test_write_to_transient(): ) +@pytest.mark.no_autoopt def test_write_to_transient_two(): A = np.random.choice([0.001, 5.0], size=(N, N)) B = np.random.choice([0.001, 5.0], size=(N, N)) @@ -1166,6 +1192,7 @@ def test_write_to_transient_two(): ) +@pytest.mark.no_autoopt def test_double_empty_state(): A = np.random.choice([0.001, 5.0], size=(N, N)) B = np.random.choice([0.001, 5.0], size=(N, N)) @@ -1211,6 +1238,7 @@ def complicated_pattern_for_manual_clean_up_one( c[i, 0] = 0.0 +@pytest.mark.no_autoopt def test_complicated_pattern_for_manual_clean_up_one(): A = np.random.choice([0.001, 5.0], size=(N, N)) B = np.random.choice([0.001, 5.0], size=(N, N)) @@ -1269,6 +1297,7 @@ def test_complicated_pattern_for_manual_clean_up_one(): assert all({isinstance(n, dace.SDFGState) for n in body1.nodes()}) +@pytest.mark.no_autoopt def test_try_clean_on_complicated_pattern_for_manual_clean_up_one(): A = np.random.choice([0.001, 5.0], size=(N, N)) B = np.random.choice([0.001, 5.0], size=(N, N)) @@ -1342,6 +1371,7 @@ def complicated_pattern_for_manual_clean_up_two( a[i, 3] = zlcrit * 2.0 +@pytest.mark.no_autoopt def test_try_clean_on_complicated_pattern_for_manual_clean_up_two(): A = np.random.choice([0.001, 5.0], size=(N, N)) B = np.random.choice([0.001, 5.0], size=(N, N)) @@ -1414,12 +1444,14 @@ def single_assignment_cond_from_scalar(a: dace.float64[512]): a[i] = 0.0 +@pytest.mark.no_autoopt def test_single_assignment(): if_cond_1 = np.array([1], dtype=np.float64) A = np.ones(shape=(N, ), dtype=np.float64) run_and_compare(single_assignment, 0, True, "single_assignment", a=A, _if_cond_1=if_cond_1[0]) +@pytest.mark.no_autoopt def test_single_assignment_cond_from_scalar(): A = np.ones(shape=(512, ), dtype=np.float64) before = single_assignment_cond_from_scalar.to_sdfg() @@ -1482,6 +1514,7 @@ def _get_sdfg_with_condition_from_transient_scalar() -> dace.SDFG: return sdfg +@pytest.mark.no_autoopt def test_condition_from_transient_scalar(): zsolac = np.random.choice([8.0, 11.0], size=(N, )) zlcond2 = np.random.choice([8.0, 11.0], size=(N, )) @@ -1595,6 +1628,7 @@ def _get_disjoint_chain_sdfg() -> dace.SDFG: @pytest.mark.parametrize("rtt_val", [0.0, 4.0, 6.0]) +@pytest.mark.no_autoopt def test_disjoint_chain_split_branch_only(rtt_val): sdfg, nsdfg_parent_state = _get_disjoint_chain_sdfg() sdfg.name = f"disjoint_chain_split_branch_only_rtt_val_{str(rtt_val).replace('.','_')}" @@ -1630,6 +1664,7 @@ def test_disjoint_chain_split_branch_only(rtt_val): @pytest.mark.parametrize("rtt_val", [0.0, 4.0, 6.0]) +@pytest.mark.no_autoopt def test_disjoint_chain(rtt_val): sdfg, _ = _get_disjoint_chain_sdfg() zsolqa = np.random.choice([0.001, 5.0], size=(N, 5, 5)) @@ -1669,6 +1704,7 @@ def pattern_from_cloudsc_one( @pytest.mark.parametrize("c_val", [0.0, 1.0, 6.0]) +@pytest.mark.no_autoopt def test_pattern_from_cloudsc_one(c_val): A = np.random.choice([0.001, 5.0], size=( 2, @@ -1708,6 +1744,7 @@ def map_param_usage( b[i, i] = zmdn +@pytest.mark.no_autoopt def test_can_be_applied_on_map_param_usage(): A = np.random.choice([0.001, 5.0], size=( N, @@ -1791,6 +1828,7 @@ def _get_safe_map_param_use_in_nested_sdfg() -> dace.SDFG: return outer_sdfg +@pytest.mark.no_autoopt def 
test_safe_map_param_use_in_nested_sdfg(): sdfg = _get_safe_map_param_use_in_nested_sdfg() sdfg.validate() @@ -1887,6 +1925,7 @@ def _get_nsdfg_with_return(return_arr: bool) -> dace.SDFG: @pytest.mark.parametrize("ret_arr", [True, False]) +@pytest.mark.no_autoopt def test_nested_sdfg_with_return(ret_arr): sdfg = _get_nsdfg_with_return(ret_arr) sdfg.validate() @@ -2011,6 +2050,7 @@ def huge_sdfg(pap: dace.float64[N], ptsphy: dace.float64, r2es: dace.float64, r3 @pytest.mark.parametrize("eps_operator_type_for_log_and_div", ["max", "add"]) +@pytest.mark.no_autoopt def test_huge_sdfg_with_log_exp_div(eps_operator_type_for_log_and_div: str): """Generate test data for the loop body function""" @@ -2091,6 +2131,7 @@ def safe_uniform(low, high, size): @pytest.mark.parametrize("eps_operator_type_for_log_and_div", ["max", "add"]) +@pytest.mark.no_autoopt def test_mid_sdfg_with_log_exp_div(eps_operator_type_for_log_and_div: str): """Generate test data for the loop body function""" @@ -2190,6 +2231,7 @@ def loop_param_usage(A: dace.float64[6, N, N], B: dace.float64[N, N], C: dace.fl C[i, j] = 2.0 + C[i, j] +@pytest.mark.no_autoopt def test_loop_param_usage(): A = np.random.choice([0.001, 5.0], size=(6, N, N)) B = np.random.choice([0.001, 5.0], size=(N, N)) @@ -2210,6 +2252,7 @@ def test_loop_param_usage(): run_and_compare_sdfg(sdfg, False, "loop_param_usage", A=A, B=B, C=C) +@pytest.mark.no_autoopt def test_can_be_applied_on_wcr_edge(): sdfg = wcr_edge.to_sdfg() From 38aa67c93aa403354b9afa6f328c842a12254709 Mon Sep 17 00:00:00 2001 From: Yakup Koray Budanaz Date: Sat, 8 Nov 2025 22:36:16 +0100 Subject: [PATCH 17/17] Use wrapper instead --- tests/conftest.py | 35 ------ .../interstate/branch_elimination_test.py | 107 +++++++++++------- 2 files changed, 64 insertions(+), 78 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index bbb9a7d3a0..57f611ce66 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,38 +13,3 @@ def pytest_terminal_summary(terminalreporter, exitstatus, config): if config.option.markexpr == 'mpi': if exitstatus in (pytest.ExitCode.TESTS_FAILED, pytest.ExitCode.INTERNAL_ERROR, pytest.ExitCode.INTERRUPTED): os._exit(1) - - -@pytest.hookimpl() -def pytest_collection_modifyitems(config, items): - """Automatically skip tests marked with @pytest.mark.no_autoopt - when auto-optimization is enabled.""" - import dace - - # 1. Determine if autooptimize is active - autoopt_env = os.environ.get("DACE_optimizer_autooptimize") - autoopt_enabled = False - - if autoopt_env is not None: - autoopt_enabled = autoopt_env.lower() in ("1", "true", "yes", "on") - elif dace is not None: - try: - autoopt_enabled = bool(dace.Config.get("optimizer", "autooptimize")) - except Exception: - autoopt_enabled = False - - # 2. If not enabled, nothing to skip - if not autoopt_enabled: - return - - # 3. 
Apply skip mark to all tests with @pytest.mark.no_autoopt - skip_marker = pytest.mark.skip(reason="Skipped because autooptimize is enabled") - skipped = 0 - for item in items: - if "no_autoopt" in item.keywords: - item.add_marker(skip_marker) - skipped += 1 - - if skipped: - config.pluginmanager.get_plugin("terminalreporter").write_line( - f"[pytest] autooptimize enabled — skipped {skipped} test(s) with @no_autoopt") diff --git a/tests/transformations/interstate/branch_elimination_test.py b/tests/transformations/interstate/branch_elimination_test.py index 27b47c48fb..039d3da595 100644 --- a/tests/transformations/interstate/branch_elimination_test.py +++ b/tests/transformations/interstate/branch_elimination_test.py @@ -1,4 +1,5 @@ import copy +import functools import numpy as np import dace import pytest @@ -16,6 +17,26 @@ S2 = 32 +def temporarily_disable_autoopt_and_serialization(func): + + @functools.wraps(func) + def wrapper(*args, **kwargs): + # Save original values + orig_autoopt = dace.config.Config.get("optimizer", "autooptimize") + orig_serialization = dace.config.Config.get("testing", "serialization") + try: + # Set both to False + dace.config.Config.set("optimizer", "autooptimize", value=False) + dace.config.Config.set("testing", "serialization", value=False) + return func(*args, **kwargs) + finally: + # Restore original values + dace.config.Config.set("optimizer", "autooptimize", value=orig_autoopt) + dace.config.Config.set("testing", "serialization", value=orig_serialization) + + return wrapper + + @dace.program def branch_dependent_value_write( a: dace.float64[N, N], @@ -385,7 +406,7 @@ def run_and_compare_sdfg( @pytest.mark.parametrize("use_pass_flag", [True, False]) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_branch_dependent_value_write(use_pass_flag): a = np.random.rand(N, N) b = np.random.rand(N, N) @@ -401,7 +422,7 @@ def test_branch_dependent_value_write(use_pass_flag): d=d) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_weird_condition(): a = np.random.rand(N, N) b = np.random.rand(N, N) @@ -410,7 +431,7 @@ def test_weird_condition(): @pytest.mark.parametrize("use_pass_flag", [True, False]) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_branch_dependent_value_write_two(use_pass_flag): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.zeros((N, N)) @@ -427,7 +448,7 @@ def test_branch_dependent_value_write_two(use_pass_flag): @pytest.mark.parametrize("use_pass_flag", [True, False]) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_branch_dependent_value_write_single_branch(use_pass_flag): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 5.0], size=(N, N)) @@ -442,7 +463,7 @@ def test_branch_dependent_value_write_single_branch(use_pass_flag): @pytest.mark.parametrize("use_pass_flag", [True, False]) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_complicated_if(use_pass_flag): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 5.0], size=(N, N)) @@ -457,7 +478,7 @@ def test_complicated_if(use_pass_flag): @pytest.mark.parametrize("use_pass_flag", [True, False]) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_multi_state_branch_body(use_pass_flag): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 5.0], size=(N, N)) @@ -476,7 +497,7 @@ def test_multi_state_branch_body(use_pass_flag): 
@pytest.mark.parametrize("use_pass_flag", [True, False]) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_nested_if(use_pass_flag): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 5.0], size=(N, N)) @@ -494,7 +515,7 @@ def test_nested_if(use_pass_flag): s=s[0]) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_condition_on_bounds(): a = np.random.choice([0.001, 3.0], size=(2, 2)) b = np.random.choice([0.001, 5.0], size=(2, 2)) @@ -515,7 +536,7 @@ def test_condition_on_bounds(): assert len(nsdfgs) == 1 # Can be applied should return false -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_nested_if_two(): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 5.0], size=(N, N)) @@ -525,7 +546,7 @@ def test_nested_if_two(): @pytest.mark.parametrize("use_pass_flag", [True, False]) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_tasklets_in_if(use_pass_flag): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 5.0], size=(N, N)) @@ -542,7 +563,7 @@ def test_tasklets_in_if(use_pass_flag): @pytest.mark.parametrize("use_pass_flag", [True, False]) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_branch_dependent_value_write_single_branch_nonzero_write(use_pass_flag): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 5.0], size=(N, N)) @@ -556,7 +577,7 @@ def test_branch_dependent_value_write_single_branch_nonzero_write(use_pass_flag) d=d) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_branch_dependent_value_write_with_transient_reuse(): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 3.0], size=(N, N)) @@ -571,7 +592,7 @@ def test_branch_dependent_value_write_with_transient_reuse(): @pytest.mark.parametrize("use_pass_flag", [True, False]) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_single_branch_connectors(use_pass_flag): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 5.0], size=(N, N)) @@ -611,7 +632,7 @@ def test_single_branch_connectors(use_pass_flag): @pytest.mark.parametrize("use_pass_flag", [True, False]) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_disjoint_subsets(use_pass_flag): if_cond_58 = np.array([1], dtype=np.int32) A = np.random.choice([0.001, 3.0], size=(N, )) @@ -651,7 +672,7 @@ def _multi_state_nested_if( B[6, 0, 2] = A[6] + B[6, 0, 2] -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_try_clean(): sdfg1 = _multi_state_nested_if.to_sdfg() cblocks = {n for n, g in sdfg1.all_nodes_recursive() if isinstance(n, ConditionalBlock)} @@ -712,7 +733,7 @@ def test_try_clean(): offset=offset[0]) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_try_clean_as_pass(): # This is a test to check the different configurations of try clean, applicability depends on the SDFG and the pass sdfg = _multi_state_nested_if.to_sdfg() @@ -793,7 +814,7 @@ def _get_sdfg_with_interstate_array_condition(): return sdfg -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_sdfg_with_interstate_array_condition(): sdfg = _get_sdfg_with_interstate_array_condition() llindex = np.ones(shape=(4, 4, 4), dtype=np.int64) @@ -836,7 +857,7 @@ def repeated_condition_variables( c[i, j] = a[i, j] * b[i, j] 
-@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_repeated_condition_variables(): a = np.random.choice([0.001, 3.0], size=(N, N)) b = np.random.choice([0.001, 3.0], size=(N, N)) @@ -852,7 +873,7 @@ def _find_state(root_sdfg: dace.SDFG, node): return None -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_if_over_map(): sdfg = if_over_map.to_sdfg() cblocks = {n for n in sdfg.all_control_flow_regions() if isinstance(n, ConditionalBlock)} @@ -879,7 +900,7 @@ def test_if_over_map(): sdfg=xform.conditional.parent_graph.sdfg) is True -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_if_over_map_with_top_level_tasklets(): sdfg = if_over_map.to_sdfg() cblocks = {n for n in sdfg.all_control_flow_regions() if isinstance(n, ConditionalBlock)} @@ -911,7 +932,7 @@ def test_if_over_map_with_top_level_tasklets(): sdfg=xform.conditional.parent_graph.sdfg) is True -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_can_be_applied_parameters_on_nested_sdfg(): sdfg = nested_if.to_sdfg() cblocks = {n for n in sdfg.all_control_flow_regions() if isinstance(n, ConditionalBlock)} @@ -973,7 +994,7 @@ def non_trivial_subset_after_combine_tasklet( g[6, 6] = tc4 -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_non_trivial_subset_after_combine_tasklet(): A = np.random.choice([0.001, 5.0], size=(N, N)) B = np.random.choice([0.001, 5.0], size=(N, N)) @@ -1034,7 +1055,7 @@ def split_on_disjoint_subsets_nested( b[i, 4] = 0.0 -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_split_on_disjoint_subsets(): A = np.random.choice([0.001, 5.0], size=(N, N, 2)) B = np.random.choice([0.001, 5.0], size=(N, N)) @@ -1078,7 +1099,7 @@ def test_split_on_disjoint_subsets(): ) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_split_on_disjoint_subsets_nested(): A = np.random.choice([0.001, 5.0], size=(N, N, 2)) B = np.random.choice([0.001, 5.0], size=(N, N)) @@ -1156,7 +1177,7 @@ def write_to_transient_two( b[i, 3] = zmdn -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_write_to_transient(): A = np.random.choice([0.001, 5.0], size=(N, N)) B = np.random.choice([0.001, 5.0], size=(N, N)) @@ -1174,7 +1195,7 @@ def test_write_to_transient(): ) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_write_to_transient_two(): A = np.random.choice([0.001, 5.0], size=(N, N)) B = np.random.choice([0.001, 5.0], size=(N, N)) @@ -1192,7 +1213,7 @@ def test_write_to_transient_two(): ) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_double_empty_state(): A = np.random.choice([0.001, 5.0], size=(N, N)) B = np.random.choice([0.001, 5.0], size=(N, N)) @@ -1238,7 +1259,7 @@ def complicated_pattern_for_manual_clean_up_one( c[i, 0] = 0.0 -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_complicated_pattern_for_manual_clean_up_one(): A = np.random.choice([0.001, 5.0], size=(N, N)) B = np.random.choice([0.001, 5.0], size=(N, N)) @@ -1297,7 +1318,7 @@ def test_complicated_pattern_for_manual_clean_up_one(): assert all({isinstance(n, dace.SDFGState) for n in body1.nodes()}) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_try_clean_on_complicated_pattern_for_manual_clean_up_one(): A = np.random.choice([0.001, 5.0], size=(N, N)) B = np.random.choice([0.001, 5.0], size=(N, N)) @@ -1371,7 
+1392,7 @@ def complicated_pattern_for_manual_clean_up_two( a[i, 3] = zlcrit * 2.0 -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_try_clean_on_complicated_pattern_for_manual_clean_up_two(): A = np.random.choice([0.001, 5.0], size=(N, N)) B = np.random.choice([0.001, 5.0], size=(N, N)) @@ -1444,14 +1465,14 @@ def single_assignment_cond_from_scalar(a: dace.float64[512]): a[i] = 0.0 -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_single_assignment(): if_cond_1 = np.array([1], dtype=np.float64) A = np.ones(shape=(N, ), dtype=np.float64) run_and_compare(single_assignment, 0, True, "single_assignment", a=A, _if_cond_1=if_cond_1[0]) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_single_assignment_cond_from_scalar(): A = np.ones(shape=(512, ), dtype=np.float64) before = single_assignment_cond_from_scalar.to_sdfg() @@ -1514,7 +1535,7 @@ def _get_sdfg_with_condition_from_transient_scalar() -> dace.SDFG: return sdfg -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_condition_from_transient_scalar(): zsolac = np.random.choice([8.0, 11.0], size=(N, )) zlcond2 = np.random.choice([8.0, 11.0], size=(N, )) @@ -1628,7 +1649,7 @@ def _get_disjoint_chain_sdfg() -> dace.SDFG: @pytest.mark.parametrize("rtt_val", [0.0, 4.0, 6.0]) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_disjoint_chain_split_branch_only(rtt_val): sdfg, nsdfg_parent_state = _get_disjoint_chain_sdfg() sdfg.name = f"disjoint_chain_split_branch_only_rtt_val_{str(rtt_val).replace('.','_')}" @@ -1664,7 +1685,7 @@ def test_disjoint_chain_split_branch_only(rtt_val): @pytest.mark.parametrize("rtt_val", [0.0, 4.0, 6.0]) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_disjoint_chain(rtt_val): sdfg, _ = _get_disjoint_chain_sdfg() zsolqa = np.random.choice([0.001, 5.0], size=(N, 5, 5)) @@ -1704,7 +1725,7 @@ def pattern_from_cloudsc_one( @pytest.mark.parametrize("c_val", [0.0, 1.0, 6.0]) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_pattern_from_cloudsc_one(c_val): A = np.random.choice([0.001, 5.0], size=( 2, @@ -1744,7 +1765,7 @@ def map_param_usage( b[i, i] = zmdn -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_can_be_applied_on_map_param_usage(): A = np.random.choice([0.001, 5.0], size=( N, @@ -1828,7 +1849,7 @@ def _get_safe_map_param_use_in_nested_sdfg() -> dace.SDFG: return outer_sdfg -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_safe_map_param_use_in_nested_sdfg(): sdfg = _get_safe_map_param_use_in_nested_sdfg() sdfg.validate() @@ -1925,7 +1946,7 @@ def _get_nsdfg_with_return(return_arr: bool) -> dace.SDFG: @pytest.mark.parametrize("ret_arr", [True, False]) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_nested_sdfg_with_return(ret_arr): sdfg = _get_nsdfg_with_return(ret_arr) sdfg.validate() @@ -2050,7 +2071,7 @@ def huge_sdfg(pap: dace.float64[N], ptsphy: dace.float64, r2es: dace.float64, r3 @pytest.mark.parametrize("eps_operator_type_for_log_and_div", ["max", "add"]) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_huge_sdfg_with_log_exp_div(eps_operator_type_for_log_and_div: str): """Generate test data for the loop body function""" @@ -2131,7 +2152,7 @@ def safe_uniform(low, high, size): @pytest.mark.parametrize("eps_operator_type_for_log_and_div", ["max", "add"]) 
-@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_mid_sdfg_with_log_exp_div(eps_operator_type_for_log_and_div: str): """Generate test data for the loop body function""" @@ -2231,7 +2252,7 @@ def loop_param_usage(A: dace.float64[6, N, N], B: dace.float64[N, N], C: dace.fl C[i, j] = 2.0 + C[i, j] -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_loop_param_usage(): A = np.random.choice([0.001, 5.0], size=(6, N, N)) B = np.random.choice([0.001, 5.0], size=(N, N)) @@ -2252,7 +2273,7 @@ def test_loop_param_usage(): run_and_compare_sdfg(sdfg, False, "loop_param_usage", A=A, B=B, C=C) -@pytest.mark.no_autoopt +@temporarily_disable_autoopt_and_serialization def test_can_be_applied_on_wcr_edge(): sdfg = wcr_edge.to_sdfg()
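
Note on the wrapper introduced in PATCH 17: because the restore happens in a
`finally` block, both config entries are reset even when the wrapped test
raises. A minimal, generalized sketch of the same pattern, assuming only the
`dace.config.Config.get`/`set` API already used in the patches above
(`with_config` and `my_test` are hypothetical names, not part of the series):

    import functools

    import dace

    def with_config(category: str, key: str, tmp_value):
        """Temporarily override a single dace config entry around a callable."""

        def decorator(func):

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                orig = dace.config.Config.get(category, key)
                try:
                    dace.config.Config.set(category, key, value=tmp_value)
                    return func(*args, **kwargs)
                finally:
                    # Runs on success and on exception alike
                    dace.config.Config.set(category, key, value=orig)

            return wrapper

        return decorator

    @with_config("optimizer", "autooptimize", False)
    def my_test():
        ...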