Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import dace
from dace import (
data as dace_data,
dtypes as dace_dtypes,
properties as dace_properties,
subsets as dace_sbs,
Expand Down Expand Up @@ -154,6 +155,16 @@ def can_be_applied(
if all(len(rel_df) == 0 for rel_df in relocatable_dataflow.values()):
return False

# Check if relatability is possible.
if not self._check_for_data_and_symbol_conflicts(
sdfg=sdfg,
state=graph,
relocatable_dataflow=relocatable_dataflow,
enclosing_map=enclosing_map,
if_block=if_block,
):
return False

# Because the transformation can only handle `if` expressions that
# are _directly_ inside a Map, we must check if the upstream contains
# suitable `if` expressions that must be processed first. The simplest way
Expand Down Expand Up @@ -480,6 +491,74 @@ def _update_symbol_mapping(
or dace_dtypes.typeclass(int),
)

def _check_for_data_and_symbol_conflicts(
self,
sdfg: dace.SDFG,
state: dace.SDFGState,
relocatable_dataflow: dict[str, set[dace_nodes.Node]],
if_block: dace_nodes.NestedSDFG,
enclosing_map: dace_nodes.MapEntry,
) -> bool:
"""Check if the relocation would cause any conflict, such as a symbol clash."""

# TODO(phimuell): There is an obscure case where the nested SDFG, on its own,
# defines a symbol that is also mapped, for example a dynamic Map range.
# It is probably not a problem, because of the scopes DaCe adds when
# generating the C++ code.

# Create a subgraph to compute the free symbols, i.e. the symbols that
# need to be supplied from the outside. However, this are not all.
# Note, just adding some "well chosen" nodes to the set will not work.
all_relocated_dataflow: set[dace_nodes.Node] = functools.reduce(
lambda s1, s2: s1.union(s2), relocatable_dataflow.values(), set()
)
requiered_symbols: set[str] = dace.sdfg.state.StateSubgraphView(
state, all_relocated_dataflow
).free_symbols

inner_data_names = if_block.sdfg.arrays.keys()
for node_to_check in all_relocated_dataflow:
if (
isinstance(node_to_check, dace_nodes.AccessNode)
and node_to_check.data in inner_data_names
):
# There is already a data descriptor that is used on the inside as on
# the outside. Thus we would have to perform some renaming, which we
# currently do not.
# TODO(phimell): Handle this case.
return False

for iedge in state.in_edges(node_to_check):
src_node = iedge.src
if src_node not in all_relocated_dataflow:
# This means that `src_node` is not relocated but mapped into the
# `if` block. This means that `edge` is replicated as well.
# NOTE: This code is based on the one found in `DataflowGraphView`.
# TODO(phimuell): Do we have to inspect the full Memlet path here?
assert isinstance(src_node, dace_nodes.AccessNode) or src_node is enclosing_map
requiered_symbols |= iedge.data.used_symbols(True, edge=iedge)

# The (beyond the enclosing Map) data is also mapped into the `if` block, so we
# have to consider that as well.
for iedge in state.in_edges(if_block):
if iedge.src is enclosing_map and (not iedge.data.is_empty()):
outside_desc = sdfg.arrays[iedge.data.data]
if isinstance(outside_desc, dace_data.View):
return False # Handle this case.
requiered_symbols |= outside_desc.used_symbols(True)

# A conflicting symbol is a free symbol of the relocatable dataflow, that is not a
# direct mapping. For example if there is a symbol `n` on the inside and outside
# then everything is okay if the symbol mapping is `{n: n}` i.e. the symbol has the
# same meaning inside and outside. Everything else is not okay.
symbol_mapping = if_block.symbol_mapping
conflicting_symbols = requiered_symbols.intersection((str(k) for k in symbol_mapping))
for conflicting_symbol in conflicting_symbols:
if conflicting_symbol != str(symbol_mapping[conflicting_symbol]):
return False

return True

def _find_branch_for(
self,
if_block: dace_nodes.NestedSDFG,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,8 @@ def _perform_test(
return sdfg

# General case, run the SDFG first and then compare the result.
ref, res = util.make_sdfg_args(sdfg)

if explected_applies != 0:
ref, res = util.make_sdfg_args(sdfg)
util.compile_and_run_sdfg(sdfg, **ref)

nb_apply = sdfg.apply_transformations_repeated(
Expand Down Expand Up @@ -1199,3 +1198,102 @@ def test_if_mover_access_node_between():
"__cond",
}
assert set(top_if_block.sdfg.arrays.keys()) == expected_top_if_block_data


def test_if_mover_symbol_aliasing():
"""Tests if symbol clashes are detected.

Essentially there is a symbol `n` both in the parent SDFG and the `if_block`,
however, with different meanings. Thus the relocation will lead to invalid
behaviour and should be rejected.
"""
sdfg = dace.SDFG(util.unique_name("if_mover_symbol_alias"))
state = sdfg.add_state(is_start_block=True)

scalar_names = ["cond", "a1", "b2"]
array_names = list("abcd")
sdfg.add_symbol("n", stype=dace.int32)
for aname in array_names:
sdfg.add_array(
aname,
shape=((10, "n") if aname in "ab" else (10,)),
dtype=dace.float64,
transient=False,
)
for sname in scalar_names:
sdfg.add_scalar(
sname,
dtype=(dace.bool_ if sname == "cond" else dace.float64),
transient=True,
)
a, b, c, d, cond_ac, true_ac, false_ac = (
state.add_access(name) for name in array_names + scalar_names
)

me, mx = state.add_map("outer_map", ndrange={"__i": "0:10"})

for ac in [a, b, c]:
state.add_edge(
ac,
None,
me,
f"IN_{ac.data}",
dace.Memlet(f"{ac.data}[0:10" + ("]" if ac is c else ", 0:n]")),
)
me.add_scope_connectors(ac.data)

# Make the condition.
cond_tlet = state.add_tasklet(
"cond_tlet",
inputs={"__in0"},
outputs={"__out"},
code="__out = __in0 < 0.0",
)
state.add_edge(me, "OUT_c", cond_tlet, "__in0", dace.Memlet("c[__i]"))
state.add_edge(cond_tlet, "__out", cond_ac, None, dace.Memlet(f"{cond_ac.data}[0]"))

# The true branch.
true_tlet = state.add_tasklet(
"true_tlet",
inputs={"__in0"},
outputs={"__out"},
code="__out = __in0 + 1.0",
)
state.add_edge(me, "OUT_a", true_tlet, "__in0", dace.Memlet("a[__i, n - 1]"))
state.add_edge(true_tlet, "__out", true_ac, None, dace.Memlet(f"{true_ac.data}[0]"))

# False branch
false_tlet = state.add_tasklet(
"false_tlet",
inputs={"__in0"},
outputs={"__out"},
code="__out = __in0 + 1.0",
)
state.add_edge(me, "OUT_b", false_tlet, "__in0", dace.Memlet("b[__i, n - 3]"))
state.add_edge(false_tlet, "__out", false_ac, None, dace.Memlet(f"{false_ac.data}[0]"))

# Create the top `if_block`
if_block = _make_if_block(state, sdfg)

# By Adding this symbol mapping, we emulate the case where something is used
# inside and special case must be taken.
assert len(if_block.symbol_mapping) == 0
if_block.symbol_mapping["n"] = "n - 1"

# Connect the inputs to the if block.
state.add_edge(true_ac, None, if_block, "__arg1", dace.Memlet(f"{true_ac}[0]"))
state.add_edge(false_ac, None, if_block, "__arg2", dace.Memlet(f"{false_ac}[0]"))
state.add_edge(cond_ac, None, if_block, "__cond", dace.Memlet(f"{cond_ac}[0]"))

state.add_edge(if_block, "__output", mx, "IN_d", dace.Memlet("d[__i]"))
state.add_edge(mx, "OUT_d", d, None, dace.Memlet("d[0:10]"))
mx.add_scope_connectors("d")

sdfg.validate()

# Because `n` is already taken, see above, we need an additional symbol mapping
# to account for the access on the Memlets of the `{true, false}_tlet`.
_perform_test(
sdfg=sdfg,
explected_applies=0,
)