Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
923e3cd
Enable maxnreg setting
iomaganaris Nov 27, 2025
0d2b3e1
Added test for gpu_maxnreg
iomaganaris Jan 29, 2026
215e775
Fix formating
iomaganaris Jan 29, 2026
d58ade9
Mention that maxnreg takes precedence over launch bounds
iomaganaris Feb 3, 2026
956282f
Remove scalar copies
iomaganaris Jan 22, 2026
03c8cbb
Rename if statements in _make_if_block
iomaganaris Jan 23, 2026
e10b2ee
[WIP] Added fuse condition block transformation
iomaganaris Jan 23, 2026
a2029f3
[WIP] Copy access nodes between condition blocks
iomaganaris Jan 26, 2026
8961d5e
Fix branch selection
iomaganaris Jan 26, 2026
49b7619
Moved kill aliasing transformation and updates to the fuse conditionals
iomaganaris Jan 27, 2026
eb462f6
Rename properly second map arrays
iomaganaris Jan 27, 2026
73dc4c1
Enable if grouping in proper place
iomaganaris Jan 28, 2026
08ce68c
Rename kill aliasing scalars to remove
iomaganaris Jan 29, 2026
e110ab2
Cleared a bit FuseConditionalBlocks
iomaganaris Jan 29, 2026
15ae54d
Fixed imports
iomaganaris Jan 29, 2026
8f44eb5
Remove sdfg.save
iomaganaris Feb 2, 2026
6a8b1b4
Fix issues after cherry-picking
iomaganaris Feb 3, 2026
0ce4337
Applied suggestions to FuseHorizontalConditionBlocks and moved unique…
iomaganaris Feb 3, 2026
3319c00
Apply review comments on RemoveScalarCopies
iomaganaris Feb 3, 2026
202aac6
Remove _dacegraphs
iomaganaris Feb 3, 2026
0cf2be2
Merge remote-tracking branch 'origin/set_gpu_maxnreg' into new_graupe…
iomaganaris Feb 3, 2026
9e6671c
Added more elaborative comment in FuseHorizontalConditionBlocks for r…
iomaganaris Feb 3, 2026
c55a04a
make sure that the symbol mapping is the same between the fused neste…
iomaganaris Feb 3, 2026
d232b73
Added check for NestedSDFGs in FuseHorizontalConditionBlocks
iomaganaris Feb 3, 2026
ed7835d
Handle single use data check only for second AccessNode in RemoveScal…
iomaganaris Feb 3, 2026
678f18f
Fix issues with true/false_branch_x_x_x
iomaganaris Feb 4, 2026
562742b
Address Philip's comments
iomaganaris Feb 10, 2026
d237219
Merge branch 'set_gpu_maxnreg' into graupel_group_ifs_updated_main
iomaganaris Feb 10, 2026
636f4c5
Fix issues
iomaganaris Feb 10, 2026
caf69a3
Address Philip's comments
iomaganaris Feb 10, 2026
4f03b21
Removed check for nestedsdfg
iomaganaris Feb 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
gt_auto_optimize,
)
from .dead_dataflow_elimination import gt_eliminate_dead_dataflow, gt_remove_map
from .fuse_horizontal_conditionblocks import FuseHorizontalConditionBlocks
from .gpu_utils import (
GPUSetBlockSize,
gt_gpu_transform_non_standard_memlet,
Expand Down Expand Up @@ -53,6 +54,7 @@
)
from .redundant_array_removers import CopyChainRemover, DoubleWriteRemover, gt_remove_copy_chain
from .remove_access_node_copies import RemoveAccessNodeCopies
from .remove_scalar_copies import RemoveScalarCopies
from .remove_views import RemovePointwiseViews
from .scan_loop_unrolling import ScanLoopUnrolling
from .simplify import (
Expand All @@ -77,12 +79,13 @@
gt_propagate_strides_from_access_node,
gt_propagate_strides_of,
)
from .utils import gt_make_transients_persistent
from .utils import gt_make_transients_persistent, unique_name


__all__ = [
"CopyChainRemover",
"DoubleWriteRemover",
"FuseHorizontalConditionBlocks",
"GPUSetBlockSize",
"GT4PyAutoOptHook",
"GT4PyAutoOptHookFun",
Expand All @@ -105,6 +108,7 @@
"MultiStateGlobalSelfCopyElimination2",
"RemoveAccessNodeCopies",
"RemovePointwiseViews",
"RemoveScalarCopies",
"ScanLoopUnrolling",
"SingleStateGlobalDirectSelfCopyElimination",
"SingleStateGlobalSelfCopyElimination",
Expand Down Expand Up @@ -138,4 +142,5 @@
"gt_vertical_map_split_fusion",
"inline_dataflow_into_map",
"splitting_tools",
"unique_name",
]
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,25 @@ def _gt_auto_process_dataflow_inside_maps(
validate_all=validate_all,
)

find_single_use_data = dace_analysis.FindSingleUseData()
single_use_data = find_single_use_data.apply_pass(sdfg, None)

sdfg.apply_transformations_repeated(
gtx_transformations.RemoveScalarCopies(
single_use_data=single_use_data,
),
validate=False,
validate_all=validate_all,
)

# Make sure that this runs before MoveDataflowIntoIfBody because atm it doesn't handle
# NestedSDFGs inside the ConditionalBlocks it fuses.
sdfg.apply_transformations_repeated(
gtx_transformations.FuseHorizontalConditionBlocks(),
validate=True,
validate_all=True,
)

# Move dataflow into the branches of the `if` such that they are only evaluated
# if they are needed. Important to call it repeatedly.
# TODO(phimuell): It is unclear if `MoveDataflowIntoIfBody` should be called
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,289 @@
# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

import copy
from typing import Any, Union

import dace
from dace import properties as dace_properties, transformation as dace_transformation
from dace.sdfg import graph as dace_graph, nodes as dace_nodes
from dace.transformation import helpers as dace_helpers

from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations


@dace_properties.make_properties
class FuseHorizontalConditionBlocks(dace_transformation.SingleStateTransformation):
"""Fuses two conditional blocks that share the same condition variable and are
independent of each other (i.e. the output of neither of them is used as input to the
other) into a single conditional block.
The motivation for this transformation is to reduce the number of conditional blocks
which generate if statements in the CPU or GPU code, which can lead to performance improvements.
The transformation only applies if the shared condition is a boolean scalar, both
nested SDFGs have a name starting with "if_stmt" and both are located in the same Map scope.
Example:
Before fusion:
```
if __cond:
__output1 = __arg1 * 2.0
else:
__output1 = __arg2 + 3.0
if __cond:
__output2 = __arg3 - 1.0
else:
__output2 = __arg4 / 4.0
```
After fusion:
```
if __cond:
__output1 = __arg1 * 2.0
__output2 = __arg3 - 1.0
else:
__output1 = __arg2 + 3.0
__output2 = __arg4 / 4.0
```
"""

# Pattern nodes: the AccessNode holding the condition and the two NestedSDFG
# nodes (each wrapping one conditional block) that read from it.
conditional_access_node = dace_transformation.PatternNode(dace_nodes.AccessNode)
first_conditional_block = dace_transformation.PatternNode(dace_nodes.NestedSDFG)
second_conditional_block = dace_transformation.PatternNode(dace_nodes.NestedSDFG)

@classmethod
def expressions(cls) -> Any:
    """Return the pattern graphs this transformation matches.

    There is a single pattern: the condition AccessNode with one outgoing edge
    to each of the two NestedSDFG pattern nodes.
    """
    pattern = dace_graph.OrderedMultiDiConnectorGraph()
    # Edge order is kept: first conditional block, then the second one.
    for nested_sdfg_node in (cls.first_conditional_block, cls.second_conditional_block):
        pattern.add_nedge(cls.conditional_access_node, nested_sdfg_node, dace.Memlet())
    return [pattern]

# Decide whether the matched pattern (one AccessNode feeding two NestedSDFG nodes)
# can be fused. Every check below is a guard that returns False on the first violation.
def can_be_applied(
self,
graph: Union[dace.SDFGState, dace.SDFG],
expr_index: int,
sdfg: dace.SDFG,
permissive: bool = False,
) -> bool:
conditional_access_node: dace_nodes.AccessNode = self.conditional_access_node
conditional_access_node_desc = conditional_access_node.desc(sdfg)
first_cb: dace_nodes.NestedSDFG = self.first_conditional_block
second_cb: dace_nodes.NestedSDFG = self.second_conditional_block
scope_dict = graph.scope_dict()

# Check that the common access node is a boolean scalar
if (
not isinstance(conditional_access_node_desc, dace.data.Scalar)
or conditional_access_node_desc.dtype != dace.bool_
):
return False

# Check that both conditional blocks are in the same parent SDFG
if first_cb.sdfg.parent != second_cb.sdfg.parent:
return False

# Check that the names of both nested SDFGs start with "if_stmt"
if not (
first_cb.sdfg.name.startswith("if_stmt") and second_cb.sdfg.name.startswith("if_stmt")
):
return False

# Make sure that each nested SDFG contains exactly one node
# (which is checked below to be a ConditionalBlock)
if first_cb.sdfg.number_of_nodes() != 1 or second_cb.sdfg.number_of_nodes() != 1:
return False

# Check that the symbol mappings are compatible: a symbol that is mapped by both
# nested SDFGs must map to the same expression (compared by string representation).
# Symbols that appear in only one of the two mappings are not compared here.
sym_map1 = first_cb.symbol_mapping
sym_map2 = second_cb.symbol_mapping
if any(str(sym_map1[sym]) != str(sym_map2[sym]) for sym in sym_map2 if sym in sym_map1):
return False

# Get the actual conditional blocks and require that each of them is a
# ConditionalBlock with exactly two sub regions
first_conditional_block = next(iter(first_cb.sdfg.nodes()))
second_conditional_block = next(iter(second_cb.sdfg.nodes()))
if not (
isinstance(first_conditional_block, dace.sdfg.state.ConditionalBlock)
and len(first_conditional_block.sub_regions()) == 2
and isinstance(second_conditional_block, dace.sdfg.state.ConditionalBlock)
and len(second_conditional_block.sub_regions()) == 2
):
return False
# Sanity check on the state names: each conditional block must contain at least one
# state whose name contains "true_branch" and one whose name contains "false_branch".
first_conditional_block_state_names = [
state.name for state in first_conditional_block.all_states()
]
second_conditional_block_state_names = [
state.name for state in second_conditional_block.all_states()
]
if not (
any("true_branch" in name for name in first_conditional_block_state_names)
and any("false_branch" in name for name in first_conditional_block_state_names)
and any("true_branch" in name for name in second_conditional_block_state_names)
and any("false_branch" in name for name in second_conditional_block_state_names)
):
return False
Comment on lines +115 to +127
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is kind of okay, but I would look into the conditions under which there could be multiple states inside the region.
However, I am fine with a TODO in that regard.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think atm there's any possibility to get multiple states that are not true/false_branch so this check is mostly for sanity

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then a TODO should be sufficient.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't understand, would you like to remove this check or would you like to check if there can be more than true/false branches inside?
The latter is checked above:

if not (
            isinstance(first_conditional_block, dace.sdfg.state.ConditionalBlock)
            and len(first_conditional_block.sub_regions()) == 2
            and isinstance(second_conditional_block, dace.sdfg.state.ConditionalBlock)
            and len(second_conditional_block.sub_regions()) == 2
        ):


# Make sure that both conditional blocks are in the same scope
if scope_dict[first_cb] != scope_dict[second_cb]:
return False

# Make sure that both conditional blocks are in a map scope
if not isinstance(scope_dict[first_cb], dace.nodes.MapEntry):
return False

# Check that each conditional block has exactly one incoming edge
# with dst_conn == "__cond"
cond_edges_first = [
e for e in graph.in_edges(first_cb) if e.dst_conn and e.dst_conn == "__cond"
]
if len(cond_edges_first) != 1:
return False
cond_edges_second = [
e for e in graph.in_edges(second_cb) if e.dst_conn and e.dst_conn == "__cond"
]
if len(cond_edges_second) != 1:
return False
cond_edge_first = cond_edges_first[0]
cond_edge_second = cond_edges_second[0]
# The "__cond" edges must carry data (no empty Memlets) ...
if cond_edge_first.data.is_empty() or cond_edge_second.data.is_empty():
return False
# ... and both must originate from the matched condition AccessNode.
if not all(
cond_edge.src is conditional_access_node
for cond_edge in [cond_edge_first, cond_edge_second]
):
return False

# Need to check also that first and second nested SDFGs are not reachable from each other,
# i.e. neither of them depends on the other one
if gtx_transformations.utils.is_reachable(
start=first_cb,
target=second_cb,
state=graph,
) or gtx_transformations.utils.is_reachable(
start=second_cb,
target=first_cb,
state=graph,
):
return False

# All checks passed, the two conditional blocks can be fused.
return True

# Fuse the second matched conditional block into the first one.
# `graph` is the state containing the matched nodes and `sdfg` the SDFG it belongs to.
def apply(
self,
graph: Union[dace.SDFGState, dace.SDFG],
sdfg: dace.SDFG,
) -> None:
conditional_access_node: dace_nodes.AccessNode = self.conditional_access_node
first_cb: dace_nodes.NestedSDFG = self.first_conditional_block
second_cb: dace_nodes.NestedSDFG = self.second_conditional_block

# The single node inside each nested SDFG, i.e. the actual conditional block
first_conditional_block = next(iter(first_cb.sdfg.nodes()))
second_conditional_block = next(iter(second_cb.sdfg.nodes()))

# Store number of original arrays to check later that all the necessary arrays have been moved
total_original_arrays = len(first_conditional_block.sdfg.arrays) + len(
second_conditional_block.sdfg.arrays
)

# Store the new names for the arrays of the second conditional block (transients and
# globals) to avoid name clashes and add their data descriptors to the first conditional
# block SDFG. We don't have to add `__cond` because we know it's the same for both
# conditional blocks.
# TODO(iomaganaris): Remove inputs to the conditional block that come from the same AccessNodes (same data)
second_arrays_rename_map: dict[str, str] = {}
for data_name, data_desc in second_conditional_block.sdfg.arrays.items():
if data_name == "__cond":
continue
# New name: a unique name derived from the old one plus a fixed suffix
new_data_name = gtx_transformations.utils.unique_name(data_name) + "_from_cb_fusion"
second_arrays_rename_map[data_name] = new_data_name
# Deep copy so that the descriptor is not shared between the two nested SDFGs
data_desc_renamed = copy.deepcopy(data_desc)
first_cb.sdfg.add_datadesc(new_data_name, data_desc_renamed)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is something missing here.
While in can_be_applied() you check if there are no symbol conflicts you never copy the symbols that are exclusive to the second conditional block.
Put it differently you detect the case where SDFG 1 has n := a + 1 and SDFG 2 has n := a - 1, but you do not handle the case where SDFG 1 has m := a + 1 and SDFG 2 has o := a + 1.
So you must copy them as well, for this you have to do something like:

missing_symbols = {sym2: val2 for sym2, val2 in second_cb.symbol_mapping.items() if sym2 not in first_cb.symbol_mapping}
for missing_symb, symb_def in missing_symbols.items():
    first_cb.symbol_mapping[missing_symb] = symb_def
    first_cb.add_symbol(missing_symb, second_cb.sdfg.symbols[missing_symb], find_new_name=False)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have to figure out how to add this to a test 👍

# Snapshot of the states of the second conditional block; used to look up the
# state that corresponds to a given branch of the first conditional block.
second_conditional_states = list(second_conditional_block.all_states())

# Move the connectors from the second conditional block to the first.
# The `__cond` input is skipped because the first block already has it.
for edge in graph.in_edges(second_cb):
    if edge.dst_conn == "__cond":
        continue
    renamed_in_conn = second_arrays_rename_map[edge.dst_conn]
    first_cb.add_in_connector(renamed_in_conn)
    dace_helpers.redirect_edge(
        state=graph,
        edge=edge,
        new_dst_conn=renamed_in_conn,
        new_dst=first_cb,
    )
for edge in graph.out_edges(second_cb):
    renamed_out_conn = second_arrays_rename_map[edge.src_conn]
    first_cb.add_out_connector(renamed_out_conn)
    dace_helpers.redirect_edge(
        state=graph,
        edge=edge,
        new_src_conn=renamed_out_conn,
        new_src=first_cb,
    )


def _find_corresponding_state_in_second(
    inner_state: dace.SDFGState,
) -> dace.SDFGState:
    """Return the state of the second block for the same branch as `inner_state`.

    The branch is identified through the state name: a name containing
    "true_branch" selects the true branch, anything else the false branch.
    """
    is_true_branch = "true_branch" in inner_state.name
    branch_type = "true_branch" if is_true_branch else "false_branch"
    return next(state for state in second_conditional_states if branch_type in state.name)


# Copy first the nodes from the second conditional block to the first
# Create a dictionary that maps the original nodes in the second conditional
# block to the new nodes in the first conditional block to be able to properly connect the edges later
nodes_renamed_map: dict[dace_nodes.Node, dace_nodes.Node] = {}
for first_inner_state in first_conditional_block.all_states():
    corresponding_state_in_second = _find_corresponding_state_in_second(first_inner_state)
    # Save the edges of the state of the second conditional block before its nodes are
    # removed, such that they can be recreated in the first conditional block.
    edges_to_copy = list(corresponding_state_in_second.edges())
    nodes_to_move = list(corresponding_state_in_second.nodes())
    for node in nodes_to_move:
        new_node = node
        if isinstance(node, dace_nodes.AccessNode):
            # AccessNodes are recreated such that they refer to the renamed data.
            new_data_name = second_arrays_rename_map[node.data]
            new_node = dace_nodes.AccessNode(new_data_name)
        nodes_renamed_map[node] = new_node
        # Remove the original node from the second conditional block to avoid any potential issues
        # with the nodes coexisting in two states
        corresponding_state_in_second.remove_node(node)
        first_inner_state.add_node(new_node)

    for edge_to_copy in edges_to_copy:
        new_edge = first_inner_state.add_edge(
            nodes_renamed_map[edge_to_copy.src],
            edge_to_copy.src_conn,
            nodes_renamed_map[edge_to_copy.dst],
            edge_to_copy.dst_conn,
            edge_to_copy.data,
        )
        # Let the Memlet refer to the renamed data of the fused nested SDFG.
        if (
            not new_edge.data.is_empty()
        ) and new_edge.data.data in second_arrays_rename_map:
            new_edge.data.data = second_arrays_rename_map[new_edge.data.data]

# Remove the `__cond` edge(s) that still go into the second conditional block.
for edge in list(graph.out_edges(conditional_access_node)):
    if edge.dst == second_cb:
        graph.remove_edge(edge)

# Copy missing symbols from second conditional block to the first one
missing_symbols = {
    sym2: val2
    for sym2, val2 in second_cb.symbol_mapping.items()
    if sym2 not in first_cb.symbol_mapping
}
for missing_symb, symb_def in missing_symbols.items():
    first_cb.symbol_mapping[missing_symb] = symb_def
    # The symbol must be registered on the nested SDFG; the NestedSDFG node itself
    # only carries the symbol mapping and offers no `add_symbol()`.
    first_cb.sdfg.add_symbol(
        missing_symb, second_cb.sdfg.symbols[missing_symb], find_new_name=False
    )

# TODO(iomaganaris): Atm need to remove both references to remove NestedSDFG from graph
#  second_conditional_block is inside the SDFG of NestedSDFG second_cb and removing only
#  one of them keeps a reference to the other one so none is properly deleted from the SDFG.
#  For now remove both but maybe this can be improved in the future.
graph.remove_node(second_conditional_block)
graph.remove_node(second_cb)

# Sanity check: all arrays of both blocks must now be in the first nested SDFG,
# except for `__cond` of the second block, which was not copied.
new_arrays = len(first_cb.sdfg.arrays)
assert new_arrays == total_original_arrays - 1, (
    f"After fusion, expected {total_original_arrays - 1} arrays but found {new_arrays}"
)
Loading