From 9f31211eef6869e6cc7c8367ad2f77756e167706 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 2 Jan 2024 09:30:57 -0800 Subject: [PATCH 1/3] Support undefined variables in memlet propagation --- dace/sdfg/analysis/schedule_tree/treenodes.py | 3 +- dace/sdfg/propagation.py | 43 ++++++++++++++----- 2 files changed, 34 insertions(+), 12 deletions(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 0d709ee0fd..a2c2265fc2 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -4,7 +4,8 @@ from dace import nodes, data, subsets from dace.properties import CodeBlock from dace.sdfg import InterstateEdge -from dace.sdfg.state import ConditionalBlock, LoopRegion, SDFGState +from dace.sdfg.state import LoopRegion, SDFGState +from dace.sdfg.state import SDFGState from dace.symbolic import symbol from dace.memlet import Memlet from typing import TYPE_CHECKING, Dict, Iterator, List, Literal, Optional, Set, Union diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 41f6969836..4c7ff03a0a 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -423,6 +423,10 @@ def can_be_applied(self, dim_exprs, variable_context, node_range, orig_edges, di if symbolic.issymbolic(dim): used_symbols.update(dim.free_symbols) + if any(s not in defined_vars for s in (used_symbols - set(self.params))): + # Cannot propagate symbols that are undefined outside scope (e.g., internal symbols) + return False + if (used_symbols & set(self.params) and any(symbolic.pystr_to_symbolic(s) not in defined_vars for s in node_range.free_symbols)): # Cannot propagate symbols that are undefined in the outer range @@ -1446,6 +1450,7 @@ def propagate_subset(memlets: List[Memlet], params: List[str], rng: subsets.Subset, defined_variables: Set[symbolic.SymbolicType] = None, + undefined_variables: Set[symbolic.SymbolicType] = None, use_dst: bool = False) -> Memlet: """ Tries to propagate a list of memlets through a range (computes the image of the memlet function applied on an integer set of, e.g., a @@ -1458,8 +1463,12 @@ def propagate_subset(memlets: List[Memlet], range to propagate with. :param defined_variables: A set of symbols defined that will remain the same throughout propagation. If None, assumes - that all symbols outside of `params` have been - defined. + that all symbols outside of ``params``, except + for ``undefined_variables``, have been defined. + :param undefined_variables: A set of symbols that are explicitly considered + as not defined throughout propagation, such as + locals. Their existence will trigger propagating + the entire memlet. :param use_dst: Whether to propagate the memlets' dst subset or use the src instead, depending on propagation direction. :return: Memlet with propagated subset and volume. @@ -1473,6 +1482,11 @@ def propagate_subset(memlets: List[Memlet], defined_variables |= memlet.free_symbols defined_variables -= set(params) defined_variables = set(symbolic.pystr_to_symbolic(p) for p in defined_variables) + else: + defined_variables = set(defined_variables) + + if undefined_variables: + defined_variables = defined_variables - set(symbolic.pystr_to_symbolic(p) for p in undefined_variables) # Propagate subset variable_context = [defined_variables, [symbolic.pystr_to_symbolic(p) for p in params]] @@ -1503,18 +1517,25 @@ def propagate_subset(memlets: List[Memlet], tmp_subset = pattern.propagate(arr, [subset], rng) break else: - # No patterns found. Emit a warning and propagate the entire - # array whenever symbols are used - warnings.warn('Cannot find appropriate memlet pattern to ' - 'propagate %s through %s' % (str(subset), str(rng))) + # No patterns found. Propagate the entire array whenever symbols are used entire_array = subsets.Range.from_array(arr) paramset = set(map(str, params)) # Fill in the entire array only if one of the parameters appears in the - # free symbols list of the subset dimension - tmp_subset = subsets.Range([ - ea if any(set(map(str, _freesyms(sd))) & paramset for sd in s) else s - for s, ea in zip(subset, entire_array) - ]) + # free symbols list of the subset dimension or is undefined outside + tmp_subset_rng = [] + for s, ea in zip(subset, entire_array): + contains_params = False + contains_undefs = False + for sdim in s: + fsyms = _freesyms(sdim) + fsyms_str = set(map(str, fsyms)) + contains_params |= len(fsyms_str & paramset) != 0 + contains_undefs |= len(fsyms - defined_variables) != 0 + if contains_params or contains_undefs: + tmp_subset_rng.append(ea) + else: + tmp_subset_rng.append(s) + tmp_subset = subsets.Range(tmp_subset_rng) # Union edges as necessary if new_subset is None: From eb8a3c656425403f927faa49f32071aa1d49b126 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 12 Jan 2026 11:10:04 +0100 Subject: [PATCH 2/3] unrelated: fix typos in tests --- tests/passes/writeset_underapproximation_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/passes/writeset_underapproximation_test.py b/tests/passes/writeset_underapproximation_test.py index 19e9820def..82ac4767b7 100644 --- a/tests/passes/writeset_underapproximation_test.py +++ b/tests/passes/writeset_underapproximation_test.py @@ -384,7 +384,7 @@ def test_map_in_loop(): def test_map_in_loop_multiplied_indices_first_dimension(): """ Map nested in a loop that writes to array. Subscript expression - of array access multiplies two indicies in first dimension + of array access multiplies two indices in first dimension --> Approximated write-set of loop to array is empty """ @@ -417,7 +417,7 @@ def test_map_in_loop_multiplied_indices_first_dimension(): def test_map_in_loop_multiplied_indices_second_dimension(): """ Map nested in a loop that writes to array. Subscript expression - of array access multiplies two indicies in second dimension + of array access multiplies two indices in second dimension --> Approximated write-set of loop to array is empty """ sdfg = dace.SDFG("nested") From f2aa826ed420c34062d4bbfde84ef11b5f7b001b Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Mon, 12 Jan 2026 11:11:13 +0100 Subject: [PATCH 3/3] only undefined symbols trigger full memlet propagation --- dace/sdfg/propagation.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 4c7ff03a0a..2fc3b6ad22 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -1430,7 +1430,7 @@ def propagate_memlet(dfg_state, # Propagate subset if isinstance(entry_node, nodes.MapEntry): mapnode = entry_node.map - return propagate_subset(aggdata, arr, mapnode.params, mapnode.range, defined_vars, use_dst=use_dst) + return propagate_subset(aggdata, arr, mapnode.params, mapnode.range, defined_variables=defined_vars, use_dst=use_dst) elif isinstance(entry_node, nodes.ConsumeEntry): # Nothing to analyze/propagate in consume @@ -1449,6 +1449,7 @@ def propagate_subset(memlets: List[Memlet], arr: data.Data, params: List[str], rng: subsets.Subset, + *, defined_variables: Set[symbolic.SymbolicType] = None, undefined_variables: Set[symbolic.SymbolicType] = None, use_dst: bool = False) -> Memlet: @@ -1487,6 +1488,8 @@ def propagate_subset(memlets: List[Memlet], if undefined_variables: defined_variables = defined_variables - set(symbolic.pystr_to_symbolic(p) for p in undefined_variables) + else: + undefined_variables = set() # Propagate subset variable_context = [defined_variables, [symbolic.pystr_to_symbolic(p) for p in params]] @@ -1530,7 +1533,7 @@ def propagate_subset(memlets: List[Memlet], fsyms = _freesyms(sdim) fsyms_str = set(map(str, fsyms)) contains_params |= len(fsyms_str & paramset) != 0 - contains_undefs |= len(fsyms - defined_variables) != 0 + contains_undefs |= len(fsyms & undefined_variables) != 0 if contains_params or contains_undefs: tmp_subset_rng.append(ea) else: