diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 0d709ee0fd..a2c2265fc2 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -4,7 +4,8 @@ from dace import nodes, data, subsets from dace.properties import CodeBlock from dace.sdfg import InterstateEdge -from dace.sdfg.state import ConditionalBlock, LoopRegion, SDFGState +from dace.sdfg.state import LoopRegion, SDFGState +from dace.sdfg.state import SDFGState from dace.symbolic import symbol from dace.memlet import Memlet from typing import TYPE_CHECKING, Dict, Iterator, List, Literal, Optional, Set, Union diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 41f6969836..2fc3b6ad22 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -423,6 +423,10 @@ def can_be_applied(self, dim_exprs, variable_context, node_range, orig_edges, di if symbolic.issymbolic(dim): used_symbols.update(dim.free_symbols) + if any(s not in defined_vars for s in (used_symbols - set(self.params))): + # Cannot propagate symbols that are undefined outside scope (e.g., internal symbols) + return False + if (used_symbols & set(self.params) and any(symbolic.pystr_to_symbolic(s) not in defined_vars for s in node_range.free_symbols)): # Cannot propagate symbols that are undefined in the outer range @@ -1426,7 +1430,7 @@ def propagate_memlet(dfg_state, # Propagate subset if isinstance(entry_node, nodes.MapEntry): mapnode = entry_node.map - return propagate_subset(aggdata, arr, mapnode.params, mapnode.range, defined_vars, use_dst=use_dst) + return propagate_subset(aggdata, arr, mapnode.params, mapnode.range, defined_variables=defined_vars, use_dst=use_dst) elif isinstance(entry_node, nodes.ConsumeEntry): # Nothing to analyze/propagate in consume @@ -1445,7 +1449,9 @@ def propagate_subset(memlets: List[Memlet], arr: data.Data, params: List[str], rng: subsets.Subset, + *, defined_variables: Set[symbolic.SymbolicType] = None, + undefined_variables: Set[symbolic.SymbolicType] = None, use_dst: bool = False) -> Memlet: """ Tries to propagate a list of memlets through a range (computes the image of the memlet function applied on an integer set of, e.g., a @@ -1458,8 +1464,12 @@ def propagate_subset(memlets: List[Memlet], range to propagate with. :param defined_variables: A set of symbols defined that will remain the same throughout propagation. If None, assumes - that all symbols outside of `params` have been - defined. + that all symbols outside of ``params``, except + for ``undefined_variables``, have been defined. + :param undefined_variables: A set of symbols that are explicitly considered + as not defined throughout propagation, such as + locals. Their existence will trigger propagating + the entire memlet. :param use_dst: Whether to propagate the memlets' dst subset or use the src instead, depending on propagation direction. :return: Memlet with propagated subset and volume. @@ -1473,6 +1483,13 @@ def propagate_subset(memlets: List[Memlet], defined_variables |= memlet.free_symbols defined_variables -= set(params) defined_variables = set(symbolic.pystr_to_symbolic(p) for p in defined_variables) + else: + defined_variables = set(defined_variables) + + if undefined_variables: + defined_variables = defined_variables - set(symbolic.pystr_to_symbolic(p) for p in undefined_variables) + else: + undefined_variables = set() # Propagate subset variable_context = [defined_variables, [symbolic.pystr_to_symbolic(p) for p in params]] @@ -1503,18 +1520,25 @@ def propagate_subset(memlets: List[Memlet], tmp_subset = pattern.propagate(arr, [subset], rng) break else: - # No patterns found. Emit a warning and propagate the entire - # array whenever symbols are used - warnings.warn('Cannot find appropriate memlet pattern to ' - 'propagate %s through %s' % (str(subset), str(rng))) + # No patterns found. Propagate the entire array whenever symbols are used entire_array = subsets.Range.from_array(arr) paramset = set(map(str, params)) # Fill in the entire array only if one of the parameters appears in the - # free symbols list of the subset dimension - tmp_subset = subsets.Range([ - ea if any(set(map(str, _freesyms(sd))) & paramset for sd in s) else s - for s, ea in zip(subset, entire_array) - ]) + # free symbols list of the subset dimension or is undefined outside + tmp_subset_rng = [] + for s, ea in zip(subset, entire_array): + contains_params = False + contains_undefs = False + for sdim in s: + fsyms = _freesyms(sdim) + fsyms_str = set(map(str, fsyms)) + contains_params |= len(fsyms_str & paramset) != 0 + contains_undefs |= len(fsyms & undefined_variables) != 0 + if contains_params or contains_undefs: + tmp_subset_rng.append(ea) + else: + tmp_subset_rng.append(s) + tmp_subset = subsets.Range(tmp_subset_rng) # Union edges as necessary if new_subset is None: diff --git a/tests/passes/writeset_underapproximation_test.py b/tests/passes/writeset_underapproximation_test.py index 19e9820def..82ac4767b7 100644 --- a/tests/passes/writeset_underapproximation_test.py +++ b/tests/passes/writeset_underapproximation_test.py @@ -384,7 +384,7 @@ def test_map_in_loop(): def test_map_in_loop_multiplied_indices_first_dimension(): """ Map nested in a loop that writes to array. Subscript expression - of array access multiplies two indicies in first dimension + of array access multiplies two indices in first dimension --> Approximated write-set of loop to array is empty """ @@ -417,7 +417,7 @@ def test_map_in_loop_multiplied_indices_first_dimension(): def test_map_in_loop_multiplied_indices_second_dimension(): """ Map nested in a loop that writes to array. Subscript expression - of array access multiplies two indicies in second dimension + of array access multiplies two indices in second dimension --> Approximated write-set of loop to array is empty """ sdfg = dace.SDFG("nested")