From 9cade71d42a129334bdb31de88facb0c2fc34e15 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 22 Jan 2026 08:24:25 +0100 Subject: [PATCH 01/61] Let's try this order. NOT WORKING: 5.7612245082855225 --- .../dace/transformations/auto_optimize.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 1d04c21fc3..92361f36b8 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -15,6 +15,7 @@ import dace from dace import data as dace_data from dace.sdfg import nodes as dace_nodes, propagation as dace_propagation, utils as dace_sdutils +from dace.transformation import dataflow as dace_dataflow from dace.transformation.auto import auto_optimize as dace_aoptimize from dace.transformation.passes import analysis as dace_analysis @@ -674,13 +675,7 @@ def _gt_auto_process_dataflow_inside_maps( time, so the compiler will fully unroll them anyway. """ - # Constants (tasklets are needed to write them into a variable) should not be - # arguments to a kernel but be present inside the body. - sdfg.apply_transformations_once_everywhere( - gtx_transformations.GT4PyMoveTaskletIntoMap, - validate=False, - validate_all=validate_all, - ) + # TODO(phimuell): Find out if needed. gtx_transformations.gt_simplify( sdfg, skip=gtx_transformations.constants._GT_AUTO_OPT_INNER_DATAFLOW_STAGE_SIMPLIFY_SKIP_LIST, @@ -737,6 +732,20 @@ def _gt_auto_process_dataflow_inside_maps( validate_all=validate_all, ) + sdfg.apply_transformations_repeated( + dace_dataflow.TaskletFusion, + validate=False, + validate_all=validate_all, + ) + + # Constants (tasklets are needed to write them into a variable) should not be + # arguments to a kernel but be present inside the body. 
+ sdfg.apply_transformations_once_everywhere( + gtx_transformations.GT4PyMoveTaskletIntoMap, + validate=False, + validate_all=validate_all, + ) + return sdfg From cba099634b6cec6c9389b17780ff50b07b940211 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 22 Jan 2026 08:42:36 +0100 Subject: [PATCH 02/61] Maybe this is better. DOES NOT WORK: 5.77982020 --- .../runners/dace/transformations/auto_optimize.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 92361f36b8..1fbbf82af8 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -696,6 +696,12 @@ def _gt_auto_process_dataflow_inside_maps( validate_all=validate_all, ) + sdfg.apply_transformations_repeated( + dace_dataflow.TaskletFusion, + validate=False, + validate_all=validate_all, + ) + # Move dataflow into the branches of the `if` such that they are only evaluated # if they are needed. Important to call it repeatedly. # TODO(phimuell): It is unclear if `MoveDataflowIntoIfBody` should be called @@ -732,12 +738,6 @@ def _gt_auto_process_dataflow_inside_maps( validate_all=validate_all, ) - sdfg.apply_transformations_repeated( - dace_dataflow.TaskletFusion, - validate=False, - validate_all=validate_all, - ) - # Constants (tasklets are needed to write them into a variable) should not be # arguments to a kernel but be present inside the body. sdfg.apply_transformations_once_everywhere( From 8a389ee590f59c4c2bca5c9414eb232adc1b0fa1 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 22 Jan 2026 08:58:36 +0100 Subject: [PATCH 03/61] Maybe the simplify call was unneeded. 
NOT WORKING: 5.89192s --- .../runners/dace/transformations/auto_optimize.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 1fbbf82af8..27da81cfac 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -675,14 +675,6 @@ def _gt_auto_process_dataflow_inside_maps( time, so the compiler will fully unroll them anyway. """ - # TODO(phimuell): Find out if needed. - gtx_transformations.gt_simplify( - sdfg, - skip=gtx_transformations.constants._GT_AUTO_OPT_INNER_DATAFLOW_STAGE_SIMPLIFY_SKIP_LIST, - validate=False, - validate_all=validate_all, - ) - # Blocking is performed first, because this ensures that as much as possible # is moved into the k independent part. if blocking_dim is not None: From 244dc10c2c987e77d52141fec9452fbfe8806101 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 22 Jan 2026 09:26:52 +0100 Subject: [PATCH 04/61] This is nearer at the empirical version, let's try it. SEEMS WORKING: 4.57165s --- .../dace/transformations/auto_optimize.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 27da81cfac..5ccaa2c182 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -694,6 +694,16 @@ def _gt_auto_process_dataflow_inside_maps( validate_all=validate_all, ) + # Constants (tasklets are needed to write them into a variable) should not be + # arguments to a kernel but be present inside the body. 
+ sdfg.apply_transformations_once_everywhere( + gtx_transformations.GT4PyMoveTaskletIntoMap, + validate=False, + validate_all=validate_all, + ) + + # TODO(phimuell): Do we need a simplify here. + # Move dataflow into the branches of the `if` such that they are only evaluated # if they are needed. Important to call it repeatedly. # TODO(phimuell): It is unclear if `MoveDataflowIntoIfBody` should be called @@ -730,14 +740,6 @@ def _gt_auto_process_dataflow_inside_maps( validate_all=validate_all, ) - # Constants (tasklets are needed to write them into a variable) should not be - # arguments to a kernel but be present inside the body. - sdfg.apply_transformations_once_everywhere( - gtx_transformations.GT4PyMoveTaskletIntoMap, - validate=False, - validate_all=validate_all, - ) - return sdfg From ac2c5ce1175e4706a01e8a3aaa86d882c1dbdfdd Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 22 Jan 2026 09:55:14 +0100 Subject: [PATCH 05/61] This is a bit nicer than the previous version, i.e. it has an explanation. But it also has an additional simplify that was present when TF was run in stage 1, but not in the other version. PERFORMANCE: 4.5106589s --- .../dace/transformations/auto_optimize.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 5ccaa2c182..43cb769d39 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -688,6 +688,13 @@ def _gt_auto_process_dataflow_inside_maps( validate_all=validate_all, ) + # Empirical observation in MuPhys have shown that running `TaskletFusion` increases + # performance quite drastically. Thus it was added here. However, to ensure + # that `LoopBlocking` still works, i.e. 
independent and dependent Tasklets are + # not mixed it must run _after_ `LoopBlocking`. Furthermore, it has been shown + # that it has to run _before_ `GT4PyMoveTaskletIntoMap`. The reasons are not + # clear but it can be measured. + # TODO(phimuell): Restrict it to Tasklets only inside Maps. sdfg.apply_transformations_repeated( dace_dataflow.TaskletFusion, validate=False, @@ -701,8 +708,13 @@ def _gt_auto_process_dataflow_inside_maps( validate=False, validate_all=validate_all, ) - - # TODO(phimuell): Do we need a simplify here. + # TODO(phimuell): figuring out if this is needed? + gtx_transformations.gt_simplify( + sdfg, + skip=gtx_transformations.constants._GT_AUTO_OPT_INNER_DATAFLOW_STAGE_SIMPLIFY_SKIP_LIST, + validate=False, + validate_all=validate_all, + ) # Move dataflow into the branches of the `if` such that they are only evaluated # if they are needed. Important to call it repeatedly. From d7077175397578fe50e37adf1da59a532e8a010f Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Mon, 26 Jan 2026 07:56:20 +0100 Subject: [PATCH 06/61] Updated the description. --- .../dace/transformations/auto_optimize.py | 21 ++++++++++++------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 43cb769d39..c59c3532b3 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -675,8 +675,9 @@ def _gt_auto_process_dataflow_inside_maps( time, so the compiler will fully unroll them anyway. """ - # Blocking is performed first, because this ensures that as much as possible - # is moved into the k independent part. + # Separate Tasklets into dependent and independent parts to promote data + # reusability. 
It is important that this step has to be performed before + # `TaskletFusion` is used. if blocking_dim is not None: sdfg.apply_transformations_once_everywhere( gtx_transformations.LoopBlocking( @@ -688,13 +689,14 @@ def _gt_auto_process_dataflow_inside_maps( validate_all=validate_all, ) - # Empirical observation in MuPhys have shown that running `TaskletFusion` increases - # performance quite drastically. Thus it was added here. However, to ensure - # that `LoopBlocking` still works, i.e. independent and dependent Tasklets are - # not mixed it must run _after_ `LoopBlocking`. Furthermore, it has been shown - # that it has to run _before_ `GT4PyMoveTaskletIntoMap`. The reasons are not - # clear but it can be measured. + # Merge Tasklets into bigger ones. + # NOTE: Empirical observation for Graupel have shown that this leads to an increase + # in performance, however, it has to be run before `GT4PyMoveTaskletIntoMap` + # (not fully clear why though, probably a compiler artefact) and as well as + # `MoveDataflowIntoIfBody` (not fully clear either, it `TaskletFusion` makes + # things simpler or prevent it from doing certain, negative, things). # TODO(phimuell): Restrict it to Tasklets only inside Maps. + # TODO(phimuell): Investigate more. sdfg.apply_transformations_repeated( dace_dataflow.TaskletFusion, validate=False, @@ -708,6 +710,7 @@ def _gt_auto_process_dataflow_inside_maps( validate=False, validate_all=validate_all, ) + # TODO(phimuell): figuring out if this is needed? gtx_transformations.gt_simplify( sdfg, @@ -729,6 +732,8 @@ def _gt_auto_process_dataflow_inside_maps( validate=False, validate_all=validate_all, ) + + # TODO(phimuell): figuring out if this is needed? 
gtx_transformations.gt_simplify( sdfg, skip=gtx_transformations.constants._GT_AUTO_OPT_INNER_DATAFLOW_STAGE_SIMPLIFY_SKIP_LIST, From 6f1c95b10f71925a50d6e1cdb18bf9ae8c7accf1 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 28 Jan 2026 08:50:02 +0100 Subject: [PATCH 07/61] Updated the CPU memory order. --- .../dace/transformations/auto_optimize.py | 19 +++++----- .../runners/dace/transformations/strides.py | 35 ++++++++++++------- 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 1d04c21fc3..ac2fadab09 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -762,12 +762,12 @@ def _gt_auto_configure_maps_and_strides( For a description of the arguments see the `gt_auto_optimize()` function. """ - # We now set the iteration order of the Maps. For that we use `unit_strides_kind` - # argument and if not supplied we guess depending if we are on the GPU or not. + # If no unit stride is given explicitly we assume that it is in the horizontal. + # NOTE: Before the optimizer assumed that the memory layout was different for + # GPU (horizontal first) and CPU (vertical first). However this was wrong. if unit_strides_kind is None: - unit_strides_kind = ( - gtx_common.DimensionKind.HORIZONTAL if gpu else gtx_common.DimensionKind.VERTICAL - ) + unit_strides_kind = gtx_common.DimensionKind.HORIZONTAL + # It is not possible to use the `unit_strides_dim` argument of the # function, because `LoopBlocking`, if run, changed the name of the # parameter but the dimension can still be identified by its "kind". @@ -782,11 +782,10 @@ def _gt_auto_configure_maps_and_strides( # get expanded, i.e. 
turned into Maps because no `cudaMemcpy*()` call exists, # which requires that the final strides are there. Furthermore, Memlet expansion # has to happen before the GPU block size is set. There are several possible - # solutions for that, of which none is really good. The one that is the least - # bad thing is to set the strides of the transients here. The main downside - # is that this and the `_gt_auto_post_processing()` function has these weird - # names. - gtx_transformations.gt_change_strides(sdfg, gpu=gpu) + # solutions for that, of which none is really good. The least bad one is to + # set the strides of the transients here. The main downside is that this and + # the `_gt_auto_post_processing()` function has these weird names. + gtx_transformations.gt_change_strides(sdfg) if gpu: # TODO(phimuell): The GPU function might modify the map iteration order. diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace/transformations/strides.py index 928fa04d54..84fc2ee14c 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/strides.py @@ -33,37 +33,45 @@ def gt_change_strides( sdfg: dace.SDFG, - gpu: bool, ) -> dace.SDFG: """Modifies the strides of transients. The function will analyse the access patterns and set the strides of - transients in the optimal way. - The function should run after all maps have been created. + transients in the optimal way. The function should run after _all_ + Maps have been created. + After the adjustment of the strides they will be propagated into the nested + SDFGs, see `gt_propagate_strides_of()` for more. - After the strides have been adjusted the function will also propagate - the strides into nested SDFG, see `gt_propagate_strides_of()` for more. Args: sdfg: The SDFG to process. - gpu: If the SDFG is supposed to run on the GPU. 
Note: Currently the function will not scan the access pattern. Instead it will - either use FORTRAN order for GPU or C order. This function needs to be called + translate the memory layout such that the horizontal dimension has stride 1, + which is used by the GT4Py allocator. This function needs to be called for both CPU and GPU to handle strides of memlets inside nested SDFGs. Todo: - - Implement the estimation correctly. + - Update this function such that the memory order is computed based on the + access pattern. Probably also merge it with `gt_set_iteration_order()` + function as the task are related. + - Im """ - # TODO(phimeull): Implement this function correctly. + # TODO(phimeull): Implement this function correctly, such that it decides the + # order based on the access pattern. Probably also merge it with + # `gt_set_iteration_order()` as the two things are related. + # NOTE: This function builds on the fact that in GT4Py the horizontal dimension + # is always the first dimensions, i.e. column or FORTRAN order and that in + # DaCe the default order (which the lowering uses), is row or C order. + # Thus we just have to inverse the order for all transients and propagate + # the new strides. for nsdfg in sdfg.all_sdfgs_recursive(): - _gt_change_strides_non_recursive_impl(nsdfg, gpu) + _gt_change_strides_non_recursive_impl(nsdfg) def _gt_change_strides_non_recursive_impl( sdfg: dace.SDFG, - gpu: bool, ) -> None: """Set optimal strides of all access nodes in the SDFG. @@ -103,7 +111,7 @@ def _gt_change_strides_non_recursive_impl( # access nodes because the non-transients come from outside and have their # own strides. # TODO(phimuell): Set the stride based on the actual access pattern. - if desc.transient and gpu: + if desc.transient: new_stride_order = list(range(ndim)) desc.set_strides_from_layout(*new_stride_order) @@ -124,7 +132,8 @@ def _gt_change_strides_non_recursive_impl( ) # Now handle the views. 
- # TODO(phimuell): Remove once `gt_propagate_strides_from_access_node()` can handle views. + # TODO(phimuell): Remove once `gt_propagate_strides_from_access_node()` can + # handle views. However, we should get to a point where we do not have views. _gt_modify_strides_of_views_non_recursive(sdfg) From 2027ad6dbe498c63bf7bdd6d3ff71e272c5ddc42 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 28 Jan 2026 09:01:06 +0100 Subject: [PATCH 08/61] Made some additional notes. --- .../runners/dace/transformations/strides.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace/transformations/strides.py index 84fc2ee14c..c99e9b79a2 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/strides.py @@ -50,16 +50,15 @@ def gt_change_strides( translate the memory layout such that the horizontal dimension has stride 1, which is used by the GT4Py allocator. This function needs to be called for both CPU and GPU to handle strides of memlets inside nested SDFGs. - - Todo: - - Update this function such that the memory order is computed based on the - access pattern. Probably also merge it with `gt_set_iteration_order()` - function as the task are related. - - Im + Furthermore, the current implementation assumes that there is only one + horizontal dimension. """ # TODO(phimeull): Implement this function correctly, such that it decides the # order based on the access pattern. Probably also merge it with # `gt_set_iteration_order()` as the two things are related. + # TODO(phimuell): The current implementation assumes that there is only one + # horizontal dimension. If there are multiple horizontal ones then we might + # have a problem. 
# NOTE: This function builds on the fact that in GT4Py the horizontal dimension # is always the first dimensions, i.e. column or FORTRAN order and that in # DaCe the default order (which the lowering uses), is row or C order. From 211956abb11ff234672f47906ebf664da4db6e36 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 28 Jan 2026 14:25:08 +0100 Subject: [PATCH 09/61] Identified an additional source of indeterministic behaviour that is for now not so important. --- .../runners/dace/transformations/auto_optimize.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 1d04c21fc3..404e9c01ee 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -707,6 +707,9 @@ def _gt_auto_process_dataflow_inside_maps( # before or after `LoopBlocking`. In cases where the condition is `False` # most of the times calling it before is better, but if the condition is # `True` then this order is better. Solve that issue. + # TODO(phimuell): Because of the limitation that the transformation only works + # for dataflow that is directly enclosed by a Map, the order in which it is + # applied matters. Instead we have to run it into a topological order. sdfg.apply_transformations_repeated( gtx_transformations.MoveDataflowIntoIfBody( ignore_upstream_blocks=False, From 6a57bc9fa8135cf1adb84b40bcbde57f232ea042 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 28 Jan 2026 14:26:20 +0100 Subject: [PATCH 10/61] Refined the check what can be moved into the `if` body and what not. 
--- .../move_dataflow_into_if_body.py | 83 ++++++++++++++----- 1 file changed, 64 insertions(+), 19 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index 55497fc11a..6c133a7d19 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -147,6 +147,7 @@ def can_be_applied( if_block=if_block, raw_relocatable_dataflow=raw_relocatable_dataflow, non_relocatable_dataflow=non_relocatable_dataflow, + enclosing_map=enclosing_map, ) # If no branch has something to inline then we are done. @@ -204,6 +205,7 @@ def apply( if_block=if_block, raw_relocatable_dataflow=raw_relocatable_dataflow, non_relocatable_dataflow=non_relocatable_dataflow, + enclosing_map=enclosing_map, ) # Finally relocate the dataflow @@ -551,6 +553,7 @@ def _has_if_block_relocatable_dataflow( if_block=upstream_if_block, raw_relocatable_dataflow=raw_relocatable_dataflow, non_relocatable_dataflow=non_relocatable_dataflow, + enclosing_map=enclosing_map, ) if all(len(rel_df) == 0 for rel_df in filtered_relocatable_dataflow.values()): return False @@ -564,6 +567,7 @@ def _filter_relocatable_dataflow( if_block: dace_nodes.NestedSDFG, raw_relocatable_dataflow: dict[str, set[dace_nodes.Node]], non_relocatable_dataflow: dict[str, set[dace_nodes.Node]], + enclosing_map: dace_nodes.MapEntry, ) -> dict[str, set[dace_nodes.Node]]: """Partition the dependencies. @@ -581,6 +585,8 @@ def _filter_relocatable_dataflow( that can be relocated, not yet filtered. non_relocatable_dataflow: The connectors and their associated dataflow that can not be relocated. + enclosing_map: The limiting node, i.e. the MapEntry of the Map `if_block` + is located in. """ # Remove the parts of the dataflow that is unrelocatable. 
@@ -592,8 +598,9 @@ def _filter_relocatable_dataflow( for conn_name, rel_df in raw_relocatable_dataflow.items() } - # Now we determine the nodes that are in more than one sets. - # These sets must be removed, from the individual sets. + # Relocating nodes that are in more than one set is difficult. In the most + # common case of just two branches, this anyway means they have to be + # executed in any case. Thus we remove them now. known_nodes: set[dace_nodes.Node] = set() multiple_df_nodes: set[dace_nodes.Node] = set() for rel_df in relocatable_dataflow.values(): @@ -606,35 +613,73 @@ def _filter_relocatable_dataflow( for conn_name, rel_df in relocatable_dataflow.items() } - # However, not all dataflow can be moved inside the branch. For example if - # something is used outside the dataflow, that is moved inside the `if`, - # then we can not relocate it. # TODO(phimuell): If we operate outside of a Map we also have to make sure that # the data is single use data, is not an AccessNode that refers to global # memory nor is a source AccessNode. def filter_nodes( - branch_nodes: set[dace_nodes.Node], - sdfg: dace.SDFG, - state: dace.SDFGState, + nodes_proposed_for_reloc: set[dace_nodes.Node], ) -> set[dace_nodes.Node]: - # For this to work the `if_block` must be considered part, we remove it later. - branch_nodes.add(if_block) has_been_updated = True while has_been_updated: has_been_updated = False - for node in list(branch_nodes): - if node is if_block: + + for reloc_node in list(nodes_proposed_for_reloc): + assert ( + state.in_degree(reloc_node) > 0 + ) # Because we are currently always inside a Map + + # If the node is needed by anything that is not also moved + # into the `if` body, then it has to remain outside. For that we + # have to pretend that `if_block` is also relocated. 
+ if any( + oedge.dst not in nodes_proposed_for_reloc + for oedge in state.out_edges(reloc_node) + if oedge.dst is not if_block + ): + nodes_proposed_for_reloc.discard(reloc_node) + has_been_updated = True continue - if any(oedge.dst not in branch_nodes for oedge in state.out_edges(node)): - branch_nodes.remove(node) + + # We do not look at all incoming nodes, but have to ignore some of them. + # We ignore `enclosed_map` because it acts as boundary, and the node + # on the other side of it is mapped into the `if` body anyway. Then we + # have to ignore all AccessNodes, since they are either relocated into + # the `if` body or are mapped into. We then have to look only at the + # remaining nodes. + incoming_nodes: set[dace_nodes.Node] = { + iedge.src + for iedge in state.in_edges(reloc_node) + if not ( + (iedge.src is enclosing_map) + or isinstance(iedge.src, dace_nodes.AccessNode) + ) + } + if incoming_nodes.issubset(nodes_proposed_for_reloc): + # All nodes will be moved into the `if` body too, so no problem. + pass + + elif incoming_nodes.isdisjoint(nodes_proposed_for_reloc): + # None of the incoming nodes will be moved into the if body, + # thus `reloc_node` is an interface node, it might be _mapped_ + # into the `if` body (if it is an `AccessNode`), but the node + # itself will not be moved into the `if` body. + nodes_proposed_for_reloc.discard(reloc_node) has_been_updated = True - assert if_block in branch_nodes - branch_nodes.remove(if_block) - return branch_nodes + + else: + # Only some of the incoming nodes will be moved into the `if` + # body. This is legal only if the not moved nodes are + # AccessNodes, because we have ignored them in the definition + # of `incoming_nodes`, `reloc_node` can not be moved into + # the `if` body and neither can the incoming nodes. 
+ nodes_proposed_for_reloc.difference_update(incoming_nodes) + nodes_proposed_for_reloc.discard(reloc_node) + has_been_updated = True + + return nodes_proposed_for_reloc return { - conn_name: filter_nodes(rel_df, sdfg, state) - for conn_name, rel_df in relocatable_dataflow.items() + conn_name: filter_nodes(rel_df) for conn_name, rel_df in relocatable_dataflow.items() } def _partition_if_block( From 83d99f3df453e2b1e8beb202383401bcd903be4f Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 28 Jan 2026 14:27:49 +0100 Subject: [PATCH 11/61] Added a new unit test, however it is currently disabled because it is not done. --- .../test_move_dataflow_into_if_body.py | 128 ++++++++++++++++++ 1 file changed, 128 insertions(+) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index 45c1620108..a5b5f4f51b 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -1047,3 +1047,131 @@ def test_if_mover_symbolic_tasklet(): assert if_block.sdfg.symbols["__i"] in {dace.int32, dace.int64} assert if_block.symbol_mapping.keys() == expected_symb.union(["__i"]) assert all(str(sym) == str(symval) for sym, symval in if_block.symbol_mapping.items()) + + +def test_if_mover_access_node_between(): + """ + Essentially tests the following situation: + ```python + a = foo(...) + b = bar(...) + c = baz(...) + bb = a if c else b + cc = baz2(d, ...) + aa = foo2(...) + e = aa if cc else bb + ``` + """ + # This test is temporarily disabled. 
+ return + sdfg = dace.SDFG(util.unique_name("if_mover_chain_of_blocks")) + state = sdfg.add_state(is_start_block=True) + + # Inputs + input_names = ["a", "b", "c", "d", "e", "f"] + for name in input_names: + sdfg.add_array( + name, + shape=(10,), + dtype=dace.float64, + transient=False, + ) + + # Temporaries + temporary_names = ["a1", "b1", "c1", "a2", "b2", "c2", "o"] + for name in temporary_names: + sdfg.add_scalar( + name, dtype=dace.bool_ if name.startswith("c") else dace.float64, transient=True + ) + + a1, b1, c1, a2, b2, c2, o = (state.add_access(name) for name in temporary_names) + me, mx = state.add_map("comp", ndrange={"__i": "0:10"}) + + # First branch of top `if_block` + tasklet_a1 = state.add_tasklet( + "tasklet_a1", inputs={"__in"}, outputs={"__out"}, code="__out = math.sin(__in)" + ) + state.add_edge(state.add_access("a"), None, me, "IN_a", dace.Memlet("a[0:10]")) + state.add_edge(me, "OUT_a", tasklet_a1, "__in", dace.Memlet("a[__i]")) + state.add_edge(tasklet_a1, "__out", a1, None, dace.Memlet("a1[0]")) + + # Second branch of the top `if_block` + tasklet_b1 = state.add_tasklet( + "tasklet_b1", inputs={"__in"}, outputs={"__out"}, code="__out = math.cos(__in)" + ) + state.add_edge(state.add_access("b"), None, me, "IN_b", dace.Memlet("b[0:10]")) + state.add_edge(me, "OUT_b", tasklet_b1, "__in", dace.Memlet("b[__i]")) + state.add_edge(tasklet_b1, "__out", b1, None, dace.Memlet("b1[0]")) + + # The condition of the top `if_block` + tasklet_c1 = state.add_tasklet( + "tasklet_c1", inputs={"__in"}, outputs={"__out"}, code="__out = __in < 0.5" + ) + state.add_edge(state.add_access("c"), None, me, "IN_c", dace.Memlet("c[0:10]")) + state.add_edge(me, "OUT_c", tasklet_c1, "__in", dace.Memlet("c[__i]")) + state.add_edge(tasklet_c1, "__out", c1, None, dace.Memlet("c1[0]")) + + # Create the top `if_block` + top_if_block = _make_if_block(state, sdfg) + state.add_edge(a1, None, top_if_block, "__arg1", dace.Memlet("a1[0]")) + state.add_edge(b1, None, top_if_block, 
"__arg2", dace.Memlet("b1[0]")) + state.add_edge(c1, None, top_if_block, "__cond", dace.Memlet("c1[0]")) + state.add_edge(top_if_block, "__output", t1, None, dace.Memlet("t1[0]")) + + # The first branch of the lower/second `if_block`, which uses data computed + # by the top `if_block`. + tasklet_t2 = state.add_tasklet( + "tasklet_t2", inputs={"__in"}, outputs={"__out"}, code="__out = math.exp(__in)" + ) + state.add_edge(t1, None, tasklet_t2, "__in", dace.Memlet("t1[0]")) + state.add_edge(tasklet_t2, "__out", t2, None, dace.Memlet("t2[0]")) + + # Second branch of the second `if_block`. + tasklet_d1 = state.add_tasklet( + "tasklet_d1", inputs={"__in"}, outputs={"__out"}, code="__out = math.atan(__in)" + ) + state.add_edge(state.add_access("d"), None, me, "IN_d", dace.Memlet("d[0:10]")) + state.add_edge(me, "OUT_d", tasklet_d1, "__in", dace.Memlet("d[__i]")) + state.add_edge(tasklet_d1, "__out", d1, None, dace.Memlet("d1[0]")) + + # Condition branch of the second `if_block`. + tasklet_cc1 = state.add_tasklet( + "tasklet_cc1", inputs={"__in"}, outputs={"__out"}, code="__out = __in < 0.5" + ) + state.add_edge(state.add_access("cc"), None, me, "IN_cc", dace.Memlet("cc[0:10]")) + state.add_edge(me, "OUT_cc", tasklet_cc1, "__in", dace.Memlet("cc[__i]")) + state.add_edge(tasklet_cc1, "__out", cc1, None, dace.Memlet("cc1[0]")) + + # Create the second `if_block` + bot_if_block = _make_if_block(state, sdfg) + state.add_edge(t2, None, bot_if_block, "__arg1", dace.Memlet("t2[0]")) + state.add_edge(d1, None, bot_if_block, "__arg2", dace.Memlet("d1[0]")) + state.add_edge(cc1, None, bot_if_block, "__cond", dace.Memlet("cc1[0]")) + + # Generate the output + state.add_edge(bot_if_block, "__output", mx, "IN_e", dace.Memlet("e[__i]")) + state.add_edge(mx, "OUT_e", state.add_access("e"), None, dace.Memlet("e[0:10]")) + + # Now add the connectors to the Map* + for iname in input_names: + if iname == "e": + mx.add_in_connector(f"IN_{iname}") + mx.add_out_connector(f"OUT_{iname}") + else: 
+ me.add_in_connector(f"IN_{iname}") + me.add_out_connector(f"OUT_{iname}") + sdfg.validate() + + # It is not possible to apply the transformation on the lower `if_block`, + # because it is limited by the top one. + _perform_test( + sdfg, + explected_applies=0, + if_block=bot_if_block, + ) + + # But we are able to inline both. + _perform_test( + sdfg, + explected_applies=2, + ) From 6179a6adb9d2762a1a508c34778db48a15e8ec11 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 28 Jan 2026 15:45:25 +0100 Subject: [PATCH 12/61] Updated the description and naming a bit. --- .../runners/dace/transformations/auto_optimize.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index ac2fadab09..82750703b8 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -763,17 +763,19 @@ def _gt_auto_configure_maps_and_strides( """ # If no unit stride is given explicitly we assume that it is in the horizontal. - # NOTE: Before the optimizer assumed that the memory layout was different for - # GPU (horizontal first) and CPU (vertical first). However this was wrong. - if unit_strides_kind is None: - unit_strides_kind = gtx_common.DimensionKind.HORIZONTAL + # This has also technical reasons to avoid launch errors (on GPU we have to make + # sure that the biggest dimension ends up on the `x` direction, which is most + # likely the horizontal dimension). + prime_direction_kind = ( + gtx_common.DimensionKind.HORIZONTAL if unit_strides_kind is None else unit_strides_kind + ) # It is not possible to use the `unit_strides_dim` argument of the # function, because `LoopBlocking`, if run, changed the name of the # parameter but the dimension can still be identified by its "kind". 
gtx_transformations.gt_set_iteration_order( sdfg=sdfg, - unit_strides_kind=unit_strides_kind, + unit_strides_kind=prime_direction_kind, validate=False, validate_all=validate_all, ) From 88ac3babc838bec85f8b36ea8613fa4301eff531 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 29 Jan 2026 07:56:12 +0100 Subject: [PATCH 13/61] Changed the selection of the leading kind and also clarified on the description. If the leading kind is not known then it will not reorder strides nor the iteration order. However, for cetain reasons (launch errors) we have to set one for GPU in that case. --- .../dace/transformations/auto_optimize.py | 70 ++++++++++++------- .../runners/dace/transformations/gpu_utils.py | 1 + .../runners/dace/transformations/strides.py | 62 +++++++++------- 3 files changed, 84 insertions(+), 49 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 82750703b8..5116f72075 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -762,34 +762,56 @@ def _gt_auto_configure_maps_and_strides( For a description of the arguments see the `gt_auto_optimize()` function. """ - # If no unit stride is given explicitly we assume that it is in the horizontal. - # This has also technical reasons to avoid launch errors (on GPU we have to make - # sure that the biggest dimension ends up on the `x` direction, which is most - # likely the horizontal dimension). - prime_direction_kind = ( - gtx_common.DimensionKind.HORIZONTAL if unit_strides_kind is None else unit_strides_kind - ) - - # It is not possible to use the `unit_strides_dim` argument of the - # function, because `LoopBlocking`, if run, changed the name of the - # parameter but the dimension can still be identified by its "kind". 
- gtx_transformations.gt_set_iteration_order( - sdfg=sdfg, - unit_strides_kind=prime_direction_kind, - validate=False, - validate_all=validate_all, - ) + # If `unit_strides_kind` is unknown we will not modify the Map order nor the + # strides, except if we are on GPU. The reason for this is that the maximal + # number of blocks is different for each dimension. If the largest dimension + # is for example associated with the `z` dimension, we would get launch errors + # at some point. Thus in that case we pretend that it is horizontal. Which is + # a valid assumption for any ICON-like code or if the GT4Py allocator is used. + # TODO(phimuell): Make this selection more intelligent. + if unit_strides_kind is None and gpu: + prefered_direction_kind: Optional[gtx_common.DimensionKind] = ( + gtx_common.DimensionKind.HORIZONTAL + ) + else: + prefered_direction_kind = unit_strides_kind + + # We should actually use a `gtx.Dimension` here and not a `gtx.DimensionKind`, + # since they are unique. However at this stage, especially after the expansion + # of non standard Memlets (which happens in the GPU transformation) associating + # Map parameters with GT4Py dimension is very hard to impossible. At this stage + # the kind is the most reliable indicator we have. + # NOTE: This is not the only location where we manipulate the Map order, we also + # do it in the GPU transformation, where we have to set the order of the + # expanded Memlets. + if prefered_direction_kind is not None: + gtx_transformations.gt_set_iteration_order( + sdfg=sdfg, + unit_strides_kind=prefered_direction_kind, + validate=False, + validate_all=validate_all, + ) # NOTE: We have to set the strides of transients before the non-standard Memlets - # get expanded, i.e. turned into Maps because no `cudaMemcpy*()` call exists, - # which requires that the final strides are there. Furthermore, Memlet expansion - # has to happen before the GPU block size is set. 
There are several possible - # solutions for that, of which none is really good. The least bad one is to - # set the strides of the transients here. The main downside is that this and - # the `_gt_auto_post_processing()` function has these weird names. - gtx_transformations.gt_change_strides(sdfg) + # get expanded, i.e. turned into Maps because no matching `cudaMemcpy*()` call + # exists, which requires that the final strides are there. Furthermore, Memlet + # expansion has to happen before the GPU block size is set. There are several + # possible solutions for that, of which none is really good. The least bad one + # is to set the strides of the transients here. The main downside is that we + # slightly modify the SDFG in the GPU transformation after we have set the + # strides. + if prefered_direction_kind is not None: + gtx_transformations.gt_change_strides(sdfg, prefered_direction_kind=prefered_direction_kind) if gpu: + if unit_strides_kind != gtx_common.DimensionKind.HORIZONTAL: + warnings.warn( + "The GT4Py DaCe backend assumes that in GPU mode the leading dimension" + f" is horizontal, but it was '{unit_strides_kind}', this might lead" + " to suboptimal performance", + stacklevel=2, + ) + # TODO(phimuell): The GPU function might modify the map iteration order. # This is because how it is implemented (promotion and fusion). 
However, # because of its current state, this should not happen, but we have to look diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index 1786913edb..ac81d2cd64 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -241,6 +241,7 @@ def restrict_fusion_to_newly_created_maps_horizontal( if len(maps_to_modify) == 0: return sdfg + # NOTE: This inherently assumes a particular memory order, see `gt_change_strides()`. for me_to_modify in maps_to_modify: map_to_modify: dace_nodes.Map = me_to_modify.map map_to_modify.params = list(reversed(map_to_modify.params)) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace/transformations/strides.py index c99e9b79a2..0840f77755 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/strides.py @@ -12,6 +12,7 @@ from dace import data as dace_data from dace.sdfg import nodes as dace_nodes +from gt4py.next import common as gtx_common from gt4py.next.program_processors.runners.dace import ( sdfg_args as gtx_dace_args, transformations as gtx_transformations, @@ -33,6 +34,7 @@ def gt_change_strides( sdfg: dace.SDFG, + prefered_direction_kind: gtx_common.DimensionKind, ) -> dace.SDFG: """Modifies the strides of transients. @@ -44,35 +46,50 @@ def gt_change_strides( Args: sdfg: The SDFG to process. + prefered_direction_kind: `DimensionKind` of the dimension with stride 1. Note: - Currently the function will not scan the access pattern. Instead it will - translate the memory layout such that the horizontal dimension has stride 1, - which is used by the GT4Py allocator. 
This function needs to be called - for both CPU and GPU to handle strides of memlets inside nested SDFGs. - Furthermore, the current implementation assumes that there is only one - horizontal dimension. + - This function should be run after `gt_set_iteration_order()` has been run. + - Currently the function will not scan the access pattern. Instead it relies + on the default behaviour of the lowering and how the GT4Py allocator works. + - The current implementation assumes that there is only one dimension of the + given kind. """ # TODO(phimeull): Implement this function correctly, such that it decides the # order based on the access pattern. Probably also merge it with # `gt_set_iteration_order()` as the two things are related. - # TODO(phimuell): The current implementation assumes that there is only one - # horizontal dimension. If there are multiple horizontal ones then we might - # have a problem. - # NOTE: This function builds on the fact that in GT4Py the horizontal dimension - # is always the first dimensions, i.e. column or FORTRAN order and that in - # DaCe the default order (which the lowering uses), is row or C order. - # Thus we just have to inverse the order for all transients and propagate - # the new strides. - for nsdfg in sdfg.all_sdfgs_recursive(): - _gt_change_strides_non_recursive_impl(nsdfg) + # NOTE: This function builds inherently assumes the dimension order defined by + # `gtx_common.order_dimensions()`, the default behaviour of the lowering, + # partially how the GT4Py allocator works and that there is only one dimension + # of any kind (which is true for ICON4Py, but not in general, for example + # in Cartesian grids). Its base assumption is that the ordering (Map parameters + # and strides) generated by the lowering "out of the box" are in row major/C + # order. Because of the GT4Py dimension order this is the right order for + # `gtx_common.DimensionKind.VERTICAL`. 
If the primary direction kind is + # `HORIZONTAL`, then according to the GT4Py dimension order column major/FORTRAN + # order should be used. To get there we have to reverse the strides order, which + # `_gt_change_strides_non_recursive_impl()` does. This is very brittle but at + # this point the best thing we can do. + + match prefered_direction_kind: + case gtx_common.DimensionKind.VERTICAL: + return # Nothing to do in that case. Maybe run Memlet propagation here? + + case gtx_common.DimensionKind.HORIZONTAL: + for nsdfg in sdfg.all_sdfgs_recursive(): + _gt_change_strides_non_recursive_impl(nsdfg) + + case _: + raise ValueError( + f"Encountered unknown `DimensionKind` value: {prefered_direction_kind}" + ) def _gt_change_strides_non_recursive_impl( sdfg: dace.SDFG, ) -> None: - """Set optimal strides of all access nodes in the SDFG. + """Set "optimal" strides of all access nodes in the SDFG. The function will look for all top level access node, see `_gt_find_toplevel_data_accesses()` and set their strides such that the access is optimal, see Note. The function @@ -81,14 +98,9 @@ def _gt_change_strides_non_recursive_impl( This function should never be called directly but always through `gt_change_strides()`! Note: - Currently the function just reverses the strides of the data descriptor - of transient access nodes it processes. Since DaCe generates `C` order by default - this lead to FORTRAN order, which is (for now) sufficient to optimize the memory - layout to GPU. - - Todo: - Make this function more intelligent to analyse the access pattern and then - figuring out the best order. + This function has the same underlying assumption as they are outlined in + `gt_change_strides()`, see there from more informations about the underlying + assumptions and limitations. """ # NOTE: We have to process all access nodes (transient and globals). If we are inside a # NestedSDFG then they were handled before on the level above us. 
From f390cff3e52659114469f5952697ce22c4177647 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 29 Jan 2026 10:09:54 +0100 Subject: [PATCH 14/61] Added a compatibility layer for ICON4Py. --- src/gt4py/next/metrics.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 src/gt4py/next/metrics.py diff --git a/src/gt4py/next/metrics.py b/src/gt4py/next/metrics.py new file mode 100644 index 0000000000..735465a81a --- /dev/null +++ b/src/gt4py/next/metrics.py @@ -0,0 +1,12 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +# Needed for compatibility with ICON4Py +from gt4py.next.instrumentation.metrics import * # noqa: F403 [undefined-local-with-import-star] From 080c669d20e929cd9884b5171dc5d877f4e9f6af Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 29 Jan 2026 10:52:04 +0100 Subject: [PATCH 15/61] Updated the warning. --- .../runners/dace/transformations/auto_optimize.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 5116f72075..a32c446799 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -806,9 +806,10 @@ def _gt_auto_configure_maps_and_strides( if gpu: if unit_strides_kind != gtx_common.DimensionKind.HORIZONTAL: warnings.warn( - "The GT4Py DaCe backend assumes that in GPU mode the leading dimension" - f" is horizontal, but it was '{unit_strides_kind}', this might lead" - " to suboptimal performance", + "The GT4Py DaCe GPU backend assumes that the leading dimension, i.e." 
+ " where stride is 1, is of kind 'HORIZONTAL', however it was" + f" '{unit_strides_kind}' and is the last index. Other configurations" + " might lead to suboptimal performance.", stacklevel=2, ) From 7569c4ca63fc19a02bbc794c8cce3ab472484f82 Mon Sep 17 00:00:00 2001 From: Edoardo Paone Date: Thu, 29 Jan 2026 10:55:51 +0100 Subject: [PATCH 16/61] Revert "Added a compatibility layer for ICON4Py." This reverts commit f390cff3e52659114469f5952697ce22c4177647. --- src/gt4py/next/metrics.py | 12 ------------ 1 file changed, 12 deletions(-) delete mode 100644 src/gt4py/next/metrics.py diff --git a/src/gt4py/next/metrics.py b/src/gt4py/next/metrics.py deleted file mode 100644 index 735465a81a..0000000000 --- a/src/gt4py/next/metrics.py +++ /dev/null @@ -1,12 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. -# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -# Needed for compatibility with ICON4Py -from gt4py.next.instrumentation.metrics import * # noqa: F403 [undefined-local-with-import-star] From 9d092c44cfe2e3881f50406dc772b1060c7747e9 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 29 Jan 2026 13:31:39 +0100 Subject: [PATCH 17/61] Added an option to disable TaskletFusion. By default it is off. 
--- .../dace/transformations/auto_optimize.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index c59c3532b3..a1dae67352 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -131,6 +131,7 @@ def gt_auto_optimize( assume_pointwise: bool = True, optimization_hooks: Optional[dict[GT4PyAutoOptHook, GT4PyAutoOptHookFun]] = None, demote_fields: Optional[list[str]] = None, + compact_tasklets: bool = False, validate: bool = True, validate_all: bool = False, **kwargs: Any, @@ -198,6 +199,7 @@ def gt_auto_optimize( see `GT4PyAutoOptHook` for more information. demote_fields: Consider these fields as transients for the purpose of optimization. Use at your own risk. See Notes for all implications. + compact_tasklets: Reduces the number of Tasklets by fusing them. validate: Perform validation during the steps. validate_all: Perform extensive validation. @@ -325,6 +327,7 @@ def gt_auto_optimize( blocking_only_if_independent_nodes=blocking_only_if_independent_nodes, scan_loop_unrolling=scan_loop_unrolling, scan_loop_unrolling_factor=scan_loop_unrolling_factor, + compact_tasklets=compact_tasklets, validate_all=validate_all, ) @@ -661,6 +664,7 @@ def _gt_auto_process_dataflow_inside_maps( blocking_only_if_independent_nodes: Optional[bool], scan_loop_unrolling: bool, scan_loop_unrolling_factor: int, + compact_tasklets: bool, validate_all: bool, ) -> dace.SDFG: """Optimizes the dataflow inside the top level Maps of the SDFG inplace. 
@@ -695,13 +699,14 @@ def _gt_auto_process_dataflow_inside_maps( # (not fully clear why though, probably a compiler artefact) and as well as # `MoveDataflowIntoIfBody` (not fully clear either, it `TaskletFusion` makes # things simpler or prevent it from doing certain, negative, things). - # TODO(phimuell): Restrict it to Tasklets only inside Maps. # TODO(phimuell): Investigate more. - sdfg.apply_transformations_repeated( - dace_dataflow.TaskletFusion, - validate=False, - validate_all=validate_all, - ) + # TODO(phimuell): Restrict it to Tasklets only inside Maps. + if compact_tasklets: + sdfg.apply_transformations_repeated( + dace_dataflow.TaskletFusion, + validate=False, + validate_all=validate_all, + ) # Constants (tasklets are needed to write them into a variable) should not be # arguments to a kernel but be present inside the body. From ad3a0489519f37cf3efe29aa017963e99b1a6479 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 29 Jan 2026 14:19:16 +0100 Subject: [PATCH 18/61] Removed the compatibility hack. --- src/gt4py/next/metrics.py | 12 ------------ 1 file changed, 12 deletions(-) delete mode 100644 src/gt4py/next/metrics.py diff --git a/src/gt4py/next/metrics.py b/src/gt4py/next/metrics.py deleted file mode 100644 index 735465a81a..0000000000 --- a/src/gt4py/next/metrics.py +++ /dev/null @@ -1,12 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2024, ETH Zurich -# All rights reserved. -# -# Please, refer to the LICENSE file in the root directory. 
-# SPDX-License-Identifier: BSD-3-Clause - -from __future__ import annotations - -# Needed for compatibility with ICON4Py -from gt4py.next.instrumentation.metrics import * # noqa: F403 [undefined-local-with-import-star] From 2d84e42b8db40c8dcbd461d844ab2c947b82d17d Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 27 Nov 2025 13:57:07 +0100 Subject: [PATCH 19/61] Enable maxnreg setting --- .../dace/transformations/auto_optimize.py | 4 ++++ .../runners/dace/transformations/gpu_utils.py | 17 ++++++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 1d04c21fc3..1bd8e7ad9a 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -116,6 +116,7 @@ def gt_auto_optimize( gpu_block_size_1d: Optional[Sequence[int | str] | str] = (64, 1, 1), gpu_block_size_2d: Optional[Sequence[int | str] | str] = None, gpu_block_size_3d: Optional[Sequence[int | str] | str] = None, + gpu_maxnreg: Optional[int] = None, blocking_dim: Optional[gtx_common.Dimension] = None, blocking_size: int = 10, blocking_only_if_independent_nodes: bool = True, @@ -371,6 +372,7 @@ def gt_auto_optimize( gpu_block_size=gpu_block_size, gpu_launch_factor=gpu_launch_factor, gpu_launch_bounds=gpu_launch_bounds, + gpu_maxnreg=gpu_maxnreg, optimization_hooks=optimization_hooks, gpu_block_size_spec=gpu_block_size_spec if gpu_block_size_spec else None, validate_all=validate_all, @@ -747,6 +749,7 @@ def _gt_auto_configure_maps_and_strides( gpu_block_size: Optional[Sequence[int | str] | str], gpu_launch_bounds: Optional[int | str], gpu_launch_factor: Optional[int], + gpu_maxnreg: Optional[int], optimization_hooks: dict[GT4PyAutoOptHook, GT4PyAutoOptHookFun], gpu_block_size_spec: Optional[dict[str, 
Sequence[int | str] | str]], validate_all: bool, @@ -799,6 +802,7 @@ def _gt_auto_configure_maps_and_strides( gpu_launch_bounds=gpu_launch_bounds, gpu_launch_factor=gpu_launch_factor, gpu_block_size_spec=gpu_block_size_spec, + gpu_maxnreg=gpu_maxnreg, validate=False, validate_all=validate_all, try_removing_trivial_maps=True, diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index 1786913edb..a1df77cad4 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -34,6 +34,7 @@ def gt_gpu_transformation( gpu_launch_bounds: Optional[int | str] = None, gpu_launch_factor: Optional[int] = None, gpu_block_size_spec: Optional[dict[str, Sequence[int | str] | str]] = None, + gpu_maxnreg: Optional[int] = None, validate: bool = True, validate_all: bool = False, **kwargs: Any, @@ -123,6 +124,7 @@ def gt_gpu_transformation( launch_bounds=gpu_launch_bounds, launch_factor=gpu_launch_factor, **gpu_block_size_spec, + gpu_maxnreg=gpu_maxnreg, validate=False, validate_all=validate_all, ) @@ -362,6 +364,7 @@ def gt_set_gpu_blocksize( block_size: Optional[Sequence[int | str] | str], launch_bounds: Optional[int | str] = None, launch_factor: Optional[int] = None, + gpu_maxnreg: Optional[int] = None, validate: bool = True, validate_all: bool = False, **kwargs: Any, @@ -393,6 +396,7 @@ def gt_set_gpu_blocksize( }.items(): if f"{arg}_{dim}d" not in kwargs: kwargs[f"{arg}_{dim}d"] = val + kwargs["maxnreg"] = gpu_maxnreg setter = GPUSetBlockSize(**kwargs) @@ -590,6 +594,12 @@ class GPUSetBlockSize(dace_transformation.SingleStateTransformation): default=None, desc="Set the launch bound property for 3 dimensional map.", ) + maxnreg = dace_properties.Property( + dtype=int, + allow_none=True, + default=None, + desc="Set the maxnreg property for the GPU maps.", + ) # 
Pattern matching map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) @@ -605,6 +615,7 @@ def __init__( launch_factor_1d: int | None = None, launch_factor_2d: int | None = None, launch_factor_3d: int | None = None, + maxnreg: int | None = None, ) -> None: super().__init__() if block_size_1d is not None: @@ -632,6 +643,8 @@ def __init__( self.launch_bounds_3d = _gpu_launch_bound_parser( self.block_size_3d, launch_bounds_3d, launch_factor_3d ) + if maxnreg is not None: + self.maxnreg = maxnreg @classmethod def expressions(cls) -> Any: @@ -748,7 +761,9 @@ def apply( block_size[i] = map_size[map_dim_idx_to_inspect] gpu_map.gpu_block_size = tuple(block_size) - if launch_bounds is not None: # Note: empty string has a meaning in DaCe + if self.maxnreg is not None: + gpu_map.gpu_maxnreg = self.maxnreg + elif launch_bounds is not None: # Note: empty string has a meaning in DaCe gpu_map.gpu_launch_bounds = launch_bounds From 7a8f22b9e2a9a4891959ab0773d127e8dfcd9e4d Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Thu, 29 Jan 2026 14:26:09 +0100 Subject: [PATCH 20/61] Made the suggested renaming. 
--- .../runners/dace/transformations/auto_optimize.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index a1dae67352..91aba1ba4c 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -131,7 +131,7 @@ def gt_auto_optimize( assume_pointwise: bool = True, optimization_hooks: Optional[dict[GT4PyAutoOptHook, GT4PyAutoOptHookFun]] = None, demote_fields: Optional[list[str]] = None, - compact_tasklets: bool = False, + fuse_tasklets: bool = False, validate: bool = True, validate_all: bool = False, **kwargs: Any, @@ -199,7 +199,7 @@ def gt_auto_optimize( see `GT4PyAutoOptHook` for more information. demote_fields: Consider these fields as transients for the purpose of optimization. Use at your own risk. See Notes for all implications. - compact_tasklets: Reduces the number of Tasklets by fusing them. + fuse_tasklets: Reduces the number of Tasklets by fusing them. validate: Perform validation during the steps. validate_all: Perform extensive validation. @@ -327,7 +327,7 @@ def gt_auto_optimize( blocking_only_if_independent_nodes=blocking_only_if_independent_nodes, scan_loop_unrolling=scan_loop_unrolling, scan_loop_unrolling_factor=scan_loop_unrolling_factor, - compact_tasklets=compact_tasklets, + fuse_tasklets=fuse_tasklets, validate_all=validate_all, ) @@ -664,7 +664,7 @@ def _gt_auto_process_dataflow_inside_maps( blocking_only_if_independent_nodes: Optional[bool], scan_loop_unrolling: bool, scan_loop_unrolling_factor: int, - compact_tasklets: bool, + fuse_tasklets: bool, validate_all: bool, ) -> dace.SDFG: """Optimizes the dataflow inside the top level Maps of the SDFG inplace. 
@@ -701,7 +701,7 @@ def _gt_auto_process_dataflow_inside_maps( # things simpler or prevent it from doing certain, negative, things). # TODO(phimuell): Investigate more. # TODO(phimuell): Restrict it to Tasklets only inside Maps. - if compact_tasklets: + if fuse_tasklets: sdfg.apply_transformations_repeated( dace_dataflow.TaskletFusion, validate=False, From bb7bd9190f1b17c597a81be10f9ea2fedd9eef3b Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 29 Jan 2026 14:29:09 +0100 Subject: [PATCH 21/61] Added test for gpu_maxnreg --- .../transformation_tests/test_gpu_utils.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py index 1ef64da559..720a269547 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py @@ -351,3 +351,36 @@ def test_set_gpu_properties_2D_3D(): assert len(map4.params) == 4 assert map4.gpu_block_size == [1, 2, 32] assert map4.gpu_launch_bounds == "0" + + +def test_set_gpu_maxnreg(): + """Tests if gpu_maxnreg property is set correctly to GPU maps.""" + sdfg = dace.SDFG(util.unique_name("gpu_maxnreg_test")) + state = sdfg.add_state(is_start_block=True) + dim = 2 + shape = (10,) * (dim - 1) + (1,) + sdfg.add_array( + f"A_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global + ) + sdfg.add_array( + f"B_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global + ) + _, me, _ = state.add_mapped_tasklet( + f"map_{dim}", + map_ranges={f"__i{i}": f"0:{s}" for i, s in enumerate(shape)}, + inputs={"__in": dace.Memlet(f"A_{dim}[{','.join(f'__i{i}' for i in 
range(dim))}]")}, + code="__out = math.cos(__in)", + outputs={"__out": dace.Memlet(f"B_{dim}[{','.join(f'__i{i}' for i in range(dim))}]")}, + external_edges=True, + ) + sdfg.validate() + sdfg.apply_gpu_transformations() + gtx_dace_fieldview_gpu_utils.gt_set_gpu_blocksize( + sdfg=sdfg, + block_size_1d=(128, 1, 1), + block_size_2d=(64, 2, 1), + block_size_3d=(2, 2, 32), + block_size=(32, 4, 1), + gpu_maxnreg=128, + ) + assert me.gpu_maxnreg == 128 From 80a4b8e07c2aefe327564b363f6bafc2b2e83bbe Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 29 Jan 2026 14:31:08 +0100 Subject: [PATCH 22/61] Fix formating --- .../dace_tests/transformation_tests/test_gpu_utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py index 720a269547..dd4927bb03 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py @@ -359,12 +359,8 @@ def test_set_gpu_maxnreg(): state = sdfg.add_state(is_start_block=True) dim = 2 shape = (10,) * (dim - 1) + (1,) - sdfg.add_array( - f"A_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global - ) - sdfg.add_array( - f"B_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global - ) + sdfg.add_array(f"A_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global) + sdfg.add_array(f"B_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global) _, me, _ = state.add_mapped_tasklet( f"map_{dim}", map_ranges={f"__i{i}": f"0:{s}" for i, s in enumerate(shape)}, From c080943508402bfa69521a4382c5c0debc08890b Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?Philip=20M=C3=BCller?= <147368808+philip-paul-mueller@users.noreply.github.com> Date: Thu, 29 Jan 2026 14:36:05 +0100 Subject: [PATCH 23/61] Updated The Intranode Optimization Branch (#2463) --- .../dace/transformations/auto_optimize.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 7e827a90ba..10da20e1c2 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -131,6 +131,7 @@ def gt_auto_optimize( assume_pointwise: bool = True, optimization_hooks: Optional[dict[GT4PyAutoOptHook, GT4PyAutoOptHookFun]] = None, demote_fields: Optional[list[str]] = None, + fuse_tasklets: bool = False, validate: bool = True, validate_all: bool = False, **kwargs: Any, @@ -198,6 +199,7 @@ def gt_auto_optimize( see `GT4PyAutoOptHook` for more information. demote_fields: Consider these fields as transients for the purpose of optimization. Use at your own risk. See Notes for all implications. + fuse_tasklets: Reduces the number of Tasklets by fusing them. validate: Perform validation during the steps. validate_all: Perform extensive validation. @@ -325,6 +327,7 @@ def gt_auto_optimize( blocking_only_if_independent_nodes=blocking_only_if_independent_nodes, scan_loop_unrolling=scan_loop_unrolling, scan_loop_unrolling_factor=scan_loop_unrolling_factor, + fuse_tasklets=fuse_tasklets, validate_all=validate_all, ) @@ -661,6 +664,7 @@ def _gt_auto_process_dataflow_inside_maps( blocking_only_if_independent_nodes: Optional[bool], scan_loop_unrolling: bool, scan_loop_unrolling_factor: int, + fuse_tasklets: bool, validate_all: bool, ) -> dace.SDFG: """Optimizes the dataflow inside the top level Maps of the SDFG inplace. 
@@ -695,13 +699,14 @@ def _gt_auto_process_dataflow_inside_maps( # (not fully clear why though, probably a compiler artefact) and as well as # `MoveDataflowIntoIfBody` (not fully clear either, it `TaskletFusion` makes # things simpler or prevent it from doing certain, negative, things). - # TODO(phimuell): Restrict it to Tasklets only inside Maps. # TODO(phimuell): Investigate more. - sdfg.apply_transformations_repeated( - dace_dataflow.TaskletFusion, - validate=False, - validate_all=validate_all, - ) + # TODO(phimuell): Restrict it to Tasklets only inside Maps. + if fuse_tasklets: + sdfg.apply_transformations_repeated( + dace_dataflow.TaskletFusion, + validate=False, + validate_all=validate_all, + ) # Constants (tasklets are needed to write them into a variable) should not be # arguments to a kernel but be present inside the body. From 458f99801ac883f332166e8265c10a0ddec939f7 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Fri, 30 Jan 2026 08:48:12 +0100 Subject: [PATCH 24/61] Correction. --- .../runners/dace/transformations/auto_optimize.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index a32c446799..80e43711d3 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -808,8 +808,8 @@ def _gt_auto_configure_maps_and_strides( warnings.warn( "The GT4Py DaCe GPU backend assumes that the leading dimension, i.e." " where stride is 1, is of kind 'HORIZONTAL', however it was" - f" '{unit_strides_kind}' and is the last index. Other configurations" - " might lead to suboptimal performance.", + f" '{unit_strides_kind}'. Furthermore, it should be the last dimension." 
+ " Other configurations might lead to suboptimal performance.", stacklevel=2, ) From 17f8b37cb722020bcafac95bda61eb87e3e2e7c1 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 3 Feb 2026 09:19:43 +0100 Subject: [PATCH 25/61] Mention that maxnreg takes precedence over launch bounds --- .../runners/dace/transformations/gpu_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index a1df77cad4..c565ff0f0b 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -598,7 +598,7 @@ class GPUSetBlockSize(dace_transformation.SingleStateTransformation): dtype=int, allow_none=True, default=None, - desc="Set the maxnreg property for the GPU maps.", + desc="Set the maxnreg property for the GPU maps. Takes precedence over any launch_bounds.", ) # Pattern matching From 923e3cd7d5d1c21551aa815aaba90fb66eff46b0 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 27 Nov 2025 13:57:07 +0100 Subject: [PATCH 26/61] Enable maxnreg setting --- .../dace/transformations/auto_optimize.py | 4 ++++ .../runners/dace/transformations/gpu_utils.py | 17 ++++++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 486af807ab..8a11b64714 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -117,6 +117,7 @@ def gt_auto_optimize( gpu_block_size_1d: Optional[Sequence[int | str] | str] = (64, 1, 1), gpu_block_size_2d: Optional[Sequence[int | str] | str] = None, gpu_block_size_3d: 
Optional[Sequence[int | str] | str] = None, + gpu_maxnreg: Optional[int] = None, blocking_dim: Optional[gtx_common.Dimension] = None, blocking_size: int = 10, blocking_only_if_independent_nodes: bool = True, @@ -375,6 +376,7 @@ def gt_auto_optimize( gpu_block_size=gpu_block_size, gpu_launch_factor=gpu_launch_factor, gpu_launch_bounds=gpu_launch_bounds, + gpu_maxnreg=gpu_maxnreg, optimization_hooks=optimization_hooks, gpu_block_size_spec=gpu_block_size_spec if gpu_block_size_spec else None, validate_all=validate_all, @@ -781,6 +783,7 @@ def _gt_auto_configure_maps_and_strides( gpu_block_size: Optional[Sequence[int | str] | str], gpu_launch_bounds: Optional[int | str], gpu_launch_factor: Optional[int], + gpu_maxnreg: Optional[int], optimization_hooks: dict[GT4PyAutoOptHook, GT4PyAutoOptHookFun], gpu_block_size_spec: Optional[dict[str, Sequence[int | str] | str]], validate_all: bool, @@ -833,6 +836,7 @@ def _gt_auto_configure_maps_and_strides( gpu_launch_bounds=gpu_launch_bounds, gpu_launch_factor=gpu_launch_factor, gpu_block_size_spec=gpu_block_size_spec, + gpu_maxnreg=gpu_maxnreg, validate=False, validate_all=validate_all, try_removing_trivial_maps=True, diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index 1786913edb..a1df77cad4 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -34,6 +34,7 @@ def gt_gpu_transformation( gpu_launch_bounds: Optional[int | str] = None, gpu_launch_factor: Optional[int] = None, gpu_block_size_spec: Optional[dict[str, Sequence[int | str] | str]] = None, + gpu_maxnreg: Optional[int] = None, validate: bool = True, validate_all: bool = False, **kwargs: Any, @@ -123,6 +124,7 @@ def gt_gpu_transformation( launch_bounds=gpu_launch_bounds, launch_factor=gpu_launch_factor, **gpu_block_size_spec, + 
gpu_maxnreg=gpu_maxnreg, validate=False, validate_all=validate_all, ) @@ -362,6 +364,7 @@ def gt_set_gpu_blocksize( block_size: Optional[Sequence[int | str] | str], launch_bounds: Optional[int | str] = None, launch_factor: Optional[int] = None, + gpu_maxnreg: Optional[int] = None, validate: bool = True, validate_all: bool = False, **kwargs: Any, @@ -393,6 +396,7 @@ def gt_set_gpu_blocksize( }.items(): if f"{arg}_{dim}d" not in kwargs: kwargs[f"{arg}_{dim}d"] = val + kwargs["maxnreg"] = gpu_maxnreg setter = GPUSetBlockSize(**kwargs) @@ -590,6 +594,12 @@ class GPUSetBlockSize(dace_transformation.SingleStateTransformation): default=None, desc="Set the launch bound property for 3 dimensional map.", ) + maxnreg = dace_properties.Property( + dtype=int, + allow_none=True, + default=None, + desc="Set the maxnreg property for the GPU maps.", + ) # Pattern matching map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) @@ -605,6 +615,7 @@ def __init__( launch_factor_1d: int | None = None, launch_factor_2d: int | None = None, launch_factor_3d: int | None = None, + maxnreg: int | None = None, ) -> None: super().__init__() if block_size_1d is not None: @@ -632,6 +643,8 @@ def __init__( self.launch_bounds_3d = _gpu_launch_bound_parser( self.block_size_3d, launch_bounds_3d, launch_factor_3d ) + if maxnreg is not None: + self.maxnreg = maxnreg @classmethod def expressions(cls) -> Any: @@ -748,7 +761,9 @@ def apply( block_size[i] = map_size[map_dim_idx_to_inspect] gpu_map.gpu_block_size = tuple(block_size) - if launch_bounds is not None: # Note: empty string has a meaning in DaCe + if self.maxnreg is not None: + gpu_map.gpu_maxnreg = self.maxnreg + elif launch_bounds is not None: # Note: empty string has a meaning in DaCe gpu_map.gpu_launch_bounds = launch_bounds From 0d2b3e14cf9acc520a2d89174a010678cea1df44 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 29 Jan 2026 14:29:09 +0100 Subject: [PATCH 27/61] Added test for gpu_maxnreg --- 
.../transformation_tests/test_gpu_utils.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py index 1ef64da559..720a269547 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py @@ -351,3 +351,36 @@ def test_set_gpu_properties_2D_3D(): assert len(map4.params) == 4 assert map4.gpu_block_size == [1, 2, 32] assert map4.gpu_launch_bounds == "0" + + +def test_set_gpu_maxnreg(): + """Tests if gpu_maxnreg property is set correctly to GPU maps.""" + sdfg = dace.SDFG(util.unique_name("gpu_maxnreg_test")) + state = sdfg.add_state(is_start_block=True) + dim = 2 + shape = (10,) * (dim - 1) + (1,) + sdfg.add_array( + f"A_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global + ) + sdfg.add_array( + f"B_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global + ) + _, me, _ = state.add_mapped_tasklet( + f"map_{dim}", + map_ranges={f"__i{i}": f"0:{s}" for i, s in enumerate(shape)}, + inputs={"__in": dace.Memlet(f"A_{dim}[{','.join(f'__i{i}' for i in range(dim))}]")}, + code="__out = math.cos(__in)", + outputs={"__out": dace.Memlet(f"B_{dim}[{','.join(f'__i{i}' for i in range(dim))}]")}, + external_edges=True, + ) + sdfg.validate() + sdfg.apply_gpu_transformations() + gtx_dace_fieldview_gpu_utils.gt_set_gpu_blocksize( + sdfg=sdfg, + block_size_1d=(128, 1, 1), + block_size_2d=(64, 2, 1), + block_size_3d=(2, 2, 32), + block_size=(32, 4, 1), + gpu_maxnreg=128, + ) + assert me.gpu_maxnreg == 128 From 215e775cc41826a3030f98c289b7343dd81f3726 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 
29 Jan 2026 14:31:08 +0100 Subject: [PATCH 28/61] Fix formating --- .../dace_tests/transformation_tests/test_gpu_utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py index 720a269547..dd4927bb03 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py @@ -359,12 +359,8 @@ def test_set_gpu_maxnreg(): state = sdfg.add_state(is_start_block=True) dim = 2 shape = (10,) * (dim - 1) + (1,) - sdfg.add_array( - f"A_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global - ) - sdfg.add_array( - f"B_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global - ) + sdfg.add_array(f"A_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global) + sdfg.add_array(f"B_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global) _, me, _ = state.add_mapped_tasklet( f"map_{dim}", map_ranges={f"__i{i}": f"0:{s}" for i, s in enumerate(shape)}, From d58ade95788075d3d159483856534482e5a17dba Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 3 Feb 2026 09:19:43 +0100 Subject: [PATCH 29/61] Mention that maxnreg takes precedence over launch bounds --- .../runners/dace/transformations/gpu_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index a1df77cad4..c565ff0f0b 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ 
b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -598,7 +598,7 @@ class GPUSetBlockSize(dace_transformation.SingleStateTransformation): dtype=int, allow_none=True, default=None, - desc="Set the maxnreg property for the GPU maps.", + desc="Set the maxnreg property for the GPU maps. Takes precedence over any launch_bounds.", ) # Pattern matching From 956282f77338facfb6d73ed127ce127125d2c67c Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 22 Jan 2026 15:31:28 +0100 Subject: [PATCH 30/61] Remove scalar copies --- .../runners/dace/transformations/__init__.py | 1 + .../dace/transformations/auto_optimize.py | 20 +++ .../transformations/kill_aliasing_scalars.py | 143 ++++++++++++++++++ .../test_kill_aliasing_scalars.py | 85 +++++++++++ 4 files changed, 249 insertions(+) create mode 100644 src/gt4py/next/program_processors/runners/dace/transformations/kill_aliasing_scalars.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_kill_aliasing_scalars.py diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py index a1b766100a..31ed0a5ef7 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py @@ -27,6 +27,7 @@ gt_set_gpu_blocksize, ) from .inline_fuser import inline_dataflow_into_map +from .kill_aliasing_scalars import KillAliasingScalars from .local_double_buffering import gt_create_local_double_buffering from .loop_blocking import LoopBlocking from .map_fusion import ( diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 8a11b64714..6a8d83268b 100644 --- 
a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -773,6 +773,26 @@ def _gt_auto_process_dataflow_inside_maps( validate_all=validate_all, ) + sdfg.save("before_kill_aliasing_scalars.sdfg") + + find_single_use_data = dace_analysis.FindSingleUseData() + single_use_data = find_single_use_data.apply_pass(sdfg, None) + + sdfg.apply_transformations_repeated( + gtx_transformations.KillAliasingScalars( + single_use_data=single_use_data, + ), + validate=False, + validate_all=validate_all, + ) + # sdfg.apply_transformations_repeated( + # gtx_transformations.CopyChainRemover( + # single_use_data=single_use_data, + # ), + # validate=False, + # validate_all=validate_all, + # ) + return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/kill_aliasing_scalars.py b/src/gt4py/next/program_processors/runners/dace/transformations/kill_aliasing_scalars.py new file mode 100644 index 0000000000..db5bc96a44 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/transformations/kill_aliasing_scalars.py @@ -0,0 +1,143 @@ +import copy +import warnings +from typing import Any, Callable, Mapping, Optional, TypeAlias, Union + +import dace +from dace import ( + properties as dace_properties, + subsets as dace_subsets, + transformation as dace_transformation, +) +from dace.sdfg import nodes as dace_nodes +from dace.transformation import helpers as dace_helpers +from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations + +@dace_properties.make_properties +class KillAliasingScalars(dace_transformation.SingleStateTransformation): + first_access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) + second_access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) + + assume_single_use_data = dace_properties.Property( + dtype=bool, + default=False, + desc="Always assume that 
`self.access_node` is single use data. Only useful if used through `SplitAccessNode.apply_to()`.", + ) + + # Name of all data that is used at only one place. Is computed by the + # `FindSingleUseData` pass and be passed at construction time. Needed until + # [issue#1911](https://github.com/spcl/dace/issues/1911) has been solved. + _single_use_data: Optional[dict[dace.SDFG, set[str]]] + + def __init__( + self, + *args: Any, + single_use_data: Optional[dict[dace.SDFG, set[str]]] = None, + assume_single_use_data: Optional[bool] = None, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self._single_use_data = single_use_data + if assume_single_use_data is not None: + self.assume_single_use_data = assume_single_use_data + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.first_access_node, cls.second_access_node)] + + def can_be_applied( + self, + graph: Union[dace.SDFGState, dace.SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + first_node: dace_nodes.AccessNode = self.first_access_node + first_node_desc = first_node.desc(sdfg) + second_node: dace_nodes.AccessNode = self.second_access_node + second_node_desc = second_node.desc(sdfg) + + scope_dict = graph.scope_dict() + if first_node not in scope_dict or second_node not in scope_dict: + return False + + if scope_dict[first_node] != scope_dict[second_node]: + return False + + if not first_node_desc.transient or not second_node_desc.transient: + return False + + edges = graph.edges_between(first_node, second_node) + assert len(edges) == 1 + edge = next(iter(edges)) + # Check if edge volume is 1 + if edge.data.num_elements() != 1: + return False + if edge.data.dynamic: + return False + + for out_edges in graph.out_edges(second_node): + if out_edges.data.num_elements() != 1: + return False + # if out_edges.data.dynamic: + # return False + # breakpoint() + # subset: dace_subsets.Subset = edge.data.get("subset", None) + # if 
subset is None: + # return False + + # Make sure that the edge subset is 1 + if not isinstance(first_node_desc, dace.data.Scalar) or not isinstance( + second_node_desc, dace.data.Scalar): + return False + + # Make sure that both access nodes are not views + if isinstance(first_node_desc, dace.data.View) or isinstance( + second_node_desc, dace.data.View): + return False + + # Make sure that both access nodes are transients + if not first_node_desc.transient or not second_node_desc.transient: + return False + + if graph.in_degree(second_node) != 1: + return False + + if self.assume_single_use_data: + single_use_data = {sdfg: {first_node.data}} + if self._single_use_data is None: + find_single_use_data = first_node.FindSingleUseData() + single_use_data = find_single_use_data.apply_pass(sdfg, None) + else: + single_use_data = self._single_use_data + if first_node.data not in single_use_data[sdfg]: + return False + + if self.assume_single_use_data: + single_use_data = {sdfg: {second_node.data}} + if self._single_use_data is None: + find_single_use_data = second_node.FindSingleUseData() + single_use_data = find_single_use_data.apply_pass(sdfg, None) + else: + single_use_data = self._single_use_data + if second_node.data not in single_use_data[sdfg]: + return False + + return True + + def apply( + self, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + ) -> None: + first_node: dace_nodes.AccessNode = self.first_access_node + second_node: dace_nodes.AccessNode = self.second_access_node + + # Redirect all outcoming edges of the second access node to the first + for edge in list(graph.out_edges(second_node)): + dace_helpers.redirect_edge(state=graph, edge=edge, new_src=first_node, new_data=first_node.data if edge.data.data == second_node.data else edge.data.data) + # edge.subset = first_node.desc(sdfg).get_subset() + # if edge.other_subset is not None: + # edge.other_subset = edge.dst.desc(sdfg).get_subset() + + # Remove the second access node + 
graph.remove_node(second_node) \ No newline at end of file diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_kill_aliasing_scalars.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_kill_aliasing_scalars.py new file mode 100644 index 0000000000..1d03b3e3e8 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_kill_aliasing_scalars.py @@ -0,0 +1,85 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes +from dace.transformation.passes import analysis as dace_analysis + +from gt4py.next.program_processors.runners.dace import ( + transformations as gtx_transformations, +) + +from . 
import util + +import dace + + +def _make_map_with_scalar_copies() -> tuple[ + dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapExit +]: + sdfg = dace.SDFG(util.unique_name("scalar_elimination")) + state = sdfg.add_state(is_start_block=True) + + sdfg.add_array( + "a", + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.add_array( + "b", + shape=(10,), + dtype=dace.float64, + transient=True, + ) + + a, b = (state.add_access(name) for name in "ab") + for i in range(3): + sdfg.add_scalar(f"tmp{i}", dtype=dace.float64, transient=True) + tmp0, tmp1, tmp2 = (state.add_access(f"tmp{i}") for i in range(3)) + + me, mx = state.add_map("copy_map", ndrange={"__i": "0:10"}) + me.add_in_connector("IN_a") + me.add_out_connector("OUT_a") + mx.add_in_connector("IN_b") + mx.add_out_connector("OUT_b") + state.add_edge(a, None, me, "IN_a", dace.Memlet("a[__i]")) + state.add_edge(me, "OUT_a", tmp0, None, dace.Memlet("a[__i]")) + state.add_edge(tmp0, None, tmp1, None, dace.Memlet("tmp1[0]")) + state.add_edge(tmp1, None, tmp2, None, dace.Memlet("tmp1[0]")) + state.add_edge(tmp2, None, mx, "IN_b", dace.Memlet("[0] -> b[__i]")) + state.add_edge(mx, "OUT_b", b, None, dace.Memlet("b[__i]")) + + sdfg.validate() + return sdfg, state, me, mx + + +def test_remove_double_write_single_consumer(): + sdfg, state, me, mx = _make_map_with_scalar_copies() + + access_nodes_inside_original_map = util.count_nodes(state.scope_subgraph(me, include_entry=False, include_exit=False), dace_nodes.AccessNode) + assert access_nodes_inside_original_map == 3 + + sdfg.view() + breakpoint() + find_single_use_data = dace_analysis.FindSingleUseData() + single_use_data = find_single_use_data.apply_pass(sdfg, None) + sdfg.apply_transformations_repeated( + gtx_transformations.KillAliasingScalars( + single_use_data=single_use_data, + assume_single_use_data=False, + ), + validate=True, + validate_all=True, + ) + sdfg.view() + breakpoint() + access_nodes_inside_new_map = 
util.count_nodes(state.scope_subgraph(me, include_entry=False, include_exit=False), dace_nodes.AccessNode) + assert access_nodes_inside_new_map == 1 From 03c8cbb15d424a6ad1a9ccf1de2b860cf54c8566 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Fri, 23 Jan 2026 10:08:28 +0100 Subject: [PATCH 31/61] Rename if statements in _make_if_block --- .../transformation_tests/test_move_dataflow_into_if_body.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index 03aba6599e..5338b49920 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -37,7 +37,7 @@ def _make_if_block( b2_type: dace.typeclass = dace.float64, output_type: dace.typeclass = dace.float64, ) -> dace_nodes.NestedSDFG: - inner_sdfg = dace.SDFG(util.unique_name("inner_sdfg")) + inner_sdfg = dace.SDFG(util.unique_name("if_stmt_")) types = {b1_name: b1_type, b2_name: b2_type, cond_name: dace.bool_, output_name: output_type} for name in {b1_name, b2_name, cond_name, output_name}: @@ -47,7 +47,7 @@ def _make_if_block( transient=False, ) - if_region = dace.sdfg.state.ConditionalBlock("if") + if_region = dace.sdfg.state.ConditionalBlock(util.unique_name("if")) inner_sdfg.add_node(if_region, is_start_block=True) then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=inner_sdfg) From e10b2ee6f16dc28829fb7aa4db76bba2ad0e7b31 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Fri, 23 Jan 2026 10:08:50 +0100 Subject: [PATCH 32/61] [WIP] Added fuse condition block transformation --- 
.../runners/dace/transformations/__init__.py | 1 + .../fuse_horizontal_conditionblocks.py | 115 +++++++++++++++++ .../test_fuse_horizontal_conditionblocks.py | 119 ++++++++++++++++++ 3 files changed, 235 insertions(+) create mode 100644 src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py index 31ed0a5ef7..f899c7e6a2 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py @@ -20,6 +20,7 @@ gt_auto_optimize, ) from .dead_dataflow_elimination import gt_eliminate_dead_dataflow, gt_remove_map +from .fuse_horizontal_conditionblocks import FuseHorizontalConditionBlocks from .gpu_utils import ( GPUSetBlockSize, gt_gpu_transform_non_standard_memlet, diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py new file mode 100644 index 0000000000..e68f93a48a --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -0,0 +1,115 @@ +import copy +import warnings +from typing import Any, Callable, Mapping, Optional, TypeAlias, Union + +import dace +from dace import ( + properties as dace_properties, + subsets as dace_subsets, + transformation as dace_transformation, +) + +from dace.sdfg import nodes as dace_nodes, graph as dace_graph +from dace.transformation import helpers as dace_helpers +from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations + 
+@dace_properties.make_properties +class FuseHorizontalConditionBlocks(dace_transformation.SingleStateTransformation): + access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) + first_conditional_block = dace_transformation.PatternNode(dace_nodes.NestedSDFG) + second_conditional_block = dace_transformation.PatternNode(dace_nodes.NestedSDFG) + + @classmethod + def expressions(cls) -> Any: + map_fusion_parallel_match = dace_graph.OrderedMultiDiConnectorGraph() + map_fusion_parallel_match.add_nedge(cls.access_node, cls.first_conditional_block, dace.Memlet()) + map_fusion_parallel_match.add_nedge(cls.access_node, cls.second_conditional_block, dace.Memlet()) + return [map_fusion_parallel_match] + + def can_be_applied( + self, + graph: Union[dace.SDFGState, dace.SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + access_node: dace_nodes.AccessNode = self.access_node + access_node_desc = access_node.desc(sdfg) + first_cb: dace_nodes.NestedSDFG = self.first_conditional_block + second_cb: dace_nodes.NestedSDFG = self.second_conditional_block + scope_dict = graph.scope_dict() + + if first_cb.sdfg.parent != second_cb.sdfg.parent: + return False + # Check that the nested SDFGs' names starts with "if_stmt" + if not (first_cb.sdfg.name.startswith("if_stmt") and second_cb.sdfg.name.startswith("if_stmt")): + return False + + # Check that the common access node is a boolean scalar + if not isinstance(access_node_desc, dace.data.Scalar) or access_node_desc.dtype != dace.bool_: + return False + + if len(first_cb.sdfg.nodes()) > 1 or len(second_cb.sdfg.nodes()) > 1: + return False + + first_conditional_block = next(iter(first_cb.sdfg.nodes())) + second_conditional_block = next(iter(second_cb.sdfg.nodes())) + if not (isinstance(first_conditional_block, dace.sdfg.state.ConditionalBlock) and isinstance(second_conditional_block, dace.sdfg.state.ConditionalBlock)): + return False + + if scope_dict[first_cb] != scope_dict[second_cb]: + 
return False + + if not isinstance(scope_dict[first_cb], dace.nodes.MapEntry): + return False + + # Check that there is an edge to the conditional blocks with dst_conn == "__cond" + cond_edges_first = [e for e in graph.in_edges(first_cb) if e.dst_conn == "__cond"] + if len(cond_edges_first) != 1: + return False + cond_edges_second = [e for e in graph.in_edges(second_cb) if e.dst_conn == "__cond"] + if len(cond_edges_second) != 1: + return False + cond_edge_first = cond_edges_first[0] + cond_edge_second = cond_edges_second[0] + if cond_edge_first.src != cond_edge_second.src and (cond_edge_first.src != access_node or cond_edge_second.src != access_node): + return False + + print(f"Found valid conditional blocks: {first_cb} and {second_cb}", flush=True) + # breakpoint() + + # TODO(iomaganaris): Need to check also that first and second nested SDFGs are not reachable from each other + + return True + + def apply( + self, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + ) -> None: + access_node: dace_nodes.AccessNode = self.access_node + first_cb: dace.sdfg.state.ConditionalBlock = self.first_conditional_block + second_cb: dace.sdfg.state.ConditionalBlock = self.second_conditional_block + + first_conditional_block = next(iter(first_cb.sdfg.nodes())) + second_conditional_block = next(iter(second_cb.sdfg.nodes())) + + second_conditional_states = list(second_conditional_block.all_states()) + + for first_inner_state in first_conditional_block.all_states(): + first_inner_state_name = first_inner_state.name + corresponding_state_in_second = None + for state in second_conditional_states: + if state.name == first_inner_state_name: + corresponding_state_in_second = state + break + if corresponding_state_in_second is None: + raise RuntimeError(f"Could not find corresponding state in second conditional block for state {first_inner_state_name}") + nodes_to_move = list(corresponding_state_in_second.nodes()) + in_connectors_to_move = {k: v for k, v in 
second_cb.in_connectors.items() if k != "__cond"} + out_connectors_to_move = second_cb.out_connectors + breakpoint() + + + # print(f"Fused conditional blocks into: {new_nested_sdfg}", flush=True) + # breakpoint() \ No newline at end of file diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py new file mode 100644 index 0000000000..2d5025affa --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py @@ -0,0 +1,119 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes +from dace.transformation.passes import analysis as dace_analysis + +from gt4py.next.program_processors.runners.dace import ( + transformations as gtx_transformations, +) + +from . 
import util +from .test_move_dataflow_into_if_body import _make_if_block + +import dace + +def _make_map_with_conditional_blocks() -> tuple[ + dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapExit +]: + sdfg = dace.SDFG(util.unique_name("map_with_conditional_blocks")) + state = sdfg.add_state(is_start_block=True) + + sdfg.add_array( + "a", + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.add_array( + "b", + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.add_array( + "c", + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.add_array( + "d", + shape=(10,), + dtype=dace.float64, + transient=False, + ) + a, b, c, d = (state.add_access(name) for name in "abcd") + + for tmp_name in ["tmp_a", "tmp_b", "tmp_c", "tmp_d"]: + sdfg.add_scalar(tmp_name, dtype=dace.float64, transient=True) + tmp_a, tmp_b, tmp_c, tmp_d = (state.add_access(f"tmp_{name}") for name in "abcd") + + sdfg.add_scalar("cond_var", dtype=dace.bool_, transient=True) + cond_var = state.add_access("cond_var") + + me, mx = state.add_map("map_with_ifs", ndrange={"__i": "0:10"}) + me.add_in_connector("IN_a") + me.add_out_connector("OUT_a") + me.add_in_connector("IN_b") + me.add_out_connector("OUT_b") + mx.add_in_connector("IN_c") + mx.add_out_connector("OUT_c") + mx.add_in_connector("IN_d") + mx.add_out_connector("OUT_d") + state.add_edge(a, None, me, "IN_a", dace.Memlet("a[__i]")) + state.add_edge(b, None, me, "IN_b", dace.Memlet("b[__i]")) + state.add_edge(me, "OUT_a", tmp_a, None, dace.Memlet("a[__i]")) + state.add_edge(me, "OUT_b", tmp_b, None, dace.Memlet("b[__i]")) + + tasklet_cond = state.add_tasklet( + "tasklet_cond", + inputs={"__in"}, + outputs={"__out"}, + code="__out = __in <= 0.0", + ) + state.add_edge(tmp_a, None, tasklet_cond, "__in", dace.Memlet("tmp_a[0]")) + state.add_edge(tasklet_cond, "__out", cond_var, None, dace.Memlet("cond_var")) + + if_block_0 = _make_if_block(state=state, outer_sdfg=sdfg) + state.add_edge(cond_var, None, 
if_block_0, "__cond", dace.Memlet("cond_var")) + state.add_edge(tmp_a, None, if_block_0, "__arg1", dace.Memlet("tmp_a[0]")) + state.add_edge(tmp_b, None, if_block_0, "__arg2", dace.Memlet("tmp_b[0]")) + state.add_edge(if_block_0, "__output", tmp_c, None, dace.Memlet("tmp_c[0]")) + state.add_edge(tmp_c, None, mx, "IN_c", dace.Memlet("c[__i]")) + + if_block_1 = _make_if_block(state=state, outer_sdfg=sdfg) + state.add_edge(cond_var, None, if_block_1, "__cond", dace.Memlet("cond_var")) + state.add_edge(tmp_a, None, if_block_1, "__arg1", dace.Memlet("tmp_a[0]")) + state.add_edge(tmp_b, None, if_block_1, "__arg2", dace.Memlet("tmp_b[0]")) + state.add_edge(if_block_1, "__output", tmp_d, None, dace.Memlet("tmp_d[0]")) + state.add_edge(tmp_d, None, mx, "IN_d", dace.Memlet("d[__i]")) + + state.add_edge(mx, "OUT_c", c, None, dace.Memlet("c[__i]")) + state.add_edge(mx, "OUT_d", d, None, dace.Memlet("d[__i]")) + + sdfg.validate() + return sdfg, state, me, mx + +def test_fuse_horizontal_condition_blocks(): + sdfg, state, me, mx = _make_map_with_conditional_blocks() + + # sdfg.view() + # breakpoint() + + sdfg.apply_transformations_repeated( + gtx_transformations.FuseHorizontalConditionBlocks(), + validate=True, + validate_all=True, + ) + + # sdfg.view() + # breakpoint() From a2029f32b28b9ab52a16ade7ea8026414691a688 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Mon, 26 Jan 2026 17:47:07 +0100 Subject: [PATCH 33/61] [WIP] Copy access nodes between condition blocks --- .../fuse_horizontal_conditionblocks.py | 80 +++++++++++++++++-- 1 file changed, 75 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index e68f93a48a..4836f2115d 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ 
b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -12,6 +12,7 @@ from dace.sdfg import nodes as dace_nodes, graph as dace_graph from dace.transformation import helpers as dace_helpers from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations +from dace.sdfg import utils as sdutil @dace_properties.make_properties class FuseHorizontalConditionBlocks(dace_transformation.SingleStateTransformation): @@ -88,14 +89,66 @@ def apply( sdfg: dace.SDFG, ) -> None: access_node: dace_nodes.AccessNode = self.access_node - first_cb: dace.sdfg.state.ConditionalBlock = self.first_conditional_block - second_cb: dace.sdfg.state.ConditionalBlock = self.second_conditional_block + first_cb: dace_nodes.NestedSDFG = self.first_conditional_block + second_cb: dace_nodes.NestedSDFG = self.second_conditional_block first_conditional_block = next(iter(first_cb.sdfg.nodes())) second_conditional_block = next(iter(second_cb.sdfg.nodes())) second_conditional_states = list(second_conditional_block.all_states()) + in_connectors_to_move = {k: v for k, v in second_cb.in_connectors.items() if k != "__cond"} + out_connectors_to_move = second_cb.out_connectors + in_connectors_to_move_rename_map = {} + out_connectors_to_move_rename_map = {} + for k, v in in_connectors_to_move.items(): + new_connector_name = k + if new_connector_name in first_cb.in_connectors: + new_connector_name = f"{k}_from_second" + in_connectors_to_move_rename_map[k] = new_connector_name + first_cb.add_in_connector(new_connector_name) + for edge in graph.in_edges(second_cb): + if edge.dst_conn == k: + dace_helpers.redirect_edge(state=graph, edge=edge, new_dst_conn=new_connector_name, new_dst=first_cb) + for k, v in out_connectors_to_move.items(): + new_connector_name = k + if new_connector_name in first_cb.out_connectors: + new_connector_name = f"{k}_from_second" + out_connectors_to_move_rename_map[k] = new_connector_name + 
first_cb.add_out_connector(new_connector_name) + for edge in graph.out_edges(second_cb): + if edge.src_conn == k: + dace_helpers.redirect_edge(state=graph, edge=edge, new_src_conn=new_connector_name, new_src=first_cb) + + nodes_renamed_map = {} + for first_inner_state in first_conditional_block.all_states(): + first_inner_state_name = first_inner_state.name + corresponding_state_in_second = None + for state in second_conditional_states: + if state.name == first_inner_state_name: + corresponding_state_in_second = state + break + if corresponding_state_in_second is None: + raise RuntimeError(f"Could not find corresponding state in second conditional block for state {first_inner_state_name}") + nodes_to_move = list(corresponding_state_in_second.nodes()) + for node in nodes_to_move: + new_node = node + if isinstance(node, dace_nodes.AccessNode): + if node.data in first_cb.in_connectors or node.data in first_cb.out_connectors: + new_data_name = f"{node.data}_from_second" + new_node = dace_nodes.AccessNode(new_data_name) + new_desc = copy.deepcopy(node.desc(second_cb.sdfg)) + new_desc.name = new_data_name + if new_data_name not in first_cb.sdfg.arrays: + first_cb.sdfg.add_datadesc(new_data_name, new_desc) + else: + second_cb.sdfg.remove_data(node.data) + nodes_renamed_map[node] = new_node + first_inner_state.add_node(new_node) + + second_to_first_connections = {} + for node in nodes_renamed_map: + second_to_first_connections[node.data] = nodes_renamed_map[node].data for first_inner_state in first_conditional_block.all_states(): first_inner_state_name = first_inner_state.name corresponding_state_in_second = None @@ -106,9 +159,26 @@ def apply( if corresponding_state_in_second is None: raise RuntimeError(f"Could not find corresponding state in second conditional block for state {first_inner_state_name}") nodes_to_move = list(corresponding_state_in_second.nodes()) - in_connectors_to_move = {k: v for k, v in second_cb.in_connectors.items() if k != "__cond"} - 
out_connectors_to_move = second_cb.out_connectors - breakpoint() + for node in nodes_to_move: + for edge in list(corresponding_state_in_second.out_edges(node)): + dst = edge.dst + if dst in nodes_to_move: + new_memlet = copy.deepcopy(edge.data) + if edge.data.data in second_to_first_connections: + new_memlet.data = second_to_first_connections[edge.data.data] + first_inner_state.add_edge(nodes_renamed_map[node], nodes_renamed_map[node].data, nodes_renamed_map[edge.dst], second_to_first_connections[node.data], new_memlet) + for edge in list(graph.out_edges(access_node)): + if edge.dst == second_cb: + graph.remove_edge(edge) + + # TODO(iomaganaris): Figure out if I have to handle any symbols + + # Need to remove both references to remove NestedSDFG from graph + graph.remove_node(second_conditional_block) + graph.remove_node(second_cb) + + sdfg.view() + breakpoint() # print(f"Fused conditional blocks into: {new_nested_sdfg}", flush=True) From 8961d5e764daa37e97084e851c5e781499b57317 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Mon, 26 Jan 2026 20:36:10 +0100 Subject: [PATCH 34/61] Fix branch selection --- .../fuse_horizontal_conditionblocks.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 4836f2115d..9c239cb1a5 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -123,9 +123,13 @@ def apply( nodes_renamed_map = {} for first_inner_state in first_conditional_block.all_states(): first_inner_state_name = first_inner_state.name + true_branch = "true_branch" in first_inner_state_name corresponding_state_in_second = None for state in second_conditional_states: - if 
state.name == first_inner_state_name: + if true_branch and "true_branch" in state.name: + corresponding_state_in_second = state + break + elif not true_branch and "false_branch" in state.name: corresponding_state_in_second = state break if corresponding_state_in_second is None: @@ -151,9 +155,13 @@ def apply( second_to_first_connections[node.data] = nodes_renamed_map[node].data for first_inner_state in first_conditional_block.all_states(): first_inner_state_name = first_inner_state.name + true_branch = "true_branch" in first_inner_state_name corresponding_state_in_second = None for state in second_conditional_states: - if state.name == first_inner_state_name: + if true_branch and "true_branch" in state.name: + corresponding_state_in_second = state + break + elif not true_branch and "false_branch" in state.name: corresponding_state_in_second = state break if corresponding_state_in_second is None: From 49b7619347b0b17f57c1c1ad3edee2db70deae6c Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 27 Jan 2026 09:16:23 +0100 Subject: [PATCH 35/61] Moved kill aliasing transformation and updates to the fuse conditionals --- .../dace/transformations/auto_optimize.py | 34 ++++++++----------- .../fuse_horizontal_conditionblocks.py | 28 ++++++++++++--- 2 files changed, 37 insertions(+), 25 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 6a8d83268b..1b9d9b2157 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -652,9 +652,23 @@ def _gt_auto_process_top_level_maps( validate_all=validate_all, skip=gtx_transformations.constants._GT_AUTO_OPT_TOP_LEVEL_STAGE_SIMPLIFY_SKIP_LIST, ) + sdfg.save("before_kill_aliasing_scalars_top_level.sdfg") + + find_single_use_data = dace_analysis.FindSingleUseData() + 
single_use_data = find_single_use_data.apply_pass(sdfg, None) + + sdfg.apply_transformations_repeated( + gtx_transformations.KillAliasingScalars( + single_use_data=single_use_data, + ), + validate=False, + validate_all=validate_all, + ) + sdfg.save("after_kill_aliasing_scalars_top_level.sdfg") if GT4PyAutoOptHook.TopLevelDataFlowPost in optimization_hooks: optimization_hooks[GT4PyAutoOptHook.TopLevelDataFlowPost](sdfg) # type: ignore[call-arg] + sdfg.save("after_top_level_map_optimization.sdfg") return sdfg @@ -773,26 +787,6 @@ def _gt_auto_process_dataflow_inside_maps( validate_all=validate_all, ) - sdfg.save("before_kill_aliasing_scalars.sdfg") - - find_single_use_data = dace_analysis.FindSingleUseData() - single_use_data = find_single_use_data.apply_pass(sdfg, None) - - sdfg.apply_transformations_repeated( - gtx_transformations.KillAliasingScalars( - single_use_data=single_use_data, - ), - validate=False, - validate_all=validate_all, - ) - # sdfg.apply_transformations_repeated( - # gtx_transformations.CopyChainRemover( - # single_use_data=single_use_data, - # ), - # validate=False, - # validate_all=validate_all, - # ) - return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 9c239cb1a5..83fbae70c8 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -1,4 +1,5 @@ import copy +import uuid import warnings from typing import Any, Callable, Mapping, Optional, TypeAlias, Union @@ -14,6 +15,14 @@ from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations from dace.sdfg import utils as sdutil +def unique_name(name: str) -> str: + """Adds a unique string to `name`.""" + maximal_length = 200 + unique_sufix 
= str(uuid.uuid1()).replace("-", "_") + if len(name) > (maximal_length - len(unique_sufix)): + name = name[: (maximal_length - len(unique_sufix) - 1)] + return f"{name}_{unique_sufix}" + @dace_properties.make_properties class FuseHorizontalConditionBlocks(dace_transformation.SingleStateTransformation): access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) @@ -80,6 +89,16 @@ def can_be_applied( # breakpoint() # TODO(iomaganaris): Need to check also that first and second nested SDFGs are not reachable from each other + if gtx_transformations.utils.is_reachable( + start=first_cb, + target=second_cb, + state=graph, + ) or gtx_transformations.utils.is_reachable( + start=second_cb, + target=first_cb, + state=graph, + ): + return False return True @@ -104,7 +123,7 @@ def apply( for k, v in in_connectors_to_move.items(): new_connector_name = k if new_connector_name in first_cb.in_connectors: - new_connector_name = f"{k}_from_second" + new_connector_name = unique_name(k) in_connectors_to_move_rename_map[k] = new_connector_name first_cb.add_in_connector(new_connector_name) for edge in graph.in_edges(second_cb): @@ -113,7 +132,7 @@ def apply( for k, v in out_connectors_to_move.items(): new_connector_name = k if new_connector_name in first_cb.out_connectors: - new_connector_name = f"{k}_from_second" + new_connector_name = unique_name(k) out_connectors_to_move_rename_map[k] = new_connector_name first_cb.add_out_connector(new_connector_name) for edge in graph.out_edges(second_cb): @@ -139,14 +158,12 @@ def apply( new_node = node if isinstance(node, dace_nodes.AccessNode): if node.data in first_cb.in_connectors or node.data in first_cb.out_connectors: - new_data_name = f"{node.data}_from_second" + new_data_name = unique_name(node.data) new_node = dace_nodes.AccessNode(new_data_name) new_desc = copy.deepcopy(node.desc(second_cb.sdfg)) new_desc.name = new_data_name if new_data_name not in first_cb.sdfg.arrays: first_cb.sdfg.add_datadesc(new_data_name, new_desc) - 
else: - second_cb.sdfg.remove_data(node.data) nodes_renamed_map[node] = new_node first_inner_state.add_node(new_node) @@ -180,6 +197,7 @@ def apply( graph.remove_edge(edge) # TODO(iomaganaris): Figure out if I have to handle any symbols + sdfg.save(f"after_fuse_horizontal_conditionblocks_{first_cb}_{second_cb}.sdfg") # Need to remove both references to remove NestedSDFG from graph graph.remove_node(second_conditional_block) From eb462f6a47d8ff3bf1e08e1dce8222cd22364e55 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 27 Jan 2026 09:31:56 +0100 Subject: [PATCH 36/61] Rename properly second map arrays --- .../fuse_horizontal_conditionblocks.py | 41 +++++++++++++++---- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 83fbae70c8..ab7105a3b5 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -114,6 +114,30 @@ def apply( first_conditional_block = next(iter(first_cb.sdfg.nodes())) second_conditional_block = next(iter(second_cb.sdfg.nodes())) + original_arrays_first_conditional_block = {} + for data_name, data_desc in first_conditional_block.sdfg.arrays.items(): + original_arrays_first_conditional_block[data_name] = data_desc + original_arrays_second_conditional_block = {} + for data_name, data_desc in second_conditional_block.sdfg.arrays.items(): + original_arrays_second_conditional_block[data_name] = data_desc + total_original_arrays = len(original_arrays_first_conditional_block) + len(original_arrays_second_conditional_block) + + second_arrays_rename_map = {} + for data_name, data_desc in original_arrays_second_conditional_block.items(): + if data_name == "__cond": + continue + if 
data_name in original_arrays_first_conditional_block: + new_data_name = unique_name(data_name) + second_arrays_rename_map[data_name] = new_data_name + data_desc_renamed = copy.deepcopy(data_desc) + data_desc_renamed.name = new_data_name + if new_data_name not in first_cb.sdfg.arrays: + first_cb.sdfg.add_datadesc(new_data_name, data_desc_renamed) + else: + second_arrays_rename_map[data_name] = data_name + if data_name not in first_cb.sdfg.arrays: + first_cb.sdfg.add_datadesc(data_name, copy.deepcopy(data_desc)) + second_conditional_states = list(second_conditional_block.all_states()) in_connectors_to_move = {k: v for k, v in second_cb.in_connectors.items() if k != "__cond"} @@ -123,7 +147,7 @@ def apply( for k, v in in_connectors_to_move.items(): new_connector_name = k if new_connector_name in first_cb.in_connectors: - new_connector_name = unique_name(k) + new_connector_name = second_arrays_rename_map[k] in_connectors_to_move_rename_map[k] = new_connector_name first_cb.add_in_connector(new_connector_name) for edge in graph.in_edges(second_cb): @@ -132,7 +156,7 @@ def apply( for k, v in out_connectors_to_move.items(): new_connector_name = k if new_connector_name in first_cb.out_connectors: - new_connector_name = unique_name(k) + new_connector_name = second_arrays_rename_map[k] out_connectors_to_move_rename_map[k] = new_connector_name first_cb.add_out_connector(new_connector_name) for edge in graph.out_edges(second_cb): @@ -158,12 +182,12 @@ def apply( new_node = node if isinstance(node, dace_nodes.AccessNode): if node.data in first_cb.in_connectors or node.data in first_cb.out_connectors: - new_data_name = unique_name(node.data) + new_data_name = second_arrays_rename_map[node.data] new_node = dace_nodes.AccessNode(new_data_name) new_desc = copy.deepcopy(node.desc(second_cb.sdfg)) new_desc.name = new_data_name - if new_data_name not in first_cb.sdfg.arrays: - first_cb.sdfg.add_datadesc(new_data_name, new_desc) + # if new_data_name not in first_cb.sdfg.arrays: + # 
first_cb.sdfg.add_datadesc(new_data_name, new_desc) nodes_renamed_map[node] = new_node first_inner_state.add_node(new_node) @@ -203,8 +227,11 @@ def apply( graph.remove_node(second_conditional_block) graph.remove_node(second_cb) - sdfg.view() - breakpoint() + new_arrays = len(first_cb.sdfg.arrays) + assert new_arrays == total_original_arrays - 1, f"After fusion, expected {total_original_arrays - 1} arrays but found {new_arrays}" + + # sdfg.view() + # breakpoint() # print(f"Fused conditional blocks into: {new_nested_sdfg}", flush=True) From 73dc4c1e1e41a74c880783649fc3e8e14459f143 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Wed, 28 Jan 2026 18:11:56 +0100 Subject: [PATCH 37/61] Enable if grouping in proper place --- .../dace/transformations/auto_optimize.py | 33 +++++++++++-------- .../fuse_horizontal_conditionblocks.py | 6 ++-- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 1b9d9b2157..10101a7b7e 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -652,19 +652,6 @@ def _gt_auto_process_top_level_maps( validate_all=validate_all, skip=gtx_transformations.constants._GT_AUTO_OPT_TOP_LEVEL_STAGE_SIMPLIFY_SKIP_LIST, ) - sdfg.save("before_kill_aliasing_scalars_top_level.sdfg") - - find_single_use_data = dace_analysis.FindSingleUseData() - single_use_data = find_single_use_data.apply_pass(sdfg, None) - - sdfg.apply_transformations_repeated( - gtx_transformations.KillAliasingScalars( - single_use_data=single_use_data, - ), - validate=False, - validate_all=validate_all, - ) - sdfg.save("after_kill_aliasing_scalars_top_level.sdfg") if GT4PyAutoOptHook.TopLevelDataFlowPost in optimization_hooks: 
optimization_hooks[GT4PyAutoOptHook.TopLevelDataFlowPost](sdfg) # type: ignore[call-arg] @@ -740,6 +727,26 @@ def _gt_auto_process_dataflow_inside_maps( validate_all=validate_all, ) + sdfg.save("before_kill_aliasing_scalars_top_level.sdfg") + + find_single_use_data = dace_analysis.FindSingleUseData() + single_use_data = find_single_use_data.apply_pass(sdfg, None) + + sdfg.apply_transformations_repeated( + gtx_transformations.KillAliasingScalars( + single_use_data=single_use_data, + ), + validate=False, + validate_all=validate_all, + ) + sdfg.save("after_kill_aliasing_scalars_top_level.sdfg") + + sdfg.apply_transformations_repeated( + gtx_transformations.FuseHorizontalConditionBlocks(), + validate=True, + validate_all=True, + ) + # Move dataflow into the branches of the `if` such that they are only evaluated # if they are needed. Important to call it repeatedly. # TODO(phimuell): It is unclear if `MoveDataflowIntoIfBody` should be called diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index ab7105a3b5..0417500d45 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -220,16 +220,14 @@ def apply( if edge.dst == second_cb: graph.remove_edge(edge) - # TODO(iomaganaris): Figure out if I have to handle any symbols - sdfg.save(f"after_fuse_horizontal_conditionblocks_{first_cb}_{second_cb}.sdfg") - # Need to remove both references to remove NestedSDFG from graph graph.remove_node(second_conditional_block) graph.remove_node(second_cb) new_arrays = len(first_cb.sdfg.arrays) assert new_arrays == total_original_arrays - 1, f"After fusion, expected {total_original_arrays - 1} arrays but found {new_arrays}" - + # TODO(iomaganaris): Figure out if I have to 
handle any symbols + sdfg.save(f"after_fuse_horizontal_conditionblocks_{first_cb}_{second_cb}.sdfg") # sdfg.view() # breakpoint() From 08ce68cb3f94b0e44e27496b73a70ea2ce874e1c Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 29 Jan 2026 10:19:06 +0100 Subject: [PATCH 38/61] Rename kill aliasing scalars to remove --- .../runners/dace/transformations/__init__.py | 3 +- .../dace/transformations/auto_optimize.py | 2 +- ..._scalars.py => remove_aliasing_scalars.py} | 62 ++++++++++--------- ...ars.py => test_remove_aliasing_scalars.py} | 15 ++--- 4 files changed, 44 insertions(+), 38 deletions(-) rename src/gt4py/next/program_processors/runners/dace/transformations/{kill_aliasing_scalars.py => remove_aliasing_scalars.py} (78%) rename tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/{test_kill_aliasing_scalars.py => test_remove_aliasing_scalars.py} (85%) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py index f899c7e6a2..0b575fafd9 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py @@ -28,7 +28,7 @@ gt_set_gpu_blocksize, ) from .inline_fuser import inline_dataflow_into_map -from .kill_aliasing_scalars import KillAliasingScalars +from .remove_aliasing_scalars import RemoveAliasingScalars from .local_double_buffering import gt_create_local_double_buffering from .loop_blocking import LoopBlocking from .map_fusion import ( @@ -94,6 +94,7 @@ "GT4PyStateFusion", "HorizontalMapFusionCallback", "HorizontalMapSplitCallback", + "RemoveAliasingScalars", "LoopBlocking", "MapFusionHorizontal", "MapFusionVertical", diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 
10101a7b7e..85a9e78591 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -733,7 +733,7 @@ def _gt_auto_process_dataflow_inside_maps( single_use_data = find_single_use_data.apply_pass(sdfg, None) sdfg.apply_transformations_repeated( - gtx_transformations.KillAliasingScalars( + gtx_transformations.RemoveAliasingScalars( single_use_data=single_use_data, ), validate=False, diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/kill_aliasing_scalars.py b/src/gt4py/next/program_processors/runners/dace/transformations/remove_aliasing_scalars.py similarity index 78% rename from src/gt4py/next/program_processors/runners/dace/transformations/kill_aliasing_scalars.py rename to src/gt4py/next/program_processors/runners/dace/transformations/remove_aliasing_scalars.py index db5bc96a44..3572f43980 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/kill_aliasing_scalars.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/remove_aliasing_scalars.py @@ -1,19 +1,21 @@ -import copy -import warnings -from typing import Any, Callable, Mapping, Optional, TypeAlias, Union +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + +from typing import Any, Optional, Union import dace -from dace import ( - properties as dace_properties, - subsets as dace_subsets, - transformation as dace_transformation, -) +from dace import properties as dace_properties, transformation as dace_transformation from dace.sdfg import nodes as dace_nodes from dace.transformation import helpers as dace_helpers -from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations + @dace_properties.make_properties -class KillAliasingScalars(dace_transformation.SingleStateTransformation): +class RemoveAliasingScalars(dace_transformation.SingleStateTransformation): first_access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) second_access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) @@ -59,40 +61,40 @@ def can_be_applied( scope_dict = graph.scope_dict() if first_node not in scope_dict or second_node not in scope_dict: return False - + + # Make sure that both access nodes are in the same scope if scope_dict[first_node] != scope_dict[second_node]: return False + # Make sure that both access nodes are transients if not first_node_desc.transient or not second_node_desc.transient: return False edges = graph.edges_between(first_node, second_node) assert len(edges) == 1 edge = next(iter(edges)) - # Check if edge volume is 1 + + # Check if the edge transfers only one element if edge.data.num_elements() != 1: return False if edge.data.dynamic: return False - + + # Check that all the outgoing edges of the second access node transfer only one element for out_edges in graph.out_edges(second_node): if out_edges.data.num_elements() != 1: return False - # if out_edges.data.dynamic: - # return False - # breakpoint() - # subset: dace_subsets.Subset = edge.data.get("subset", None) - # if subset is None: - # return False # Make sure that the edge subset is 1 if not isinstance(first_node_desc, dace.data.Scalar) or not isinstance( - 
second_node_desc, dace.data.Scalar): + second_node_desc, dace.data.Scalar + ): return False # Make sure that both access nodes are not views if isinstance(first_node_desc, dace.data.View) or isinstance( - second_node_desc, dace.data.View): + second_node_desc, dace.data.View + ): return False # Make sure that both access nodes are transients @@ -101,7 +103,8 @@ def can_be_applied( if graph.in_degree(second_node) != 1: return False - + + # Make sure that both access nodes are single use data if self.assume_single_use_data: single_use_data = {sdfg: {first_node.data}} if self._single_use_data is None: @@ -111,7 +114,6 @@ def can_be_applied( single_use_data = self._single_use_data if first_node.data not in single_use_data[sdfg]: return False - if self.assume_single_use_data: single_use_data = {sdfg: {second_node.data}} if self._single_use_data is None: @@ -123,7 +125,7 @@ def can_be_applied( return False return True - + def apply( self, graph: Union[dace.SDFGState, dace.SDFG], @@ -134,10 +136,12 @@ def apply( # Redirect all outcoming edges of the second access node to the first for edge in list(graph.out_edges(second_node)): - dace_helpers.redirect_edge(state=graph, edge=edge, new_src=first_node, new_data=first_node.data if edge.data.data == second_node.data else edge.data.data) - # edge.subset = first_node.desc(sdfg).get_subset() - # if edge.other_subset is not None: - # edge.other_subset = edge.dst.desc(sdfg).get_subset() + dace_helpers.redirect_edge( + state=graph, + edge=edge, + new_src=first_node, + new_data=first_node.data if edge.data.data == second_node.data else edge.data.data, + ) # Remove the second access node - graph.remove_node(second_node) \ No newline at end of file + graph.remove_node(second_node) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_kill_aliasing_scalars.py 
b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_aliasing_scalars.py similarity index 85% rename from tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_kill_aliasing_scalars.py rename to tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_aliasing_scalars.py index 1d03b3e3e8..a755504f40 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_kill_aliasing_scalars.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_aliasing_scalars.py @@ -64,22 +64,23 @@ def _make_map_with_scalar_copies() -> tuple[ def test_remove_double_write_single_consumer(): sdfg, state, me, mx = _make_map_with_scalar_copies() - access_nodes_inside_original_map = util.count_nodes(state.scope_subgraph(me, include_entry=False, include_exit=False), dace_nodes.AccessNode) + access_nodes_inside_original_map = util.count_nodes( + state.scope_subgraph(me, include_entry=False, include_exit=False), dace_nodes.AccessNode + ) assert access_nodes_inside_original_map == 3 - sdfg.view() - breakpoint() find_single_use_data = dace_analysis.FindSingleUseData() single_use_data = find_single_use_data.apply_pass(sdfg, None) sdfg.apply_transformations_repeated( - gtx_transformations.KillAliasingScalars( + gtx_transformations.RemoveAliasingScalars( single_use_data=single_use_data, assume_single_use_data=False, ), validate=True, validate_all=True, ) - sdfg.view() - breakpoint() - access_nodes_inside_new_map = util.count_nodes(state.scope_subgraph(me, include_entry=False, include_exit=False), dace_nodes.AccessNode) + + access_nodes_inside_new_map = util.count_nodes( + state.scope_subgraph(me, include_entry=False, include_exit=False), dace_nodes.AccessNode + ) assert access_nodes_inside_new_map == 1 From 
e110ab2b129a84c204a5077f1ae0c230aa664d54 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 29 Jan 2026 10:19:35 +0100 Subject: [PATCH 39/61] Cleared a bit FuseConditionalBlocks --- .../runners/dace/transformations/__init__.py | 1 + .../fuse_horizontal_conditionblocks.py | 140 +++++++++++------- .../test_fuse_horizontal_conditionblocks.py | 94 ++++++++++-- 3 files changed, 171 insertions(+), 64 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py index 0b575fafd9..f3b2f0bc3a 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py @@ -85,6 +85,7 @@ __all__ = [ "CopyChainRemover", "DoubleWriteRemover", + "FuseHorizontalConditionBlocks", "GPUSetBlockSize", "GT4PyAutoOptHook", "GT4PyAutoOptHookFun", diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 0417500d45..4ea72c7ddf 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -1,19 +1,22 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. 
+# SPDX-License-Identifier: BSD-3-Clause + import copy import uuid -import warnings -from typing import Any, Callable, Mapping, Optional, TypeAlias, Union +from typing import Any, Optional, Union import dace -from dace import ( - properties as dace_properties, - subsets as dace_subsets, - transformation as dace_transformation, -) - -from dace.sdfg import nodes as dace_nodes, graph as dace_graph +from dace import properties as dace_properties, transformation as dace_transformation +from dace.sdfg import graph as dace_graph, nodes as dace_nodes from dace.transformation import helpers as dace_helpers + from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations -from dace.sdfg import utils as sdutil + def unique_name(name: str) -> str: """Adds a unique string to `name`.""" @@ -23,6 +26,7 @@ def unique_name(name: str) -> str: name = name[: (maximal_length - len(unique_sufix) - 1)] return f"{name}_{unique_sufix}" + @dace_properties.make_properties class FuseHorizontalConditionBlocks(dace_transformation.SingleStateTransformation): access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) @@ -32,8 +36,12 @@ class FuseHorizontalConditionBlocks(dace_transformation.SingleStateTransformatio @classmethod def expressions(cls) -> Any: map_fusion_parallel_match = dace_graph.OrderedMultiDiConnectorGraph() - map_fusion_parallel_match.add_nedge(cls.access_node, cls.first_conditional_block, dace.Memlet()) - map_fusion_parallel_match.add_nedge(cls.access_node, cls.second_conditional_block, dace.Memlet()) + map_fusion_parallel_match.add_nedge( + cls.access_node, cls.first_conditional_block, dace.Memlet() + ) + map_fusion_parallel_match.add_nedge( + cls.access_node, cls.second_conditional_block, dace.Memlet() + ) return [map_fusion_parallel_match] def can_be_applied( @@ -49,27 +57,41 @@ def can_be_applied( second_cb: dace_nodes.NestedSDFG = self.second_conditional_block scope_dict = graph.scope_dict() + # Check that both conditional blocks 
are in the same parent SDFG if first_cb.sdfg.parent != second_cb.sdfg.parent: return False + # Check that the nested SDFGs' names starts with "if_stmt" - if not (first_cb.sdfg.name.startswith("if_stmt") and second_cb.sdfg.name.startswith("if_stmt")): + if not ( + first_cb.sdfg.name.startswith("if_stmt") and second_cb.sdfg.name.startswith("if_stmt") + ): return False # Check that the common access node is a boolean scalar - if not isinstance(access_node_desc, dace.data.Scalar) or access_node_desc.dtype != dace.bool_: + if ( + not isinstance(access_node_desc, dace.data.Scalar) + or access_node_desc.dtype != dace.bool_ + ): return False + # Make sure that the conditional blocks contain only one conditional block each if len(first_cb.sdfg.nodes()) > 1 or len(second_cb.sdfg.nodes()) > 1: return False + # Get the actual conditional blocks first_conditional_block = next(iter(first_cb.sdfg.nodes())) second_conditional_block = next(iter(second_cb.sdfg.nodes())) - if not (isinstance(first_conditional_block, dace.sdfg.state.ConditionalBlock) and isinstance(second_conditional_block, dace.sdfg.state.ConditionalBlock)): + if not ( + isinstance(first_conditional_block, dace.sdfg.state.ConditionalBlock) + and isinstance(second_conditional_block, dace.sdfg.state.ConditionalBlock) + ): return False + # Make sure that both conditional blocks are in the same scope if scope_dict[first_cb] != scope_dict[second_cb]: return False + # Make sure that both conditional blocks are in a map scope if not isinstance(scope_dict[first_cb], dace.nodes.MapEntry): return False @@ -82,13 +104,12 @@ def can_be_applied( return False cond_edge_first = cond_edges_first[0] cond_edge_second = cond_edges_second[0] - if cond_edge_first.src != cond_edge_second.src and (cond_edge_first.src != access_node or cond_edge_second.src != access_node): + if cond_edge_first.src != cond_edge_second.src and ( + cond_edge_first.src != access_node or cond_edge_second.src != access_node + ): return False - print(f"Found valid 
conditional blocks: {first_cb} and {second_cb}", flush=True) - # breakpoint() - - # TODO(iomaganaris): Need to check also that first and second nested SDFGs are not reachable from each other + # Need to check also that first and second nested SDFGs are not reachable from each other if gtx_transformations.utils.is_reachable( start=first_cb, target=second_cb, @@ -114,14 +135,19 @@ def apply( first_conditional_block = next(iter(first_cb.sdfg.nodes())) second_conditional_block = next(iter(second_cb.sdfg.nodes())) + # Store original arrays to check later that all the necessary arrays have been moved original_arrays_first_conditional_block = {} for data_name, data_desc in first_conditional_block.sdfg.arrays.items(): original_arrays_first_conditional_block[data_name] = data_desc original_arrays_second_conditional_block = {} for data_name, data_desc in second_conditional_block.sdfg.arrays.items(): original_arrays_second_conditional_block[data_name] = data_desc - total_original_arrays = len(original_arrays_first_conditional_block) + len(original_arrays_second_conditional_block) + total_original_arrays = len(original_arrays_first_conditional_block) + len( + original_arrays_second_conditional_block + ) + # Store the new names for the arrays in the second conditional block to avoid name clashes and add their data descriptors + # to the first conditional block SDFG second_arrays_rename_map = {} for data_name, data_desc in original_arrays_second_conditional_block.items(): if data_name == "__cond": @@ -140,11 +166,12 @@ def apply( second_conditional_states = list(second_conditional_block.all_states()) + # Move the connectors from the second conditional block to the first in_connectors_to_move = {k: v for k, v in second_cb.in_connectors.items() if k != "__cond"} out_connectors_to_move = second_cb.out_connectors in_connectors_to_move_rename_map = {} out_connectors_to_move_rename_map = {} - for k, v in in_connectors_to_move.items(): + for k, _v in in_connectors_to_move.items(): 
new_connector_name = k if new_connector_name in first_cb.in_connectors: new_connector_name = second_arrays_rename_map[k] @@ -152,8 +179,10 @@ def apply( first_cb.add_in_connector(new_connector_name) for edge in graph.in_edges(second_cb): if edge.dst_conn == k: - dace_helpers.redirect_edge(state=graph, edge=edge, new_dst_conn=new_connector_name, new_dst=first_cb) - for k, v in out_connectors_to_move.items(): + dace_helpers.redirect_edge( + state=graph, edge=edge, new_dst_conn=new_connector_name, new_dst=first_cb + ) + for k, _v in out_connectors_to_move.items(): new_connector_name = k if new_connector_name in first_cb.out_connectors: new_connector_name = second_arrays_rename_map[k] @@ -161,12 +190,15 @@ def apply( first_cb.add_out_connector(new_connector_name) for edge in graph.out_edges(second_cb): if edge.src_conn == k: - dace_helpers.redirect_edge(state=graph, edge=edge, new_src_conn=new_connector_name, new_src=first_cb) + dace_helpers.redirect_edge( + state=graph, edge=edge, new_src_conn=new_connector_name, new_src=first_cb + ) - nodes_renamed_map = {} - for first_inner_state in first_conditional_block.all_states(): - first_inner_state_name = first_inner_state.name - true_branch = "true_branch" in first_inner_state_name + def _find_corresponding_state_in_second( + inner_state: dace.SDFGState, + ) -> dace.SDFGState: + inner_state_name = inner_state.name + true_branch = "true_branch" in inner_state_name corresponding_state_in_second = None for state in second_conditional_states: if true_branch and "true_branch" in state.name: @@ -176,7 +208,15 @@ def apply( corresponding_state_in_second = state break if corresponding_state_in_second is None: - raise RuntimeError(f"Could not find corresponding state in second conditional block for state {first_inner_state_name}") + raise RuntimeError( + f"Could not find corresponding state in second conditional block for state {inner_state_name}" + ) + return corresponding_state_in_second + + # Copy first the nodes from the second 
conditional block to the first + nodes_renamed_map = {} + for first_inner_state in first_conditional_block.all_states(): + corresponding_state_in_second = _find_corresponding_state_in_second(first_inner_state) nodes_to_move = list(corresponding_state_in_second.nodes()) for node in nodes_to_move: new_node = node @@ -186,27 +226,15 @@ def apply( new_node = dace_nodes.AccessNode(new_data_name) new_desc = copy.deepcopy(node.desc(second_cb.sdfg)) new_desc.name = new_data_name - # if new_data_name not in first_cb.sdfg.arrays: - # first_cb.sdfg.add_datadesc(new_data_name, new_desc) nodes_renamed_map[node] = new_node first_inner_state.add_node(new_node) + # Then copy the edges second_to_first_connections = {} for node in nodes_renamed_map: second_to_first_connections[node.data] = nodes_renamed_map[node].data for first_inner_state in first_conditional_block.all_states(): - first_inner_state_name = first_inner_state.name - true_branch = "true_branch" in first_inner_state_name - corresponding_state_in_second = None - for state in second_conditional_states: - if true_branch and "true_branch" in state.name: - corresponding_state_in_second = state - break - elif not true_branch and "false_branch" in state.name: - corresponding_state_in_second = state - break - if corresponding_state_in_second is None: - raise RuntimeError(f"Could not find corresponding state in second conditional block for state {first_inner_state_name}") + corresponding_state_in_second = _find_corresponding_state_in_second(first_inner_state) nodes_to_move = list(corresponding_state_in_second.nodes()) for node in nodes_to_move: for edge in list(corresponding_state_in_second.out_edges(node)): @@ -215,7 +243,17 @@ def apply( new_memlet = copy.deepcopy(edge.data) if edge.data.data in second_to_first_connections: new_memlet.data = second_to_first_connections[edge.data.data] - first_inner_state.add_edge(nodes_renamed_map[node], nodes_renamed_map[node].data, nodes_renamed_map[edge.dst], 
second_to_first_connections[node.data], new_memlet) + first_inner_state.add_edge( + nodes_renamed_map[node], + nodes_renamed_map[node].data + if isinstance(node, dace_nodes.AccessNode) and edge.src_conn + else edge.src_conn, + nodes_renamed_map[dst], + second_to_first_connections[dst.data] + if isinstance(edge.dst, dace_nodes.AccessNode) and edge.dst_conn + else edge.dst_conn, + new_memlet, + ) for edge in list(graph.out_edges(access_node)): if edge.dst == second_cb: graph.remove_edge(edge) @@ -225,12 +263,6 @@ def apply( graph.remove_node(second_cb) new_arrays = len(first_cb.sdfg.arrays) - assert new_arrays == total_original_arrays - 1, f"After fusion, expected {total_original_arrays - 1} arrays but found {new_arrays}" - # TODO(iomaganaris): Figure out if I have to handle any symbols - sdfg.save(f"after_fuse_horizontal_conditionblocks_{first_cb}_{second_cb}.sdfg") - # sdfg.view() - # breakpoint() - - - # print(f"Fused conditional blocks into: {new_nested_sdfg}", flush=True) - # breakpoint() \ No newline at end of file + assert new_arrays == total_original_arrays - 1, ( + f"After fusion, expected {total_original_arrays - 1} arrays but found {new_arrays}" + ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py index 2d5025affa..04dd6a7dcc 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py @@ -21,9 +21,72 @@ import dace -def _make_map_with_conditional_blocks() -> tuple[ - dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapExit -]: + +def _make_if_block_with_tasklet( + state: 
dace.SDFGState, + b1_name: str = "__arg1", + b2_name: str = "__arg2", + cond_name: str = "__cond", + output_name: str = "__output", + b1_type: dace.typeclass = dace.float64, + b2_type: dace.typeclass = dace.float64, + output_type: dace.typeclass = dace.float64, +) -> dace_nodes.NestedSDFG: + inner_sdfg = dace.SDFG(util.unique_name("if_stmt_")) + + types = {b1_name: b1_type, b2_name: b2_type, cond_name: dace.bool_, output_name: output_type} + for name in {b1_name, b2_name, cond_name, output_name}: + inner_sdfg.add_scalar( + name, + dtype=types[name], + transient=False, + ) + + if_region = dace.sdfg.state.ConditionalBlock(util.unique_name("if")) + inner_sdfg.add_node(if_region, is_start_block=True) + + then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=inner_sdfg) + tstate = then_body.add_state("true_branch", is_start_block=True) + tasklet = tstate.add_tasklet( + "true_tasklet", + inputs={"__tasklet_in"}, + outputs={"__tasklet_out"}, + code="__tasklet_out = __tasklet_in * 2.0", + ) + tstate.add_edge( + tstate.add_access(b1_name), + None, + tasklet, + "__tasklet_in", + dace.Memlet(f"{b1_name}[0]"), + ) + tstate.add_edge( + tasklet, + "__tasklet_out", + tstate.add_access(output_name), + None, + dace.Memlet(f"{output_name}[0]"), + ) + + else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=inner_sdfg) + fstate = else_body.add_state("false_branch", is_start_block=True) + fstate.add_nedge( + fstate.add_access(b2_name), + fstate.add_access(output_name), + dace.Memlet(f"{b2_name}[0] -> [0]"), + ) + + if_region.add_branch(dace.sdfg.state.CodeBlock(cond_name), then_body) + if_region.add_branch(dace.sdfg.state.CodeBlock(f"not {cond_name}"), else_body) + + return state.add_nested_sdfg( + sdfg=inner_sdfg, + inputs={b1_name, b2_name, cond_name}, + outputs={output_name}, + ) + + +def _make_map_with_conditional_blocks() -> dace.SDFG: sdfg = dace.SDFG(util.unique_name("map_with_conditional_blocks")) state = sdfg.add_state(is_start_block=True) @@ -90,7 +153,7 
@@ def _make_map_with_conditional_blocks() -> tuple[ state.add_edge(if_block_0, "__output", tmp_c, None, dace.Memlet("tmp_c[0]")) state.add_edge(tmp_c, None, mx, "IN_c", dace.Memlet("c[__i]")) - if_block_1 = _make_if_block(state=state, outer_sdfg=sdfg) + if_block_1 = _make_if_block_with_tasklet(state=state) state.add_edge(cond_var, None, if_block_1, "__cond", dace.Memlet("cond_var")) state.add_edge(tmp_a, None, if_block_1, "__arg1", dace.Memlet("tmp_a[0]")) state.add_edge(tmp_b, None, if_block_1, "__arg2", dace.Memlet("tmp_b[0]")) @@ -101,13 +164,19 @@ def _make_map_with_conditional_blocks() -> tuple[ state.add_edge(mx, "OUT_d", d, None, dace.Memlet("d[__i]")) sdfg.validate() - return sdfg, state, me, mx + return sdfg + def test_fuse_horizontal_condition_blocks(): - sdfg, state, me, mx = _make_map_with_conditional_blocks() + sdfg = _make_map_with_conditional_blocks() + + conditional_blocks = [ + n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, dace.sdfg.state.ConditionalBlock) + ] + assert len(conditional_blocks) == 2 - # sdfg.view() - # breakpoint() + ref, res = util.make_sdfg_args(sdfg) + util.compile_and_run_sdfg(sdfg, **ref) sdfg.apply_transformations_repeated( gtx_transformations.FuseHorizontalConditionBlocks(), @@ -115,5 +184,10 @@ def test_fuse_horizontal_condition_blocks(): validate_all=True, ) - # sdfg.view() - # breakpoint() + new_conditional_blocks = [ + n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, dace.sdfg.state.ConditionalBlock) + ] + assert len(new_conditional_blocks) == 1 + + util.compile_and_run_sdfg(sdfg, **res) + assert util.compare_sdfg_res(ref=ref, res=res) From 15ae54d62d721e9f2eb52a90d4c0ffe72880e8f4 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 29 Jan 2026 10:27:07 +0100 Subject: [PATCH 40/61] Fixed imports --- .../runners/dace/transformations/__init__.py | 4 ++-- .../dace/transformations/fuse_horizontal_conditionblocks.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git 
a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py index f3b2f0bc3a..edd52d0280 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py @@ -28,7 +28,6 @@ gt_set_gpu_blocksize, ) from .inline_fuser import inline_dataflow_into_map -from .remove_aliasing_scalars import RemoveAliasingScalars from .local_double_buffering import gt_create_local_double_buffering from .loop_blocking import LoopBlocking from .map_fusion import ( @@ -55,6 +54,7 @@ ) from .redundant_array_removers import CopyChainRemover, DoubleWriteRemover, gt_remove_copy_chain from .remove_access_node_copies import RemoveAccessNodeCopies +from .remove_aliasing_scalars import RemoveAliasingScalars from .remove_views import RemovePointwiseViews from .scan_loop_unrolling import ScanLoopUnrolling from .simplify import ( @@ -95,7 +95,6 @@ "GT4PyStateFusion", "HorizontalMapFusionCallback", "HorizontalMapSplitCallback", - "RemoveAliasingScalars", "LoopBlocking", "MapFusionHorizontal", "MapFusionVertical", @@ -108,6 +107,7 @@ "MultiStateGlobalSelfCopyElimination", "MultiStateGlobalSelfCopyElimination2", "RemoveAccessNodeCopies", + "RemoveAliasingScalars", "RemovePointwiseViews", "ScanLoopUnrolling", "SingleStateGlobalDirectSelfCopyElimination", diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 4ea72c7ddf..9fcec15b14 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -8,7 +8,7 @@ import copy import uuid -from typing import Any, Optional, Union +from typing import Any, Union 
import dace from dace import properties as dace_properties, transformation as dace_transformation From 8f44eb5475e90c7f48dbd6be7bbc05ef11738476 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Mon, 2 Feb 2026 17:16:20 +0100 Subject: [PATCH 41/61] Remove sdfg.save --- .../runners/dace/transformations/auto_optimize.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 85a9e78591..10a6256fef 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -655,7 +655,6 @@ def _gt_auto_process_top_level_maps( if GT4PyAutoOptHook.TopLevelDataFlowPost in optimization_hooks: optimization_hooks[GT4PyAutoOptHook.TopLevelDataFlowPost](sdfg) # type: ignore[call-arg] - sdfg.save("after_top_level_map_optimization.sdfg") return sdfg @@ -727,8 +726,6 @@ def _gt_auto_process_dataflow_inside_maps( validate_all=validate_all, ) - sdfg.save("before_kill_aliasing_scalars_top_level.sdfg") - find_single_use_data = dace_analysis.FindSingleUseData() single_use_data = find_single_use_data.apply_pass(sdfg, None) @@ -739,7 +736,6 @@ def _gt_auto_process_dataflow_inside_maps( validate=False, validate_all=validate_all, ) - sdfg.save("after_kill_aliasing_scalars_top_level.sdfg") sdfg.apply_transformations_repeated( gtx_transformations.FuseHorizontalConditionBlocks(), From 6a8b1b4fea50ebba496ec3c8ce594cf2ac80f32d Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 3 Feb 2026 09:32:44 +0100 Subject: [PATCH 42/61] Fix issues after cherry-picking --- .../dace/transformations/fuse_horizontal_conditionblocks.py | 3 ++- .../test_fuse_horizontal_conditionblocks.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git 
a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 9fcec15b14..470610aa5f 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -232,7 +232,8 @@ def _find_corresponding_state_in_second( # Then copy the edges second_to_first_connections = {} for node in nodes_renamed_map: - second_to_first_connections[node.data] = nodes_renamed_map[node].data + if isinstance(node, dace_nodes.AccessNode): + second_to_first_connections[node.data] = nodes_renamed_map[node].data for first_inner_state in first_conditional_block.all_states(): corresponding_state_in_second = _find_corresponding_state_in_second(first_inner_state) nodes_to_move = list(corresponding_state_in_second.nodes()) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py index 04dd6a7dcc..ec23b4ee24 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py @@ -141,7 +141,7 @@ def _make_map_with_conditional_blocks() -> dace.SDFG: "tasklet_cond", inputs={"__in"}, outputs={"__out"}, - code="__out = __in <= 0.0", + code="__out = __in <= 0.5", ) state.add_edge(tmp_a, None, tasklet_cond, "__in", dace.Memlet("tmp_a[0]")) state.add_edge(tasklet_cond, "__out", cond_var, None, dace.Memlet("cond_var")) From 0ce43372b2c6cc4a89968a1ec01768343edc3bb4 Mon 
Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 3 Feb 2026 14:08:39 +0100 Subject: [PATCH 43/61] Applied suggestions to FuseHorizontalConditionBlocks and moved unique_name to gtx_transformations utils --- _dacegraphs/invalid.sdfgz | Bin 0 -> 5370 bytes .../runners/dace/transformations/__init__.py | 3 +- .../fuse_horizontal_conditionblocks.py | 142 +++++++++++------- .../runners/dace/transformations/utils.py | 10 ++ .../test_auto_optimizer_hooks.py | 2 +- .../test_copy_chain_remover.py | 20 +-- .../test_create_local_double_buffering.py | 10 +- .../test_dead_dataflow_elimination.py | 4 +- .../test_distributed_buffer_relocator.py | 24 ++- .../test_double_write_remover.py | 8 +- .../test_fuse_horizontal_conditionblocks.py | 6 +- .../transformation_tests/test_gpu_utils.py | 11 +- .../test_horizontal_map_split_fusion.py | 2 +- .../transformation_tests/test_inline_fuser.py | 8 +- .../test_loop_blocking.py | 28 ++-- .../test_make_transients_persistent.py | 4 +- .../test_map_buffer_elimination.py | 10 +- .../test_map_fusion_utils.py | 5 +- .../transformation_tests/test_map_order.py | 2 +- .../transformation_tests/test_map_promoter.py | 10 +- .../transformation_tests/test_map_splitter.py | 2 +- .../transformation_tests/test_map_to_copy.py | 2 +- .../test_move_dataflow_into_if_body.py | 24 +-- .../test_move_tasklet_into_map.py | 2 +- ...ulti_state_global_self_copy_elimination.py | 16 +- .../test_multiple_copies_global.py | 2 +- .../test_remove_aliasing_scalars.py | 2 +- .../test_remove_point_view.py | 2 +- ...ngle_state_global_self_copy_elimination.py | 32 ++-- .../test_split_access_node.py | 40 ++--- .../transformation_tests/test_split_memlet.py | 6 +- .../test_splitting_tools.py | 10 +- .../transformation_tests/test_state_fusion.py | 28 ++-- .../transformation_tests/test_strides.py | 28 ++-- .../test_vertical_map_split_fusion.py | 6 +- .../dace_tests/transformation_tests/util.py | 15 +- 36 files changed, 306 insertions(+), 220 deletions(-) create mode 100644 
_dacegraphs/invalid.sdfgz diff --git a/_dacegraphs/invalid.sdfgz b/_dacegraphs/invalid.sdfgz new file mode 100644 index 0000000000000000000000000000000000000000..2869814a1c9a9c302e27afa792eef634827ef0c0 GIT binary patch literal 5370 zcmV95~>_x$uSL%B9$2`B7ndQRH^Z9b2 zE@%33Y5u*QWhBE?3p2yms~I7l%%s!|esZitCaLADd;>ghrdBKU`SAtCYb!r0<@$SZ z;oWVU>-LLxcV-PIYez&bHlc}4cxV&a z*o22RVSaTzYv99Qjk;8;uy!0)g_bs6sl|Hk7DYqms~dMn_4H|1&zyGk?CDmouxw_Z zpcdoRP%oF4FI>H|N+qY~mrD&~X_nSiEik{@Sb3Fd7fif3pIMa`bTqBj`%*2=Z`65N z!LNV*d0P9X`C^_w$9vqL&-0%1n|T5KX(BbjsWd22BtzU&hNSfZVZt26NhUcJ_qOhZ zT3;0RWWuD95t1^@wN4bKMp2T%fi!VcCxm32WGT^t+BDIU5SLQyOy)DHG51A@%5qE> z%j(zhp}tj9{?}aP?xA{6Iqfb>C>FN*z@h~ZjH`|##h_N8<}HgL~oy}yzraiq92Nj-tAtp&ebZPUCtN9g#fX= z3mmXq%jwPifQOr_MHM1j_qUw(^GkO>Y5ZUSQ~7fBVQp3_Hy0nw+N?j!muj)NQs*DO z|K;z$|NQO$eK^ng_2&=J8u|Le!u^@gHh#ovy0_w{-5O{FS+n;KD(LBzQ6uF-Kx^0Tv3cNpJ0?^8Q07(_#C+o*pefGt%r9LVhm->k`CC?BY~}rust={+D6zaNhu=|<_$@TIFlM> zne<3XZxoX*$zU4M4qJ94u-z7*z=+zOur-CZhKg(c5V8N=FGSKqkm09?+|Kf-AuTm0hULk;rrBPh_glg7--t;)o`>Arepu{xOXY5NU3N_U=^O^l=Jej#+npmH+tE2Ldx&=S5RpB4vd4T;+_G>6>b0)x zqaS`Mr9HcAGxz88?#fKvOO`iz$6NmK`)|#!DNY8x=bq|BDkMpSMV@kpxXCi4d5V$8 zm}16+mYq{2Ie49_H`f)nJq^``bgpUXTsvOpnzqiheL6SOKBb^CKrkVOnS?U}W)cJ= zD>6zP%e>8K*Om`{mJiu)YW?p^*VMWB=4?4)!@G1xG}5FLna!YECSOhd2B=v~{^zR+ zexH1?&#qMw$6B&fB@Fk zpSO*^y%Kdjam!(1c%kome9u#22M}s?P8!a(=@$FcCPcBh-KMqj#%8Hwi}@d-a%{7v zYX(PkEM4E^vGE2q9^j6~rW@3B-;5-xx96L7sp)Q`)LuQ zG&4dbD6xhiSi~=cSCk_~oY02pL~hGmYoAx+7tp%Ed_UN}e)OIG-z|5Hz9aadI+hup zN(tpxpkm-Q;YfS$k&s?;sGS10lW8zbHCL;tYP;3TSIf)qE|-7&;?AL(&8!l{QdPl` z$TB9O+qBoP4k>V2i0V_M661xW*f1}fSm?H~6z%nGt8<6jO#_zimlVK@Yo=l8!d$f9 zvF{A9cS)n;PW`B>cjl?RcF)15?!VT;O?DlD>i(l$7qG4ic<@~T?h%; zP;Q6H-7LTJhQRR@wEWmwi{>TC2AZ!dQzjzlp32H);2m9ynT@aMdCRdcdR{F*x$=_= zR^1=4^q4PUzhkDC_VeC)XYb*#vqv5X%q_`tqEio=(Imz2sq)gn{y9@f!FGf8+*l@( zt}=F(Fn}@xwxQZ?sgY0*DCtwP#@zzfiSS0+n>W(P8^_8UfzYhFrWr~>T~3g7!ho|x zYLF^81nfSqz|jZnLOv36kkes~x;@hXyEG5ItuJ-7OcuLlhqL=qJ?_QcAAmjC|DMqI 
z_=_-J>Ds0p4DT1=MckVgapc7(+(aIVo?#;}U<11Lrt07cA*9e+bZ%mIl%`>WduVYU z^zv33!V-$8%3vcLlr2aSr4cyVX(DKvxs;9c^`LLw?hN0kc~a9fq+$0aGvk!IDxL#j zd%_ZhWR`%J9JVRa%sRwKYOI2FJLAWi@vu1K>9K`%T-Y14X;YA+Cx}7#23Y9^ zcMF|@yG7!hahk#3T#ST1egx=iIs377na|ntF_-y%yt&NhU*CTQzqRM?3E;NDPA^0A zM|TxwRP--TdC#Z9^mENZ__`jx(VN(eVXx>E{x{y%;t(*jzxy~4w3*BU?rL{9Xl}b^ zM}#4RDs+v#m(jMdBl8QlnniF#aE!73Jn;Lxy#PfWORQkXdJ92j5&`?Inj3_{4U7p0 z*kO4ZOnNI5u^q8pi&9u}h|q}8s57(IJ^ z4x5?;^m8#R!n(E`Ko2WW!Sf6tQEhNulm25Kk_k|-^}<}fL7kd9)IY0;xBait;-ogmYq(CXmC3 zaqvZhWGJBbS4d-#A^{=$g%Um?)V~^+tr8T`XKNDbvvuk&S@c9)D+KQ=Ar8DFN{C;^ zC5!fvh$Y5q4S6UK#O&7~P^2CtVu?W}O;WA|WU{h6x@1L{tQy`(dh4Od+yjoGCu znK=#NCF67=4uCkFh|q}BiKh|1vZI+!JgeH(WY`g!dX%aRq>jvKq$LE2z=CJiU@0-> zuUqZvO8H~A+LZ*VUHM^Z*R#MD>1d6xme@)}S7dXOU{mzAcv@soJ8eVXoprWL)2Nwl zB&sbRjiygYi=uZg8f#51#DP&(8#$<@IQ_NDeTMZYDJNRVHAuzU!94wR%tMk-1Pb^1 ziJAM!nL`E5lMu}p>Ar+^U&6XC;b(^!9UUL#Km~M&9XLcAG9Wd@koBzIn#RGZ#$hQc zh9s#NsBIjiZajT`W1v+&BwX022`j>sm|+U(L5iQ=nxB9YKL@vjOHyAGyQBlyC5`NI z{Ol48*OAj&^xzkAiv03hc&z-vLH7F-OpjC-PnhgtV7iN7@V^*~;sGp*Ba6OB`ioFB z4IP~WlVA)?gArPrL8&l?Cc_9V)R+k|dVF_iv}>OvLB_$RZ|ElLlZDwEHBm+nM6!s; zC-U)m2~3V1p=s0v8a+_Sj|P?P!5cajD8oE|dO>@Drf(P!!*zYblzqdreS_3}L-c)v zg$l!@3jM?i!{rJ?1q*{E3&TVU;j)EL;X<%6}Q!9u^MNyjl%>=$wW zhH*bc0vG}Fb^$X~7&ufK7*^ZcQKl}z9!GtAjt0sEhX@77orK4kN}Q=gHhoUBbIhb4 zqu8e7CH{!EX&f8}*`~b!{058^*274lLyVMu2|}W|63vy!$ZyA789$ZC7zWFT=|qMn z6p5=UabO*`s?sCT$Sbl~Xb+2pKjX4WpTr}FZYcNqu#c92q`lP5=UQqL;F{UbpQ=wX zl6#dwdxwYSBnkA>Y-*^6`XBC^__xI}pKSaSzc_gZ#X}OT9I=C9aH^H)p*ZQCdga92 z59x_H2!$nA8Iq7CO50!D{gC#PAC8fnCAuF*_rtzPTZW`kO-LQ1`^#s;!GvZR3e+knM#DrY0p%~OkOj_bY{c^HgP%=2S=RFM8L%9 z%=5{w*^y0WUO;nwXd0Z?rMd1-i8DxZ&7VqheHLTT91$dGrwA1&DO3=h1B*q+I}`@783Td%GuVY?hH=+C@5=uS0A-SK#?B4 zARMMZX^N3ci_pXlvM(u>z`l;tnkf(AVfB^*XzMw5Iv>%(cU_W>#WN>pUvjYiRJE0@z!g) z&F(G(`!?+RSkrv>%L=NlZM&BR9>AXwU8pb(l3qX*WGWfR-#7zLT;VxMJ;V>OwcXb> z)sJ1+*@fLTu<1Z7#BU}xg-n1T=LJn% z=7_^-8nCAtyYSQ(>qsY~?WB^6d=|w6FX5g*mBX6<20!vbXjWk|m z(LLdfahj~-p4dNu71K+iA+|G4z++v()HB0aAy0fJ 
zG}l^c8aXX;THUJL08Z19)8c@5Y6k@&S;8btIa%R6$`ye)QbbGSwbHOVs?)SKp=;Fg zp0bvg)H{;WF<8w+R*P_dY7YftA^3xn96?&kB3Cv1BU48p3kSyxCtjsiMplcgwjYXX zAgkHPYH@-288$wJ76#LdAcKpIPi+MRO=GaOMqsl*d74O?y>4QAwLSepq*m`BpB+kS zv+Y#vk<=p8>o-0Roq$r75XLO{#S`g^AYKy0x)9EbdPTtL%6ZrUjg0n|7>(94niz%A zt|O^Mw0D!*q4Q6hf_TFbu*(z$zXOg50u={@Gg49z#!VuLj-(bzEgg&0!eVQ-#H(!J zok#cRB(o4xkfTH}SW%JSBP}7(DWbw=puyysmk+TvH%~AAEuX&=1h_u%?@qAa?Et?! zL4C_mx4VIS%VTBS*-sFM7DCOId&}&#OEko5p*#Lv~%gMtRVDhh_&T#*lPX18q$@%<~TTGypC-e1WvCJoGvU2B8FSl}b zQowRz9YpAzUl(tiFD8(+WU_!*xp4N^Z@&Bf{p8y})z$T-`*L!vAPTR88{sbish#9Y zc$&Xin6m7D%<~IC1{}I7!u`${=M%MBsn6vJc str: - """Adds a unique string to `name`.""" - maximal_length = 200 - unique_sufix = str(uuid.uuid1()).replace("-", "_") - if len(name) > (maximal_length - len(unique_sufix)): - name = name[: (maximal_length - len(unique_sufix) - 1)] - return f"{name}_{unique_sufix}" - - @dace_properties.make_properties class FuseHorizontalConditionBlocks(dace_transformation.SingleStateTransformation): - access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) + """Fuses two conditional blocks that share the same condition variable and are + not dependent to each other (i.e. the output of one of them is used as input to the other) + into a single conditional block. + The motivation for this transformation is to reduce the number of conditional blocks + which generate if statements in the CPU or GPU code, which can lead to performance improvements. 
+ Example: + Before fusion: + ``` + if __cond: + __output1 = __arg1 * 2.0 + else: + __output1 = __arg2 + 3.0 + if __cond: + __output2 = __arg3 - 1.0 + else: + __output2 = __arg4 / 4.0 + ``` + After fusion: + ``` + if __cond: + __output1 = __arg1 * 2.0 + __output2 = __arg3 - 1.0 + else: + __output1 = __arg2 + 3.0 + __output2 = __arg4 / 4.0 + ``` + """ + + conditional_access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) first_conditional_block = dace_transformation.PatternNode(dace_nodes.NestedSDFG) second_conditional_block = dace_transformation.PatternNode(dace_nodes.NestedSDFG) @classmethod def expressions(cls) -> Any: - map_fusion_parallel_match = dace_graph.OrderedMultiDiConnectorGraph() - map_fusion_parallel_match.add_nedge( - cls.access_node, cls.first_conditional_block, dace.Memlet() + conditionalblock_fusion_parallel_match = dace_graph.OrderedMultiDiConnectorGraph() + conditionalblock_fusion_parallel_match.add_nedge( + cls.conditional_access_node, cls.first_conditional_block, dace.Memlet() ) - map_fusion_parallel_match.add_nedge( - cls.access_node, cls.second_conditional_block, dace.Memlet() + conditionalblock_fusion_parallel_match.add_nedge( + cls.conditional_access_node, cls.second_conditional_block, dace.Memlet() ) - return [map_fusion_parallel_match] + return [conditionalblock_fusion_parallel_match] def can_be_applied( self, @@ -51,12 +69,19 @@ def can_be_applied( sdfg: dace.SDFG, permissive: bool = False, ) -> bool: - access_node: dace_nodes.AccessNode = self.access_node - access_node_desc = access_node.desc(sdfg) + conditional_access_node: dace_nodes.AccessNode = self.conditional_access_node + conditional_access_node_desc = conditional_access_node.desc(sdfg) first_cb: dace_nodes.NestedSDFG = self.first_conditional_block second_cb: dace_nodes.NestedSDFG = self.second_conditional_block scope_dict = graph.scope_dict() + # Check that the common access node is a boolean scalar + if ( + not isinstance(conditional_access_node_desc, 
dace.data.Scalar) + or conditional_access_node_desc.dtype != dace.bool_ + ): + return False + # Check that both conditional blocks are in the same parent SDFG if first_cb.sdfg.parent != second_cb.sdfg.parent: return False @@ -67,15 +92,8 @@ def can_be_applied( ): return False - # Check that the common access node is a boolean scalar - if ( - not isinstance(access_node_desc, dace.data.Scalar) - or access_node_desc.dtype != dace.bool_ - ): - return False - # Make sure that the conditional blocks contain only one conditional block each - if len(first_cb.sdfg.nodes()) > 1 or len(second_cb.sdfg.nodes()) > 1: + if first_cb.sdfg.number_of_nodes() > 1 or second_cb.sdfg.number_of_nodes() > 1: return False # Get the actual conditional blocks @@ -83,7 +101,22 @@ def can_be_applied( second_conditional_block = next(iter(second_cb.sdfg.nodes())) if not ( isinstance(first_conditional_block, dace.sdfg.state.ConditionalBlock) + and len(first_conditional_block.sub_regions()) == 2 and isinstance(second_conditional_block, dace.sdfg.state.ConditionalBlock) + and len(second_conditional_block.sub_regions()) == 2 + ): + return False + first_conditional_block_state_names = [ + state.name for state in first_conditional_block.all_states() + ] + second_conditional_block_state_names = [ + state.name for state in second_conditional_block.all_states() + ] + if not ( + "true_branch" in first_conditional_block_state_names + and "false_branch" in first_conditional_block_state_names + and "true_branch" in second_conditional_block_state_names + and "false_branch" in second_conditional_block_state_names ): return False @@ -96,16 +129,23 @@ def can_be_applied( return False # Check that there is an edge to the conditional blocks with dst_conn == "__cond" - cond_edges_first = [e for e in graph.in_edges(first_cb) if e.dst_conn == "__cond"] + cond_edges_first = [ + e for e in graph.in_edges(first_cb) if e.dst_conn and e.dst_conn == "__cond" + ] if len(cond_edges_first) != 1: return False - 
cond_edges_second = [e for e in graph.in_edges(second_cb) if e.dst_conn == "__cond"] + cond_edges_second = [ + e for e in graph.in_edges(second_cb) if e.dst_conn and e.dst_conn == "__cond" + ] if len(cond_edges_second) != 1: return False cond_edge_first = cond_edges_first[0] cond_edge_second = cond_edges_second[0] - if cond_edge_first.src != cond_edge_second.src and ( - cond_edge_first.src != access_node or cond_edge_second.src != access_node + if cond_edge_first.data.is_empty() or cond_edge_second.data.is_empty(): + return False + if not all( + cond_edge.src is conditional_access_node + for cond_edge in [cond_edge_first, cond_edge_second] ): return False @@ -128,7 +168,7 @@ def apply( graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG, ) -> None: - access_node: dace_nodes.AccessNode = self.access_node + conditional_access_node: dace_nodes.AccessNode = self.conditional_access_node first_cb: dace_nodes.NestedSDFG = self.first_conditional_block second_cb: dace_nodes.NestedSDFG = self.second_conditional_block @@ -136,12 +176,8 @@ def apply( second_conditional_block = next(iter(second_cb.sdfg.nodes())) # Store original arrays to check later that all the necessary arrays have been moved - original_arrays_first_conditional_block = {} - for data_name, data_desc in first_conditional_block.sdfg.arrays.items(): - original_arrays_first_conditional_block[data_name] = data_desc - original_arrays_second_conditional_block = {} - for data_name, data_desc in second_conditional_block.sdfg.arrays.items(): - original_arrays_second_conditional_block[data_name] = data_desc + original_arrays_first_conditional_block = first_conditional_block.sdfg.arrays.copy() + original_arrays_second_conditional_block = second_conditional_block.sdfg.arrays.copy() total_original_arrays = len(original_arrays_first_conditional_block) + len( original_arrays_second_conditional_block ) @@ -153,7 +189,7 @@ def apply( if data_name == "__cond": continue if data_name in 
original_arrays_first_conditional_block: - new_data_name = unique_name(data_name) + new_data_name = gtx_transformations.utils.unique_name(data_name) + "_from_cb_fusion" second_arrays_rename_map[data_name] = new_data_name data_desc_renamed = copy.deepcopy(data_desc) data_desc_renamed.name = new_data_name @@ -171,25 +207,25 @@ def apply( out_connectors_to_move = second_cb.out_connectors in_connectors_to_move_rename_map = {} out_connectors_to_move_rename_map = {} - for k, _v in in_connectors_to_move.items(): - new_connector_name = k + for original_in_connector_name, _v in in_connectors_to_move.items(): + new_connector_name = original_in_connector_name if new_connector_name in first_cb.in_connectors: - new_connector_name = second_arrays_rename_map[k] - in_connectors_to_move_rename_map[k] = new_connector_name + new_connector_name = second_arrays_rename_map[original_in_connector_name] + in_connectors_to_move_rename_map[original_in_connector_name] = new_connector_name first_cb.add_in_connector(new_connector_name) for edge in graph.in_edges(second_cb): - if edge.dst_conn == k: + if edge.dst_conn == original_in_connector_name: dace_helpers.redirect_edge( state=graph, edge=edge, new_dst_conn=new_connector_name, new_dst=first_cb ) - for k, _v in out_connectors_to_move.items(): - new_connector_name = k + for original_out_connector_name, _v in out_connectors_to_move.items(): + new_connector_name = original_out_connector_name if new_connector_name in first_cb.out_connectors: - new_connector_name = second_arrays_rename_map[k] - out_connectors_to_move_rename_map[k] = new_connector_name + new_connector_name = second_arrays_rename_map[original_out_connector_name] + out_connectors_to_move_rename_map[original_out_connector_name] = new_connector_name first_cb.add_out_connector(new_connector_name) for edge in graph.out_edges(second_cb): - if edge.src_conn == k: + if edge.src_conn == original_out_connector_name: dace_helpers.redirect_edge( state=graph, edge=edge, 
new_src_conn=new_connector_name, new_src=first_cb ) @@ -207,10 +243,6 @@ def _find_corresponding_state_in_second( elif not true_branch and "false_branch" in state.name: corresponding_state_in_second = state break - if corresponding_state_in_second is None: - raise RuntimeError( - f"Could not find corresponding state in second conditional block for state {inner_state_name}" - ) return corresponding_state_in_second # Copy first the nodes from the second conditional block to the first @@ -224,8 +256,6 @@ def _find_corresponding_state_in_second( if node.data in first_cb.in_connectors or node.data in first_cb.out_connectors: new_data_name = second_arrays_rename_map[node.data] new_node = dace_nodes.AccessNode(new_data_name) - new_desc = copy.deepcopy(node.desc(second_cb.sdfg)) - new_desc.name = new_data_name nodes_renamed_map[node] = new_node first_inner_state.add_node(new_node) @@ -255,7 +285,7 @@ def _find_corresponding_state_in_second( else edge.dst_conn, new_memlet, ) - for edge in list(graph.out_edges(access_node)): + for edge in list(graph.out_edges(conditional_access_node)): if edge.dst == second_cb: graph.remove_edge(edge) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py index e3f417276e..3a18c40353 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py @@ -8,6 +8,7 @@ """Common functionality for the transformations/optimization pipeline.""" +import uuid from typing import Optional, Sequence, TypeVar, Union import dace @@ -20,6 +21,15 @@ _PassT = TypeVar("_PassT", bound=dace_ppl.Pass) +def unique_name(name: str) -> str: + """Adds a unique string to `name`.""" + maximal_length = 200 + unique_sufix = str(uuid.uuid1()).replace("-", "_") + if len(name) > (maximal_length - len(unique_sufix)): + name = name[: (maximal_length - len(unique_sufix) - 1)] + 
return f"{name}_{unique_sufix}" + + def gt_make_transients_persistent( sdfg: dace.SDFG, device: dace.DeviceType, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_auto_optimizer_hooks.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_auto_optimizer_hooks.py index 3f686ae583..03e94b2858 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_auto_optimizer_hooks.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_auto_optimizer_hooks.py @@ -21,7 +21,7 @@ def _make_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("test")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("test")) state = sdfg.add_state(is_start_block=True) for name in "abcde": diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_copy_chain_remover.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_copy_chain_remover.py index 53bba67408..7bb46d302b 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_copy_chain_remover.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_copy_chain_remover.py @@ -29,7 +29,7 @@ def _make_simple_linear_chain_sdfg() -> dace.SDFG: All intermediates have the same size. """ - sdfg = dace.SDFG(util.unique_name("simple_linear_chain_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("simple_linear_chain_sdfg")) for name in ["a", "b", "c", "d", "e"]: sdfg.add_array( @@ -75,7 +75,7 @@ def _make_diff_sizes_pull_chain_sdfg() -> tuple[ - The AccessNode that is used as final output, refers to `e`. - The Tasklet that is within the Map. 
""" - sdfg = dace.SDFG(util.unique_name("diff_size_linear_pull_chain_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("diff_size_linear_pull_chain_sdfg")) array_size_increment = 10 array_size = 10 @@ -118,7 +118,7 @@ def _make_diff_sizes_push_chain_sdfg() -> tuple[ Same as `_make_simple_linear_pull_chain_sdfg()` but the intermediates become smaller and smaller, so the full shape of the destination array is always written. """ - sdfg = dace.SDFG(util.unique_name("diff_size_linear_push_chain_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("diff_size_linear_push_chain_sdfg")) array_size_decrement = 10 array_size = 50 @@ -155,7 +155,7 @@ def _make_diff_sizes_push_chain_sdfg() -> tuple[ def _make_multi_stage_reduction_sdfg() -> dace.SDFG: """Creates an SDFG that has a two stage copy reduction.""" - sdfg = dace.SDFG(util.unique_name("multi_stage_reduction")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("multi_stage_reduction")) state: dace.SDFGState = sdfg.add_state(is_start_block=True) # This is the size of the arrays, if not mentioned here, then its size is 10. @@ -221,7 +221,7 @@ def _make_not_fully_copied() -> dace.SDFG: Make an SDFG where two intermediate array is not fully copied. Thus the transformation only applies once, when `d` is removed. """ - sdfg = dace.SDFG(util.unique_name("not_fully_copied_intermediate")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("not_fully_copied_intermediate")) for name in ["a", "b", "c", "d", "e"]: sdfg.add_array( @@ -257,7 +257,7 @@ def _make_possible_cyclic_sdfg() -> dace.SDFG: If the transformation would remove `a1` then it would create a cycle. Thus the transformation should not apply. 
""" - sdfg = dace.SDFG(util.unique_name("possible_cyclic_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("possible_cyclic_sdfg")) anames = ["i1", "a1", "a2", "o1"] for name in anames: @@ -331,7 +331,7 @@ def make_inner_sdfg() -> dace.SDFG: inner_sdfg = make_inner_sdfg() - sdfg = dace.SDFG(util.unique_name("linear_chain_with_nested_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("linear_chain_with_nested_sdfg")) state = sdfg.add_state(is_start_block=True) array_size_increment = 10 @@ -370,7 +370,7 @@ def make_inner_sdfg() -> dace.SDFG: def _make_a1_has_output_sdfg() -> dace.SDFG: """Here `a1` has an output degree of 2, one to `a2` and one to another output.""" - sdfg = dace.SDFG(util.unique_name("a1_has_an_additional_output_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("a1_has_an_additional_output_sdfg")) state = sdfg.add_state(is_start_block=True) # All other arrays have a size of 10. @@ -411,7 +411,9 @@ def _make_copy_chain_with_reduction_node( dace_nodes.AccessNode, dace_nodes.AccessNode, ]: - sdfg = dace.SDFG(util.unique_name("copy_chain_remover_with_reduction_sdfg")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("copy_chain_remover_with_reduction_sdfg") + ) state = sdfg.add_state(is_start_block=True) if output_an_array: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py index 0db2706c0d..4bb61a01ad 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py @@ -73,7 +73,7 @@ def _create_sdfg_double_read_part_2( def 
_create_sdfg_double_read( version: int, ) -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name(f"double_read_version_{version}")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"double_read_version_{version}")) state = sdfg.add_state(is_start_block=True) for name in "AB": sdfg.add_array( @@ -98,7 +98,7 @@ def _create_sdfg_double_read( def _create_non_scalar_read() -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name(f"non_scalar_read_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"non_scalar_read_sdfg")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( @@ -145,7 +145,7 @@ def test_local_double_buffering_double_read_sdfg(): def test_local_double_buffering_no_connection(): """There is no direct connection between read and write.""" - sdfg = dace.SDFG(util.unique_name("local_double_buffering_no_connection")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("local_double_buffering_no_connection")) state = sdfg.add_state(is_start_block=True) for name in "AB": sdfg.add_array( @@ -212,7 +212,7 @@ def test_local_double_buffering_no_connection(): def test_local_double_buffering_no_apply(): """Here it does not apply, because are all distinct.""" - sdfg = dace.SDFG(util.unique_name("local_double_buffering_no_apply")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("local_double_buffering_no_apply")) state = sdfg.add_state(is_start_block=True) for name in "AB": sdfg.add_array( @@ -237,7 +237,7 @@ def test_local_double_buffering_no_apply(): def test_local_double_buffering_already_buffered(): """It is already buffered.""" - sdfg = dace.SDFG(util.unique_name("local_double_buffering_no_apply")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("local_double_buffering_no_apply")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( "A", diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_dead_dataflow_elimination.py 
b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_dead_dataflow_elimination.py index ef13fd6ea7..375f8da1d5 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_dead_dataflow_elimination.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_dead_dataflow_elimination.py @@ -19,7 +19,7 @@ def _make_empty_memlets_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("empty_memlets")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("empty_memlets")) state = sdfg.add_state(is_start_block=True) anames = ["a", "b", "c"] @@ -46,7 +46,7 @@ def _make_empty_memlets_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: def _make_zero_iter_step_map() -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("empty_map")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("empty_map")) state = sdfg.add_state(is_start_block=True) anames = ["a", "b", "c"] diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py index cadf525159..a2673295a8 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py @@ -21,7 +21,7 @@ def _mk_distributed_buffer_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("distributed_buffer_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("distributed_buffer_sdfg")) for name in ["a", "b", "tmp"]: sdfg.add_array(name, shape=(10, 10), 
dtype=dace.float64, transient=False) @@ -84,7 +84,9 @@ def test_distributed_buffer_remover(): def _make_distributed_buffer_global_memory_data_race_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_race")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("distributed_buffer_global_memory_data_race") + ) arr_names = ["a", "b", "t"] for name in arr_names: sdfg.add_array( @@ -136,7 +138,9 @@ def test_distributed_buffer_global_memory_data_race(): def _make_distributed_buffer_global_memory_data_race_sdfg2() -> tuple[ dace.SDFG, dace.SDFGState, dace.SDFGState ]: - sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_race2_sdfg")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("distributed_buffer_global_memory_data_race2_sdfg") + ) arr_names = ["a", "b", "t"] for name in arr_names: sdfg.add_array( @@ -189,7 +193,9 @@ def test_distributed_buffer_global_memory_data_race2(): def _make_distributed_buffer_global_memory_data_no_rance() -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_no_rance_sdfg")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("distributed_buffer_global_memory_data_no_rance_sdfg") + ) arr_names = ["a", "t"] for name in arr_names: sdfg.add_array( @@ -235,7 +241,11 @@ def test_distributed_buffer_global_memory_data_no_rance(): def _make_distributed_buffer_global_memory_data_no_rance2() -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_no_rance2_sdfg")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name( + "distributed_buffer_global_memory_data_no_rance2_sdfg" + ) + ) arr_names = ["a", "t"] for name in arr_names: sdfg.add_array( @@ -291,7 +301,9 @@ def test_distributed_buffer_global_memory_data_no_rance2(): def _make_distributed_buffer_non_sink_temporary_sdfg() -> tuple[ dace.SDFG, dace.SDFGState, 
dace.SDFGState ]: - sdfg = dace.SDFG(util.unique_name("distributed_buffer_non_sink_temporary_sdfg")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("distributed_buffer_non_sink_temporary_sdfg") + ) state = sdfg.add_state(is_start_block=True) wb_state = sdfg.add_state_after(state) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_double_write_remover.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_double_write_remover.py index da70155cf3..6579f1219e 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_double_write_remover.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_double_write_remover.py @@ -25,7 +25,7 @@ def _make_double_write_single_consumer( ) -> tuple[ dace.SDFG, dace.SDFGState, dace_nodes.AccessNode, dace_nodes.AccessNode, dace_nodes.MapExit ]: - sdfg = dace.SDFG(util.unique_name("double_write_elimination_1")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("double_write_elimination_1")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( @@ -148,7 +148,7 @@ def _make_double_write_multi_consumer( dace_nodes.AccessNode, dace_nodes.MapExit, ]: - sdfg = dace.SDFG(util.unique_name("double_write_elimination_2")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("double_write_elimination_2")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( @@ -255,7 +255,9 @@ def _make_double_write_multi_producer_map( ) -> tuple[ dace.SDFG, dace.SDFGState, dace_nodes.AccessNode, dace_nodes.AccessNode, dace_nodes.MapExit ]: - sdfg = dace.SDFG(util.unique_name("double_write_elimination_multi_producer_map")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("double_write_elimination_multi_producer_map") + ) state = sdfg.add_state(is_start_block=True) sdfg.add_array( diff 
--git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py index ec23b4ee24..39004d3786 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py @@ -32,7 +32,7 @@ def _make_if_block_with_tasklet( b2_type: dace.typeclass = dace.float64, output_type: dace.typeclass = dace.float64, ) -> dace_nodes.NestedSDFG: - inner_sdfg = dace.SDFG(util.unique_name("if_stmt_")) + inner_sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_stmt_")) types = {b1_name: b1_type, b2_name: b2_type, cond_name: dace.bool_, output_name: output_type} for name in {b1_name, b2_name, cond_name, output_name}: @@ -42,7 +42,7 @@ def _make_if_block_with_tasklet( transient=False, ) - if_region = dace.sdfg.state.ConditionalBlock(util.unique_name("if")) + if_region = dace.sdfg.state.ConditionalBlock(gtx_transformations.utils.unique_name("if")) inner_sdfg.add_node(if_region, is_start_block=True) then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=inner_sdfg) @@ -87,7 +87,7 @@ def _make_if_block_with_tasklet( def _make_map_with_conditional_blocks() -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name("map_with_conditional_blocks")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("map_with_conditional_blocks")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py index 
dd4927bb03..2fcacd191d 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py @@ -15,6 +15,7 @@ from gt4py.next.program_processors.runners.dace.transformations import ( gpu_utils as gtx_dace_fieldview_gpu_utils, + utils as gtx_transformations_utils, ) from . import util @@ -42,7 +43,7 @@ def _get_trivial_gpu_promotable( tasklet_code: The body of the Tasklet inside the trivial map. trivial_map_range: Range of the trivial map, defaults to `"0"`. """ - sdfg = dace.SDFG(util.unique_name("gpu_promotable_sdfg")) + sdfg = dace.SDFG(gtx_transformations_utils.unique_name("gpu_promotable_sdfg")) state = sdfg.add_state("state", is_start_block=True) sdfg.add_symbol("N", dace.int32) @@ -155,7 +156,7 @@ def test_trivial_gpu_map_promoter_2(): @pytest.mark.parametrize("method", [0, 1]) def test_set_gpu_properties(method: int): """Tests the `gtx_dace_fieldview_gpu_utils.gt_set_gpu_blocksize()`.""" - sdfg = dace.SDFG(util.unique_name("gpu_properties_test")) + sdfg = dace.SDFG(gtx_transformations_utils.unique_name("gpu_properties_test")) state = sdfg.add_state(is_start_block=True) map_entries: dict[int, dace_nodes.MapEntry] = {} @@ -232,7 +233,7 @@ def test_set_gpu_properties(method: int): def test_set_gpu_properties_1D(): """Tests the `gtx_dace_fieldview_gpu_utils.gt_set_gpu_blocksize()` with 1D maps.""" - sdfg = dace.SDFG(util.unique_name("gpu_properties_test")) + sdfg = dace.SDFG(gtx_transformations_utils.unique_name("gpu_properties_test")) state = sdfg.add_state(is_start_block=True) map_entries: dict[int, dace_nodes.MapEntry] = {} @@ -299,7 +300,7 @@ def test_set_gpu_properties_1D(): def test_set_gpu_properties_2D_3D(): """Tests the `gtx_dace_fieldview_gpu_utils.gt_set_gpu_blocksize()` with 2D, 3D and 4D maps.""" - sdfg = dace.SDFG(util.unique_name("gpu_properties_test")) + sdfg = 
dace.SDFG(gtx_transformations_utils.unique_name("gpu_properties_test")) state = sdfg.add_state(is_start_block=True) map_entries: dict[int, dace_nodes.MapEntry] = {} @@ -355,7 +356,7 @@ def test_set_gpu_properties_2D_3D(): def test_set_gpu_maxnreg(): """Tests if gpu_maxnreg property is set correctly to GPU maps.""" - sdfg = dace.SDFG(util.unique_name("gpu_maxnreg_test")) + sdfg = dace.SDFG(gtx_transformations_utils.unique_name("gpu_maxnreg_test")) state = sdfg.add_state(is_start_block=True) dim = 2 shape = (10,) * (dim - 1) + (1,) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_horizontal_map_split_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_horizontal_map_split_fusion.py index d7fd220b5f..6bad175597 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_horizontal_map_split_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_horizontal_map_split_fusion.py @@ -34,7 +34,7 @@ def _make_sdfg_with_multiple_maps_that_share_inputs( - Outputs: out1[i, j], out2[i, j], out3[i, j], out4[i, j] """ shape = (N, N) - sdfg = dace.SDFG(util.unique_name("multiple_maps")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("multiple_maps")) state = sdfg.add_state(is_start_block=True) for name in ["a", "b", "c", "d", "out1", "out2", "out3", "out4"]: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_inline_fuser.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_inline_fuser.py index b0c7df401d..6e6afecfb0 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_inline_fuser.py +++ 
b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_inline_fuser.py @@ -33,7 +33,7 @@ def _create_simple_fusable_sdfg() -> tuple[ dace_nodes.AccessNode, dace_graph.MultiConnectorEdge[dace.Memlet], ]: - sdfg = dace.SDFG(util.unique_name(f"simple_fusable_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"simple_fusable_sdfg")) state = sdfg.add_state(is_start_block=True) for name in "abc": @@ -121,7 +121,7 @@ def _make_laplap_sdfg( dace_graph.MultiConnectorEdge[dace.Memlet], dace_nodes.Tasklet, ]: - sdfg = dace.SDFG(util.unique_name(f"laplap1")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"laplap1")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( @@ -223,7 +223,7 @@ def _make_multiple_value_read_sdfg( dace_nodes.MapEntry, dace_graph.MultiConnectorEdge[dace.Memlet], ]: - sdfg = dace.SDFG(util.unique_name(f"multiple_value_generator")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"multiple_value_generator")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( @@ -367,7 +367,7 @@ def test_multiple_value_exchange_partial(): def _make_sdfg_with_dref_tasklet(): - sdfg = dace.SDFG(util.unique_name(f"sdfg_with_dref_target")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"sdfg_with_dref_target")) state = sdfg.add_state(is_start_block=True) for name in "abc": diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py index a94e3262c5..683ad9ebf0 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py @@ -32,7 +32,7 @@ def _get_simple_sdfg() -> tuple[dace.SDFG, 
Callable[[np.ndarray, np.ndarray], np can be taken out. This is because how it is constructed. However, applying some simplistic transformations will enable the transformation. """ - sdfg = dace.SDFG(util.unique_name("simple_block_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("simple_block_sdfg")) state = sdfg.add_state("state", is_start_block=True) sdfg.add_symbol("N", dace.int32) sdfg.add_symbol("M", dace.int32) @@ -55,7 +55,7 @@ def _get_chained_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], n The bottom Tasklet is the only dependent Tasklet. """ - sdfg = dace.SDFG(util.unique_name("chained_block_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("chained_block_sdfg")) state = sdfg.add_state("state", is_start_block=True) sdfg.add_symbol("N", dace.int32) sdfg.add_symbol("M", dace.int32) @@ -159,7 +159,7 @@ def _get_sdfg_with_empty_memlet( is either dependent or independent), the access node between the tasklets and the second tasklet that is always dependent. """ - sdfg = dace.SDFG(util.unique_name("empty_memlet_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("empty_memlet_sdfg")) state = sdfg.add_state("state", is_start_block=True) sdfg.add_symbol("N", dace.int32) sdfg.add_symbol("M", dace.int32) @@ -404,7 +404,7 @@ def test_direct_map_exit_connection() -> dace.SDFG: Because the tasklet is connected to the map exit it can not be independent. """ - sdfg = dace.SDFG(util.unique_name("mapped_tasklet_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("mapped_tasklet_sdfg")) state = sdfg.add_state("state", is_start_block=True) sdfg.add_array("a", (10,), dace.float64, transient=False) sdfg.add_array("b", (10, 30), dace.float64, transient=False) @@ -591,7 +591,7 @@ def _make_loop_blocking_sdfg_with_inner_map( The function will return the SDFG, the state and the map entry for the outer and inner map. 
""" - sdfg = dace.SDFG(util.unique_name("sdfg_with_inner_map")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("sdfg_with_inner_map")) state = sdfg.add_state(is_start_block=True) for name in "AB": @@ -756,7 +756,7 @@ def _make_loop_blocking_sdfg_with_independent_inner_map() -> tuple[ """ Creates a nested Map that is independent from the blocking parameter. """ - sdfg = dace.SDFG(util.unique_name("sdfg_with_inner_independent_map")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("sdfg_with_inner_independent_map")) state = sdfg.add_state(is_start_block=True) sdfg.add_array("A", shape=(40, 3), dtype=dace.float64, transient=False) @@ -833,7 +833,7 @@ def _make_loop_blocking_with_reduction( Depending on `reduction_is_dependent` the node is either dependent or not. """ sdfg = dace.SDFG( - util.unique_name( + gtx_transformations.utils.unique_name( "sdfg_with_" + ("" if reduction_is_dependent else "in") + "dependent_reduction" ) ) @@ -954,7 +954,7 @@ def _make_mixed_memlet_sdfg( - `tskl1`. - `tskl2`. 
""" - sdfg = dace.SDFG(util.unique_name("mixed_memlet_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("mixed_memlet_sdfg")) state = sdfg.add_state(is_start_block=True) names_array = ["A", "B", "C"] names_scalar = ["tmp1", "tmp2"] @@ -1148,7 +1148,7 @@ def test_loop_blocking_mixed_memlets_2(): def test_loop_blocking_no_independent_nodes(): import dace - sdfg = dace.SDFG(util.unique_name("mixed_memlet_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("mixed_memlet_sdfg")) state = sdfg.add_state(is_start_block=True) names = ["A", "B", "C"] for aname in names: @@ -1210,7 +1210,7 @@ def test_loop_blocking_no_independent_nodes(): def _make_only_last_two_elements_sdfg() -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name("simple_block_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("simple_block_sdfg")) state = sdfg.add_state("state", is_start_block=True) sdfg.add_symbol("N", dace.int32) sdfg.add_symbol("B", dace.int32) @@ -1279,7 +1279,7 @@ def test_blocking_size_too_big(): Here the blocking size is larger than the size in that dimension. Thus the transformation will not apply. 
""" - sdfg = dace.SDFG(util.unique_name("blocking_size_too_large")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("blocking_size_too_large")) state = sdfg.add_state(is_start_block=True) for name in "ab": @@ -1325,7 +1325,7 @@ def test_blocking_size_too_big(): def _make_loop_blocking_sdfg_with_semi_independent_map() -> tuple[ dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapEntry ]: - sdfg = dace.SDFG(util.unique_name("sdfg_with_inner_semi_independent_map")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("sdfg_with_inner_semi_independent_map")) state = sdfg.add_state(is_start_block=True) sdfg.add_array("A", shape=(40, 3), dtype=dace.float64, transient=False) @@ -1415,7 +1415,7 @@ def test_loop_blocking_sdfg_with_semi_independent_map(): def _make_loop_blocking_only_independent_inner_map() -> tuple[ dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapEntry ]: - sdfg = dace.SDFG(util.unique_name("sdfg_with_only_independent_inner_map")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("sdfg_with_only_independent_inner_map")) state = sdfg.add_state(is_start_block=True) sdfg.add_array("A", shape=(40, 3), dtype=dace.float64, transient=False) @@ -1479,7 +1479,7 @@ def _make_loop_blocking_output_access_node( ) -> tuple[ dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.Tasklet, dace_nodes.AccessNode ]: - sdfg = dace.SDFG(util.unique_name("sdfg_with_direct_output_access_node")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("sdfg_with_direct_output_access_node")) state = sdfg.add_state(is_start_block=True) sdfg.add_array("A", shape=(40,), dtype=dace.float64, transient=False) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_make_transients_persistent.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_make_transients_persistent.py index d8cf8e33f8..1e7ce47197 100644 --- 
a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_make_transients_persistent.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_make_transients_persistent.py @@ -21,7 +21,9 @@ def _make_transients_persistent_inner_access_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("transients_persistent_inner_access_sdfg")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("transients_persistent_inner_access_sdfg") + ) state = sdfg.add_state(is_start_block=True) for name in "abc": diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py index aa5fce9d76..1bf30f3697 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py @@ -48,7 +48,7 @@ def _make_test_sdfg( if out_offset is None: out_offset = in_offset - sdfg = dace.SDFG(util.unique_name("map_buffer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("map_buffer")) state = sdfg.add_state(is_start_block=True) names = {input_name, tmp_name, output_name} for name in names: @@ -220,7 +220,7 @@ def test_map_buffer_elimination_offset_5(): def test_map_buffer_elimination_not_apply(): """Indirect accessing, because of this the double buffer is needed.""" - sdfg = dace.SDFG(util.unique_name("map_buffer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("map_buffer")) state = sdfg.add_state(is_start_block=True) names = ["A", "tmp", "idx"] @@ -269,7 +269,7 @@ def test_map_buffer_elimination_with_nested_sdfgs(): stride1, stride2, stride3 
= [dace.symbol(f"stride{i}", dace.int32) for i in range(3)] # top-level sdfg - sdfg = dace.SDFG(util.unique_name("map_buffer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("map_buffer")) inp, inp_desc = sdfg.add_array("__inp", (10,), dace.float64) out, out_desc = sdfg.add_array( "__out", (10, 10, 10), dace.float64, strides=(stride1, stride2, stride3) @@ -278,14 +278,14 @@ def test_map_buffer_elimination_with_nested_sdfgs(): state = sdfg.add_state() tmp_node = state.add_access(tmp) - nsdfg1 = dace.SDFG(util.unique_name("map_buffer")) + nsdfg1 = dace.SDFG(gtx_transformations.utils.unique_name("map_buffer")) inp1, inp1_desc = nsdfg1.add_array("__inp", (10,), dace.float64) out1, out1_desc = nsdfg1.add_array("__out", (10, 10), dace.float64) tmp1, _ = nsdfg1.add_temp_transient_like(out1_desc) state1 = nsdfg1.add_state() tmp1_node = state1.add_access(tmp1) - nsdfg2 = dace.SDFG(util.unique_name("map_buffer")) + nsdfg2 = dace.SDFG(gtx_transformations.utils.unique_name("map_buffer")) inp2, _ = nsdfg2.add_array("__inp", (10,), dace.float64) out2, out2_desc = nsdfg2.add_array("__out", (10,), dace.float64) tmp2, _ = nsdfg2.add_temp_transient_like(out2_desc) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion_utils.py index 83e0bd921a..3fd392946c 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion_utils.py @@ -15,6 +15,7 @@ from gt4py.next.program_processors.runners.dace.transformations import ( map_fusion_utils as gtx_map_fusion_utils, + utils as gtx_transformations_utils, ) import numpy as np @@ -24,7 +25,7 @@ def test_copy_map_graph(): N = dace.symbol("N", 
dace.int32) - sdfg = dace.SDFG(util.unique_name("copy_map_graph")) + sdfg = dace.SDFG(gtx_transformations_utils.unique_name("copy_map_graph")) A, A_desc = sdfg.add_array("A", [N], dtype=dace.float64) B = sdfg.add_datadesc("B", A_desc.clone()) st = sdfg.add_state() @@ -103,7 +104,7 @@ def test_copy_map_graph(): def test_split_overlapping_map_range(map_ranges): first_ndrange, second_ndrange = map_ranges[0:2] - sdfg = dace.SDFG(util.unique_name("split_overlapping_map_range")) + sdfg = dace.SDFG(gtx_transformations_utils.unique_name("split_overlapping_map_range")) st = sdfg.add_state() first_map_entry, _ = st.add_map("first", first_ndrange) second_map_entry, _ = st.add_map("second", second_ndrange) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py index c067e47f4f..95132c39f0 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py @@ -62,7 +62,7 @@ def _perform_reorder_test( def _make_test_sdfg(map_params: list[str]) -> dace.SDFG: """Generate an SDFG for the test.""" - sdfg = dace.SDFG(util.unique_name("gpu_promotable_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("gpu_promotable_sdfg")) state: dace.SDFGState = sdfg.add_state("state", is_start_block=True) dim = len(map_params) for aname in ["a", "b"]: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_promoter.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_promoter.py index 435c37d629..cef3058ab0 100644 --- 
a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_promoter.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_promoter.py @@ -31,7 +31,7 @@ def _make_serial_map_promotion_sdfg( ) -> tuple[dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapEntry]: shape_1d = (N,) shape_2d = (N, N) - sdfg = dace.SDFG(util.unique_name("serial_promotable_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("serial_promotable_sdfg")) state = sdfg.add_state(is_start_block=True) # 1D Arrays @@ -252,7 +252,7 @@ def test_serial_map_promotion_on_symbolic_range(use_symbolic_range): def test_serial_map_promotion_2d_top_1d_bottom(): """Does not apply because the bottom map needs to be promoted.""" - sdfg = dace.SDFG(util.unique_name("serial_map_promoter_2d_map_on_top")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("serial_map_promoter_2d_map_on_top")) state = sdfg.add_state(is_start_block=True) # 2D Arrays @@ -311,7 +311,7 @@ def test_serial_map_promotion_2d_top_1d_bottom(): def _make_horizontal_promoter_sdfg( d1_map_is_vertical: bool, ) -> tuple[dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapEntry]: - sdfg = dace.SDFG(util.unique_name("serial_map_promoter_tester")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("serial_map_promoter_tester")) state = sdfg.add_state(is_start_block=True) h_idx = gtx_dace_lowering.get_map_variable(gtx_common.Dimension("boden")) @@ -447,7 +447,9 @@ def test_horizonal_promotion_promotion_and_merge(d1_map_is_vertical: bool): def _make_sdfg_different_1d_map_name( d1_map_param: str, ) -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("serial_map_promoter_different_names_" + d1_map_param)) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("serial_map_promoter_different_names_" + d1_map_param) + ) state = sdfg.add_state(is_start_block=True) # 1D Arrays diff 
--git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_splitter.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_splitter.py index 6d569e79b5..28f0653922 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_splitter.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_splitter.py @@ -22,7 +22,7 @@ def _make_sdfg_simple() -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("simple_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("simple_sdfg")) state = sdfg.add_state(is_start_block=True) for name in "abc": diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_to_copy.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_to_copy.py index f4c9020de4..d6dc0ae54a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_to_copy.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_to_copy.py @@ -21,7 +21,7 @@ def _make_sdfg_1( consumer_is_map: bool, ) -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("simple_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("simple_sdfg")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index 5338b49920..89f8439472 100644 --- 
a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -37,7 +37,7 @@ def _make_if_block( b2_type: dace.typeclass = dace.float64, output_type: dace.typeclass = dace.float64, ) -> dace_nodes.NestedSDFG: - inner_sdfg = dace.SDFG(util.unique_name("if_stmt_")) + inner_sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_stmt_")) types = {b1_name: b1_type, b2_name: b2_type, cond_name: dace.bool_, output_name: output_type} for name in {b1_name, b2_name, cond_name, output_name}: @@ -47,7 +47,7 @@ def _make_if_block( transient=False, ) - if_region = dace.sdfg.state.ConditionalBlock(util.unique_name("if")) + if_region = dace.sdfg.state.ConditionalBlock(gtx_transformations.utils.unique_name("if")) inner_sdfg.add_node(if_region, is_start_block=True) then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=inner_sdfg) @@ -126,7 +126,7 @@ def test_if_mover_independent_branches(): d = b ``` """ - sdfg = dace.SDFG(util.unique_name("independent_branches")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("independent_branches")) state = sdfg.add_state(is_start_block=True) # Inputs @@ -237,7 +237,7 @@ def test_if_mover_independent_branches(): def test_if_mover_invalid_if_block(): - sdfg = dace.SDFG(util.unique_name("invalid")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("invalid")) state = sdfg.add_state(is_start_block=True) input_names = ["a", "b", "c", "d"] @@ -346,7 +346,7 @@ def test_if_mover_dependent_branch_1(): d = b ``` """ - sdfg = dace.SDFG(util.unique_name("if_mover_dependent_branches")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_dependent_branches")) state = sdfg.add_state(is_start_block=True) # Inputs @@ -487,7 +487,7 @@ def test_if_mover_dependent_branch_2(): d = b1 ``` """ - sdfg = 
dace.SDFG(util.unique_name("if_mover_dependent_branches_2")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_dependent_branches_2")) state = sdfg.add_state(is_start_block=True) # Inputs @@ -587,7 +587,7 @@ def test_if_mover_dependent_branch_3(): Very similar test to `test_if_mover_dependent_branch_1()`, but the common data is an AccessNode outside the Map. """ - sdfg = dace.SDFG(util.unique_name("if_mover_dependent_branches")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_dependent_branches")) state = sdfg.add_state(is_start_block=True) gnames = ["a", "b", "c", "d", "cond"] @@ -683,7 +683,7 @@ def test_if_mover_no_ops(): ``` I.e. there is no gain from moving something inside the body. """ - sdfg = dace.SDFG(util.unique_name("if_mover_no_ops")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_no_ops")) state = sdfg.add_state(is_start_block=True) # Inputs @@ -743,7 +743,7 @@ def test_if_mover_one_branch_is_nothing(): I.e. in one case something can be moved in but there is nothing to move for the other branch. 
""" - sdfg = dace.SDFG(util.unique_name("if_mover_one_branch_is_nothing")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_one_branch_is_nothing")) state = sdfg.add_state(is_start_block=True) # Inputs @@ -833,7 +833,7 @@ def test_if_mover_chain(): e = aa if cc else bb ``` """ - sdfg = dace.SDFG(util.unique_name("if_mover_chain_of_blocks")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_chain_of_blocks")) state = sdfg.add_state(is_start_block=True) # Inputs @@ -947,7 +947,7 @@ def test_if_mover_chain(): def test_if_mover_symbolic_tasklet(): - sdfg = dace.SDFG(util.unique_name("if_mover_symbols_in_tasklets")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_symbols_in_tasklets")) state = sdfg.add_state(is_start_block=True) for i in [1, 2]: @@ -1062,7 +1062,7 @@ def test_if_mover_access_node_between(): e = aa if cc else bb ``` """ - sdfg = dace.SDFG(util.unique_name("if_mover_chain_of_blocks")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_chain_of_blocks")) state = sdfg.add_state(is_start_block=True) # Inputs diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py index 2ca57691f2..51a577e467 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py @@ -26,7 +26,7 @@ def _make_movable_tasklet( ) -> tuple[ dace.SDFG, dace.SDFGState, dace_nodes.Tasklet, dace_nodes.AccessNode, dace_nodes.MapEntry ]: - sdfg = dace.SDFG(util.unique_name("gpu_promotable_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("gpu_promotable_sdfg")) state = sdfg.add_state("state", 
is_start_block=True) sdfg.add_scalar("outer_scalar", dtype=dace.float64, transient=True) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_multi_state_global_self_copy_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_multi_state_global_self_copy_elimination.py index fc07d3a831..30dd8de029 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_multi_state_global_self_copy_elimination.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_multi_state_global_self_copy_elimination.py @@ -32,7 +32,7 @@ def apply_distributed_self_copy_elimination( def _make_not_apply_because_of_write_to_g_sdfg() -> dace.SDFG: """This SDFG is not eligible, because there is a write to `G`.""" - sdfg = dace.SDFG(util.unique_name("not_apply_because_of_write_to_g_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("not_apply_because_of_write_to_g_sdfg")) # This is the `G` array. sdfg.add_array(name="a", shape=(10,), dtype=dace.float64, transient=False) @@ -83,7 +83,7 @@ def _make_eligible_sdfg_1() -> dace.SDFG: The main difference is that there is no mutating write to `a` and thus the transformation applies. """ - sdfg = dace.SDFG(util.unique_name("_make_eligible_sdfg_1")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("_make_eligible_sdfg_1")) # This is the `G` array. sdfg.add_array(name="a", shape=(10,), dtype=dace.float64, transient=False) @@ -126,7 +126,7 @@ def _make_eligible_sdfg_1() -> dace.SDFG: def _make_multiple_temporaries_sdfg1() -> dace.SDFG: """Generates an SDFG in which `G` is saved into different temporaries.""" - sdfg = dace.SDFG(util.unique_name("multiple_temporaries")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("multiple_temporaries")) # This is the `G` array. 
sdfg.add_array(name="a", shape=(10,), dtype=dace.float64, transient=False) @@ -170,7 +170,7 @@ def _make_multiple_temporaries_sdfg2() -> dace.SDFG: generated by `_make_multiple_temporaries_sdfg()` is that the temporaries are used sequentially. """ - sdfg = dace.SDFG(util.unique_name("multiple_temporaries_sequential")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("multiple_temporaries_sequential")) # This is the `G` array. sdfg.add_array(name="a", shape=(10,), dtype=dace.float64, transient=False) @@ -281,7 +281,7 @@ def _make_non_eligible_because_of_pseudo_temporary() -> dace.SDFG: Note that in this particular case it would be possible, but we do not support it. """ - sdfg = dace.SDFG(util.unique_name("multiple_temporaries_sequential")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("multiple_temporaries_sequential")) # This is the `G` array. sdfg.add_array(name="a", shape=(10,), dtype=dace.float64, transient=False) @@ -321,7 +321,7 @@ def _make_wb_single_state_sdfg() -> dace.SDFG: This pattern is handled by the `SingleStateGlobalSelfCopyElimination` transformation. """ - sdfg = dace.SDFG(util.unique_name("single_state_write_back_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("single_state_write_back_sdfg")) sdfg.add_array("g", shape=(10,), dtype=dace.float64, transient=False) sdfg.add_array("t", shape=(10,), dtype=dace.float64, transient=True) @@ -351,7 +351,7 @@ def _make_wb_single_state_sdfg() -> dace.SDFG: def _make_non_eligible_sdfg_with_branches(): """Creates an SDFG with two different definitions of `T`.""" - sdfg = dace.SDFG(util.unique_name("non_eligible_sdfg_with_branches_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("non_eligible_sdfg_with_branches_sdfg")) # This is the `G` array, it is also used as output. 
sdfg.add_array("a", shape=(10,), dtype=dace.float64, transient=False) @@ -400,7 +400,7 @@ def _make_write_into_global_at_t_definition() -> tuple[ This SDFG is different from `_make_not_apply_because_of_write_to_g_sdfg` as the write happens before we define `t`. """ - sdfg = dace.SDFG(util.unique_name("write_into_global_at_t_definition")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("write_into_global_at_t_definition")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_multiple_copies_global.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_multiple_copies_global.py index b763393253..963b3321bc 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_multiple_copies_global.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_multiple_copies_global.py @@ -22,7 +22,7 @@ def test_complex_copies_global_access_node(): N = 64 K = 80 - sdfg = dace.SDFG(util.unique_name("vertically_implicit_solver_like_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("vertically_implicit_solver_like_sdfg")) A, _ = sdfg.add_array("A", [N, K + 1], dtype=dace.float64) B, _ = sdfg.add_array("B", [N, K + 1], dtype=dace.float64) tmp0, _ = sdfg.add_temp_transient([N], dtype=dace.float64) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_aliasing_scalars.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_aliasing_scalars.py index a755504f40..e7aa91c517 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_aliasing_scalars.py +++ 
b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_aliasing_scalars.py @@ -24,7 +24,7 @@ def _make_map_with_scalar_copies() -> tuple[ dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapExit ]: - sdfg = dace.SDFG(util.unique_name("scalar_elimination")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("scalar_elimination")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_point_view.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_point_view.py index 04312fb7c5..d86529f79d 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_point_view.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_point_view.py @@ -25,7 +25,7 @@ def _make_sdfg_with_map_with_view( use_array_as_temp: bool = False, ) -> dace.SDFG: shape = (N, N) - sdfg = dace.SDFG(util.unique_name("simple_map_with_view")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("simple_map_with_view")) state = sdfg.add_state(is_start_block=True) for name in ["a", "out"]: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_single_state_global_self_copy_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_single_state_global_self_copy_elimination.py index 306f05454e..76e99a4678 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_single_state_global_self_copy_elimination.py +++ 
b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_single_state_global_self_copy_elimination.py @@ -25,7 +25,7 @@ def _make_self_copy_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: """Generates an SDFG that contains the self copying pattern.""" - sdfg = dace.SDFG(util.unique_name("self_copy_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("self_copy_sdfg")) state = sdfg.add_state(is_start_block=True) for name in "GT": @@ -46,7 +46,7 @@ def _make_self_copy_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: def _make_direct_self_copy_elimination_used_sdfg() -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name("direct_self_copy_elimination_used")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("direct_self_copy_elimination_used")) state = sdfg.add_state(is_start_block=True) for name in "ABCG": @@ -88,7 +88,7 @@ def _make_self_copy_sdfg_with_multiple_paths() -> tuple[ `SingleStateGlobalDirectSelfCopyElimination` transformation can not handle this case, but its split node can handle it. 
""" - sdfg = dace.SDFG(util.unique_name("self_copy_sdfg_with_multiple_paths")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("self_copy_sdfg_with_multiple_paths")) state = sdfg.add_state(is_start_block=True) for name in "GT": @@ -132,7 +132,9 @@ def _make_concat_where_like( j_idx = 0 j_range = "1:10" - sdfg = dace.SDFG(util.unique_name(f"self_copy_sdfg_concat_where_like_{j_desc}_level")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name(f"self_copy_sdfg_concat_where_like_{j_desc}_level") + ) state = sdfg.add_state(is_start_block=True) sdfg.add_array( @@ -215,7 +217,11 @@ def _make_concat_where_like_with_silent_write_to_g1( dace_nodes.AccessNode, dace_nodes.AccessNode, ]: - sdfg = dace.SDFG(util.unique_name(f"self_copy_sdfg_concat_where_like_multiple_writes_to_g1")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name( + f"self_copy_sdfg_concat_where_like_multiple_writes_to_g1" + ) + ) state = sdfg.add_state(is_start_block=True) sdfg.add_array( @@ -307,7 +313,9 @@ def _make_concat_where_like_41_to_60( dace_nodes.AccessNode, dace_nodes.AccessNode, ]: - sdfg = dace.SDFG(util.unique_name(f"self_copy_sdfg_concat_where_like_41_to_60")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name(f"self_copy_sdfg_concat_where_like_41_to_60") + ) state = sdfg.add_state(is_start_block=True) sdfg.add_array( @@ -443,7 +451,7 @@ def _make_concat_where_like_not_possible() -> tuple[ """Because the "Bulk Map" writes more into `tmp` than is written back the transformation is not applicable. """ - sdfg = dace.SDFG(util.unique_name(f"self_copy_too_big_bulk_write")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"self_copy_too_big_bulk_write")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( @@ -499,7 +507,7 @@ def _make_multi_t_patch_sdfg() -> tuple[ uninitialized because it is not read. 
""" - sdfg = dace.SDFG(util.unique_name(f"multi_t_patch_description_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"multi_t_patch_description_sdfg")) state = sdfg.add_state(is_start_block=True) for name in "gabt": @@ -582,7 +590,7 @@ def _make_not_everything_is_written_back( in fact running `SplitAccessNode` would do the job. """ - sdfg = dace.SDFG(util.unique_name(f"not_everything_is_written_back")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"not_everything_is_written_back")) state = sdfg.add_state(is_start_block=True) consumer_shape = (10,) if consume_all_of_t else (6,) @@ -645,7 +653,7 @@ def _make_not_everything_is_written_back( def _make_write_write_conflict( conflict_at_t: bool, ) -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name(f"write_write_conflict_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"write_write_conflict_sdfg")) state = sdfg.add_state(is_start_block=True) for name in "gto": @@ -717,7 +725,7 @@ def _make_write_write_conflict( def _make_read_write_conflict() -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name(f"write_write_conflict_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"write_write_conflict_sdfg")) state = sdfg.add_state(is_start_block=True) for name in "gtom": @@ -896,7 +904,7 @@ def test_global_self_copy_elimination_tmp_downstream(): def test_direct_global_self_copy_simple(): - sdfg = dace.SDFG(util.unique_name("simple_direct_self_copy")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("simple_direct_self_copy")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_split_access_node.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_split_access_node.py index be7ed5405d..1bfa02cfe1 100644 --- 
a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_split_access_node.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_split_access_node.py @@ -64,7 +64,7 @@ def _perform_test( def test_map_producer_ac_consumer(): """The data is generated by a Map and then consumed by an AccessNode.""" - sdfg = dace.SDFG(util.unique_name("map_producer_ac_consumer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("map_producer_ac_consumer")) state = sdfg.add_state(is_start_block=True) for name in "abtcd": @@ -96,7 +96,7 @@ def test_map_producer_ac_consumer(): def test_map_producer_map_consumer(): """The data is generated by a Map and then consumed by another Map.""" - sdfg = dace.SDFG(util.unique_name("map_producer_map_consumer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("map_producer_map_consumer")) state = sdfg.add_state(is_start_block=True) for name in "abtcd": @@ -135,7 +135,7 @@ def test_map_producer_map_consumer(): def test_ac_producer_ac_consumer(): - sdfg = dace.SDFG(util.unique_name("ac_producer_ac_consumer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("ac_producer_ac_consumer")) state = sdfg.add_state(is_start_block=True) for name in "abtcd": @@ -159,7 +159,7 @@ def test_ac_producer_ac_consumer(): def test_ac_producer_map_consumer(): - sdfg = dace.SDFG(util.unique_name("ac_producer_map_consumer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("ac_producer_map_consumer")) state = sdfg.add_state(is_start_block=True) for name in "abtcd": @@ -193,7 +193,9 @@ def test_simple_splitable_ac_source_not_full_consume(): """Similar to `test_simple_splitable_ac_source_full_consume`, but one consumer does not fully consumer what is produced. 
""" - sdfg = dace.SDFG(util.unique_name("simple_splitable_ac_source_not_full_consume")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("simple_splitable_ac_source_not_full_consume") + ) state = sdfg.add_state(is_start_block=True) for name in "abtcde": @@ -221,7 +223,9 @@ def test_simple_splitable_ac_source_multiple_consumer(): """Similar to `test_simple_splitable_ac_source_not_full_consume`, but there are multiple consumer, per producer. """ - sdfg = dace.SDFG(util.unique_name("simple_splitable_ac_source_multiple_consumer")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("simple_splitable_ac_source_multiple_consumer") + ) state = sdfg.add_state(is_start_block=True) for name in "abtcde": @@ -259,7 +263,9 @@ def _make_transient_producer_sdfg( depending on the value of `partial_read`. """ sdfg = dace.SDFG( - util.unique_name("partial_ac_read" + ("_partial_read" if partial_read else "_full_read")) + gtx_transformations.utils.unique_name( + "partial_ac_read" + ("_partial_read" if partial_read else "_full_read") + ) ) state = sdfg.add_state(is_start_block=True) for name in ["a", "b", "c", "d", "t1", "t2"]: @@ -302,7 +308,7 @@ def test_transient_producer_partial_read(): def test_overlapping_consume_ac_source(): """There are 2 producers, but only one consumer that needs both producers.""" - sdfg = dace.SDFG(util.unique_name("overlapping_consume_ac_source")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("overlapping_consume_ac_source")) state = sdfg.add_state(is_start_block=True) for name in "abtc": @@ -333,7 +339,7 @@ def _make_map_producer_multiple_consumer( `partial_read` it will either read the full data for only parts of it. 
""" sdfg = dace.SDFG( - util.unique_name( + gtx_transformations.utils.unique_name( "map_producer_map_consumer" + ("_partial_read" if partial_read else "_full_read") ) ) @@ -392,7 +398,7 @@ def test_map_producer_multi_consumer_partialread(): def test_same_producer(): - sdfg = dace.SDFG(util.unique_name("same_producer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("same_producer")) state = sdfg.add_state(is_start_block=True) for name in "abt": @@ -454,7 +460,7 @@ def test_map_producer_complex_map_consumer(): b[15:20] -> e[15:20] c[20:25] -> e[20:25] """ - sdfg = dace.SDFG(util.unique_name("map_producer_ac_consumer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("map_producer_ac_consumer")) state = sdfg.add_state(is_start_block=True) for name in "abtcde": @@ -531,7 +537,7 @@ def test_map_producer_complex_map_consumer(): def test_map_producer_map_consumer_complex(): """The data is generated by a Map and then consumed by another Map.""" - sdfg = dace.SDFG(util.unique_name("map_producer_map_consumer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("map_producer_map_consumer")) state = sdfg.add_state(is_start_block=True) for name in "abtcde": @@ -623,7 +629,7 @@ def test_map_producer_ac_consumer_complex(): " nodes that have as inputs Maps whose outputs are assigned to AccessNodes." ) - sdfg = dace.SDFG(util.unique_name("map_producer_map_consumer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("map_producer_map_consumer")) state = sdfg.add_state(is_start_block=True) for name in "abtcde": @@ -693,7 +699,7 @@ def test_ac_producer_complex_map_consumer(): " nodes that have AccessNodes as input and Maps as outputs." 
) - sdfg = dace.SDFG(util.unique_name("map_producer_ac_consumer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("map_producer_ac_consumer")) state = sdfg.add_state(is_start_block=True) for name in "abtcde": @@ -760,7 +766,7 @@ def test_ac_producer_complex_ac_consumer(): b[33:38] -> e[22:27] c[14:19] -> e[27:32] """ - sdfg = dace.SDFG(util.unique_name("ac_producer_complex_ac_consumer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("ac_producer_complex_ac_consumer")) state = sdfg.add_state(is_start_block=True) for name in "abtcde": @@ -817,7 +823,7 @@ def test_ac_producer_ac_consumer_complex(): b[28:33] -> d[27:32] b[33:38] -> e[34:39] """ - sdfg = dace.SDFG(util.unique_name("ac_producer_ac_consumer_complex")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("ac_producer_ac_consumer_complex")) state = sdfg.add_state(is_start_block=True) for name in "abtcde": @@ -868,7 +874,7 @@ def test_unused_partial_read_from_inout_node(): a[0:5] -> c[0:5] t[5:10] -> c[5:10] """ - sdfg = dace.SDFG(util.unique_name("full_write_from_global_to_split_node")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("full_write_from_global_to_split_node")) for data in "abc": sdfg.add_array(data, [10], dace.float64) t, _ = sdfg.add_temp_transient([10], dace.float64) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_split_memlet.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_split_memlet.py index e5d5532ade..9dda69f253 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_split_memlet.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_split_memlet.py @@ -25,7 +25,9 @@ def _make_split_edge_two_ac_producer_one_ac_consumer_1d_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = 
dace.SDFG(util.unique_name("split_edge_two_ac_producer_one_ac_consumer_1d")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("split_edge_two_ac_producer_one_ac_consumer_1d") + ) state = sdfg.add_state(is_start_block=True) sdfg.add_array("src1", shape=(10,), dtype=dace.float64, transient=False) @@ -79,7 +81,7 @@ def _make_split_edge_mock_apply_diffusion_to_w_sdfg() -> tuple[ ]: # Test is roughly equivalent to what we see in `apply_diffusion_to_w` # Although instead of Maps we sometimes use direct edges. - sdfg = dace.SDFG(util.unique_name("split_edge_mock_apply_diffusion_to_w")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("split_edge_mock_apply_diffusion_to_w")) state = sdfg.add_state(is_start_block=True) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_splitting_tools.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_splitting_tools.py index d8373593c5..06e2caf187 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_splitting_tools.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_splitting_tools.py @@ -37,7 +37,7 @@ def _make_distributed_split_sdfg() -> tuple[ dace.SDFGState, dace_nodes.AccessNode, ]: - sdfg = dace.SDFG(util.unique_name("distributed_split_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("distributed_split_sdfg")) state = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state) @@ -115,7 +115,7 @@ def test_distributed_split(): def _make_split_node_simple_sdfg() -> tuple[ dace.SDFG, dace.SDFGState, dace_nodes.AccessNode, dace_nodes.MapExit, dace_nodes.MapExit ]: - sdfg = dace.SDFG(util.unique_name("single_state_split")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("single_state_split")) state = sdfg.add_state(is_start_block=True) for 
name in "abt": @@ -201,7 +201,7 @@ def test_simple_node_split(): def _make_split_edge_sdfg() -> tuple[ dace.SDFG, dace.SDFGState, dace_nodes.AccessNode, dace_nodes.AccessNode ]: - sdfg = dace.SDFG(util.unique_name("split_edge")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("split_edge")) state = sdfg.add_state(is_start_block=True) for name in "abt": @@ -281,7 +281,7 @@ def test_split_edge(): def test_split_edge_2d(): - sdfg = dace.SDFG(util.unique_name("split_edge_2d")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("split_edge_2d")) state = sdfg.add_state(is_start_block=True) for name in "ab": @@ -422,7 +422,7 @@ def test_subset_merging_stability(): def _make_sdfg_for_deterministic_splitting() -> tuple[ dace.SDFG, dace.SDFGState, dace_nodes.AccessNode ]: - sdfg = dace.SDFG(util.unique_name("deterministic_splitter")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("deterministic_splitter")) state = sdfg.add_state(is_start_block=True) for name in "abtcd": diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_state_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_state_fusion.py index 653a36bad9..d5734f9f27 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_state_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_state_fusion.py @@ -25,7 +25,7 @@ def _make_simple_two_state_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("simple_linear_states")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("simple_linear_states")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -71,7 +71,7 @@ def _make_simple_two_state_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGS def 
_make_global_in_both_read_and_write() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: """The first state contains a read to a global and the second contains a write to the global.""" - sdfg = dace.SDFG(util.unique_name("global_in_both_states_read_and_write")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("global_in_both_states_read_and_write")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -106,7 +106,7 @@ def _make_global_in_both_read_and_write() -> tuple[dace.SDFG, dace.SDFGState, da def _make_global_both_state_read() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: """In both states the same global is read.""" - sdfg = dace.SDFG(util.unique_name("global_read_in_the_same_state")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("global_read_in_the_same_state")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -140,7 +140,7 @@ def _make_global_both_state_read() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFG def _make_global_both_state_write() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: """In both states the same global is written.""" - sdfg = dace.SDFG(util.unique_name("global_write_in_the_same_state")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("global_write_in_the_same_state")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -176,7 +176,9 @@ def _make_empty_state( first_state_empty: bool, ) -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: sdfg = dace.SDFG( - util.unique_name("global_" + ("first" if first_state_empty else "second") + "_state_empty") + gtx_transformations.utils.unique_name( + "global_" + ("first" if first_state_empty else "second") + "_state_empty" + ) ) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -204,7 +206,7 @@ def _make_empty_state( def _make_global_merge_1() -> dace.SDFG: """The first state writes to a global while the second state 
reads and writes to it.""" - sdfg = dace.SDFG(util.unique_name("global_merge_1")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("global_merge_1")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -237,7 +239,7 @@ def _make_global_merge_1() -> dace.SDFG: def _make_global_merge_2() -> dace.SDFG: """In both states the global data is read and written to.""" - sdfg = dace.SDFG(util.unique_name("global_merge_2")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("global_merge_2")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -269,7 +271,7 @@ def _make_global_merge_2() -> dace.SDFG: def _make_swapping_sdfg() -> dace.SDFG: """Makes an SDFG that implements `x, y = y, x` with Memlets.""" - sdfg = dace.SDFG(util.unique_name("swapping_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("swapping_sdfg")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -299,7 +301,7 @@ def _make_swapping_sdfg() -> dace.SDFG: def _make_non_concurrent_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("non_concurrent_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("non_concurrent_sdfg")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -345,7 +347,7 @@ def _make_non_concurrent_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGSta def _make_double_producer_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("double_producer_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("double_producer_sdfg")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -394,7 +396,7 @@ def _make_double_producer_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGSt def _make_double_consumer_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: - sdfg = 
dace.SDFG(util.unique_name("double_consumer_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("double_consumer_sdfg")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -468,7 +470,7 @@ def _make_double_consumer_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGSt def _make_hidden_double_producer_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("hidden_double_producer_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("hidden_double_producer_sdfg")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -624,7 +626,7 @@ def test_empty_second_state(): def test_both_states_are_empty(): - sdfg = dace.SDFG(util.unique_name("full_empty_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("full_empty_sdfg")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) assert util.count_nodes(sdfg, dace_nodes.AccessNode) == 0 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py index e2fabc6383..f1c8da0143 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -25,7 +25,7 @@ def _make_strides_propagation_level3_sdfg() -> dace.SDFG: """Generates the level 3 SDFG (nested-nested) SDFG for `test_strides_propagation()`.""" - sdfg = dace.SDFG(util.unique_name("level3")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("level3")) state = sdfg.add_state(is_start_block=True) names = ["a3", "c3"] @@ -58,7 +58,7 @@ def _make_strides_propagation_level2_sdfg() -> tuple[dace.SDFG, dace_nodes.Neste The function returns the 
level 2 SDFG and the NestedSDFG node that contains the level 3 SDFG. """ - sdfg = dace.SDFG(util.unique_name("level2")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("level2")) state = sdfg.add_state(is_start_block=True) names = ["a2", "a2_alias", "b2", "c2"] @@ -125,7 +125,7 @@ def _make_strides_propagation_level1_sdfg() -> tuple[ - The NestedSDFG node that contains the lebel 3 SDFG (member of the level 2 SDFG). """ - sdfg = dace.SDFG(util.unique_name("level1")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("level1")) state = sdfg.add_state(is_start_block=True) names = ["a1", "b1", "c1"] @@ -238,7 +238,9 @@ def test_strides_propagation(): def _make_strides_propagation_dependent_symbol_nsdfg() -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name("strides_propagation_dependent_symbol_nsdfg")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("strides_propagation_dependent_symbol_nsdfg") + ) state = sdfg.add_state(is_start_block=True) array_names = ["a2", "b2"] @@ -266,7 +268,9 @@ def _make_strides_propagation_dependent_symbol_nsdfg() -> dace.SDFG: def _make_strides_propagation_dependent_symbol_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: - sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_dependent_symbol_sdfg")) + sdfg_level1 = dace.SDFG( + gtx_transformations.utils.unique_name("strides_propagation_dependent_symbol_sdfg") + ) state = sdfg_level1.add_state(is_start_block=True) array_names = ["a1", "b1"] @@ -345,7 +349,9 @@ def test_strides_propagation_symbolic_expression(): def _make_strides_propagation_shared_symbols_nsdfg() -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name("strides_propagation_shared_symbols_nsdfg")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("strides_propagation_shared_symbols_nsdfg") + ) state = sdfg.add_state(is_start_block=True) # NOTE: Both arrays have the same symbols used for strides. 
@@ -379,7 +385,9 @@ def _make_strides_propagation_shared_symbols_nsdfg() -> dace.SDFG: def _make_strides_propagation_shared_symbols_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: - sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_shared_symbols_sdfg")) + sdfg_level1 = dace.SDFG( + gtx_transformations.utils.unique_name("strides_propagation_shared_symbols_sdfg") + ) state = sdfg_level1.add_state(is_start_block=True) # NOTE: Both arrays use the same symbols as strides. @@ -471,7 +479,9 @@ def ref(a1, b1): def _make_strides_propagation_stride_1_nsdfg() -> dace.SDFG: - sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_stride_1_nsdfg")) + sdfg_level1 = dace.SDFG( + gtx_transformations.utils.unique_name("strides_propagation_stride_1_nsdfg") + ) state = sdfg_level1.add_state(is_start_block=True) a_stride_sym = dace.symbol("a_stride", dtype=dace.uint32) @@ -500,7 +510,7 @@ def _make_strides_propagation_stride_1_nsdfg() -> dace.SDFG: def _make_strides_propagation_stride_1_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: - sdfg = dace.SDFG(util.unique_name("strides_propagation_stride_1_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("strides_propagation_stride_1_sdfg")) state = sdfg.add_state(is_start_block=True) a_stride_sym = dace.symbol("a_stride", dtype=dace.uint32) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_vertical_map_split_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_vertical_map_split_fusion.py index f6e85171d8..44a0a95f2e 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_vertical_map_split_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_vertical_map_split_fusion.py @@ -23,7 +23,9 @@ def serial_map_sdfg(N, extra_intermediate_edge=False): sdfg = 
dace.SDFG( - util.unique_name("serial_map" if extra_intermediate_edge else "serial_map_extra_edge") + gtx_transformations.utils.unique_name( + "serial_map" if extra_intermediate_edge else "serial_map_extra_edge" + ) ) A, _ = sdfg.add_array("A", [N], dtype=dace.float64) B, _ = sdfg.add_array("B", [N], dtype=dace.float64) @@ -150,7 +152,7 @@ def test_vertical_map_fusion_disabled(): @pytest.mark.parametrize("run_map_fusion", [True, False]) def test_vertical_map_fusion_with_neighbor_access(run_map_fusion: bool): N = 80 - sdfg = dace.SDFG(util.unique_name("simple")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("simple")) A, _ = sdfg.add_array("A", shape=(N,), dtype=dace.float64, strides=(1,)) B, _ = sdfg.add_array("B", shape=(N,), dtype=dace.float64, strides=(1,)) C, _ = sdfg.add_array("C", shape=(N,), dtype=dace.float64, strides=(1,)) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py index 4c8e7f3899..fec2b09537 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py @@ -6,7 +6,6 @@ # Please, refer to the LICENSE file in the root directory. 
# SPDX-License-Identifier: BSD-3-Clause -import uuid from typing import Literal, Union, overload, Any import numpy as np @@ -14,6 +13,9 @@ import copy from dace.sdfg import nodes as dace_nodes from dace import data as dace_data +from gt4py.next.program_processors.runners.dace.transformations import ( + utils as gtx_transformations_utils, +) @overload @@ -58,15 +60,6 @@ def count_nodes( return len(found_nodes) -def unique_name(name: str) -> str: - """Adds a unique string to `name`.""" - maximal_length = 200 - unique_sufix = str(uuid.uuid1()).replace("-", "_") - if len(name) > (maximal_length - len(unique_sufix)): - name = name[: (maximal_length - len(unique_sufix) - 1)] - return f"{name}_{unique_sufix}" - - def compile_and_run_sdfg( sdfg: dace.SDFG, *args: Any, @@ -82,7 +75,7 @@ def compile_and_run_sdfg( with dace.config.set_temporary("compiler.use_cache", value=False): sdfg_clone = copy.deepcopy(sdfg) - sdfg_clone.name = unique_name(sdfg_clone.name) + sdfg_clone.name = gtx_transformations_utils.unique_name(sdfg_clone.name) sdfg_clone._recompile = True sdfg_clone._regenerate_code = True # TODO(phimuell): Find out if it has an effect. csdfg = sdfg_clone.compile() From 22969633f62bd106a9074d2e24fd85c12106f676 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 3 Feb 2026 14:21:32 +0100 Subject: [PATCH 44/61] Added a test to check for conflicts that are caused by the relocation of data. 
--- .../move_dataflow_into_if_body.py | 67 +++++++++++++ .../test_move_dataflow_into_if_body.py | 96 ++++++++++++++++++- 2 files changed, 161 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index 1db9047199..5afb309edd 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -154,6 +154,16 @@ def can_be_applied( if all(len(rel_df) == 0 for rel_df in relocatable_dataflow.values()): return False + # Check if relocatability is possible. + if not self._check_relocatability( + sdfg=sdfg, + state=graph, + relocatable_dataflow=relocatable_dataflow, + enclosing_map=enclosing_map, + if_block=if_block, + ): + return False + # Because the transformation can only handle `if` expressions that # are _directly_ inside a Map, we must check if the upstream contains # suitable `if` expressions that must be processed first. The simplest way @@ -480,6 +490,63 @@ def _update_symbol_mapping( or dace_dtypes.typeclass(int), ) + def _check_relocatability( + self, + sdfg: dace.SDFG, + state: dace.SDFGState, + relocatable_dataflow: dict[str, set[dace_nodes.Node]], + if_block: dace_nodes.NestedSDFG, + enclosing_map: dace_nodes.MapEntry, + ) -> bool: + """Check if the relocation would cause any conflict, such as a symbol clash.""" + + # TODO: also names of data containers. + + # TODO(phimuell): There is an obscure case where the nested SDFG, on its own, + # defines a symbol that is also mapped, for example a dynamic Map range. + # It is probably not a problem, because of the scopes DaCe adds when + # generating the C++ code. + + # Create a subgraph to compute the free symbols, i.e. the symbols that + # need to be supplied from the outside. However, these are not all.
+ # Note, just adding some "well chosen" nodes to the set will not work. + all_relocated_dataflow: set[dace_nodes.Node] = functools.reduce( + lambda s1, s2: s1.union(s2), relocatable_dataflow.values(), set() + ) + subgraph_view = dace.sdfg.state.StateSubgraphView(state, all_relocated_dataflow) + requiered_symbols: set[str] = subgraph_view.free_symbols + + for node_to_check in all_relocated_dataflow: + if isinstance(node_to_check, dace_nodes.MapEntry): + # This means that a nested Map is fully relocated into the `if` block. + # When DaCe computes the free symbols, it removes these symbols. + # TODO(phimuell): Because of C++ scoping rules it might be possible + # to skip this step, i.e. not add them to the set. + assert node_to_check is not enclosing_map + requiered_symbols |= set(node_to_check.map.params) + + for iedge in state.in_edges(node_to_check): + src_node = iedge.src + if src_node not in all_relocated_dataflow: + # This means that `src_node` is not relocated but mapped into the + # `if` block. This means that `edge` is replicated as well. + # NOTE: This code is based on the one found in `DataflowGraphView`. + # TODO(phimuell): Do we have to inspect the full Memlet path here? + assert isinstance(src_node, dace_nodes.AccessNode) or src_node is enclosing_map + requiered_symbols |= iedge.data.used_symbols(True, edge=iedge) + + # A conflicting symbol is a free symbol of the relocatable dataflow, that is not a + # direct mapping. For example if there is a symbol `n` on the inside and outside + # then everything is okay if the symbol mapping is `{n: n}` i.e. the symbol has the + # same meaning inside and outside. Everything else is not okay. 
+ symbol_mapping = if_block.symbol_mapping + conflicting_symbols = requiered_symbols.intersection((str(k) for k in symbol_mapping)) + for conflicting_symbol in conflicting_symbols: + if conflicting_symbol != str(symbol_mapping[conflicting_symbol]): + return False + + return True + def _find_branch_for( self, if_block: dace_nodes.NestedSDFG, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index 03aba6599e..15211021b1 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -93,9 +93,8 @@ def _perform_test( return sdfg # General case, run the SDFG first and then compare the result. 
- ref, res = util.make_sdfg_args(sdfg) - if explected_applies != 0: + ref, res = util.make_sdfg_args(sdfg) util.compile_and_run_sdfg(sdfg, **ref) nb_apply = sdfg.apply_transformations_repeated( @@ -1199,3 +1198,96 @@ def test_if_mover_access_node_between(): "__cond", } assert set(top_if_block.sdfg.arrays.keys()) == expected_top_if_block_data + + +def test_if_mover_symbol_aliasing(): + sdfg = dace.SDFG(util.unique_name("if_mover_symbol_alias")) + state = sdfg.add_state(is_start_block=True) + + scalar_names = ["cond", "a1", "b2"] + array_names = list("abcd") + sdfg.add_symbol("n", stype=dace.int32) + for aname in array_names: + sdfg.add_array( + aname, + shape=((10, "n") if aname in "ab" else (10,)), + dtype=dace.float64, + transient=False, + ) + for sname in scalar_names: + sdfg.add_scalar( + sname, + dtype=(dace.bool_ if sname == "cond" else dace.float64), + transient=True, + ) + a, b, c, d, cond_ac, true_ac, false_ac = ( + state.add_access(name) for name in array_names + scalar_names + ) + + me, mx = state.add_map("outer_map", ndrange={"__i": "0:10"}) + + for ac in [a, b, c]: + state.add_edge( + ac, + None, + me, + f"IN_{ac.data}", + dace.Memlet(f"{ac.data}[0:10" + ("]" if ac is c else ", 0:n]")), + ) + me.add_scope_connectors(ac.data) + + # Make the condition. + cond_tlet = state.add_tasklet( + "cond_tlet", + inputs={"__in0"}, + outputs={"__out"}, + code="__out = __in0 < 0.0", + ) + state.add_edge(me, "OUT_c", cond_tlet, "__in0", dace.Memlet("c[__i]")) + state.add_edge(cond_tlet, "__out", cond_ac, None, dace.Memlet(f"{cond_ac.data}[0]")) + + # The true branch. 
+ true_tlet = state.add_tasklet( + "true_tlet", + inputs={"__in0"}, + outputs={"__out"}, + code="__out = __in0 + 1.0", + ) + state.add_edge(me, "OUT_a", true_tlet, "__in0", dace.Memlet("a[__i, n - 1]")) + state.add_edge(true_tlet, "__out", true_ac, None, dace.Memlet(f"{true_ac.data}[0]")) + + # False branch + false_tlet = state.add_tasklet( + "false_tlet", + inputs={"__in0"}, + outputs={"__out"}, + code="__out = __in0 + 1.0", + ) + state.add_edge(me, "OUT_b", false_tlet, "__in0", dace.Memlet("b[__i, n - 3]")) + state.add_edge(false_tlet, "__out", false_ac, None, dace.Memlet(f"{false_ac.data}[0]")) + + # Create the top `if_block` + if_block = _make_if_block(state, sdfg) + + # By Adding this symbol mapping, we emulate the case where something is used + # inside and special case must be taken. + assert len(if_block.symbol_mapping) == 0 + if_block.symbol_mapping["n"] = "n - 1" + + # Connect the inputs to the if block. + state.add_edge(true_ac, None, if_block, "__arg1", dace.Memlet(f"{true_ac}[0]")) + state.add_edge(false_ac, None, if_block, "__arg2", dace.Memlet(f"{false_ac}[0]")) + state.add_edge(cond_ac, None, if_block, "__cond", dace.Memlet(f"{cond_ac}[0]")) + + state.add_edge(if_block, "__output", mx, "IN_d", dace.Memlet("d[__i]")) + state.add_edge(mx, "OUT_d", d, None, dace.Memlet("d[0:10]")) + mx.add_scope_connectors("d") + + sdfg.validate() + + # Because `n` is already taken, see above, we need an additional symbol mapping + # to account for the access on the Memlets of the `{true, false}_tlet`. 
+ _perform_test( + sdfg=sdfg, + explected_applies=0, + ) From 3319c000848799c9ba2ca80f839f65f46c506799 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 3 Feb 2026 14:35:29 +0100 Subject: [PATCH 45/61] Apply review comments on RemoveScalarCopies MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Philip Müller --- .../runners/dace/transformations/__init__.py | 4 +- .../dace/transformations/auto_optimize.py | 2 +- ...ing_scalars.py => remove_scalar_copies.py} | 48 +++++++++++++------ ...calars.py => test_remove_scalar_copies.py} | 8 ++-- 4 files changed, 39 insertions(+), 23 deletions(-) rename src/gt4py/next/program_processors/runners/dace/transformations/{remove_aliasing_scalars.py => remove_scalar_copies.py} (85%) rename tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/{test_remove_aliasing_scalars.py => test_remove_scalar_copies.py} (93%) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py index 1fe71d20da..0c1ec3424f 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py @@ -54,7 +54,7 @@ ) from .redundant_array_removers import CopyChainRemover, DoubleWriteRemover, gt_remove_copy_chain from .remove_access_node_copies import RemoveAccessNodeCopies -from .remove_aliasing_scalars import RemoveAliasingScalars +from .remove_scalar_copies import RemoveScalarCopies from .remove_views import RemovePointwiseViews from .scan_loop_unrolling import ScanLoopUnrolling from .simplify import ( @@ -107,8 +107,8 @@ "MultiStateGlobalSelfCopyElimination", "MultiStateGlobalSelfCopyElimination2", "RemoveAccessNodeCopies", - "RemoveAliasingScalars", "RemovePointwiseViews", + "RemoveScalarCopies", "ScanLoopUnrolling", 
"SingleStateGlobalDirectSelfCopyElimination", "SingleStateGlobalSelfCopyElimination", diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 10a6256fef..74bdb03351 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -730,7 +730,7 @@ def _gt_auto_process_dataflow_inside_maps( single_use_data = find_single_use_data.apply_pass(sdfg, None) sdfg.apply_transformations_repeated( - gtx_transformations.RemoveAliasingScalars( + gtx_transformations.RemoveScalarCopies( single_use_data=single_use_data, ), validate=False, diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/remove_aliasing_scalars.py b/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py similarity index 85% rename from src/gt4py/next/program_processors/runners/dace/transformations/remove_aliasing_scalars.py rename to src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py index 3572f43980..fc7b3288be 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/remove_aliasing_scalars.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py @@ -15,7 +15,27 @@ @dace_properties.make_properties -class RemoveAliasingScalars(dace_transformation.SingleStateTransformation): +class RemoveScalarCopies(dace_transformation.SingleStateTransformation): + """Removes copies between two scalar transient variables. + Example: + ___ + / \ + | A | + \___/ + | + \/ + ___ + / \ + | B | + \___/ + is transformed to + ___ + / \ + | A | + \___/ + and all uses of B are replaced with A.
+ """ + first_access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) second_access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) @@ -59,20 +79,24 @@ def can_be_applied( second_node_desc = second_node.desc(sdfg) scope_dict = graph.scope_dict() - if first_node not in scope_dict or second_node not in scope_dict: - return False # Make sure that both access nodes are in the same scope if scope_dict[first_node] != scope_dict[second_node]: return False # Make sure that both access nodes are transients - if not first_node_desc.transient or not second_node_desc.transient: + if not (first_node_desc.transient and second_node_desc.transient): + return False + + edges = list(graph.edges_between(first_node, second_node)) + if len(edges) != 1: + return False + + # Check that the second access node has only one incoming edge, which is the one from the first access node. + if graph.in_degree(second_node) != 1: return False - edges = graph.edges_between(first_node, second_node) - assert len(edges) == 1 - edge = next(iter(edges)) + edge = edges[0] # Check if the edge transfers only one element if edge.data.num_elements() != 1: @@ -85,7 +109,8 @@ def can_be_applied( if out_edges.data.num_elements() != 1: return False - # Make sure that the edge subset is 1 + # Make sure that the data descriptors of both access nodes are scalars + # TODO(iomaganaris): We could extend this transfromation to handle AccessNodes that are arrays with 1 element as well. 
if not isinstance(first_node_desc, dace.data.Scalar) or not isinstance( second_node_desc, dace.data.Scalar ): @@ -97,13 +122,6 @@ def can_be_applied( ): return False - # Make sure that both access nodes are transients - if not first_node_desc.transient or not second_node_desc.transient: - return False - - if graph.in_degree(second_node) != 1: - return False - # Make sure that both access nodes are single use data if self.assume_single_use_data: single_use_data = {sdfg: {first_node.data}} diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_aliasing_scalars.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_scalar_copies.py similarity index 93% rename from tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_aliasing_scalars.py rename to tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_scalar_copies.py index e7aa91c517..c7616541db 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_aliasing_scalars.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_scalar_copies.py @@ -46,10 +46,8 @@ def _make_map_with_scalar_copies() -> tuple[ tmp0, tmp1, tmp2 = (state.add_access(f"tmp{i}") for i in range(3)) me, mx = state.add_map("copy_map", ndrange={"__i": "0:10"}) - me.add_in_connector("IN_a") - me.add_out_connector("OUT_a") - mx.add_in_connector("IN_b") - mx.add_out_connector("OUT_b") + me.add_scope_connectors("a") + mx.add_scope_connectors("b") state.add_edge(a, None, me, "IN_a", dace.Memlet("a[__i]")) state.add_edge(me, "OUT_a", tmp0, None, dace.Memlet("a[__i]")) state.add_edge(tmp0, None, tmp1, None, dace.Memlet("tmp1[0]")) @@ -72,7 +70,7 @@ def test_remove_double_write_single_consumer(): 
find_single_use_data = dace_analysis.FindSingleUseData() single_use_data = find_single_use_data.apply_pass(sdfg, None) sdfg.apply_transformations_repeated( - gtx_transformations.RemoveAliasingScalars( + gtx_transformations.RemoveScalarCopies( single_use_data=single_use_data, assume_single_use_data=False, ), From 202aac61703a6428968492f2f66f5669292bacf9 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 3 Feb 2026 14:37:49 +0100 Subject: [PATCH 46/61] Remove _dacegraphs --- _dacegraphs/invalid.sdfgz | Bin 5370 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 _dacegraphs/invalid.sdfgz diff --git a/_dacegraphs/invalid.sdfgz b/_dacegraphs/invalid.sdfgz deleted file mode 100644 index 2869814a1c9a9c302e27afa792eef634827ef0c0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5370 zcmV95~>_x$uSL%B9$2`B7ndQRH^Z9b2 zE@%33Y5u*QWhBE?3p2yms~I7l%%s!|esZitCaLADd;>ghrdBKU`SAtCYb!r0<@$SZ z;oWVU>-LLxcV-PIYez&bHlc}4cxV&a z*o22RVSaTzYv99Qjk;8;uy!0)g_bs6sl|Hk7DYqms~dMn_4H|1&zyGk?CDmouxw_Z zpcdoRP%oF4FI>H|N+qY~mrD&~X_nSiEik{@Sb3Fd7fif3pIMa`bTqBj`%*2=Z`65N z!LNV*d0P9X`C^_w$9vqL&-0%1n|T5KX(BbjsWd22BtzU&hNSfZVZt26NhUcJ_qOhZ zT3;0RWWuD95t1^@wN4bKMp2T%fi!VcCxm32WGT^t+BDIU5SLQyOy)DHG51A@%5qE> z%j(zhp}tj9{?}aP?xA{6Iqfb>C>FN*z@h~ZjH`|##h_N8<}HgL~oy}yzraiq92Nj-tAtp&ebZPUCtN9g#fX= z3mmXq%jwPifQOr_MHM1j_qUw(^GkO>Y5ZUSQ~7fBVQp3_Hy0nw+N?j!muj)NQs*DO z|K;z$|NQO$eK^ng_2&=J8u|Le!u^@gHh#ovy0_w{-5O{FS+n;KD(LBzQ6uF-Kx^0Tv3cNpJ0?^8Q07(_#C+o*pefGt%r9LVhm->k`CC?BY~}rust={+D6zaNhu=|<_$@TIFlM> zne<3XZxoX*$zU4M4qJ94u-z7*z=+zOur-CZhKg(c5V8N=FGSKqkm09?+|Kf-AuTm0hULk;rrBPh_glg7--t;)o`>Arepu{xOXY5NU3N_U=^O^l=Jej#+npmH+tE2Ldx&=S5RpB4vd4T;+_G>6>b0)x zqaS`Mr9HcAGxz88?#fKvOO`iz$6NmK`)|#!DNY8x=bq|BDkMpSMV@kpxXCi4d5V$8 zm}16+mYq{2Ie49_H`f)nJq^``bgpUXTsvOpnzqiheL6SOKBb^CKrkVOnS?U}W)cJ= zD>6zP%e>8K*Om`{mJiu)YW?p^*VMWB=4?4)!@G1xG}5FLna!YECSOhd2B=v~{^zR+ zexH1?&#qMw$6B&fB@Fk zpSO*^y%Kdjam!(1c%kome9u#22M}s?P8!a(=@$FcCPcBh-KMqj#%8Hwi}@d-a%{7v 
zYX(PkEM4E^vGE2q9^j6~rW@3B-;5-xx96L7sp)Q`)LuQ zG&4dbD6xhiSi~=cSCk_~oY02pL~hGmYoAx+7tp%Ed_UN}e)OIG-z|5Hz9aadI+hup zN(tpxpkm-Q;YfS$k&s?;sGS10lW8zbHCL;tYP;3TSIf)qE|-7&;?AL(&8!l{QdPl` z$TB9O+qBoP4k>V2i0V_M661xW*f1}fSm?H~6z%nGt8<6jO#_zimlVK@Yo=l8!d$f9 zvF{A9cS)n;PW`B>cjl?RcF)15?!VT;O?DlD>i(l$7qG4ic<@~T?h%; zP;Q6H-7LTJhQRR@wEWmwi{>TC2AZ!dQzjzlp32H);2m9ynT@aMdCRdcdR{F*x$=_= zR^1=4^q4PUzhkDC_VeC)XYb*#vqv5X%q_`tqEio=(Imz2sq)gn{y9@f!FGf8+*l@( zt}=F(Fn}@xwxQZ?sgY0*DCtwP#@zzfiSS0+n>W(P8^_8UfzYhFrWr~>T~3g7!ho|x zYLF^81nfSqz|jZnLOv36kkes~x;@hXyEG5ItuJ-7OcuLlhqL=qJ?_QcAAmjC|DMqI z_=_-J>Ds0p4DT1=MckVgapc7(+(aIVo?#;}U<11Lrt07cA*9e+bZ%mIl%`>WduVYU z^zv33!V-$8%3vcLlr2aSr4cyVX(DKvxs;9c^`LLw?hN0kc~a9fq+$0aGvk!IDxL#j zd%_ZhWR`%J9JVRa%sRwKYOI2FJLAWi@vu1K>9K`%T-Y14X;YA+Cx}7#23Y9^ zcMF|@yG7!hahk#3T#ST1egx=iIs377na|ntF_-y%yt&NhU*CTQzqRM?3E;NDPA^0A zM|TxwRP--TdC#Z9^mENZ__`jx(VN(eVXx>E{x{y%;t(*jzxy~4w3*BU?rL{9Xl}b^ zM}#4RDs+v#m(jMdBl8QlnniF#aE!73Jn;Lxy#PfWORQkXdJ92j5&`?Inj3_{4U7p0 z*kO4ZOnNI5u^q8pi&9u}h|q}8s57(IJ^ z4x5?;^m8#R!n(E`Ko2WW!Sf6tQEhNulm25Kk_k|-^}<}fL7kd9)IY0;xBait;-ogmYq(CXmC3 zaqvZhWGJBbS4d-#A^{=$g%Um?)V~^+tr8T`XKNDbvvuk&S@c9)D+KQ=Ar8DFN{C;^ zC5!fvh$Y5q4S6UK#O&7~P^2CtVu?W}O;WA|WU{h6x@1L{tQy`(dh4Od+yjoGCu znK=#NCF67=4uCkFh|q}BiKh|1vZI+!JgeH(WY`g!dX%aRq>jvKq$LE2z=CJiU@0-> zuUqZvO8H~A+LZ*VUHM^Z*R#MD>1d6xme@)}S7dXOU{mzAcv@soJ8eVXoprWL)2Nwl zB&sbRjiygYi=uZg8f#51#DP&(8#$<@IQ_NDeTMZYDJNRVHAuzU!94wR%tMk-1Pb^1 ziJAM!nL`E5lMu}p>Ar+^U&6XC;b(^!9UUL#Km~M&9XLcAG9Wd@koBzIn#RGZ#$hQc zh9s#NsBIjiZajT`W1v+&BwX022`j>sm|+U(L5iQ=nxB9YKL@vjOHyAGyQBlyC5`NI z{Ol48*OAj&^xzkAiv03hc&z-vLH7F-OpjC-PnhgtV7iN7@V^*~;sGp*Ba6OB`ioFB z4IP~WlVA)?gArPrL8&l?Cc_9V)R+k|dVF_iv}>OvLB_$RZ|ElLlZDwEHBm+nM6!s; zC-U)m2~3V1p=s0v8a+_Sj|P?P!5cajD8oE|dO>@Drf(P!!*zYblzqdreS_3}L-c)v zg$l!@3jM?i!{rJ?1q*{E3&TVU;j)EL;X<%6}Q!9u^MNyjl%>=$wW zhH*bc0vG}Fb^$X~7&ufK7*^ZcQKl}z9!GtAjt0sEhX@77orK4kN}Q=gHhoUBbIhb4 zqu8e7CH{!EX&f8}*`~b!{058^*274lLyVMu2|}W|63vy!$ZyA789$ZC7zWFT=|qMn z6p5=UabO*`s?sCT$Sbl~Xb+2pKjX4WpTr}FZYcNqu#c92q`lP5=UQqL;F{UbpQ=wX 
zl6#dwdxwYSBnkA>Y-*^6`XBC^__xI}pKSaSzc_gZ#X}OT9I=C9aH^H)p*ZQCdga92 z59x_H2!$nA8Iq7CO50!D{gC#PAC8fnCAuF*_rtzPTZW`kO-LQ1`^#s;!GvZR3e+knM#DrY0p%~OkOj_bY{c^HgP%=2S=RFM8L%9 z%=5{w*^y0WUO;nwXd0Z?rMd1-i8DxZ&7VqheHLTT91$dGrwA1&DO3=h1B*q+I}`@783Td%GuVY?hH=+C@5=uS0A-SK#?B4 zARMMZX^N3ci_pXlvM(u>z`l;tnkf(AVfB^*XzMw5Iv>%(cU_W>#WN>pUvjYiRJE0@z!g) z&F(G(`!?+RSkrv>%L=NlZM&BR9>AXwU8pb(l3qX*WGWfR-#7zLT;VxMJ;V>OwcXb> z)sJ1+*@fLTu<1Z7#BU}xg-n1T=LJn% z=7_^-8nCAtyYSQ(>qsY~?WB^6d=|w6FX5g*mBX6<20!vbXjWk|m z(LLdfahj~-p4dNu71K+iA+|G4z++v()HB0aAy0fJ zG}l^c8aXX;THUJL08Z19)8c@5Y6k@&S;8btIa%R6$`ye)QbbGSwbHOVs?)SKp=;Fg zp0bvg)H{;WF<8w+R*P_dY7YftA^3xn96?&kB3Cv1BU48p3kSyxCtjsiMplcgwjYXX zAgkHPYH@-288$wJ76#LdAcKpIPi+MRO=GaOMqsl*d74O?y>4QAwLSepq*m`BpB+kS zv+Y#vk<=p8>o-0Roq$r75XLO{#S`g^AYKy0x)9EbdPTtL%6ZrUjg0n|7>(94niz%A zt|O^Mw0D!*q4Q6hf_TFbu*(z$zXOg50u={@Gg49z#!VuLj-(bzEgg&0!eVQ-#H(!J zok#cRB(o4xkfTH}SW%JSBP}7(DWbw=puyysmk+TvH%~AAEuX&=1h_u%?@qAa?Et?! zL4C_mx4VIS%VTBS*-sFM7DCOId&}&#OEko5p*#Lv~%gMtRVDhh_&T#*lPX18q$@%<~TTGypC-e1WvCJoGvU2B8FSl}b zQowRz9YpAzUl(tiFD8(+WU_!*xp4N^Z@&Bf{p8y})z$T-`*L!vAPTR88{sbish#9Y zc$&Xin6m7D%<~IC1{}I7!u`${=M%MBsn6vJc Date: Tue, 3 Feb 2026 14:45:02 +0100 Subject: [PATCH 47/61] Small update. 
--- .../transformation_tests/test_move_dataflow_into_if_body.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index 15211021b1..b185fc88e1 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -1201,6 +1201,12 @@ def test_if_mover_access_node_between(): def test_if_mover_symbol_aliasing(): + """Tests if symbol clashes are detected. + + Essentially there is a symbol `n` both in the parent SDFG and the `if_block`, + however, with different meanings. Thus the relocation will lead to invalid + behaviour and should be rejected. 
+ """ sdfg = dace.SDFG(util.unique_name("if_mover_symbol_alias")) state = sdfg.add_state(is_start_block=True) From 9e6671ccbfca9a49812f82b4757b84a9a6a67bba Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 3 Feb 2026 14:52:37 +0100 Subject: [PATCH 48/61] Added more elaborative comment in FuseHorizontalConditionBlocks for removal of conditional block --- .../dace/transformations/fuse_horizontal_conditionblocks.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 6ae3fefcd2..19ddb98ff9 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -289,7 +289,10 @@ def _find_corresponding_state_in_second( if edge.dst == second_cb: graph.remove_edge(edge) - # Need to remove both references to remove NestedSDFG from graph + # TODO(iomaganaris): Atm need to remove both references to remove NestedSDFG from graph + # second_conditional_block is inside the SDFG of NestedSDFG second_cb and removing only + # one of them keeps a reference to the other one so none is properly deleted from the SDFG. + # For now remove both but maybe this can be improved in the future. 
graph.remove_node(second_conditional_block) graph.remove_node(second_cb) From c55a04a5df167a9bc8ff74521c6dd13b784b7c13 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 3 Feb 2026 15:26:29 +0100 Subject: [PATCH 49/61] make sure that the symbol mapping is the same between the fused nested sdfgs --- .../dace/transformations/fuse_horizontal_conditionblocks.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 19ddb98ff9..762fae9ae1 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -96,6 +96,12 @@ def can_be_applied( if first_cb.sdfg.number_of_nodes() > 1 or second_cb.sdfg.number_of_nodes() > 1: return False + # Check that the symbol mappings are compatible + sym_map1 = first_cb.symbol_mapping + sym_map2 = second_cb.symbol_mapping + if any(str(sym_map1[sym]) != str(sym_map2[sym]) for sym in sym_map2 if sym in sym_map1): + return False + # Get the actual conditional blocks first_conditional_block = next(iter(first_cb.sdfg.nodes())) second_conditional_block = next(iter(second_cb.sdfg.nodes())) From b1be947f62693a30cf52b8f04cea08ebb99b69e7 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 3 Feb 2026 15:27:19 +0100 Subject: [PATCH 50/61] Added a check for data descriptors. 
--- .../transformations/move_dataflow_into_if_body.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index 5afb309edd..cb3b9d08d0 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -500,8 +500,6 @@ def _check_relocatability( ) -> bool: """Check if the relocation would cause any conflict, such as a symbol clash.""" - # TODO: also names of data containers. - # TODO(phimuell): There is an obscure case where the nested SDFG, on its own, # defines a symbol that is also mapped, for example a dynamic Map range. # It is probably not a problem, because of the scopes DaCe adds when @@ -516,6 +514,7 @@ def _check_relocatability( subgraph_view = dace.sdfg.state.StateSubgraphView(state, all_relocated_dataflow) requiered_symbols: set[str] = subgraph_view.free_symbols + inner_data_names = if_block.sdfg.arrays.keys() for node_to_check in all_relocated_dataflow: if isinstance(node_to_check, dace_nodes.MapEntry): # This means that a nested Map is fully relocated into the `if` block. @@ -525,6 +524,16 @@ def _check_relocatability( assert node_to_check is not enclosing_map requiered_symbols |= set(node_to_check.map.params) + if ( + isinstance(node_to_check, dace_nodes.AccessNode) + and node_to_check.data in inner_data_names + ): + # There is already a data descriptor that is used on the inside as on + # the outside. Thous we would have to perform some renaming, which we + # currently does not. + # TODO(phimell): Handle this case. 
+ return False + for iedge in state.in_edges(node_to_check): src_node = iedge.src if src_node not in all_relocated_dataflow: From d232b7349feb31c0add37517b96094f4553830dd Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 3 Feb 2026 15:30:23 +0100 Subject: [PATCH 51/61] Added check for NestedSDFGs in FuseHorizontalConditionBlocks --- .../runners/dace/transformations/auto_optimize.py | 2 ++ .../transformations/fuse_horizontal_conditionblocks.py | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 74bdb03351..aeb5ca2699 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -737,6 +737,8 @@ def _gt_auto_process_dataflow_inside_maps( validate_all=validate_all, ) + # Make sure that this runs before MoveDataflowIntoIfBody because atm it doesn't handle + # NestedSDFGs inside the ConditionalBlocks it fuses. sdfg.apply_transformations_repeated( gtx_transformations.FuseHorizontalConditionBlocks(), validate=True, diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 762fae9ae1..f29cdca619 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -167,6 +167,13 @@ def can_be_applied( ): return False + # TODO(iomaganaris): Currently we do not handle NestedSDFGs inside the conditional blocks + # however this should be very close to handling Tasklets which is currently working. 
Since + # a test is missing and there is no immediate need for this feature, we leave it for future work. + for node in second_cb.sdfg.nodes(): + if isinstance(node, dace_nodes.NestedSDFG): + return False + return True def apply( From ed7835d2e4625f4e635c2d02456091ff424af5b8 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 3 Feb 2026 15:32:37 +0100 Subject: [PATCH 52/61] Handle single use data check only for second AccessNode in RemoveScalarCopies MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Philip Müller --- .../transformations/remove_scalar_copies.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py b/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py index fc7b3288be..fd4c3fc340 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py @@ -12,7 +12,7 @@ from dace import properties as dace_properties, transformation as dace_transformation from dace.sdfg import nodes as dace_nodes from dace.transformation import helpers as dace_helpers - +from dace.transformation.passes import analysis as dace_analysis @dace_properties.make_properties class RemoveScalarCopies(dace_transformation.SingleStateTransformation): @@ -110,7 +110,7 @@ def can_be_applied( return False # Make sure that the data descriptors of both access nodes are scalars - # TODO(iomaganaris): We could extend this transfromation to handle AccessNodes that are arrays with 1 element as well. 
+ # TODO(iomaganaris): We could extend this transfromation to handle AccessNodes that are arrays with 1 element as well if not isinstance(first_node_desc, dace.data.Scalar) or not isinstance( second_node_desc, dace.data.Scalar ): @@ -122,20 +122,11 @@ def can_be_applied( ): return False - # Make sure that both access nodes are single use data - if self.assume_single_use_data: - single_use_data = {sdfg: {first_node.data}} - if self._single_use_data is None: - find_single_use_data = first_node.FindSingleUseData() - single_use_data = find_single_use_data.apply_pass(sdfg, None) - else: - single_use_data = self._single_use_data - if first_node.data not in single_use_data[sdfg]: - return False + # Make sure that the second AccessNode data is single use data since we're only going to remove that one if self.assume_single_use_data: single_use_data = {sdfg: {second_node.data}} if self._single_use_data is None: - find_single_use_data = second_node.FindSingleUseData() + find_single_use_data = dace_analysis.FindSingleUseData() single_use_data = find_single_use_data.apply_pass(sdfg, None) else: single_use_data = self._single_use_data From 4315d0c9840b7bee5bc44a2d85dd112172df788d Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 3 Feb 2026 15:41:51 +0100 Subject: [PATCH 53/61] Extend the test to also include the data that lies byond the enclosing Map. 
--- .../dace/transformations/move_dataflow_into_if_body.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index cb3b9d08d0..182e09adbc 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -12,6 +12,7 @@ import dace from dace import ( + data as dace_data, dtypes as dace_dtypes, properties as dace_properties, subsets as dace_sbs, @@ -544,6 +545,15 @@ def _check_relocatability( assert isinstance(src_node, dace_nodes.AccessNode) or src_node is enclosing_map requiered_symbols |= iedge.data.used_symbols(True, edge=iedge) + # The (beyond the enclosing Map) data is also mapped into the `if` block, so we + # have to consider that as well. + for iedge in state.in_edges(if_block): + if iedge.src is enclosing_map and (not iedge.data.is_empty()): + outside_desc = sdfg.arrays[iedge.data.data] + if isinstance(outside_desc, dace_data.View): + return False # Handle this case. + requiered_symbols |= outside_desc.used_symbols(True) + # A conflicting symbol is a free symbol of the relocatable dataflow, that is not a # direct mapping. For example if there is a symbol `n` on the inside and outside # then everything is okay if the symbol mapping is `{n: n}` i.e. 
the symbol has the From 678f18f85cdb8079ff529c06831d575d9d2ab308 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Wed, 4 Feb 2026 13:50:55 +0100 Subject: [PATCH 54/61] Fix issues with true/false_branch_x_x_x --- .../transformations/fuse_horizontal_conditionblocks.py | 8 ++++---- .../runners/dace/transformations/remove_scalar_copies.py | 1 + .../test_fuse_horizontal_conditionblocks.py | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index f29cdca619..b5f40f5508 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -119,10 +119,10 @@ def can_be_applied( state.name for state in second_conditional_block.all_states() ] if not ( - "true_branch" in first_conditional_block_state_names - and "false_branch" in first_conditional_block_state_names - and "true_branch" in second_conditional_block_state_names - and "false_branch" in second_conditional_block_state_names + any("true_branch" in name for name in first_conditional_block_state_names) + and any("false_branch" in name for name in first_conditional_block_state_names) + and any("true_branch" in name for name in second_conditional_block_state_names) + and any("false_branch" in name for name in second_conditional_block_state_names) ): return False diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py b/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py index fd4c3fc340..e4ddd56148 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py +++ 
b/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py @@ -14,6 +14,7 @@ from dace.transformation import helpers as dace_helpers from dace.transformation.passes import analysis as dace_analysis + @dace_properties.make_properties class RemoveScalarCopies(dace_transformation.SingleStateTransformation): """Removes copies between two scalar transient variables. diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py index 39004d3786..2eb11a55d7 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py @@ -46,7 +46,7 @@ def _make_if_block_with_tasklet( inner_sdfg.add_node(if_region, is_start_block=True) then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=inner_sdfg) - tstate = then_body.add_state("true_branch", is_start_block=True) + tstate = then_body.add_state("true_branch_0_1_2_3_4", is_start_block=True) tasklet = tstate.add_tasklet( "true_tasklet", inputs={"__tasklet_in"}, @@ -69,7 +69,7 @@ def _make_if_block_with_tasklet( ) else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=inner_sdfg) - fstate = else_body.add_state("false_branch", is_start_block=True) + fstate = else_body.add_state("false_branch_0_1_2_3_4", is_start_block=True) fstate.add_nedge( fstate.add_access(b2_name), fstate.add_access(output_name), From 9200078efec9f084285556718f82cdddfc8cb90c Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 4 Feb 2026 14:28:50 +0100 Subject: [PATCH 55/61] Applied Edoardo's suggestion. 
--- .../move_dataflow_into_if_body.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index 182e09adbc..bd4989227c 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -156,7 +156,7 @@ def can_be_applied( return False # Check if relatability is possible. - if not self._check_relocatability( + if not self._check_for_data_and_symbol_conflicts( sdfg=sdfg, state=graph, relocatable_dataflow=relocatable_dataflow, @@ -491,7 +491,7 @@ def _update_symbol_mapping( or dace_dtypes.typeclass(int), ) - def _check_relocatability( + def _check_for_data_and_symbol_conflicts( self, sdfg: dace.SDFG, state: dace.SDFGState, @@ -512,14 +512,16 @@ def _check_relocatability( all_relocated_dataflow: set[dace_nodes.Node] = functools.reduce( lambda s1, s2: s1.union(s2), relocatable_dataflow.values(), set() ) - subgraph_view = dace.sdfg.state.StateSubgraphView(state, all_relocated_dataflow) - requiered_symbols: set[str] = subgraph_view.free_symbols + requiered_symbols: set[str] = dace.sdfg.state.StateSubgraphView( + state, all_relocated_dataflow + ).free_symbols inner_data_names = if_block.sdfg.arrays.keys() for node_to_check in all_relocated_dataflow: if isinstance(node_to_check, dace_nodes.MapEntry): - # This means that a nested Map is fully relocated into the `if` block. - # When DaCe computes the free symbols, it removes these symbols. + # A Map is fully moved into the nested SDFG. `free_symbols` will ignore + # the Map's parameter (while including the one from the ranges). + # will now add them again to make sure that there are no clashes. 
# TODO(phimuell): Because of C++ scoping rules it might be possible # to skip this step, i.e. not add them to the set. assert node_to_check is not enclosing_map @@ -530,8 +532,8 @@ def _check_relocatability( and node_to_check.data in inner_data_names ): # There is already a data descriptor that is used on the inside as on - # the outside. Thous we would have to perform some renaming, which we - # currently does not. + # the outside. Thus we would have to perform some renaming, which we + # currently do not. # TODO(phimell): Handle this case. return False From e748e9b8edb6bb6d03b050bf228f70e23560ff72 Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Wed, 4 Feb 2026 14:39:42 +0100 Subject: [PATCH 56/61] After a discussion with Edoardo deleted that part. --- .../dace/transformations/move_dataflow_into_if_body.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py index bd4989227c..f421778051 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/move_dataflow_into_if_body.py @@ -518,15 +518,6 @@ def _check_for_data_and_symbol_conflicts( inner_data_names = if_block.sdfg.arrays.keys() for node_to_check in all_relocated_dataflow: - if isinstance(node_to_check, dace_nodes.MapEntry): - # A Map is fully moved into the nested SDFG. `free_symbols` will ignore - # the Map's parameter (while including the one from the ranges). - # will now add them again to make sure that there are no clashes. - # TODO(phimuell): Because of C++ scoping rules it might be possible - # to skip this step, i.e. not add them to the set. 
- assert node_to_check is not enclosing_map - requiered_symbols |= set(node_to_check.map.params) - if ( isinstance(node_to_check, dace_nodes.AccessNode) and node_to_check.data in inner_data_names From 562742be51469c31f6bfcdd9d605d60df0ef77a5 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 10 Feb 2026 10:41:41 +0100 Subject: [PATCH 57/61] Address Philip's comments --- .../fuse_horizontal_conditionblocks.py | 109 +++++++++--------- .../transformations/remove_scalar_copies.py | 3 +- 2 files changed, 56 insertions(+), 56 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index b5f40f5508..5a6420e0a8 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -93,7 +93,7 @@ def can_be_applied( return False # Make sure that the conditional blocks contain only one conditional block each - if first_cb.sdfg.number_of_nodes() > 1 or second_cb.sdfg.number_of_nodes() > 1: + if first_cb.sdfg.number_of_nodes() != 1 or second_cb.sdfg.number_of_nodes() != 1: return False # Check that the symbol mappings are compatible @@ -170,9 +170,10 @@ def can_be_applied( # TODO(iomaganaris): Currently we do not handle NestedSDFGs inside the conditional blocks # however this should be very close to handling Tasklets which is currently working. Since # a test is missing and there is no immediate need for this feature, we leave it for future work. 
- for node in second_cb.sdfg.nodes(): - if isinstance(node, dace_nodes.NestedSDFG): - return False + for state in second_cb.sdfg.states(): + for node in state.nodes(): + if isinstance(node, dace_nodes.NestedSDFG): + return False return True @@ -197,30 +198,23 @@ def apply( # Store the new names for the arrays in the second conditional block to avoid name clashes and add their data descriptors # to the first conditional block SDFG - second_arrays_rename_map = {} + second_arrays_rename_map: dict[str, str] = {} for data_name, data_desc in original_arrays_second_conditional_block.items(): if data_name == "__cond": continue - if data_name in original_arrays_first_conditional_block: - new_data_name = gtx_transformations.utils.unique_name(data_name) + "_from_cb_fusion" - second_arrays_rename_map[data_name] = new_data_name - data_desc_renamed = copy.deepcopy(data_desc) - data_desc_renamed.name = new_data_name - if new_data_name not in first_cb.sdfg.arrays: - first_cb.sdfg.add_datadesc(new_data_name, data_desc_renamed) - else: - second_arrays_rename_map[data_name] = data_name - if data_name not in first_cb.sdfg.arrays: - first_cb.sdfg.add_datadesc(data_name, copy.deepcopy(data_desc)) + new_data_name = gtx_transformations.utils.unique_name(data_name) + "_from_cb_fusion" + second_arrays_rename_map[data_name] = new_data_name + data_desc_renamed = copy.deepcopy(data_desc) + first_cb.sdfg.add_datadesc(new_data_name, data_desc_renamed) second_conditional_states = list(second_conditional_block.all_states()) # Move the connectors from the second conditional block to the first in_connectors_to_move = {k: v for k, v in second_cb.in_connectors.items() if k != "__cond"} out_connectors_to_move = second_cb.out_connectors - in_connectors_to_move_rename_map = {} - out_connectors_to_move_rename_map = {} - for original_in_connector_name, _v in in_connectors_to_move.items(): + in_connectors_to_move_rename_map: dict[str, str] = {} + out_connectors_to_move_rename_map: dict[str, str] = {} + for 
original_in_connector_name in in_connectors_to_move: new_connector_name = original_in_connector_name if new_connector_name in first_cb.in_connectors: new_connector_name = second_arrays_rename_map[original_in_connector_name] @@ -231,7 +225,7 @@ def apply( dace_helpers.redirect_edge( state=graph, edge=edge, new_dst_conn=new_connector_name, new_dst=first_cb ) - for original_out_connector_name, _v in out_connectors_to_move.items(): + for original_out_connector_name in out_connectors_to_move: new_connector_name = original_out_connector_name if new_connector_name in first_cb.out_connectors: new_connector_name = second_arrays_rename_map[original_out_connector_name] @@ -246,22 +240,24 @@ def apply( def _find_corresponding_state_in_second( inner_state: dace.SDFGState, ) -> dace.SDFGState: - inner_state_name = inner_state.name - true_branch = "true_branch" in inner_state_name - corresponding_state_in_second = None - for state in second_conditional_states: - if true_branch and "true_branch" in state.name: - corresponding_state_in_second = state - break - elif not true_branch and "false_branch" in state.name: - corresponding_state_in_second = state - break - return corresponding_state_in_second + is_true_branch = "true_branch" in inner_state.name + branch_type = "true_branch" if is_true_branch else "false_branch" + return next(state for state in second_conditional_states if branch_type in state.name) + + corresponding_states_first_to_second = { + state: _find_corresponding_state_in_second(state) + for state in first_conditional_block.all_states() + } # Copy first the nodes from the second conditional block to the first - nodes_renamed_map = {} + # Create a dictionary that maps the original nodes in the second conditional + # block to the new nodes in the first conditional block to be able to properly connect the edges later + nodes_renamed_map: dict[dace_nodes.Node, dace_nodes.Node] = {} + # Save edges of second conditional block to a state to be able to delete the nodes from 
the second conditional block + second_state_edges_to_add: dict[dace.SDFGState, dace_graph.Edge] = {} for first_inner_state in first_conditional_block.all_states(): - corresponding_state_in_second = _find_corresponding_state_in_second(first_inner_state) + corresponding_state_in_second = corresponding_states_first_to_second[first_inner_state] + second_state_edges_to_add[corresponding_state_in_second] = [] nodes_to_move = list(corresponding_state_in_second.nodes()) for node in nodes_to_move: new_node = node @@ -270,34 +266,39 @@ def _find_corresponding_state_in_second( new_data_name = second_arrays_rename_map[node.data] new_node = dace_nodes.AccessNode(new_data_name) nodes_renamed_map[node] = new_node + second_state_edges_to_add[corresponding_state_in_second].extend( + corresponding_state_in_second.in_edges(node) + ) + second_state_edges_to_add[corresponding_state_in_second].extend( + corresponding_state_in_second.out_edges(node) + ) + # Remove the original node from the second conditional block to avoid any potential issues + # with the nodes coexisting in two states + corresponding_state_in_second.remove_node(node) first_inner_state.add_node(new_node) # Then copy the edges - second_to_first_connections = {} + second_to_first_connections: dict[str, str] = {} for node in nodes_renamed_map: if isinstance(node, dace_nodes.AccessNode): second_to_first_connections[node.data] = nodes_renamed_map[node].data for first_inner_state in first_conditional_block.all_states(): - corresponding_state_in_second = _find_corresponding_state_in_second(first_inner_state) - nodes_to_move = list(corresponding_state_in_second.nodes()) - for node in nodes_to_move: - for edge in list(corresponding_state_in_second.out_edges(node)): - dst = edge.dst - if dst in nodes_to_move: - new_memlet = copy.deepcopy(edge.data) - if edge.data.data in second_to_first_connections: - new_memlet.data = second_to_first_connections[edge.data.data] - first_inner_state.add_edge( - nodes_renamed_map[node], - 
nodes_renamed_map[node].data - if isinstance(node, dace_nodes.AccessNode) and edge.src_conn - else edge.src_conn, - nodes_renamed_map[dst], - second_to_first_connections[dst.data] - if isinstance(edge.dst, dace_nodes.AccessNode) and edge.dst_conn - else edge.dst_conn, - new_memlet, - ) + corresponding_state_in_second = corresponding_states_first_to_second[first_inner_state] + for edge in second_state_edges_to_add[corresponding_state_in_second]: + new_memlet = copy.deepcopy(edge.data) + if edge.data.data in second_to_first_connections: + new_memlet.data = second_to_first_connections[edge.data.data] + first_inner_state.add_edge( + nodes_renamed_map[edge.src], + nodes_renamed_map[edge.src].data + if isinstance(edge.src, dace_nodes.AccessNode) and edge.src_conn + else edge.src_conn, + nodes_renamed_map[edge.dst], + second_to_first_connections[edge.dst.data] + if isinstance(edge.dst, dace_nodes.AccessNode) and edge.dst_conn + else edge.dst_conn, + new_memlet, + ) for edge in list(graph.out_edges(conditional_access_node)): if edge.dst == second_cb: graph.remove_edge(edge) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py b/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py index e4ddd56148..2cfef5f59d 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py @@ -92,13 +92,12 @@ def can_be_applied( edges = list(graph.edges_between(first_node, second_node)) if len(edges) != 1: return False + edge = edges[0] # Check that the second access node has only one incoming edge, which is the one from the first access node. 
if graph.in_degree(second_node) != 1: return False - edge = edges[0] - # Check if the edge transfers only one element if edge.data.num_elements() != 1: return False From 636f4c5137b2809f0d2e0f76f79935308edb2b72 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 10 Feb 2026 14:08:33 +0100 Subject: [PATCH 58/61] Fix issues --- .../fuse_horizontal_conditionblocks.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 5a6420e0a8..b6190de55f 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -216,8 +216,7 @@ def apply( out_connectors_to_move_rename_map: dict[str, str] = {} for original_in_connector_name in in_connectors_to_move: new_connector_name = original_in_connector_name - if new_connector_name in first_cb.in_connectors: - new_connector_name = second_arrays_rename_map[original_in_connector_name] + new_connector_name = second_arrays_rename_map[original_in_connector_name] in_connectors_to_move_rename_map[original_in_connector_name] = new_connector_name first_cb.add_in_connector(new_connector_name) for edge in graph.in_edges(second_cb): @@ -227,8 +226,7 @@ def apply( ) for original_out_connector_name in out_connectors_to_move: new_connector_name = original_out_connector_name - if new_connector_name in first_cb.out_connectors: - new_connector_name = second_arrays_rename_map[original_out_connector_name] + new_connector_name = second_arrays_rename_map[original_out_connector_name] out_connectors_to_move_rename_map[original_out_connector_name] = new_connector_name first_cb.add_out_connector(new_connector_name) for edge in graph.out_edges(second_cb): @@ -262,9 +260,8 @@ def 
_find_corresponding_state_in_second( for node in nodes_to_move: new_node = node if isinstance(node, dace_nodes.AccessNode): - if node.data in first_cb.in_connectors or node.data in first_cb.out_connectors: - new_data_name = second_arrays_rename_map[node.data] - new_node = dace_nodes.AccessNode(new_data_name) + new_data_name = second_arrays_rename_map[node.data] + new_node = dace_nodes.AccessNode(new_data_name) nodes_renamed_map[node] = new_node second_state_edges_to_add[corresponding_state_in_second].extend( corresponding_state_in_second.in_edges(node) From caf69a359a9c7023cfbedfd2bfeef4595f6fee44 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 10 Feb 2026 16:26:50 +0100 Subject: [PATCH 59/61] Address Philip's comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Philip Müller --- .../fuse_horizontal_conditionblocks.py | 120 ++++++++---------- 1 file changed, 52 insertions(+), 68 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index b6190de55f..8962b9be9d 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -189,17 +189,16 @@ def apply( first_conditional_block = next(iter(first_cb.sdfg.nodes())) second_conditional_block = next(iter(second_cb.sdfg.nodes())) - # Store original arrays to check later that all the necessary arrays have been moved - original_arrays_first_conditional_block = first_conditional_block.sdfg.arrays.copy() - original_arrays_second_conditional_block = second_conditional_block.sdfg.arrays.copy() - total_original_arrays = len(original_arrays_first_conditional_block) + len( - original_arrays_second_conditional_block + # Store number of 
original arrays to check later that all the necessary arrays have been moved + total_original_arrays = len(first_conditional_block.sdfg.arrays) + len( + second_conditional_block.sdfg.arrays ) - # Store the new names for the arrays in the second conditional block to avoid name clashes and add their data descriptors - # to the first conditional block SDFG + # Store the new names for the arrays of the second conditional block (transients and globals) to avoid name clashes and add their data descriptors + # to the first conditional block SDFG. We don't have to add `__cond` because we know it's the same for both conditional blocks. + # TODO(iomaganaris): Remove inputs to the conditional block that come from the same AccessNodes (same data) second_arrays_rename_map: dict[str, str] = {} - for data_name, data_desc in original_arrays_second_conditional_block.items(): + for data_name, data_desc in second_conditional_block.sdfg.arrays.items(): if data_name == "__cond": continue new_data_name = gtx_transformations.utils.unique_name(data_name) + "_from_cb_fusion" @@ -210,30 +209,24 @@ def apply( second_conditional_states = list(second_conditional_block.all_states()) # Move the connectors from the second conditional block to the first - in_connectors_to_move = {k: v for k, v in second_cb.in_connectors.items() if k != "__cond"} - out_connectors_to_move = second_cb.out_connectors - in_connectors_to_move_rename_map: dict[str, str] = {} - out_connectors_to_move_rename_map: dict[str, str] = {} - for original_in_connector_name in in_connectors_to_move: - new_connector_name = original_in_connector_name - new_connector_name = second_arrays_rename_map[original_in_connector_name] - in_connectors_to_move_rename_map[original_in_connector_name] = new_connector_name - first_cb.add_in_connector(new_connector_name) - for edge in graph.in_edges(second_cb): - if edge.dst_conn == original_in_connector_name: - dace_helpers.redirect_edge( - state=graph, edge=edge, new_dst_conn=new_connector_name, 
new_dst=first_cb - ) - for original_out_connector_name in out_connectors_to_move: - new_connector_name = original_out_connector_name - new_connector_name = second_arrays_rename_map[original_out_connector_name] - out_connectors_to_move_rename_map[original_out_connector_name] = new_connector_name - first_cb.add_out_connector(new_connector_name) - for edge in graph.out_edges(second_cb): - if edge.src_conn == original_out_connector_name: - dace_helpers.redirect_edge( - state=graph, edge=edge, new_src_conn=new_connector_name, new_src=first_cb - ) + for edge in graph.in_edges(second_cb): + if edge.dst_conn == "__cond": + continue + first_cb.add_in_connector(second_arrays_rename_map[edge.dst_conn]) + dace_helpers.redirect_edge( + state=graph, + edge=edge, + new_dst_conn=second_arrays_rename_map[edge.dst_conn], + new_dst=first_cb, + ) + for edge in graph.out_edges(second_cb): + first_cb.add_out_connector(second_arrays_rename_map[edge.src_conn]) + dace_helpers.redirect_edge( + state=graph, + edge=edge, + new_src_conn=second_arrays_rename_map[edge.src_conn], + new_src=first_cb, + ) def _find_corresponding_state_in_second( inner_state: dace.SDFGState, @@ -242,20 +235,14 @@ def _find_corresponding_state_in_second( branch_type = "true_branch" if is_true_branch else "false_branch" return next(state for state in second_conditional_states if branch_type in state.name) - corresponding_states_first_to_second = { - state: _find_corresponding_state_in_second(state) - for state in first_conditional_block.all_states() - } - # Copy first the nodes from the second conditional block to the first # Create a dictionary that maps the original nodes in the second conditional # block to the new nodes in the first conditional block to be able to properly connect the edges later nodes_renamed_map: dict[dace_nodes.Node, dace_nodes.Node] = {} - # Save edges of second conditional block to a state to be able to delete the nodes from the second conditional block - second_state_edges_to_add: 
dict[dace.SDFGState, dace_graph.Edge] = {} for first_inner_state in first_conditional_block.all_states(): - corresponding_state_in_second = corresponding_states_first_to_second[first_inner_state] - second_state_edges_to_add[corresponding_state_in_second] = [] + corresponding_state_in_second = _find_corresponding_state_in_second(first_inner_state) + # Save edges of second conditional block to a state to be able to delete the nodes from the second conditional block + edges_to_copy = list(corresponding_state_in_second.edges()) nodes_to_move = list(corresponding_state_in_second.nodes()) for node in nodes_to_move: new_node = node @@ -263,43 +250,40 @@ def _find_corresponding_state_in_second( new_data_name = second_arrays_rename_map[node.data] new_node = dace_nodes.AccessNode(new_data_name) nodes_renamed_map[node] = new_node - second_state_edges_to_add[corresponding_state_in_second].extend( - corresponding_state_in_second.in_edges(node) - ) - second_state_edges_to_add[corresponding_state_in_second].extend( - corresponding_state_in_second.out_edges(node) - ) # Remove the original node from the second conditional block to avoid any potential issues # with the nodes coexisting in two states corresponding_state_in_second.remove_node(node) first_inner_state.add_node(new_node) - # Then copy the edges - second_to_first_connections: dict[str, str] = {} - for node in nodes_renamed_map: - if isinstance(node, dace_nodes.AccessNode): - second_to_first_connections[node.data] = nodes_renamed_map[node].data - for first_inner_state in first_conditional_block.all_states(): - corresponding_state_in_second = corresponding_states_first_to_second[first_inner_state] - for edge in second_state_edges_to_add[corresponding_state_in_second]: - new_memlet = copy.deepcopy(edge.data) - if edge.data.data in second_to_first_connections: - new_memlet.data = second_to_first_connections[edge.data.data] - first_inner_state.add_edge( - nodes_renamed_map[edge.src], - nodes_renamed_map[edge.src].data - if 
isinstance(edge.src, dace_nodes.AccessNode) and edge.src_conn - else edge.src_conn, - nodes_renamed_map[edge.dst], - second_to_first_connections[edge.dst.data] - if isinstance(edge.dst, dace_nodes.AccessNode) and edge.dst_conn - else edge.dst_conn, - new_memlet, + for edge_to_copy in edges_to_copy: + new_edge = first_inner_state.add_edge( + nodes_renamed_map[edge_to_copy.src], + edge_to_copy.src_conn, + nodes_renamed_map[edge_to_copy.dst], + edge_to_copy.dst_conn, + edge_to_copy.data, ) + if ( + not new_edge.data.is_empty() + ) and new_edge.data.data in second_arrays_rename_map: + new_edge.data.data = second_arrays_rename_map[new_edge.data.data] + for edge in list(graph.out_edges(conditional_access_node)): if edge.dst == second_cb: graph.remove_edge(edge) + # Copy missing symbols from second conditional block to the first one + missing_symbols = { + sym2: val2 + for sym2, val2 in second_cb.symbol_mapping.items() + if sym2 not in first_cb.symbol_mapping + } + for missing_symb, symb_def in missing_symbols.items(): + first_cb.symbol_mapping[missing_symb] = symb_def + first_cb.add_symbol( + missing_symb, second_cb.sdfg.symbols[missing_symb], find_new_name=False + ) + # TODO(iomaganaris): Atm need to remove both references to remove NestedSDFG from graph # second_conditional_block is inside the SDFG of NestedSDFG second_cb and removing only # one of them keeps a reference to the other one so none is properly deleted from the SDFG. 
From 4f03b21943b759d1f02801b9221a35681e9918d4 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 10 Feb 2026 16:37:03 +0100 Subject: [PATCH 60/61] Removed check for nestedsdfg --- .../transformations/fuse_horizontal_conditionblocks.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 8962b9be9d..ba9a3d9ba5 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -167,14 +167,6 @@ def can_be_applied( ): return False - # TODO(iomaganaris): Currently we do not handle NestedSDFGs inside the conditional blocks - # however this should be very close to handling Tasklets which is currently working. Since - # a test is missing and there is no immediate need for this feature, we leave it for future work. 
- for state in second_cb.sdfg.states(): - for node in state.nodes(): - if isinstance(node, dace_nodes.NestedSDFG): - return False - return True def apply( From 902eed0c207ab9d66a8c94f615dc5355fbbe2c60 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Wed, 11 Feb 2026 10:03:44 +0100 Subject: [PATCH 61/61] Added test for symbols --- .../fuse_horizontal_conditionblocks.py | 2 +- .../test_fuse_horizontal_conditionblocks.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index ba9a3d9ba5..0c3d93587c 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -272,7 +272,7 @@ def _find_corresponding_state_in_second( } for missing_symb, symb_def in missing_symbols.items(): first_cb.symbol_mapping[missing_symb] = symb_def - first_cb.add_symbol( + first_cb.sdfg.add_symbol( missing_symb, second_cb.sdfg.symbols[missing_symb], find_new_name=False ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py index 2eb11a55d7..45b3c5ff5b 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py @@ -47,11 +47,13 @@ def _make_if_block_with_tasklet( then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=inner_sdfg) tstate = 
then_body.add_state("true_branch_0_1_2_3_4", is_start_block=True) + inner_sdfg.add_symbol("multiplier", dace.float64) + tasklet = tstate.add_tasklet( "true_tasklet", inputs={"__tasklet_in"}, outputs={"__tasklet_out"}, - code="__tasklet_out = __tasklet_in * 2.0", + code="__tasklet_out = __tasklet_in * multiplier", ) tstate.add_edge( tstate.add_access(b1_name), @@ -79,11 +81,13 @@ def _make_if_block_with_tasklet( if_region.add_branch(dace.sdfg.state.CodeBlock(cond_name), then_body) if_region.add_branch(dace.sdfg.state.CodeBlock(f"not {cond_name}"), else_body) - return state.add_nested_sdfg( + nested_sdfg = state.add_nested_sdfg( sdfg=inner_sdfg, inputs={b1_name, b2_name, cond_name}, outputs={output_name}, ) + nested_sdfg.symbol_mapping["multiplier"] = 2.0 + return nested_sdfg def _make_map_with_conditional_blocks() -> dace.SDFG: @@ -188,6 +192,10 @@ def test_fuse_horizontal_condition_blocks(): n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, dace.sdfg.state.ConditionalBlock) ] assert len(new_conditional_blocks) == 1 + conditional_block = new_conditional_blocks[0] + assert ( + len(conditional_block.sdfg.symbols) == 1 and "multiplier" in conditional_block.sdfg.symbols + ) util.compile_and_run_sdfg(sdfg, **res) assert util.compare_sdfg_res(ref=ref, res=res)