From 2d84e42b8db40c8dcbd461d844ab2c947b82d17d Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 27 Nov 2025 13:57:07 +0100 Subject: [PATCH 01/34] Enable maxnreg setting --- .../dace/transformations/auto_optimize.py | 4 ++++ .../runners/dace/transformations/gpu_utils.py | 17 ++++++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 1d04c21fc3..1bd8e7ad9a 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -116,6 +116,7 @@ def gt_auto_optimize( gpu_block_size_1d: Optional[Sequence[int | str] | str] = (64, 1, 1), gpu_block_size_2d: Optional[Sequence[int | str] | str] = None, gpu_block_size_3d: Optional[Sequence[int | str] | str] = None, + gpu_maxnreg: Optional[int] = None, blocking_dim: Optional[gtx_common.Dimension] = None, blocking_size: int = 10, blocking_only_if_independent_nodes: bool = True, @@ -371,6 +372,7 @@ def gt_auto_optimize( gpu_block_size=gpu_block_size, gpu_launch_factor=gpu_launch_factor, gpu_launch_bounds=gpu_launch_bounds, + gpu_maxnreg=gpu_maxnreg, optimization_hooks=optimization_hooks, gpu_block_size_spec=gpu_block_size_spec if gpu_block_size_spec else None, validate_all=validate_all, @@ -747,6 +749,7 @@ def _gt_auto_configure_maps_and_strides( gpu_block_size: Optional[Sequence[int | str] | str], gpu_launch_bounds: Optional[int | str], gpu_launch_factor: Optional[int], + gpu_maxnreg: Optional[int], optimization_hooks: dict[GT4PyAutoOptHook, GT4PyAutoOptHookFun], gpu_block_size_spec: Optional[dict[str, Sequence[int | str] | str]], validate_all: bool, @@ -799,6 +802,7 @@ def _gt_auto_configure_maps_and_strides( gpu_launch_bounds=gpu_launch_bounds, gpu_launch_factor=gpu_launch_factor, gpu_block_size_spec=gpu_block_size_spec, + gpu_maxnreg=gpu_maxnreg, validate=False, validate_all=validate_all, try_removing_trivial_maps=True, diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index 1786913edb..a1df77cad4 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -34,6 +34,7 @@ def gt_gpu_transformation( gpu_launch_bounds: Optional[int | str] = None, gpu_launch_factor: Optional[int] = None, gpu_block_size_spec: Optional[dict[str, Sequence[int | str] | str]] = None, + gpu_maxnreg: Optional[int] = None, validate: bool = True, validate_all: bool = False, **kwargs: Any, @@ -123,6 +124,7 @@ def gt_gpu_transformation( launch_bounds=gpu_launch_bounds, launch_factor=gpu_launch_factor, **gpu_block_size_spec, + gpu_maxnreg=gpu_maxnreg, validate=False, validate_all=validate_all, ) @@ -362,6 +364,7 @@ def gt_set_gpu_blocksize( block_size: Optional[Sequence[int | str] | str], launch_bounds: Optional[int | str] = None, launch_factor: Optional[int] = None, + gpu_maxnreg: Optional[int] = None, validate: bool = True, validate_all: bool = False, **kwargs: Any, @@ -393,6 +396,7 @@ def gt_set_gpu_blocksize( }.items(): if f"{arg}_{dim}d" not in kwargs: kwargs[f"{arg}_{dim}d"] = val + kwargs["maxnreg"] = gpu_maxnreg setter = GPUSetBlockSize(**kwargs) @@ -590,6 +594,12 @@ class GPUSetBlockSize(dace_transformation.SingleStateTransformation): default=None, desc="Set the launch bound property for 3 dimensional map.", ) + maxnreg = dace_properties.Property( + dtype=int, + allow_none=True, + default=None, + desc="Set the maxnreg property for the GPU maps.", + ) # Pattern matching map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) @@ -605,6 +615,7 @@ def __init__( launch_factor_1d: int | None = None, launch_factor_2d: int | None = None, launch_factor_3d: int | None = None, + maxnreg: int | None = None, ) -> None: super().__init__() if block_size_1d is not None: @@ -632,6 +643,8 @@ def __init__( self.launch_bounds_3d = _gpu_launch_bound_parser( self.block_size_3d, launch_bounds_3d, launch_factor_3d ) + if maxnreg is not None: + self.maxnreg = maxnreg @classmethod def expressions(cls) -> Any: @@ -748,7 +761,9 @@ def apply( block_size[i] = map_size[map_dim_idx_to_inspect] gpu_map.gpu_block_size = tuple(block_size) - if launch_bounds is not None: # Note: empty string has a meaning in DaCe + if self.maxnreg is not None: + gpu_map.gpu_maxnreg = self.maxnreg + elif launch_bounds is not None: # Note: empty string has a meaning in DaCe gpu_map.gpu_launch_bounds = launch_bounds From bb7bd9190f1b17c597a81be10f9ea2fedd9eef3b Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 29 Jan 2026 14:29:09 +0100 Subject: [PATCH 02/34] Added test for gpu_maxnreg --- .../transformation_tests/test_gpu_utils.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py index 1ef64da559..720a269547 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py @@ -351,3 +351,36 @@ def test_set_gpu_properties_2D_3D(): assert len(map4.params) == 4 assert map4.gpu_block_size == [1, 2, 32] assert map4.gpu_launch_bounds == "0" + + +def test_set_gpu_maxnreg(): + """Tests if gpu_maxnreg property is set correctly to GPU maps.""" + sdfg = dace.SDFG(util.unique_name("gpu_maxnreg_test")) + state = sdfg.add_state(is_start_block=True) + dim = 2 + shape = (10,) * (dim - 1) + (1,) + sdfg.add_array( + f"A_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global + ) + sdfg.add_array( + f"B_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global + ) + _, me, _ = state.add_mapped_tasklet( + f"map_{dim}", + map_ranges={f"__i{i}": f"0:{s}" for i, s in enumerate(shape)}, + inputs={"__in": dace.Memlet(f"A_{dim}[{','.join(f'__i{i}' for i in range(dim))}]")}, + code="__out = math.cos(__in)", + outputs={"__out": dace.Memlet(f"B_{dim}[{','.join(f'__i{i}' for i in range(dim))}]")}, + external_edges=True, + ) + sdfg.validate() + sdfg.apply_gpu_transformations() + gtx_dace_fieldview_gpu_utils.gt_set_gpu_blocksize( + sdfg=sdfg, + block_size_1d=(128, 1, 1), + block_size_2d=(64, 2, 1), + block_size_3d=(2, 2, 32), + block_size=(32, 4, 1), + gpu_maxnreg=128, + ) + assert me.gpu_maxnreg == 128 From 80a4b8e07c2aefe327564b363f6bafc2b2e83bbe Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 29 Jan 2026 14:31:08 +0100 Subject: [PATCH 03/34] Fix formating --- .../dace_tests/transformation_tests/test_gpu_utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py index 720a269547..dd4927bb03 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py @@ -359,12 +359,8 @@ def test_set_gpu_maxnreg(): state = sdfg.add_state(is_start_block=True) dim = 2 shape = (10,) * (dim - 1) + (1,) - sdfg.add_array( - f"A_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global - ) - sdfg.add_array( - f"B_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global - ) + sdfg.add_array(f"A_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global) + sdfg.add_array(f"B_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global) _, me, _ = state.add_mapped_tasklet( f"map_{dim}", map_ranges={f"__i{i}": f"0:{s}" for i, s in enumerate(shape)}, From 17f8b37cb722020bcafac95bda61eb87e3e2e7c1 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 3 Feb 2026 09:19:43 +0100 Subject: [PATCH 04/34] Mention that maxnreg takes precedence over launch bounds --- .../runners/dace/transformations/gpu_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index a1df77cad4..c565ff0f0b 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -598,7 +598,7 @@ class GPUSetBlockSize(dace_transformation.SingleStateTransformation): dtype=int, allow_none=True, default=None, - desc="Set the maxnreg property for the GPU maps.", + desc="Set the maxnreg property for the GPU maps. Takes precedence over any launch_bounds.", ) # Pattern matching From 923e3cd7d5d1c21551aa815aaba90fb66eff46b0 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 27 Nov 2025 13:57:07 +0100 Subject: [PATCH 05/34] Enable maxnreg setting --- .../dace/transformations/auto_optimize.py | 4 ++++ .../runners/dace/transformations/gpu_utils.py | 17 ++++++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 486af807ab..8a11b64714 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -117,6 +117,7 @@ def gt_auto_optimize( gpu_block_size_1d: Optional[Sequence[int | str] | str] = (64, 1, 1), gpu_block_size_2d: Optional[Sequence[int | str] | str] = None, gpu_block_size_3d: Optional[Sequence[int | str] | str] = None, + gpu_maxnreg: Optional[int] = None, blocking_dim: Optional[gtx_common.Dimension] = None, blocking_size: int = 10, blocking_only_if_independent_nodes: bool = True, @@ -375,6 +376,7 @@ def gt_auto_optimize( gpu_block_size=gpu_block_size, gpu_launch_factor=gpu_launch_factor, gpu_launch_bounds=gpu_launch_bounds, + gpu_maxnreg=gpu_maxnreg, optimization_hooks=optimization_hooks, gpu_block_size_spec=gpu_block_size_spec if gpu_block_size_spec else None, validate_all=validate_all, @@ -781,6 +783,7 @@ def _gt_auto_configure_maps_and_strides( gpu_block_size: Optional[Sequence[int | str] | str], gpu_launch_bounds: Optional[int | str], gpu_launch_factor: Optional[int], + gpu_maxnreg: Optional[int], optimization_hooks: dict[GT4PyAutoOptHook, GT4PyAutoOptHookFun], gpu_block_size_spec: Optional[dict[str, Sequence[int | str] | str]], validate_all: bool, @@ -833,6 +836,7 @@ def _gt_auto_configure_maps_and_strides( gpu_launch_bounds=gpu_launch_bounds, gpu_launch_factor=gpu_launch_factor, gpu_block_size_spec=gpu_block_size_spec, + gpu_maxnreg=gpu_maxnreg, validate=False, validate_all=validate_all, try_removing_trivial_maps=True, diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index 1786913edb..a1df77cad4 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -34,6 +34,7 @@ def gt_gpu_transformation( gpu_launch_bounds: Optional[int | str] = None, gpu_launch_factor: Optional[int] = None, gpu_block_size_spec: Optional[dict[str, Sequence[int | str] | str]] = None, + gpu_maxnreg: Optional[int] = None, validate: bool = True, validate_all: bool = False, **kwargs: Any, @@ -123,6 +124,7 @@ def gt_gpu_transformation( launch_bounds=gpu_launch_bounds, launch_factor=gpu_launch_factor, **gpu_block_size_spec, + gpu_maxnreg=gpu_maxnreg, validate=False, validate_all=validate_all, ) @@ -362,6 +364,7 @@ def gt_set_gpu_blocksize( block_size: Optional[Sequence[int | str] | str], launch_bounds: Optional[int | str] = None, launch_factor: Optional[int] = None, + gpu_maxnreg: Optional[int] = None, validate: bool = True, validate_all: bool = False, **kwargs: Any, @@ -393,6 +396,7 @@ def gt_set_gpu_blocksize( }.items(): if f"{arg}_{dim}d" not in kwargs: kwargs[f"{arg}_{dim}d"] = val + kwargs["maxnreg"] = gpu_maxnreg setter = GPUSetBlockSize(**kwargs) @@ -590,6 +594,12 @@ class GPUSetBlockSize(dace_transformation.SingleStateTransformation): default=None, desc="Set the launch bound property for 3 dimensional map.", ) + maxnreg = dace_properties.Property( + dtype=int, + allow_none=True, + default=None, + desc="Set the maxnreg property for the GPU maps.", + ) # Pattern matching map_entry = dace_transformation.PatternNode(dace_nodes.MapEntry) @@ -605,6 +615,7 @@ def __init__( launch_factor_1d: int | None = None, launch_factor_2d: int | None = None, launch_factor_3d: int | None = None, + maxnreg: int | None = None, ) -> None: super().__init__() if block_size_1d is not None: @@ -632,6 +643,8 @@ def __init__( self.launch_bounds_3d = _gpu_launch_bound_parser( self.block_size_3d, launch_bounds_3d, launch_factor_3d ) + if maxnreg is not None: + self.maxnreg = maxnreg @classmethod def expressions(cls) -> Any: @@ -748,7 +761,9 @@ def apply( block_size[i] = map_size[map_dim_idx_to_inspect] gpu_map.gpu_block_size = tuple(block_size) - if launch_bounds is not None: # Note: empty string has a meaning in DaCe + if self.maxnreg is not None: + gpu_map.gpu_maxnreg = self.maxnreg + elif launch_bounds is not None: # Note: empty string has a meaning in DaCe gpu_map.gpu_launch_bounds = launch_bounds From 0d2b3e14cf9acc520a2d89174a010678cea1df44 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 29 Jan 2026 14:29:09 +0100 Subject: [PATCH 06/34] Added test for gpu_maxnreg --- .../transformation_tests/test_gpu_utils.py | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py index 1ef64da559..720a269547 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py @@ -351,3 +351,36 @@ def test_set_gpu_properties_2D_3D(): assert len(map4.params) == 4 assert map4.gpu_block_size == [1, 2, 32] assert map4.gpu_launch_bounds == "0" + + +def test_set_gpu_maxnreg(): + """Tests if gpu_maxnreg property is set correctly to GPU maps.""" + sdfg = dace.SDFG(util.unique_name("gpu_maxnreg_test")) + state = sdfg.add_state(is_start_block=True) + dim = 2 + shape = (10,) * (dim - 1) + (1,) + sdfg.add_array( + f"A_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global + ) + sdfg.add_array( + f"B_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global + ) + _, me, _ = state.add_mapped_tasklet( + f"map_{dim}", + map_ranges={f"__i{i}": f"0:{s}" for i, s in enumerate(shape)}, + inputs={"__in": dace.Memlet(f"A_{dim}[{','.join(f'__i{i}' for i in range(dim))}]")}, + code="__out = math.cos(__in)", + outputs={"__out": dace.Memlet(f"B_{dim}[{','.join(f'__i{i}' for i in range(dim))}]")}, + external_edges=True, + ) + sdfg.validate() + sdfg.apply_gpu_transformations() + gtx_dace_fieldview_gpu_utils.gt_set_gpu_blocksize( + sdfg=sdfg, + block_size_1d=(128, 1, 1), + block_size_2d=(64, 2, 1), + block_size_3d=(2, 2, 32), + block_size=(32, 4, 1), + gpu_maxnreg=128, + ) + assert me.gpu_maxnreg == 128 From 215e775cc41826a3030f98c289b7343dd81f3726 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 29 Jan 2026 14:31:08 +0100 Subject: [PATCH 07/34] Fix formating --- .../dace_tests/transformation_tests/test_gpu_utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py index 720a269547..dd4927bb03 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py @@ -359,12 +359,8 @@ def test_set_gpu_maxnreg(): state = sdfg.add_state(is_start_block=True) dim = 2 shape = (10,) * (dim - 1) + (1,) - sdfg.add_array( - f"A_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global - ) - sdfg.add_array( - f"B_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global - ) + sdfg.add_array(f"A_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global) + sdfg.add_array(f"B_{dim}", shape=shape, dtype=dace.float64, storage=dace.StorageType.GPU_Global) _, me, _ = state.add_mapped_tasklet( f"map_{dim}", map_ranges={f"__i{i}": f"0:{s}" for i, s in enumerate(shape)}, From d58ade95788075d3d159483856534482e5a17dba Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 3 Feb 2026 09:19:43 +0100 Subject: [PATCH 08/34] Mention that maxnreg takes precedence over launch bounds --- .../runners/dace/transformations/gpu_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index a1df77cad4..c565ff0f0b 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -598,7 +598,7 @@ class GPUSetBlockSize(dace_transformation.SingleStateTransformation): dtype=int, allow_none=True, default=None, - desc="Set the maxnreg property for the GPU maps.", + desc="Set the maxnreg property for the GPU maps. Takes precedence over any launch_bounds.", ) # Pattern matching From 956282f77338facfb6d73ed127ce127125d2c67c Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 22 Jan 2026 15:31:28 +0100 Subject: [PATCH 09/34] Remove scalar copies --- .../runners/dace/transformations/__init__.py | 1 + .../dace/transformations/auto_optimize.py | 20 +++ .../transformations/kill_aliasing_scalars.py | 143 ++++++++++++++++++ .../test_kill_aliasing_scalars.py | 85 +++++++++++ 4 files changed, 249 insertions(+) create mode 100644 src/gt4py/next/program_processors/runners/dace/transformations/kill_aliasing_scalars.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_kill_aliasing_scalars.py diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py index a1b766100a..31ed0a5ef7 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py @@ -27,6 +27,7 @@ gt_set_gpu_blocksize, ) from .inline_fuser import inline_dataflow_into_map +from .kill_aliasing_scalars import KillAliasingScalars from .local_double_buffering import gt_create_local_double_buffering from .loop_blocking import LoopBlocking from .map_fusion import ( diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 8a11b64714..6a8d83268b 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -773,6 +773,26 @@ def _gt_auto_process_dataflow_inside_maps( validate_all=validate_all, ) + sdfg.save("before_kill_aliasing_scalars.sdfg") + + find_single_use_data = dace_analysis.FindSingleUseData() + single_use_data = find_single_use_data.apply_pass(sdfg, None) + + sdfg.apply_transformations_repeated( + gtx_transformations.KillAliasingScalars( + single_use_data=single_use_data, + ), + validate=False, + validate_all=validate_all, + ) + # sdfg.apply_transformations_repeated( + # gtx_transformations.CopyChainRemover( + # single_use_data=single_use_data, + # ), + # validate=False, + # validate_all=validate_all, + # ) + return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/kill_aliasing_scalars.py b/src/gt4py/next/program_processors/runners/dace/transformations/kill_aliasing_scalars.py new file mode 100644 index 0000000000..db5bc96a44 --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/transformations/kill_aliasing_scalars.py @@ -0,0 +1,143 @@ +import copy +import warnings +from typing import Any, Callable, Mapping, Optional, TypeAlias, Union + +import dace +from dace import ( + properties as dace_properties, + subsets as dace_subsets, + transformation as dace_transformation, +) +from dace.sdfg import nodes as dace_nodes +from dace.transformation import helpers as dace_helpers +from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations + +@dace_properties.make_properties +class KillAliasingScalars(dace_transformation.SingleStateTransformation): + first_access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) + second_access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) + + assume_single_use_data = dace_properties.Property( + dtype=bool, + default=False, + desc="Always assume that `self.access_node` is single use data. Only useful if used through `SplitAccessNode.apply_to()`.", + ) + + # Name of all data that is used at only one place. Is computed by the + # `FindSingleUseData` pass and be passed at construction time. Needed until + # [issue#1911](https://github.com/spcl/dace/issues/1911) has been solved. + _single_use_data: Optional[dict[dace.SDFG, set[str]]] + + def __init__( + self, + *args: Any, + single_use_data: Optional[dict[dace.SDFG, set[str]]] = None, + assume_single_use_data: Optional[bool] = None, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + self._single_use_data = single_use_data + if assume_single_use_data is not None: + self.assume_single_use_data = assume_single_use_data + + @classmethod + def expressions(cls) -> Any: + return [dace.sdfg.utils.node_path_graph(cls.first_access_node, cls.second_access_node)] + + def can_be_applied( + self, + graph: Union[dace.SDFGState, dace.SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + first_node: dace_nodes.AccessNode = self.first_access_node + first_node_desc = first_node.desc(sdfg) + second_node: dace_nodes.AccessNode = self.second_access_node + second_node_desc = second_node.desc(sdfg) + + scope_dict = graph.scope_dict() + if first_node not in scope_dict or second_node not in scope_dict: + return False + + if scope_dict[first_node] != scope_dict[second_node]: + return False + + if not first_node_desc.transient or not second_node_desc.transient: + return False + + edges = graph.edges_between(first_node, second_node) + assert len(edges) == 1 + edge = next(iter(edges)) + # Check if edge volume is 1 + if edge.data.num_elements() != 1: + return False + if edge.data.dynamic: + return False + + for out_edges in graph.out_edges(second_node): + if out_edges.data.num_elements() != 1: + return False + # if out_edges.data.dynamic: + # return False + # breakpoint() + # subset: dace_subsets.Subset = edge.data.get("subset", None) + # if subset is None: + # return False + + # Make sure that the edge subset is 1 + if not isinstance(first_node_desc, dace.data.Scalar) or not isinstance( + second_node_desc, dace.data.Scalar): + return False + + # Make sure that both access nodes are not views + if isinstance(first_node_desc, dace.data.View) or isinstance( + second_node_desc, dace.data.View): + return False + + # Make sure that both access nodes are transients + if not first_node_desc.transient or not second_node_desc.transient: + return False + + if graph.in_degree(second_node) != 1: + return False + + if self.assume_single_use_data: + single_use_data = {sdfg: {first_node.data}} + if self._single_use_data is None: + find_single_use_data = first_node.FindSingleUseData() + single_use_data = find_single_use_data.apply_pass(sdfg, None) + else: + single_use_data = self._single_use_data + if first_node.data not in single_use_data[sdfg]: + return False + + if self.assume_single_use_data: + single_use_data = {sdfg: {second_node.data}} + if self._single_use_data is None: + find_single_use_data = second_node.FindSingleUseData() + single_use_data = find_single_use_data.apply_pass(sdfg, None) + else: + single_use_data = self._single_use_data + if second_node.data not in single_use_data[sdfg]: + return False + + return True + + def apply( + self, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + ) -> None: + first_node: dace_nodes.AccessNode = self.first_access_node + second_node: dace_nodes.AccessNode = self.second_access_node + + # Redirect all outcoming edges of the second access node to the first + for edge in list(graph.out_edges(second_node)): + dace_helpers.redirect_edge(state=graph, edge=edge, new_src=first_node, new_data=first_node.data if edge.data.data == second_node.data else edge.data.data) + # edge.subset = first_node.desc(sdfg).get_subset() + # if edge.other_subset is not None: + # edge.other_subset = edge.dst.desc(sdfg).get_subset() + + # Remove the second access node + graph.remove_node(second_node) \ No newline at end of file diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_kill_aliasing_scalars.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_kill_aliasing_scalars.py new file mode 100644 index 0000000000..1d03b3e3e8 --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_kill_aliasing_scalars.py @@ -0,0 +1,85 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes +from dace.transformation.passes import analysis as dace_analysis + +from gt4py.next.program_processors.runners.dace import ( + transformations as gtx_transformations, +) + +from . import util + +import dace + + +def _make_map_with_scalar_copies() -> tuple[ + dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapExit +]: + sdfg = dace.SDFG(util.unique_name("scalar_elimination")) + state = sdfg.add_state(is_start_block=True) + + sdfg.add_array( + "a", + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.add_array( + "b", + shape=(10,), + dtype=dace.float64, + transient=True, + ) + + a, b = (state.add_access(name) for name in "ab") + for i in range(3): + sdfg.add_scalar(f"tmp{i}", dtype=dace.float64, transient=True) + tmp0, tmp1, tmp2 = (state.add_access(f"tmp{i}") for i in range(3)) + + me, mx = state.add_map("copy_map", ndrange={"__i": "0:10"}) + me.add_in_connector("IN_a") + me.add_out_connector("OUT_a") + mx.add_in_connector("IN_b") + mx.add_out_connector("OUT_b") + state.add_edge(a, None, me, "IN_a", dace.Memlet("a[__i]")) + state.add_edge(me, "OUT_a", tmp0, None, dace.Memlet("a[__i]")) + state.add_edge(tmp0, None, tmp1, None, dace.Memlet("tmp1[0]")) + state.add_edge(tmp1, None, tmp2, None, dace.Memlet("tmp1[0]")) + state.add_edge(tmp2, None, mx, "IN_b", dace.Memlet("[0] -> b[__i]")) + state.add_edge(mx, "OUT_b", b, None, dace.Memlet("b[__i]")) + + sdfg.validate() + return sdfg, state, me, mx + + +def test_remove_double_write_single_consumer(): + sdfg, state, me, mx = _make_map_with_scalar_copies() + + access_nodes_inside_original_map = util.count_nodes(state.scope_subgraph(me, include_entry=False, include_exit=False), dace_nodes.AccessNode) + assert access_nodes_inside_original_map == 3 + + sdfg.view() + breakpoint() + find_single_use_data = dace_analysis.FindSingleUseData() + single_use_data = find_single_use_data.apply_pass(sdfg, None) + sdfg.apply_transformations_repeated( + gtx_transformations.KillAliasingScalars( + single_use_data=single_use_data, + assume_single_use_data=False, + ), + validate=True, + validate_all=True, + ) + sdfg.view() + breakpoint() + access_nodes_inside_new_map = util.count_nodes(state.scope_subgraph(me, include_entry=False, include_exit=False), dace_nodes.AccessNode) + assert access_nodes_inside_new_map == 1 From 03c8cbb15d424a6ad1a9ccf1de2b860cf54c8566 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Fri, 23 Jan 2026 10:08:28 +0100 Subject: [PATCH 10/34] Rename if statements in _make_if_block --- .../transformation_tests/test_move_dataflow_into_if_body.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index 03aba6599e..5338b49920 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -37,7 +37,7 @@ def _make_if_block( b2_type: dace.typeclass = dace.float64, output_type: dace.typeclass = dace.float64, ) -> dace_nodes.NestedSDFG: - inner_sdfg = dace.SDFG(util.unique_name("inner_sdfg")) + inner_sdfg = dace.SDFG(util.unique_name("if_stmt_")) types = {b1_name: b1_type, b2_name: b2_type, cond_name: dace.bool_, output_name: output_type} for name in {b1_name, b2_name, cond_name, output_name}: @@ -47,7 +47,7 @@ def _make_if_block( transient=False, ) - if_region = dace.sdfg.state.ConditionalBlock("if") + if_region = dace.sdfg.state.ConditionalBlock(util.unique_name("if")) inner_sdfg.add_node(if_region, is_start_block=True) then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=inner_sdfg) From e10b2ee6f16dc28829fb7aa4db76bba2ad0e7b31 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Fri, 23 Jan 2026 10:08:50 +0100 Subject: [PATCH 11/34] [WIP] Added fuse condition block transformation --- .../runners/dace/transformations/__init__.py | 1 + .../fuse_horizontal_conditionblocks.py | 115 +++++++++++++++++ .../test_fuse_horizontal_conditionblocks.py | 119 ++++++++++++++++++ 3 files changed, 235 insertions(+) create mode 100644 src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py create mode 100644 tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py index 31ed0a5ef7..f899c7e6a2 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py @@ -20,6 +20,7 @@ gt_auto_optimize, ) from .dead_dataflow_elimination import gt_eliminate_dead_dataflow, gt_remove_map +from .fuse_horizontal_conditionblocks import FuseHorizontalConditionBlocks from .gpu_utils import ( GPUSetBlockSize, gt_gpu_transform_non_standard_memlet, diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py new file mode 100644 index 0000000000..e68f93a48a --- /dev/null +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -0,0 +1,115 @@ +import copy +import warnings +from typing import Any, Callable, Mapping, Optional, TypeAlias, Union + +import dace +from dace import ( + properties as dace_properties, + subsets as dace_subsets, + transformation as dace_transformation, +) + +from dace.sdfg import nodes as dace_nodes, graph as dace_graph +from dace.transformation import helpers as dace_helpers +from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations + +@dace_properties.make_properties +class FuseHorizontalConditionBlocks(dace_transformation.SingleStateTransformation): + access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) + first_conditional_block = dace_transformation.PatternNode(dace_nodes.NestedSDFG) + second_conditional_block = dace_transformation.PatternNode(dace_nodes.NestedSDFG) + + @classmethod + def expressions(cls) -> Any: + map_fusion_parallel_match = dace_graph.OrderedMultiDiConnectorGraph() + map_fusion_parallel_match.add_nedge(cls.access_node, cls.first_conditional_block, dace.Memlet()) + map_fusion_parallel_match.add_nedge(cls.access_node, cls.second_conditional_block, dace.Memlet()) + return [map_fusion_parallel_match] + + def can_be_applied( + self, + graph: Union[dace.SDFGState, dace.SDFG], + expr_index: int, + sdfg: dace.SDFG, + permissive: bool = False, + ) -> bool: + access_node: dace_nodes.AccessNode = self.access_node + access_node_desc = access_node.desc(sdfg) + first_cb: dace_nodes.NestedSDFG = self.first_conditional_block + second_cb: dace_nodes.NestedSDFG = self.second_conditional_block + scope_dict = graph.scope_dict() + + if first_cb.sdfg.parent != second_cb.sdfg.parent: + return False + # Check that the nested SDFGs' names starts with "if_stmt" + if not (first_cb.sdfg.name.startswith("if_stmt") and second_cb.sdfg.name.startswith("if_stmt")): + return False + + # Check that the common access node is a boolean scalar + if not isinstance(access_node_desc, dace.data.Scalar) or access_node_desc.dtype != dace.bool_: + return False + + if len(first_cb.sdfg.nodes()) > 1 or len(second_cb.sdfg.nodes()) > 1: + return False + + first_conditional_block = next(iter(first_cb.sdfg.nodes())) + second_conditional_block = next(iter(second_cb.sdfg.nodes())) + if not (isinstance(first_conditional_block, dace.sdfg.state.ConditionalBlock) and isinstance(second_conditional_block, dace.sdfg.state.ConditionalBlock)): + return False + + if scope_dict[first_cb] != scope_dict[second_cb]: + return False + + if not isinstance(scope_dict[first_cb], dace.nodes.MapEntry): + return False + + # Check that there is an edge to the conditional blocks with dst_conn == "__cond" + cond_edges_first = [e for e in graph.in_edges(first_cb) if e.dst_conn == "__cond"] + if len(cond_edges_first) != 1: + return False + cond_edges_second = [e for e in graph.in_edges(second_cb) if e.dst_conn == "__cond"] + if len(cond_edges_second) != 1: + return False + cond_edge_first = cond_edges_first[0] + cond_edge_second = cond_edges_second[0] + if cond_edge_first.src != cond_edge_second.src and (cond_edge_first.src != access_node or cond_edge_second.src != access_node): + return False + + print(f"Found valid conditional blocks: {first_cb} and {second_cb}", flush=True) + # breakpoint() + + # TODO(iomaganaris): Need to check also that first and second nested SDFGs are not reachable from each other + + return True + + def apply( + self, + graph: Union[dace.SDFGState, dace.SDFG], + sdfg: dace.SDFG, + ) -> None: + access_node: dace_nodes.AccessNode = self.access_node + first_cb: dace.sdfg.state.ConditionalBlock = self.first_conditional_block + second_cb: dace.sdfg.state.ConditionalBlock = self.second_conditional_block + + first_conditional_block = next(iter(first_cb.sdfg.nodes())) + second_conditional_block = next(iter(second_cb.sdfg.nodes())) + + second_conditional_states = list(second_conditional_block.all_states()) + + for first_inner_state in first_conditional_block.all_states(): + first_inner_state_name = first_inner_state.name + corresponding_state_in_second = None + for state in second_conditional_states: + if state.name == first_inner_state_name: + corresponding_state_in_second = state + break + if corresponding_state_in_second is None: + raise RuntimeError(f"Could not find corresponding state in second conditional block for state {first_inner_state_name}") + nodes_to_move = list(corresponding_state_in_second.nodes()) + in_connectors_to_move = {k: v for k, v in second_cb.in_connectors.items() if k != "__cond"} + out_connectors_to_move = second_cb.out_connectors + breakpoint() + + + # print(f"Fused conditional blocks into: {new_nested_sdfg}", flush=True) + # breakpoint() \ No newline at end of file diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py new file mode 100644 index 0000000000..2d5025affa --- /dev/null +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py @@ -0,0 +1,119 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +dace = pytest.importorskip("dace") +from dace.sdfg import nodes as dace_nodes +from dace.transformation.passes import analysis as dace_analysis + +from gt4py.next.program_processors.runners.dace import ( + transformations as gtx_transformations, +) + +from . import util +from .test_move_dataflow_into_if_body import _make_if_block + +import dace + +def _make_map_with_conditional_blocks() -> tuple[ + dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapExit +]: + sdfg = dace.SDFG(util.unique_name("map_with_conditional_blocks")) + state = sdfg.add_state(is_start_block=True) + + sdfg.add_array( + "a", + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.add_array( + "b", + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.add_array( + "c", + shape=(10,), + dtype=dace.float64, + transient=False, + ) + sdfg.add_array( + "d", + shape=(10,), + dtype=dace.float64, + transient=False, + ) + a, b, c, d = (state.add_access(name) for name in "abcd") + + for tmp_name in ["tmp_a", "tmp_b", "tmp_c", "tmp_d"]: + sdfg.add_scalar(tmp_name, dtype=dace.float64, transient=True) + tmp_a, tmp_b, tmp_c, tmp_d = (state.add_access(f"tmp_{name}") for name in "abcd") + + sdfg.add_scalar("cond_var", dtype=dace.bool_, transient=True) + cond_var = state.add_access("cond_var") + + me, mx = state.add_map("map_with_ifs", ndrange={"__i": "0:10"}) + me.add_in_connector("IN_a") + me.add_out_connector("OUT_a") + me.add_in_connector("IN_b") + me.add_out_connector("OUT_b") + mx.add_in_connector("IN_c") + mx.add_out_connector("OUT_c") + mx.add_in_connector("IN_d") + mx.add_out_connector("OUT_d") + state.add_edge(a, None, me, "IN_a", dace.Memlet("a[__i]")) + state.add_edge(b, None, me, "IN_b", dace.Memlet("b[__i]")) + state.add_edge(me, "OUT_a", tmp_a, None, dace.Memlet("a[__i]")) + state.add_edge(me, "OUT_b", tmp_b, None, dace.Memlet("b[__i]")) + + tasklet_cond = state.add_tasklet( + "tasklet_cond", + inputs={"__in"}, + outputs={"__out"}, + code="__out = __in <= 0.0", + ) + state.add_edge(tmp_a, None, tasklet_cond, "__in", dace.Memlet("tmp_a[0]")) + state.add_edge(tasklet_cond, "__out", cond_var, None, dace.Memlet("cond_var")) + + if_block_0 = _make_if_block(state=state, outer_sdfg=sdfg) + state.add_edge(cond_var, None, if_block_0, "__cond", dace.Memlet("cond_var")) + state.add_edge(tmp_a, None, if_block_0, "__arg1", dace.Memlet("tmp_a[0]")) + state.add_edge(tmp_b, None, if_block_0, "__arg2", dace.Memlet("tmp_b[0]")) + state.add_edge(if_block_0, "__output", tmp_c, None, dace.Memlet("tmp_c[0]")) + state.add_edge(tmp_c, None, mx, "IN_c", dace.Memlet("c[__i]")) + + if_block_1 = _make_if_block(state=state, outer_sdfg=sdfg) + state.add_edge(cond_var, None, if_block_1, "__cond", dace.Memlet("cond_var")) + state.add_edge(tmp_a, None, if_block_1, "__arg1", dace.Memlet("tmp_a[0]")) + state.add_edge(tmp_b, None, if_block_1, "__arg2", dace.Memlet("tmp_b[0]")) + state.add_edge(if_block_1, "__output", tmp_d, None, dace.Memlet("tmp_d[0]")) + state.add_edge(tmp_d, None, mx, "IN_d", dace.Memlet("d[__i]")) + + state.add_edge(mx, "OUT_c", c, None, dace.Memlet("c[__i]")) + state.add_edge(mx, "OUT_d", d, None, dace.Memlet("d[__i]")) + + sdfg.validate() + return sdfg, state, me, mx + +def test_fuse_horizontal_condition_blocks(): + sdfg, state, me, mx = _make_map_with_conditional_blocks() + + # sdfg.view() + # breakpoint() + + sdfg.apply_transformations_repeated( + gtx_transformations.FuseHorizontalConditionBlocks(), + validate=True, + validate_all=True, + ) + + # sdfg.view() + # breakpoint() From a2029f32b28b9ab52a16ade7ea8026414691a688 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Mon, 26 Jan 2026 17:47:07 +0100 Subject: [PATCH 12/34] [WIP] Copy access nodes between condition blocks --- .../fuse_horizontal_conditionblocks.py | 80 +++++++++++++++++-- 1 file changed, 75 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index e68f93a48a..4836f2115d 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -12,6 +12,7 @@ from dace.sdfg import nodes as dace_nodes, graph as dace_graph from dace.transformation import helpers as dace_helpers from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations +from dace.sdfg import utils as sdutil @dace_properties.make_properties class FuseHorizontalConditionBlocks(dace_transformation.SingleStateTransformation): @@ -88,14 +89,66 @@ def apply( sdfg: dace.SDFG, ) -> None: access_node: dace_nodes.AccessNode = self.access_node - first_cb: dace.sdfg.state.ConditionalBlock = self.first_conditional_block - second_cb: dace.sdfg.state.ConditionalBlock = self.second_conditional_block + first_cb: dace_nodes.NestedSDFG = self.first_conditional_block + second_cb: dace_nodes.NestedSDFG = self.second_conditional_block first_conditional_block = next(iter(first_cb.sdfg.nodes())) second_conditional_block = next(iter(second_cb.sdfg.nodes())) second_conditional_states = list(second_conditional_block.all_states()) + in_connectors_to_move = {k: v for k, v in second_cb.in_connectors.items() if k != "__cond"} + out_connectors_to_move = second_cb.out_connectors + in_connectors_to_move_rename_map = {} + out_connectors_to_move_rename_map = {} + for k, v in in_connectors_to_move.items(): + new_connector_name = k + if new_connector_name in first_cb.in_connectors: + new_connector_name = f"{k}_from_second" + in_connectors_to_move_rename_map[k] = new_connector_name + first_cb.add_in_connector(new_connector_name) + for edge in graph.in_edges(second_cb): + if edge.dst_conn == k: + dace_helpers.redirect_edge(state=graph, edge=edge, new_dst_conn=new_connector_name, new_dst=first_cb) + for k, v in out_connectors_to_move.items(): + new_connector_name = k + if new_connector_name in first_cb.out_connectors: + new_connector_name = f"{k}_from_second" + out_connectors_to_move_rename_map[k] = new_connector_name + first_cb.add_out_connector(new_connector_name) + for edge in graph.out_edges(second_cb): + if edge.src_conn == k: + dace_helpers.redirect_edge(state=graph, edge=edge, new_src_conn=new_connector_name, new_src=first_cb) + + nodes_renamed_map = {} + for first_inner_state in first_conditional_block.all_states(): + first_inner_state_name = first_inner_state.name + corresponding_state_in_second = None + for state in second_conditional_states: + if state.name == first_inner_state_name: + corresponding_state_in_second = state + break + if corresponding_state_in_second is None: + raise RuntimeError(f"Could not find corresponding state in second conditional block for state {first_inner_state_name}") + nodes_to_move = list(corresponding_state_in_second.nodes()) + for node in nodes_to_move: + new_node = node + if isinstance(node, dace_nodes.AccessNode): + if node.data in first_cb.in_connectors or node.data in first_cb.out_connectors: + new_data_name = f"{node.data}_from_second" + new_node = dace_nodes.AccessNode(new_data_name) + new_desc = copy.deepcopy(node.desc(second_cb.sdfg)) + new_desc.name = new_data_name + if new_data_name not in first_cb.sdfg.arrays: + first_cb.sdfg.add_datadesc(new_data_name, new_desc) + else: + second_cb.sdfg.remove_data(node.data) + nodes_renamed_map[node] = new_node + first_inner_state.add_node(new_node) + + second_to_first_connections = {} + for node in nodes_renamed_map: + second_to_first_connections[node.data] = nodes_renamed_map[node].data for first_inner_state in first_conditional_block.all_states(): first_inner_state_name = first_inner_state.name corresponding_state_in_second = None @@ -106,9 +159,26 @@ def apply( if corresponding_state_in_second is None: raise RuntimeError(f"Could not find corresponding state in second conditional block for state {first_inner_state_name}") nodes_to_move = list(corresponding_state_in_second.nodes()) - in_connectors_to_move = {k: v for k, v in second_cb.in_connectors.items() if k != "__cond"} - out_connectors_to_move = second_cb.out_connectors - breakpoint() + for node in nodes_to_move: + for edge in list(corresponding_state_in_second.out_edges(node)): + dst = edge.dst + if dst in nodes_to_move: + new_memlet = copy.deepcopy(edge.data) + if edge.data.data in second_to_first_connections: + new_memlet.data = second_to_first_connections[edge.data.data] + first_inner_state.add_edge(nodes_renamed_map[node], nodes_renamed_map[node].data, nodes_renamed_map[edge.dst], second_to_first_connections[node.data], new_memlet) + for edge in list(graph.out_edges(access_node)): + if edge.dst == second_cb: + graph.remove_edge(edge) + + # TODO(iomaganaris): Figure out if I have to handle any symbols + + # Need to remove both references to remove NestedSDFG from graph + graph.remove_node(second_conditional_block) + graph.remove_node(second_cb) + + sdfg.view() + breakpoint() # print(f"Fused conditional blocks into: {new_nested_sdfg}", flush=True) From 8961d5e764daa37e97084e851c5e781499b57317 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Mon, 26 Jan 2026 20:36:10 +0100 Subject: [PATCH 13/34] Fix branch selection --- .../fuse_horizontal_conditionblocks.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 4836f2115d..9c239cb1a5 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -123,9 +123,13 @@ def apply( nodes_renamed_map = {} for first_inner_state in first_conditional_block.all_states(): first_inner_state_name = first_inner_state.name + true_branch = "true_branch" in first_inner_state_name corresponding_state_in_second = None for state in second_conditional_states: - if state.name == first_inner_state_name: + if true_branch and "true_branch" in state.name: + corresponding_state_in_second = state + break + elif not true_branch and "false_branch" in state.name: corresponding_state_in_second = state break if corresponding_state_in_second is None: @@ -151,9 +155,13 @@ def apply( second_to_first_connections[node.data] = nodes_renamed_map[node].data for first_inner_state in first_conditional_block.all_states(): first_inner_state_name = first_inner_state.name + true_branch = "true_branch" in first_inner_state_name corresponding_state_in_second = None for state in second_conditional_states: - if state.name == first_inner_state_name: + if true_branch and "true_branch" in state.name: + corresponding_state_in_second = state + break + elif not true_branch and "false_branch" in state.name: corresponding_state_in_second = state break if corresponding_state_in_second is None: From 49b7619347b0b17f57c1c1ad3edee2db70deae6c Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 27 Jan 2026 09:16:23 +0100 Subject: [PATCH 14/34] Moved kill aliasing transformation and updates to the fuse conditionals --- .../dace/transformations/auto_optimize.py | 34 ++++++++----------- .../fuse_horizontal_conditionblocks.py | 28 ++++++++++++--- 2 files changed, 37 insertions(+), 25 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 6a8d83268b..1b9d9b2157 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -652,9 +652,23 @@ def _gt_auto_process_top_level_maps( validate_all=validate_all, skip=gtx_transformations.constants._GT_AUTO_OPT_TOP_LEVEL_STAGE_SIMPLIFY_SKIP_LIST, ) + sdfg.save("before_kill_aliasing_scalars_top_level.sdfg") + + find_single_use_data = dace_analysis.FindSingleUseData() + single_use_data = find_single_use_data.apply_pass(sdfg, None) + + sdfg.apply_transformations_repeated( + gtx_transformations.KillAliasingScalars( + single_use_data=single_use_data, + ), + validate=False, + validate_all=validate_all, + ) + sdfg.save("after_kill_aliasing_scalars_top_level.sdfg") if GT4PyAutoOptHook.TopLevelDataFlowPost in optimization_hooks: optimization_hooks[GT4PyAutoOptHook.TopLevelDataFlowPost](sdfg) # type: ignore[call-arg] + sdfg.save("after_top_level_map_optimization.sdfg") return sdfg @@ -773,26 +787,6 @@ def _gt_auto_process_dataflow_inside_maps( validate_all=validate_all, ) - sdfg.save("before_kill_aliasing_scalars.sdfg") - - find_single_use_data = dace_analysis.FindSingleUseData() - single_use_data = find_single_use_data.apply_pass(sdfg, None) - - sdfg.apply_transformations_repeated( - gtx_transformations.KillAliasingScalars( - single_use_data=single_use_data, - ), - validate=False, - validate_all=validate_all, - ) - # sdfg.apply_transformations_repeated( - # gtx_transformations.CopyChainRemover( - # single_use_data=single_use_data, - # ), - # validate=False, - # validate_all=validate_all, - # ) - return sdfg diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 9c239cb1a5..83fbae70c8 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -1,4 +1,5 @@ import copy +import uuid import warnings from typing import Any, Callable, Mapping, Optional, TypeAlias, Union @@ -14,6 +15,14 @@ from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations from dace.sdfg import utils as sdutil +def unique_name(name: str) -> str: + """Adds a unique string to `name`.""" + maximal_length = 200 + unique_sufix = str(uuid.uuid1()).replace("-", "_") + if len(name) > (maximal_length - len(unique_sufix)): + name = name[: (maximal_length - len(unique_sufix) - 1)] + return f"{name}_{unique_sufix}" + @dace_properties.make_properties class FuseHorizontalConditionBlocks(dace_transformation.SingleStateTransformation): access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) @@ -80,6 +89,16 @@ def can_be_applied( # breakpoint() # TODO(iomaganaris): Need to check also that first and second nested SDFGs are not reachable from each other + if gtx_transformations.utils.is_reachable( + start=first_cb, + target=second_cb, + state=graph, + ) or gtx_transformations.utils.is_reachable( + start=second_cb, + target=first_cb, + state=graph, + ): + return False return True @@ -104,7 +123,7 @@ def apply( for k, v in in_connectors_to_move.items(): new_connector_name = k if new_connector_name in first_cb.in_connectors: - new_connector_name = f"{k}_from_second" + new_connector_name = unique_name(k) in_connectors_to_move_rename_map[k] = new_connector_name first_cb.add_in_connector(new_connector_name) for edge in graph.in_edges(second_cb): @@ -113,7 +132,7 @@ def apply( for k, v in out_connectors_to_move.items(): new_connector_name = k if new_connector_name in first_cb.out_connectors: - new_connector_name = f"{k}_from_second" + new_connector_name = unique_name(k) out_connectors_to_move_rename_map[k] = new_connector_name first_cb.add_out_connector(new_connector_name) for edge in graph.out_edges(second_cb): @@ -139,14 +158,12 @@ def apply( new_node = node if isinstance(node, dace_nodes.AccessNode): if node.data in first_cb.in_connectors or node.data in first_cb.out_connectors: - new_data_name = f"{node.data}_from_second" + new_data_name = unique_name(node.data) new_node = dace_nodes.AccessNode(new_data_name) new_desc = copy.deepcopy(node.desc(second_cb.sdfg)) new_desc.name = new_data_name if new_data_name not in first_cb.sdfg.arrays: first_cb.sdfg.add_datadesc(new_data_name, new_desc) - else: - second_cb.sdfg.remove_data(node.data) nodes_renamed_map[node] = new_node first_inner_state.add_node(new_node) @@ -180,6 +197,7 @@ def apply( graph.remove_edge(edge) # TODO(iomaganaris): Figure out if I have to handle any symbols + sdfg.save(f"after_fuse_horizontal_conditionblocks_{first_cb}_{second_cb}.sdfg") # Need to remove both references to remove NestedSDFG from graph graph.remove_node(second_conditional_block) From eb462f6a47d8ff3bf1e08e1dce8222cd22364e55 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 27 Jan 2026 09:31:56 +0100 Subject: [PATCH 15/34] Rename properly second map arrays --- .../fuse_horizontal_conditionblocks.py | 41 +++++++++++++++---- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 83fbae70c8..ab7105a3b5 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -114,6 +114,30 @@ def apply( first_conditional_block = next(iter(first_cb.sdfg.nodes())) second_conditional_block = next(iter(second_cb.sdfg.nodes())) + original_arrays_first_conditional_block = {} + for data_name, data_desc in first_conditional_block.sdfg.arrays.items(): + original_arrays_first_conditional_block[data_name] = data_desc + original_arrays_second_conditional_block = {} + for data_name, data_desc in second_conditional_block.sdfg.arrays.items(): + original_arrays_second_conditional_block[data_name] = data_desc + total_original_arrays = len(original_arrays_first_conditional_block) + len(original_arrays_second_conditional_block) + + second_arrays_rename_map = {} + for data_name, data_desc in original_arrays_second_conditional_block.items(): + if data_name == "__cond": + continue + if data_name in original_arrays_first_conditional_block: + new_data_name = unique_name(data_name) + second_arrays_rename_map[data_name] = new_data_name + data_desc_renamed = copy.deepcopy(data_desc) + data_desc_renamed.name = new_data_name + if new_data_name not in first_cb.sdfg.arrays: + first_cb.sdfg.add_datadesc(new_data_name, data_desc_renamed) + else: + second_arrays_rename_map[data_name] = data_name + if data_name not in first_cb.sdfg.arrays: + first_cb.sdfg.add_datadesc(data_name, copy.deepcopy(data_desc)) + second_conditional_states = list(second_conditional_block.all_states()) in_connectors_to_move = {k: v for k, v in second_cb.in_connectors.items() if k != "__cond"} @@ -123,7 +147,7 @@ def apply( for k, v in in_connectors_to_move.items(): new_connector_name = k if new_connector_name in first_cb.in_connectors: - new_connector_name = unique_name(k) + new_connector_name = second_arrays_rename_map[k] in_connectors_to_move_rename_map[k] = new_connector_name first_cb.add_in_connector(new_connector_name) for edge in graph.in_edges(second_cb): @@ -132,7 +156,7 @@ def apply( for k, v in out_connectors_to_move.items(): new_connector_name = k if new_connector_name in first_cb.out_connectors: - new_connector_name = unique_name(k) + new_connector_name = second_arrays_rename_map[k] out_connectors_to_move_rename_map[k] = new_connector_name first_cb.add_out_connector(new_connector_name) for edge in graph.out_edges(second_cb): @@ -158,12 +182,12 @@ def apply( new_node = node if isinstance(node, dace_nodes.AccessNode): if node.data in first_cb.in_connectors or node.data in first_cb.out_connectors: - new_data_name = unique_name(node.data) + new_data_name = second_arrays_rename_map[node.data] new_node = dace_nodes.AccessNode(new_data_name) new_desc = copy.deepcopy(node.desc(second_cb.sdfg)) new_desc.name = new_data_name - if new_data_name not in first_cb.sdfg.arrays: - first_cb.sdfg.add_datadesc(new_data_name, new_desc) + # if new_data_name not in first_cb.sdfg.arrays: + # first_cb.sdfg.add_datadesc(new_data_name, new_desc) nodes_renamed_map[node] = new_node first_inner_state.add_node(new_node) @@ -203,8 +227,11 @@ def apply( graph.remove_node(second_conditional_block) graph.remove_node(second_cb) - sdfg.view() - breakpoint() + new_arrays = len(first_cb.sdfg.arrays) + assert new_arrays == total_original_arrays - 1, f"After fusion, expected {total_original_arrays - 1} arrays but found {new_arrays}" + + # sdfg.view() + # breakpoint() # print(f"Fused conditional blocks into: {new_nested_sdfg}", flush=True) From 73dc4c1e1e41a74c880783649fc3e8e14459f143 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Wed, 28 Jan 2026 18:11:56 +0100 Subject: [PATCH 16/34] Enable if grouping in proper place --- .../dace/transformations/auto_optimize.py | 33 +++++++++++-------- .../fuse_horizontal_conditionblocks.py | 6 ++-- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 1b9d9b2157..10101a7b7e 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -652,19 +652,6 @@ def _gt_auto_process_top_level_maps( validate_all=validate_all, skip=gtx_transformations.constants._GT_AUTO_OPT_TOP_LEVEL_STAGE_SIMPLIFY_SKIP_LIST, ) - sdfg.save("before_kill_aliasing_scalars_top_level.sdfg") - - find_single_use_data = dace_analysis.FindSingleUseData() - single_use_data = find_single_use_data.apply_pass(sdfg, None) - - sdfg.apply_transformations_repeated( - gtx_transformations.KillAliasingScalars( - single_use_data=single_use_data, - ), - validate=False, - validate_all=validate_all, - ) - sdfg.save("after_kill_aliasing_scalars_top_level.sdfg") if GT4PyAutoOptHook.TopLevelDataFlowPost in optimization_hooks: optimization_hooks[GT4PyAutoOptHook.TopLevelDataFlowPost](sdfg) # type: ignore[call-arg] @@ -740,6 +727,26 @@ def _gt_auto_process_dataflow_inside_maps( validate_all=validate_all, ) + sdfg.save("before_kill_aliasing_scalars_top_level.sdfg") + + find_single_use_data = dace_analysis.FindSingleUseData() + single_use_data = find_single_use_data.apply_pass(sdfg, None) + + sdfg.apply_transformations_repeated( + gtx_transformations.KillAliasingScalars( + single_use_data=single_use_data, + ), + validate=False, + validate_all=validate_all, + ) + sdfg.save("after_kill_aliasing_scalars_top_level.sdfg") + + sdfg.apply_transformations_repeated( + gtx_transformations.FuseHorizontalConditionBlocks(), + validate=True, + validate_all=True, + ) + # Move dataflow into the branches of the `if` such that they are only evaluated # if they are needed. Important to call it repeatedly. # TODO(phimuell): It is unclear if `MoveDataflowIntoIfBody` should be called diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index ab7105a3b5..0417500d45 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -220,16 +220,14 @@ def apply( if edge.dst == second_cb: graph.remove_edge(edge) - # TODO(iomaganaris): Figure out if I have to handle any symbols - sdfg.save(f"after_fuse_horizontal_conditionblocks_{first_cb}_{second_cb}.sdfg") - # Need to remove both references to remove NestedSDFG from graph graph.remove_node(second_conditional_block) graph.remove_node(second_cb) new_arrays = len(first_cb.sdfg.arrays) assert new_arrays == total_original_arrays - 1, f"After fusion, expected {total_original_arrays - 1} arrays but found {new_arrays}" - + # TODO(iomaganaris): Figure out if I have to handle any symbols + sdfg.save(f"after_fuse_horizontal_conditionblocks_{first_cb}_{second_cb}.sdfg") # sdfg.view() # breakpoint() From 08ce68cb3f94b0e44e27496b73a70ea2ce874e1c Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 29 Jan 2026 10:19:06 +0100 Subject: [PATCH 17/34] Rename kill aliasing scalars to remove --- .../runners/dace/transformations/__init__.py | 3 +- .../dace/transformations/auto_optimize.py | 2 +- ..._scalars.py => remove_aliasing_scalars.py} | 62 ++++++++++--------- ...ars.py => test_remove_aliasing_scalars.py} | 15 ++--- 4 files changed, 44 insertions(+), 38 deletions(-) rename src/gt4py/next/program_processors/runners/dace/transformations/{kill_aliasing_scalars.py => remove_aliasing_scalars.py} (78%) rename tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/{test_kill_aliasing_scalars.py => test_remove_aliasing_scalars.py} (85%) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py index f899c7e6a2..0b575fafd9 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py @@ -28,7 +28,7 @@ gt_set_gpu_blocksize, ) from .inline_fuser import inline_dataflow_into_map -from .kill_aliasing_scalars import KillAliasingScalars +from .remove_aliasing_scalars import RemoveAliasingScalars from .local_double_buffering import gt_create_local_double_buffering from .loop_blocking import LoopBlocking from .map_fusion import ( @@ -94,6 +94,7 @@ "GT4PyStateFusion", "HorizontalMapFusionCallback", "HorizontalMapSplitCallback", + "RemoveAliasingScalars", "LoopBlocking", "MapFusionHorizontal", "MapFusionVertical", diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 10101a7b7e..85a9e78591 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -733,7 +733,7 @@ def _gt_auto_process_dataflow_inside_maps( single_use_data = find_single_use_data.apply_pass(sdfg, None) sdfg.apply_transformations_repeated( - gtx_transformations.KillAliasingScalars( + gtx_transformations.RemoveAliasingScalars( single_use_data=single_use_data, ), validate=False, diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/kill_aliasing_scalars.py b/src/gt4py/next/program_processors/runners/dace/transformations/remove_aliasing_scalars.py similarity index 78% rename from src/gt4py/next/program_processors/runners/dace/transformations/kill_aliasing_scalars.py rename to src/gt4py/next/program_processors/runners/dace/transformations/remove_aliasing_scalars.py index db5bc96a44..3572f43980 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/kill_aliasing_scalars.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/remove_aliasing_scalars.py @@ -1,19 +1,21 @@ -import copy -import warnings -from typing import Any, Callable, Mapping, Optional, TypeAlias, Union +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Any, Optional, Union import dace -from dace import ( - properties as dace_properties, - subsets as dace_subsets, - transformation as dace_transformation, -) +from dace import properties as dace_properties, transformation as dace_transformation from dace.sdfg import nodes as dace_nodes from dace.transformation import helpers as dace_helpers -from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations + @dace_properties.make_properties -class KillAliasingScalars(dace_transformation.SingleStateTransformation): +class RemoveAliasingScalars(dace_transformation.SingleStateTransformation): first_access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) second_access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) @@ -59,40 +61,40 @@ def can_be_applied( scope_dict = graph.scope_dict() if first_node not in scope_dict or second_node not in scope_dict: return False - + + # Make sure that both access nodes are in the same scope if scope_dict[first_node] != scope_dict[second_node]: return False + # Make sure that both access nodes are transients if not first_node_desc.transient or not second_node_desc.transient: return False edges = graph.edges_between(first_node, second_node) assert len(edges) == 1 edge = next(iter(edges)) - # Check if edge volume is 1 + + # Check if the edge transfers only one element if edge.data.num_elements() != 1: return False if edge.data.dynamic: return False - + + # Check that all the outgoing edges of the second access node transfer only one element for out_edges in graph.out_edges(second_node): if out_edges.data.num_elements() != 1: return False - # if out_edges.data.dynamic: - # return False - # breakpoint() - # subset: dace_subsets.Subset = edge.data.get("subset", None) - # if subset is None: - # return False # Make sure that the edge subset is 1 if not isinstance(first_node_desc, dace.data.Scalar) or not isinstance( - second_node_desc, dace.data.Scalar): + second_node_desc, dace.data.Scalar + ): return False # Make sure that both access nodes are not views if isinstance(first_node_desc, dace.data.View) or isinstance( - second_node_desc, dace.data.View): + second_node_desc, dace.data.View + ): return False # Make sure that both access nodes are transients @@ -101,7 +103,8 @@ def can_be_applied( if graph.in_degree(second_node) != 1: return False - + + # Make sure that both access nodes are single use data if self.assume_single_use_data: single_use_data = {sdfg: {first_node.data}} if self._single_use_data is None: @@ -111,7 +114,6 @@ def can_be_applied( single_use_data = self._single_use_data if first_node.data not in single_use_data[sdfg]: return False - if self.assume_single_use_data: single_use_data = {sdfg: {second_node.data}} if self._single_use_data is None: @@ -123,7 +125,7 @@ def can_be_applied( return False return True - + def apply( self, graph: Union[dace.SDFGState, dace.SDFG], @@ -134,10 +136,12 @@ def apply( # Redirect all outcoming edges of the second access node to the first for edge in list(graph.out_edges(second_node)): - dace_helpers.redirect_edge(state=graph, edge=edge, new_src=first_node, new_data=first_node.data if edge.data.data == second_node.data else edge.data.data) - # edge.subset = first_node.desc(sdfg).get_subset() - # if edge.other_subset is not None: - # edge.other_subset = edge.dst.desc(sdfg).get_subset() + dace_helpers.redirect_edge( + state=graph, + edge=edge, + new_src=first_node, + new_data=first_node.data if edge.data.data == second_node.data else edge.data.data, + ) # Remove the second access node - graph.remove_node(second_node) \ No newline at end of file + graph.remove_node(second_node) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_kill_aliasing_scalars.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_aliasing_scalars.py similarity index 85% rename from tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_kill_aliasing_scalars.py rename to tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_aliasing_scalars.py index 1d03b3e3e8..a755504f40 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_kill_aliasing_scalars.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_aliasing_scalars.py @@ -64,22 +64,23 @@ def _make_map_with_scalar_copies() -> tuple[ def test_remove_double_write_single_consumer(): sdfg, state, me, mx = _make_map_with_scalar_copies() - access_nodes_inside_original_map = util.count_nodes(state.scope_subgraph(me, include_entry=False, include_exit=False), dace_nodes.AccessNode) + access_nodes_inside_original_map = util.count_nodes( + state.scope_subgraph(me, include_entry=False, include_exit=False), dace_nodes.AccessNode + ) assert access_nodes_inside_original_map == 3 - sdfg.view() - breakpoint() find_single_use_data = dace_analysis.FindSingleUseData() single_use_data = find_single_use_data.apply_pass(sdfg, None) sdfg.apply_transformations_repeated( - gtx_transformations.KillAliasingScalars( + gtx_transformations.RemoveAliasingScalars( single_use_data=single_use_data, assume_single_use_data=False, ), validate=True, validate_all=True, ) - sdfg.view() - breakpoint() - access_nodes_inside_new_map = util.count_nodes(state.scope_subgraph(me, include_entry=False, include_exit=False), dace_nodes.AccessNode) + + access_nodes_inside_new_map = util.count_nodes( + state.scope_subgraph(me, include_entry=False, include_exit=False), dace_nodes.AccessNode + ) assert access_nodes_inside_new_map == 1 From e110ab2b129a84c204a5077f1ae0c230aa664d54 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 29 Jan 2026 10:19:35 +0100 Subject: [PATCH 18/34] Cleared a bit FuseConditionalBlocks --- .../runners/dace/transformations/__init__.py | 1 + .../fuse_horizontal_conditionblocks.py | 140 +++++++++++------- .../test_fuse_horizontal_conditionblocks.py | 94 ++++++++++-- 3 files changed, 171 insertions(+), 64 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py index 0b575fafd9..f3b2f0bc3a 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py @@ -85,6 +85,7 @@ __all__ = [ "CopyChainRemover", "DoubleWriteRemover", + "FuseHorizontalConditionBlocks", "GPUSetBlockSize", "GT4PyAutoOptHook", "GT4PyAutoOptHookFun", diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 0417500d45..4ea72c7ddf 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -1,19 +1,22 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2024, ETH Zurich +# All rights reserved. +# +# Please, refer to the LICENSE file in the root directory. +# SPDX-License-Identifier: BSD-3-Clause + import copy import uuid -import warnings -from typing import Any, Callable, Mapping, Optional, TypeAlias, Union +from typing import Any, Optional, Union import dace -from dace import ( - properties as dace_properties, - subsets as dace_subsets, - transformation as dace_transformation, -) - -from dace.sdfg import nodes as dace_nodes, graph as dace_graph +from dace import properties as dace_properties, transformation as dace_transformation +from dace.sdfg import graph as dace_graph, nodes as dace_nodes from dace.transformation import helpers as dace_helpers + from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations -from dace.sdfg import utils as sdutil + def unique_name(name: str) -> str: """Adds a unique string to `name`.""" @@ -23,6 +26,7 @@ def unique_name(name: str) -> str: name = name[: (maximal_length - len(unique_sufix) - 1)] return f"{name}_{unique_sufix}" + @dace_properties.make_properties class FuseHorizontalConditionBlocks(dace_transformation.SingleStateTransformation): access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) @@ -32,8 +36,12 @@ class FuseHorizontalConditionBlocks(dace_transformation.SingleStateTransformatio @classmethod def expressions(cls) -> Any: map_fusion_parallel_match = dace_graph.OrderedMultiDiConnectorGraph() - map_fusion_parallel_match.add_nedge(cls.access_node, cls.first_conditional_block, dace.Memlet()) - map_fusion_parallel_match.add_nedge(cls.access_node, cls.second_conditional_block, dace.Memlet()) + map_fusion_parallel_match.add_nedge( + cls.access_node, cls.first_conditional_block, dace.Memlet() + ) + map_fusion_parallel_match.add_nedge( + cls.access_node, cls.second_conditional_block, dace.Memlet() + ) return [map_fusion_parallel_match] def can_be_applied( @@ -49,27 +57,41 @@ def can_be_applied( second_cb: dace_nodes.NestedSDFG = self.second_conditional_block scope_dict = graph.scope_dict() + # Check that both conditional blocks are in the same parent SDFG if first_cb.sdfg.parent != second_cb.sdfg.parent: return False + # Check that the nested SDFGs' names starts with "if_stmt" - if not (first_cb.sdfg.name.startswith("if_stmt") and second_cb.sdfg.name.startswith("if_stmt")): + if not ( + first_cb.sdfg.name.startswith("if_stmt") and second_cb.sdfg.name.startswith("if_stmt") + ): return False # Check that the common access node is a boolean scalar - if not isinstance(access_node_desc, dace.data.Scalar) or access_node_desc.dtype != dace.bool_: + if ( + not isinstance(access_node_desc, dace.data.Scalar) + or access_node_desc.dtype != dace.bool_ + ): return False + # Make sure that the conditional blocks contain only one conditional block each if len(first_cb.sdfg.nodes()) > 1 or len(second_cb.sdfg.nodes()) > 1: return False + # Get the actual conditional blocks first_conditional_block = next(iter(first_cb.sdfg.nodes())) second_conditional_block = next(iter(second_cb.sdfg.nodes())) - if not (isinstance(first_conditional_block, dace.sdfg.state.ConditionalBlock) and isinstance(second_conditional_block, dace.sdfg.state.ConditionalBlock)): + if not ( + isinstance(first_conditional_block, dace.sdfg.state.ConditionalBlock) + and isinstance(second_conditional_block, dace.sdfg.state.ConditionalBlock) + ): return False + # Make sure that both conditional blocks are in the same scope if scope_dict[first_cb] != scope_dict[second_cb]: return False + # Make sure that both conditional blocks are in a map scope if not isinstance(scope_dict[first_cb], dace.nodes.MapEntry): return False @@ -82,13 +104,12 @@ def can_be_applied( return False cond_edge_first = cond_edges_first[0] cond_edge_second = cond_edges_second[0] - if cond_edge_first.src != cond_edge_second.src and (cond_edge_first.src != access_node or cond_edge_second.src != access_node): + if cond_edge_first.src != cond_edge_second.src and ( + cond_edge_first.src != access_node or cond_edge_second.src != access_node + ): return False - print(f"Found valid conditional blocks: {first_cb} and {second_cb}", flush=True) - # breakpoint() - - # TODO(iomaganaris): Need to check also that first and second nested SDFGs are not reachable from each other + # Need to check also that first and second nested SDFGs are not reachable from each other if gtx_transformations.utils.is_reachable( start=first_cb, target=second_cb, @@ -114,14 +135,19 @@ def apply( first_conditional_block = next(iter(first_cb.sdfg.nodes())) second_conditional_block = next(iter(second_cb.sdfg.nodes())) + # Store original arrays to check later that all the necessary arrays have been moved original_arrays_first_conditional_block = {} for data_name, data_desc in first_conditional_block.sdfg.arrays.items(): original_arrays_first_conditional_block[data_name] = data_desc original_arrays_second_conditional_block = {} for data_name, data_desc in second_conditional_block.sdfg.arrays.items(): original_arrays_second_conditional_block[data_name] = data_desc - total_original_arrays = len(original_arrays_first_conditional_block) + len(original_arrays_second_conditional_block) + total_original_arrays = len(original_arrays_first_conditional_block) + len( + original_arrays_second_conditional_block + ) + # Store the new names for the arrays in the second conditional block to avoid name clashes and add their data descriptors + # to the first conditional block SDFG second_arrays_rename_map = {} for data_name, data_desc in original_arrays_second_conditional_block.items(): if data_name == "__cond": @@ -140,11 +166,12 @@ def apply( second_conditional_states = list(second_conditional_block.all_states()) + # Move the connectors from the second conditional block to the first in_connectors_to_move = {k: v for k, v in second_cb.in_connectors.items() if k != "__cond"} out_connectors_to_move = second_cb.out_connectors in_connectors_to_move_rename_map = {} out_connectors_to_move_rename_map = {} - for k, v in in_connectors_to_move.items(): + for k, _v in in_connectors_to_move.items(): new_connector_name = k if new_connector_name in first_cb.in_connectors: new_connector_name = second_arrays_rename_map[k] @@ -152,8 +179,10 @@ def apply( first_cb.add_in_connector(new_connector_name) for edge in graph.in_edges(second_cb): if edge.dst_conn == k: - dace_helpers.redirect_edge(state=graph, edge=edge, new_dst_conn=new_connector_name, new_dst=first_cb) - for k, v in out_connectors_to_move.items(): + dace_helpers.redirect_edge( + state=graph, edge=edge, new_dst_conn=new_connector_name, new_dst=first_cb + ) + for k, _v in out_connectors_to_move.items(): new_connector_name = k if new_connector_name in first_cb.out_connectors: new_connector_name = second_arrays_rename_map[k] @@ -161,12 +190,15 @@ def apply( first_cb.add_out_connector(new_connector_name) for edge in graph.out_edges(second_cb): if edge.src_conn == k: - dace_helpers.redirect_edge(state=graph, edge=edge, new_src_conn=new_connector_name, new_src=first_cb) + dace_helpers.redirect_edge( + state=graph, edge=edge, new_src_conn=new_connector_name, new_src=first_cb + ) - nodes_renamed_map = {} - for first_inner_state in first_conditional_block.all_states(): - first_inner_state_name = first_inner_state.name - true_branch = "true_branch" in first_inner_state_name + def _find_corresponding_state_in_second( + inner_state: dace.SDFGState, + ) -> dace.SDFGState: + inner_state_name = inner_state.name + true_branch = "true_branch" in inner_state_name corresponding_state_in_second = None for state in second_conditional_states: if true_branch and "true_branch" in state.name: @@ -176,7 +208,15 @@ def apply( corresponding_state_in_second = state break if corresponding_state_in_second is None: - raise RuntimeError(f"Could not find corresponding state in second conditional block for state {first_inner_state_name}") + raise RuntimeError( + f"Could not find corresponding state in second conditional block for state {inner_state_name}" + ) + return corresponding_state_in_second + + # Copy first the nodes from the second conditional block to the first + nodes_renamed_map = {} + for first_inner_state in first_conditional_block.all_states(): + corresponding_state_in_second = _find_corresponding_state_in_second(first_inner_state) nodes_to_move = list(corresponding_state_in_second.nodes()) for node in nodes_to_move: new_node = node @@ -186,27 +226,15 @@ def apply( new_node = dace_nodes.AccessNode(new_data_name) new_desc = copy.deepcopy(node.desc(second_cb.sdfg)) new_desc.name = new_data_name - # if new_data_name not in first_cb.sdfg.arrays: - # first_cb.sdfg.add_datadesc(new_data_name, new_desc) nodes_renamed_map[node] = new_node first_inner_state.add_node(new_node) + # Then copy the edges second_to_first_connections = {} for node in nodes_renamed_map: second_to_first_connections[node.data] = nodes_renamed_map[node].data for first_inner_state in first_conditional_block.all_states(): - first_inner_state_name = first_inner_state.name - true_branch = "true_branch" in first_inner_state_name - corresponding_state_in_second = None - for state in second_conditional_states: - if true_branch and "true_branch" in state.name: - corresponding_state_in_second = state - break - elif not true_branch and "false_branch" in state.name: - corresponding_state_in_second = state - break - if corresponding_state_in_second is None: - raise RuntimeError(f"Could not find corresponding state in second conditional block for state {first_inner_state_name}") + corresponding_state_in_second = _find_corresponding_state_in_second(first_inner_state) nodes_to_move = list(corresponding_state_in_second.nodes()) for node in nodes_to_move: for edge in list(corresponding_state_in_second.out_edges(node)): @@ -215,7 +243,17 @@ def apply( new_memlet = copy.deepcopy(edge.data) if edge.data.data in second_to_first_connections: new_memlet.data = second_to_first_connections[edge.data.data] - first_inner_state.add_edge(nodes_renamed_map[node], nodes_renamed_map[node].data, nodes_renamed_map[edge.dst], second_to_first_connections[node.data], new_memlet) + first_inner_state.add_edge( + nodes_renamed_map[node], + nodes_renamed_map[node].data + if isinstance(node, dace_nodes.AccessNode) and edge.src_conn + else edge.src_conn, + nodes_renamed_map[dst], + second_to_first_connections[dst.data] + if isinstance(edge.dst, dace_nodes.AccessNode) and edge.dst_conn + else edge.dst_conn, + new_memlet, + ) for edge in list(graph.out_edges(access_node)): if edge.dst == second_cb: graph.remove_edge(edge) @@ -225,12 +263,6 @@ def apply( graph.remove_node(second_cb) new_arrays = len(first_cb.sdfg.arrays) - assert new_arrays == total_original_arrays - 1, f"After fusion, expected {total_original_arrays - 1} arrays but found {new_arrays}" - # TODO(iomaganaris): Figure out if I have to handle any symbols - sdfg.save(f"after_fuse_horizontal_conditionblocks_{first_cb}_{second_cb}.sdfg") - # sdfg.view() - # breakpoint() - - - # print(f"Fused conditional blocks into: {new_nested_sdfg}", flush=True) - # breakpoint() \ No newline at end of file + assert new_arrays == total_original_arrays - 1, ( + f"After fusion, expected {total_original_arrays - 1} arrays but found {new_arrays}" + ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py index 2d5025affa..04dd6a7dcc 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py @@ -21,9 +21,72 @@ import dace -def _make_map_with_conditional_blocks() -> tuple[ - dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapExit -]: + +def _make_if_block_with_tasklet( + state: dace.SDFGState, + b1_name: str = "__arg1", + b2_name: str = "__arg2", + cond_name: str = "__cond", + output_name: str = "__output", + b1_type: dace.typeclass = dace.float64, + b2_type: dace.typeclass = dace.float64, + output_type: dace.typeclass = dace.float64, +) -> dace_nodes.NestedSDFG: + inner_sdfg = dace.SDFG(util.unique_name("if_stmt_")) + + types = {b1_name: b1_type, b2_name: b2_type, cond_name: dace.bool_, output_name: output_type} + for name in {b1_name, b2_name, cond_name, output_name}: + inner_sdfg.add_scalar( + name, + dtype=types[name], + transient=False, + ) + + if_region = dace.sdfg.state.ConditionalBlock(util.unique_name("if")) + inner_sdfg.add_node(if_region, is_start_block=True) + + then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=inner_sdfg) + tstate = then_body.add_state("true_branch", is_start_block=True) + tasklet = tstate.add_tasklet( + "true_tasklet", + inputs={"__tasklet_in"}, + outputs={"__tasklet_out"}, + code="__tasklet_out = __tasklet_in * 2.0", + ) + tstate.add_edge( + tstate.add_access(b1_name), + None, + tasklet, + "__tasklet_in", + dace.Memlet(f"{b1_name}[0]"), + ) + tstate.add_edge( + tasklet, + "__tasklet_out", + tstate.add_access(output_name), + None, + dace.Memlet(f"{output_name}[0]"), + ) + + else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=inner_sdfg) + fstate = else_body.add_state("false_branch", is_start_block=True) + fstate.add_nedge( + fstate.add_access(b2_name), + fstate.add_access(output_name), + dace.Memlet(f"{b2_name}[0] -> [0]"), + ) + + if_region.add_branch(dace.sdfg.state.CodeBlock(cond_name), then_body) + if_region.add_branch(dace.sdfg.state.CodeBlock(f"not {cond_name}"), else_body) + + return state.add_nested_sdfg( + sdfg=inner_sdfg, + inputs={b1_name, b2_name, cond_name}, + outputs={output_name}, + ) + + +def _make_map_with_conditional_blocks() -> dace.SDFG: sdfg = dace.SDFG(util.unique_name("map_with_conditional_blocks")) state = sdfg.add_state(is_start_block=True) @@ -90,7 +153,7 @@ def _make_map_with_conditional_blocks() -> tuple[ state.add_edge(if_block_0, "__output", tmp_c, None, dace.Memlet("tmp_c[0]")) state.add_edge(tmp_c, None, mx, "IN_c", dace.Memlet("c[__i]")) - if_block_1 = _make_if_block(state=state, outer_sdfg=sdfg) + if_block_1 = _make_if_block_with_tasklet(state=state) state.add_edge(cond_var, None, if_block_1, "__cond", dace.Memlet("cond_var")) state.add_edge(tmp_a, None, if_block_1, "__arg1", dace.Memlet("tmp_a[0]")) state.add_edge(tmp_b, None, if_block_1, "__arg2", dace.Memlet("tmp_b[0]")) @@ -101,13 +164,19 @@ def _make_map_with_conditional_blocks() -> tuple[ state.add_edge(mx, "OUT_d", d, None, dace.Memlet("d[__i]")) sdfg.validate() - return sdfg, state, me, mx + return sdfg + def test_fuse_horizontal_condition_blocks(): - sdfg, state, me, mx = _make_map_with_conditional_blocks() + sdfg = _make_map_with_conditional_blocks() + + conditional_blocks = [ + n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, dace.sdfg.state.ConditionalBlock) + ] + assert len(conditional_blocks) == 2 - # sdfg.view() - # breakpoint() + ref, res = util.make_sdfg_args(sdfg) + util.compile_and_run_sdfg(sdfg, **ref) sdfg.apply_transformations_repeated( gtx_transformations.FuseHorizontalConditionBlocks(), @@ -115,5 +184,10 @@ def test_fuse_horizontal_condition_blocks(): validate_all=True, ) - # sdfg.view() - # breakpoint() + new_conditional_blocks = [ + n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, dace.sdfg.state.ConditionalBlock) + ] + assert len(new_conditional_blocks) == 1 + + util.compile_and_run_sdfg(sdfg, **res) + assert util.compare_sdfg_res(ref=ref, res=res) From 15ae54d62d721e9f2eb52a90d4c0ffe72880e8f4 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Thu, 29 Jan 2026 10:27:07 +0100 Subject: [PATCH 19/34] Fixed imports --- .../runners/dace/transformations/__init__.py | 4 ++-- .../dace/transformations/fuse_horizontal_conditionblocks.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py index f3b2f0bc3a..edd52d0280 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py @@ -28,7 +28,6 @@ gt_set_gpu_blocksize, ) from .inline_fuser import inline_dataflow_into_map -from .remove_aliasing_scalars import RemoveAliasingScalars from .local_double_buffering import gt_create_local_double_buffering from .loop_blocking import LoopBlocking from .map_fusion import ( @@ -55,6 +54,7 @@ ) from .redundant_array_removers import CopyChainRemover, DoubleWriteRemover, gt_remove_copy_chain from .remove_access_node_copies import RemoveAccessNodeCopies +from .remove_aliasing_scalars import RemoveAliasingScalars from .remove_views import RemovePointwiseViews from .scan_loop_unrolling import ScanLoopUnrolling from .simplify import ( @@ -95,7 +95,6 @@ "GT4PyStateFusion", "HorizontalMapFusionCallback", "HorizontalMapSplitCallback", - "RemoveAliasingScalars", "LoopBlocking", "MapFusionHorizontal", "MapFusionVertical", @@ -108,6 +107,7 @@ "MultiStateGlobalSelfCopyElimination", "MultiStateGlobalSelfCopyElimination2", "RemoveAccessNodeCopies", + "RemoveAliasingScalars", "RemovePointwiseViews", "ScanLoopUnrolling", "SingleStateGlobalDirectSelfCopyElimination", diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 4ea72c7ddf..9fcec15b14 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -8,7 +8,7 @@ import copy import uuid -from typing import Any, Optional, Union +from typing import Any, Union import dace from dace import properties as dace_properties, transformation as dace_transformation From 8f44eb5475e90c7f48dbd6be7bbc05ef11738476 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Mon, 2 Feb 2026 17:16:20 +0100 Subject: [PATCH 20/34] Remove sdfg.save --- .../runners/dace/transformations/auto_optimize.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 85a9e78591..10a6256fef 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -655,7 +655,6 @@ def _gt_auto_process_top_level_maps( if GT4PyAutoOptHook.TopLevelDataFlowPost in optimization_hooks: optimization_hooks[GT4PyAutoOptHook.TopLevelDataFlowPost](sdfg) # type: ignore[call-arg] - sdfg.save("after_top_level_map_optimization.sdfg") return sdfg @@ -727,8 +726,6 @@ def _gt_auto_process_dataflow_inside_maps( validate_all=validate_all, ) - sdfg.save("before_kill_aliasing_scalars_top_level.sdfg") - find_single_use_data = dace_analysis.FindSingleUseData() single_use_data = find_single_use_data.apply_pass(sdfg, None) @@ -739,7 +736,6 @@ def _gt_auto_process_dataflow_inside_maps( validate=False, validate_all=validate_all, ) - sdfg.save("after_kill_aliasing_scalars_top_level.sdfg") sdfg.apply_transformations_repeated( gtx_transformations.FuseHorizontalConditionBlocks(), From 6a8b1b4fea50ebba496ec3c8ce594cf2ac80f32d Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 3 Feb 2026 09:32:44 +0100 Subject: [PATCH 21/34] Fix issues after cherry-picking --- .../dace/transformations/fuse_horizontal_conditionblocks.py | 3 ++- .../test_fuse_horizontal_conditionblocks.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 9fcec15b14..470610aa5f 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -232,7 +232,8 @@ def _find_corresponding_state_in_second( # Then copy the edges second_to_first_connections = {} for node in nodes_renamed_map: - second_to_first_connections[node.data] = nodes_renamed_map[node].data + if isinstance(node, dace_nodes.AccessNode): + second_to_first_connections[node.data] = nodes_renamed_map[node].data for first_inner_state in first_conditional_block.all_states(): corresponding_state_in_second = _find_corresponding_state_in_second(first_inner_state) nodes_to_move = list(corresponding_state_in_second.nodes()) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py index 04dd6a7dcc..ec23b4ee24 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py @@ -141,7 +141,7 @@ def _make_map_with_conditional_blocks() -> dace.SDFG: "tasklet_cond", inputs={"__in"}, outputs={"__out"}, - code="__out = __in <= 0.0", + code="__out = __in <= 0.5", ) state.add_edge(tmp_a, None, tasklet_cond, "__in", dace.Memlet("tmp_a[0]")) state.add_edge(tasklet_cond, "__out", cond_var, None, dace.Memlet("cond_var")) From 0ce43372b2c6cc4a89968a1ec01768343edc3bb4 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 3 Feb 2026 14:08:39 +0100 Subject: [PATCH 22/34] Applied suggestions to FuseHorizontalConditionBlocks and moved unique_name to gtx_transformations utils --- _dacegraphs/invalid.sdfgz | Bin 0 -> 5370 bytes .../runners/dace/transformations/__init__.py | 3 +- .../fuse_horizontal_conditionblocks.py | 142 +++++++++++------- .../runners/dace/transformations/utils.py | 10 ++ .../test_auto_optimizer_hooks.py | 2 +- .../test_copy_chain_remover.py | 20 +-- .../test_create_local_double_buffering.py | 10 +- .../test_dead_dataflow_elimination.py | 4 +- .../test_distributed_buffer_relocator.py | 24 ++- .../test_double_write_remover.py | 8 +- .../test_fuse_horizontal_conditionblocks.py | 6 +- .../transformation_tests/test_gpu_utils.py | 11 +- .../test_horizontal_map_split_fusion.py | 2 +- .../transformation_tests/test_inline_fuser.py | 8 +- .../test_loop_blocking.py | 28 ++-- .../test_make_transients_persistent.py | 4 +- .../test_map_buffer_elimination.py | 10 +- .../test_map_fusion_utils.py | 5 +- .../transformation_tests/test_map_order.py | 2 +- .../transformation_tests/test_map_promoter.py | 10 +- .../transformation_tests/test_map_splitter.py | 2 +- .../transformation_tests/test_map_to_copy.py | 2 +- .../test_move_dataflow_into_if_body.py | 24 +-- .../test_move_tasklet_into_map.py | 2 +- ...ulti_state_global_self_copy_elimination.py | 16 +- .../test_multiple_copies_global.py | 2 +- .../test_remove_aliasing_scalars.py | 2 +- .../test_remove_point_view.py | 2 +- ...ngle_state_global_self_copy_elimination.py | 32 ++-- .../test_split_access_node.py | 40 ++--- .../transformation_tests/test_split_memlet.py | 6 +- .../test_splitting_tools.py | 10 +- .../transformation_tests/test_state_fusion.py | 28 ++-- .../transformation_tests/test_strides.py | 28 ++-- .../test_vertical_map_split_fusion.py | 6 +- .../dace_tests/transformation_tests/util.py | 15 +- 36 files changed, 306 insertions(+), 220 deletions(-) create mode 100644 _dacegraphs/invalid.sdfgz diff --git a/_dacegraphs/invalid.sdfgz b/_dacegraphs/invalid.sdfgz new file mode 100644 index 0000000000000000000000000000000000000000..2869814a1c9a9c302e27afa792eef634827ef0c0 GIT binary patch literal 5370 zcmV95~>_x$uSL%B9$2`B7ndQRH^Z9b2 zE@%33Y5u*QWhBE?3p2yms~I7l%%s!|esZitCaLADd;>ghrdBKU`SAtCYb!r0<@$SZ z;oWVU>-LLxcV-PIYez&bHlc}4cxV&a z*o22RVSaTzYv99Qjk;8;uy!0)g_bs6sl|Hk7DYqms~dMn_4H|1&zyGk?CDmouxw_Z zpcdoRP%oF4FI>H|N+qY~mrD&~X_nSiEik{@Sb3Fd7fif3pIMa`bTqBj`%*2=Z`65N z!LNV*d0P9X`C^_w$9vqL&-0%1n|T5KX(BbjsWd22BtzU&hNSfZVZt26NhUcJ_qOhZ zT3;0RWWuD95t1^@wN4bKMp2T%fi!VcCxm32WGT^t+BDIU5SLQyOy)DHG51A@%5qE> z%j(zhp}tj9{?}aP?xA{6Iqfb>C>FN*z@h~ZjH`|##h_N8<}HgL~oy}yzraiq92Nj-tAtp&ebZPUCtN9g#fX= z3mmXq%jwPifQOr_MHM1j_qUw(^GkO>Y5ZUSQ~7fBVQp3_Hy0nw+N?j!muj)NQs*DO z|K;z$|NQO$eK^ng_2&=J8u|Le!u^@gHh#ovy0_w{-5O{FS+n;KD(LBzQ6uF-Kx^0Tv3cNpJ0?^8Q07(_#C+o*pefGt%r9LVhm->k`CC?BY~}rust={+D6zaNhu=|<_$@TIFlM> zne<3XZxoX*$zU4M4qJ94u-z7*z=+zOur-CZhKg(c5V8N=FGSKqkm09?+|Kf-AuTm0hULk;rrBPh_glg7--t;)o`>Arepu{xOXY5NU3N_U=^O^l=Jej#+npmH+tE2Ldx&=S5RpB4vd4T;+_G>6>b0)x zqaS`Mr9HcAGxz88?#fKvOO`iz$6NmK`)|#!DNY8x=bq|BDkMpSMV@kpxXCi4d5V$8 zm}16+mYq{2Ie49_H`f)nJq^``bgpUXTsvOpnzqiheL6SOKBb^CKrkVOnS?U}W)cJ= zD>6zP%e>8K*Om`{mJiu)YW?p^*VMWB=4?4)!@G1xG}5FLna!YECSOhd2B=v~{^zR+ zexH1?&#qMw$6B&fB@Fk zpSO*^y%Kdjam!(1c%kome9u#22M}s?P8!a(=@$FcCPcBh-KMqj#%8Hwi}@d-a%{7v zYX(PkEM4E^vGE2q9^j6~rW@3B-;5-xx96L7sp)Q`)LuQ zG&4dbD6xhiSi~=cSCk_~oY02pL~hGmYoAx+7tp%Ed_UN}e)OIG-z|5Hz9aadI+hup zN(tpxpkm-Q;YfS$k&s?;sGS10lW8zbHCL;tYP;3TSIf)qE|-7&;?AL(&8!l{QdPl` z$TB9O+qBoP4k>V2i0V_M661xW*f1}fSm?H~6z%nGt8<6jO#_zimlVK@Yo=l8!d$f9 zvF{A9cS)n;PW`B>cjl?RcF)15?!VT;O?DlD>i(l$7qG4ic<@~T?h%; zP;Q6H-7LTJhQRR@wEWmwi{>TC2AZ!dQzjzlp32H);2m9ynT@aMdCRdcdR{F*x$=_= zR^1=4^q4PUzhkDC_VeC)XYb*#vqv5X%q_`tqEio=(Imz2sq)gn{y9@f!FGf8+*l@( zt}=F(Fn}@xwxQZ?sgY0*DCtwP#@zzfiSS0+n>W(P8^_8UfzYhFrWr~>T~3g7!ho|x zYLF^81nfSqz|jZnLOv36kkes~x;@hXyEG5ItuJ-7OcuLlhqL=qJ?_QcAAmjC|DMqI z_=_-J>Ds0p4DT1=MckVgapc7(+(aIVo?#;}U<11Lrt07cA*9e+bZ%mIl%`>WduVYU z^zv33!V-$8%3vcLlr2aSr4cyVX(DKvxs;9c^`LLw?hN0kc~a9fq+$0aGvk!IDxL#j zd%_ZhWR`%J9JVRa%sRwKYOI2FJLAWi@vu1K>9K`%T-Y14X;YA+Cx}7#23Y9^ zcMF|@yG7!hahk#3T#ST1egx=iIs377na|ntF_-y%yt&NhU*CTQzqRM?3E;NDPA^0A zM|TxwRP--TdC#Z9^mENZ__`jx(VN(eVXx>E{x{y%;t(*jzxy~4w3*BU?rL{9Xl}b^ zM}#4RDs+v#m(jMdBl8QlnniF#aE!73Jn;Lxy#PfWORQkXdJ92j5&`?Inj3_{4U7p0 z*kO4ZOnNI5u^q8pi&9u}h|q}8s57(IJ^ z4x5?;^m8#R!n(E`Ko2WW!Sf6tQEhNulm25Kk_k|-^}<}fL7kd9)IY0;xBait;-ogmYq(CXmC3 zaqvZhWGJBbS4d-#A^{=$g%Um?)V~^+tr8T`XKNDbvvuk&S@c9)D+KQ=Ar8DFN{C;^ zC5!fvh$Y5q4S6UK#O&7~P^2CtVu?W}O;WA|WU{h6x@1L{tQy`(dh4Od+yjoGCu znK=#NCF67=4uCkFh|q}BiKh|1vZI+!JgeH(WY`g!dX%aRq>jvKq$LE2z=CJiU@0-> zuUqZvO8H~A+LZ*VUHM^Z*R#MD>1d6xme@)}S7dXOU{mzAcv@soJ8eVXoprWL)2Nwl zB&sbRjiygYi=uZg8f#51#DP&(8#$<@IQ_NDeTMZYDJNRVHAuzU!94wR%tMk-1Pb^1 ziJAM!nL`E5lMu}p>Ar+^U&6XC;b(^!9UUL#Km~M&9XLcAG9Wd@koBzIn#RGZ#$hQc zh9s#NsBIjiZajT`W1v+&BwX022`j>sm|+U(L5iQ=nxB9YKL@vjOHyAGyQBlyC5`NI z{Ol48*OAj&^xzkAiv03hc&z-vLH7F-OpjC-PnhgtV7iN7@V^*~;sGp*Ba6OB`ioFB z4IP~WlVA)?gArPrL8&l?Cc_9V)R+k|dVF_iv}>OvLB_$RZ|ElLlZDwEHBm+nM6!s; zC-U)m2~3V1p=s0v8a+_Sj|P?P!5cajD8oE|dO>@Drf(P!!*zYblzqdreS_3}L-c)v zg$l!@3jM?i!{rJ?1q*{E3&TVU;j)EL;X<%6}Q!9u^MNyjl%>=$wW zhH*bc0vG}Fb^$X~7&ufK7*^ZcQKl}z9!GtAjt0sEhX@77orK4kN}Q=gHhoUBbIhb4 zqu8e7CH{!EX&f8}*`~b!{058^*274lLyVMu2|}W|63vy!$ZyA789$ZC7zWFT=|qMn z6p5=UabO*`s?sCT$Sbl~Xb+2pKjX4WpTr}FZYcNqu#c92q`lP5=UQqL;F{UbpQ=wX zl6#dwdxwYSBnkA>Y-*^6`XBC^__xI}pKSaSzc_gZ#X}OT9I=C9aH^H)p*ZQCdga92 z59x_H2!$nA8Iq7CO50!D{gC#PAC8fnCAuF*_rtzPTZW`kO-LQ1`^#s;!GvZR3e+knM#DrY0p%~OkOj_bY{c^HgP%=2S=RFM8L%9 z%=5{w*^y0WUO;nwXd0Z?rMd1-i8DxZ&7VqheHLTT91$dGrwA1&DO3=h1B*q+I}`@783Td%GuVY?hH=+C@5=uS0A-SK#?B4 zARMMZX^N3ci_pXlvM(u>z`l;tnkf(AVfB^*XzMw5Iv>%(cU_W>#WN>pUvjYiRJE0@z!g) z&F(G(`!?+RSkrv>%L=NlZM&BR9>AXwU8pb(l3qX*WGWfR-#7zLT;VxMJ;V>OwcXb> z)sJ1+*@fLTu<1Z7#BU}xg-n1T=LJn% z=7_^-8nCAtyYSQ(>qsY~?WB^6d=|w6FX5g*mBX6<20!vbXjWk|m z(LLdfahj~-p4dNu71K+iA+|G4z++v()HB0aAy0fJ zG}l^c8aXX;THUJL08Z19)8c@5Y6k@&S;8btIa%R6$`ye)QbbGSwbHOVs?)SKp=;Fg zp0bvg)H{;WF<8w+R*P_dY7YftA^3xn96?&kB3Cv1BU48p3kSyxCtjsiMplcgwjYXX zAgkHPYH@-288$wJ76#LdAcKpIPi+MRO=GaOMqsl*d74O?y>4QAwLSepq*m`BpB+kS zv+Y#vk<=p8>o-0Roq$r75XLO{#S`g^AYKy0x)9EbdPTtL%6ZrUjg0n|7>(94niz%A zt|O^Mw0D!*q4Q6hf_TFbu*(z$zXOg50u={@Gg49z#!VuLj-(bzEgg&0!eVQ-#H(!J zok#cRB(o4xkfTH}SW%JSBP}7(DWbw=puyysmk+TvH%~AAEuX&=1h_u%?@qAa?Et?! zL4C_mx4VIS%VTBS*-sFM7DCOId&}&#OEko5p*#Lv~%gMtRVDhh_&T#*lPX18q$@%<~TTGypC-e1WvCJoGvU2B8FSl}b zQowRz9YpAzUl(tiFD8(+WU_!*xp4N^Z@&Bf{p8y})z$T-`*L!vAPTR88{sbish#9Y zc$&Xin6m7D%<~IC1{}I7!u`${=M%MBsn6vJc str: - """Adds a unique string to `name`.""" - maximal_length = 200 - unique_sufix = str(uuid.uuid1()).replace("-", "_") - if len(name) > (maximal_length - len(unique_sufix)): - name = name[: (maximal_length - len(unique_sufix) - 1)] - return f"{name}_{unique_sufix}" - - @dace_properties.make_properties class FuseHorizontalConditionBlocks(dace_transformation.SingleStateTransformation): - access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) + """Fuses two conditional blocks that share the same condition variable and are + not dependent to each other (i.e. the output of one of them is used as input to the other) + into a single conditional block. + The motivation for this transformation is to reduce the number of conditional blocks + which generate if statements in the CPU or GPU code, which can lead to performance improvements. + Example: + Before fusion: + ``` + if __cond: + __output1 = __arg1 * 2.0 + else: + __output1 = __arg2 + 3.0 + if __cond: + __output2 = __arg3 - 1.0 + else: + __output2 = __arg4 / 4.0 + ``` + After fusion: + ``` + if __cond: + __output1 = __arg1 * 2.0 + __output2 = __arg3 - 1.0 + else: + __output1 = __arg2 + 3.0 + __output2 = __arg4 / 4.0 + ``` + """ + + conditional_access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) first_conditional_block = dace_transformation.PatternNode(dace_nodes.NestedSDFG) second_conditional_block = dace_transformation.PatternNode(dace_nodes.NestedSDFG) @classmethod def expressions(cls) -> Any: - map_fusion_parallel_match = dace_graph.OrderedMultiDiConnectorGraph() - map_fusion_parallel_match.add_nedge( - cls.access_node, cls.first_conditional_block, dace.Memlet() + conditionalblock_fusion_parallel_match = dace_graph.OrderedMultiDiConnectorGraph() + conditionalblock_fusion_parallel_match.add_nedge( + cls.conditional_access_node, cls.first_conditional_block, dace.Memlet() ) - map_fusion_parallel_match.add_nedge( - cls.access_node, cls.second_conditional_block, dace.Memlet() + conditionalblock_fusion_parallel_match.add_nedge( + cls.conditional_access_node, cls.second_conditional_block, dace.Memlet() ) - return [map_fusion_parallel_match] + return [conditionalblock_fusion_parallel_match] def can_be_applied( self, @@ -51,12 +69,19 @@ def can_be_applied( sdfg: dace.SDFG, permissive: bool = False, ) -> bool: - access_node: dace_nodes.AccessNode = self.access_node - access_node_desc = access_node.desc(sdfg) + conditional_access_node: dace_nodes.AccessNode = self.conditional_access_node + conditional_access_node_desc = conditional_access_node.desc(sdfg) first_cb: dace_nodes.NestedSDFG = self.first_conditional_block second_cb: dace_nodes.NestedSDFG = self.second_conditional_block scope_dict = graph.scope_dict() + # Check that the common access node is a boolean scalar + if ( + not isinstance(conditional_access_node_desc, dace.data.Scalar) + or conditional_access_node_desc.dtype != dace.bool_ + ): + return False + # Check that both conditional blocks are in the same parent SDFG if first_cb.sdfg.parent != second_cb.sdfg.parent: return False @@ -67,15 +92,8 @@ def can_be_applied( ): return False - # Check that the common access node is a boolean scalar - if ( - not isinstance(access_node_desc, dace.data.Scalar) - or access_node_desc.dtype != dace.bool_ - ): - return False - # Make sure that the conditional blocks contain only one conditional block each - if len(first_cb.sdfg.nodes()) > 1 or len(second_cb.sdfg.nodes()) > 1: + if first_cb.sdfg.number_of_nodes() > 1 or second_cb.sdfg.number_of_nodes() > 1: return False # Get the actual conditional blocks @@ -83,7 +101,22 @@ def can_be_applied( second_conditional_block = next(iter(second_cb.sdfg.nodes())) if not ( isinstance(first_conditional_block, dace.sdfg.state.ConditionalBlock) + and len(first_conditional_block.sub_regions()) == 2 and isinstance(second_conditional_block, dace.sdfg.state.ConditionalBlock) + and len(second_conditional_block.sub_regions()) == 2 + ): + return False + first_conditional_block_state_names = [ + state.name for state in first_conditional_block.all_states() + ] + second_conditional_block_state_names = [ + state.name for state in second_conditional_block.all_states() + ] + if not ( + "true_branch" in first_conditional_block_state_names + and "false_branch" in first_conditional_block_state_names + and "true_branch" in second_conditional_block_state_names + and "false_branch" in second_conditional_block_state_names ): return False @@ -96,16 +129,23 @@ def can_be_applied( return False # Check that there is an edge to the conditional blocks with dst_conn == "__cond" - cond_edges_first = [e for e in graph.in_edges(first_cb) if e.dst_conn == "__cond"] + cond_edges_first = [ + e for e in graph.in_edges(first_cb) if e.dst_conn and e.dst_conn == "__cond" + ] if len(cond_edges_first) != 1: return False - cond_edges_second = [e for e in graph.in_edges(second_cb) if e.dst_conn == "__cond"] + cond_edges_second = [ + e for e in graph.in_edges(second_cb) if e.dst_conn and e.dst_conn == "__cond" + ] if len(cond_edges_second) != 1: return False cond_edge_first = cond_edges_first[0] cond_edge_second = cond_edges_second[0] - if cond_edge_first.src != cond_edge_second.src and ( - cond_edge_first.src != access_node or cond_edge_second.src != access_node + if cond_edge_first.data.is_empty() or cond_edge_second.data.is_empty(): + return False + if not all( + cond_edge.src is conditional_access_node + for cond_edge in [cond_edge_first, cond_edge_second] ): return False @@ -128,7 +168,7 @@ def apply( graph: Union[dace.SDFGState, dace.SDFG], sdfg: dace.SDFG, ) -> None: - access_node: dace_nodes.AccessNode = self.access_node + conditional_access_node: dace_nodes.AccessNode = self.conditional_access_node first_cb: dace_nodes.NestedSDFG = self.first_conditional_block second_cb: dace_nodes.NestedSDFG = self.second_conditional_block @@ -136,12 +176,8 @@ def apply( second_conditional_block = next(iter(second_cb.sdfg.nodes())) # Store original arrays to check later that all the necessary arrays have been moved - original_arrays_first_conditional_block = {} - for data_name, data_desc in first_conditional_block.sdfg.arrays.items(): - original_arrays_first_conditional_block[data_name] = data_desc - original_arrays_second_conditional_block = {} - for data_name, data_desc in second_conditional_block.sdfg.arrays.items(): - original_arrays_second_conditional_block[data_name] = data_desc + original_arrays_first_conditional_block = first_conditional_block.sdfg.arrays.copy() + original_arrays_second_conditional_block = second_conditional_block.sdfg.arrays.copy() total_original_arrays = len(original_arrays_first_conditional_block) + len( original_arrays_second_conditional_block ) @@ -153,7 +189,7 @@ def apply( if data_name == "__cond": continue if data_name in original_arrays_first_conditional_block: - new_data_name = unique_name(data_name) + new_data_name = gtx_transformations.utils.unique_name(data_name) + "_from_cb_fusion" second_arrays_rename_map[data_name] = new_data_name data_desc_renamed = copy.deepcopy(data_desc) data_desc_renamed.name = new_data_name @@ -171,25 +207,25 @@ def apply( out_connectors_to_move = second_cb.out_connectors in_connectors_to_move_rename_map = {} out_connectors_to_move_rename_map = {} - for k, _v in in_connectors_to_move.items(): - new_connector_name = k + for original_in_connector_name, _v in in_connectors_to_move.items(): + new_connector_name = original_in_connector_name if new_connector_name in first_cb.in_connectors: - new_connector_name = second_arrays_rename_map[k] - in_connectors_to_move_rename_map[k] = new_connector_name + new_connector_name = second_arrays_rename_map[original_in_connector_name] + in_connectors_to_move_rename_map[original_in_connector_name] = new_connector_name first_cb.add_in_connector(new_connector_name) for edge in graph.in_edges(second_cb): - if edge.dst_conn == k: + if edge.dst_conn == original_in_connector_name: dace_helpers.redirect_edge( state=graph, edge=edge, new_dst_conn=new_connector_name, new_dst=first_cb ) - for k, _v in out_connectors_to_move.items(): - new_connector_name = k + for original_out_connector_name, _v in out_connectors_to_move.items(): + new_connector_name = original_out_connector_name if new_connector_name in first_cb.out_connectors: - new_connector_name = second_arrays_rename_map[k] - out_connectors_to_move_rename_map[k] = new_connector_name + new_connector_name = second_arrays_rename_map[original_out_connector_name] + out_connectors_to_move_rename_map[original_out_connector_name] = new_connector_name first_cb.add_out_connector(new_connector_name) for edge in graph.out_edges(second_cb): - if edge.src_conn == k: + if edge.src_conn == original_out_connector_name: dace_helpers.redirect_edge( state=graph, edge=edge, new_src_conn=new_connector_name, new_src=first_cb ) @@ -207,10 +243,6 @@ def _find_corresponding_state_in_second( elif not true_branch and "false_branch" in state.name: corresponding_state_in_second = state break - if corresponding_state_in_second is None: - raise RuntimeError( - f"Could not find corresponding state in second conditional block for state {inner_state_name}" - ) return corresponding_state_in_second # Copy first the nodes from the second conditional block to the first @@ -224,8 +256,6 @@ def _find_corresponding_state_in_second( if node.data in first_cb.in_connectors or node.data in first_cb.out_connectors: new_data_name = second_arrays_rename_map[node.data] new_node = dace_nodes.AccessNode(new_data_name) - new_desc = copy.deepcopy(node.desc(second_cb.sdfg)) - new_desc.name = new_data_name nodes_renamed_map[node] = new_node first_inner_state.add_node(new_node) @@ -255,7 +285,7 @@ def _find_corresponding_state_in_second( else edge.dst_conn, new_memlet, ) - for edge in list(graph.out_edges(access_node)): + for edge in list(graph.out_edges(conditional_access_node)): if edge.dst == second_cb: graph.remove_edge(edge) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py index e3f417276e..3a18c40353 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py @@ -8,6 +8,7 @@ """Common functionality for the transformations/optimization pipeline.""" +import uuid from typing import Optional, Sequence, TypeVar, Union import dace @@ -20,6 +21,15 @@ _PassT = TypeVar("_PassT", bound=dace_ppl.Pass) +def unique_name(name: str) -> str: + """Adds a unique string to `name`.""" + maximal_length = 200 + unique_sufix = str(uuid.uuid1()).replace("-", "_") + if len(name) > (maximal_length - len(unique_sufix)): + name = name[: (maximal_length - len(unique_sufix) - 1)] + return f"{name}_{unique_sufix}" + + def gt_make_transients_persistent( sdfg: dace.SDFG, device: dace.DeviceType, diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_auto_optimizer_hooks.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_auto_optimizer_hooks.py index 3f686ae583..03e94b2858 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_auto_optimizer_hooks.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_auto_optimizer_hooks.py @@ -21,7 +21,7 @@ def _make_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("test")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("test")) state = sdfg.add_state(is_start_block=True) for name in "abcde": diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_copy_chain_remover.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_copy_chain_remover.py index 53bba67408..7bb46d302b 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_copy_chain_remover.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_copy_chain_remover.py @@ -29,7 +29,7 @@ def _make_simple_linear_chain_sdfg() -> dace.SDFG: All intermediates have the same size. """ - sdfg = dace.SDFG(util.unique_name("simple_linear_chain_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("simple_linear_chain_sdfg")) for name in ["a", "b", "c", "d", "e"]: sdfg.add_array( @@ -75,7 +75,7 @@ def _make_diff_sizes_pull_chain_sdfg() -> tuple[ - The AccessNode that is used as final output, refers to `e`. - The Tasklet that is within the Map. """ - sdfg = dace.SDFG(util.unique_name("diff_size_linear_pull_chain_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("diff_size_linear_pull_chain_sdfg")) array_size_increment = 10 array_size = 10 @@ -118,7 +118,7 @@ def _make_diff_sizes_push_chain_sdfg() -> tuple[ Same as `_make_simple_linear_pull_chain_sdfg()` but the intermediates become smaller and smaller, so the full shape of the destination array is always written. """ - sdfg = dace.SDFG(util.unique_name("diff_size_linear_push_chain_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("diff_size_linear_push_chain_sdfg")) array_size_decrement = 10 array_size = 50 @@ -155,7 +155,7 @@ def _make_diff_sizes_push_chain_sdfg() -> tuple[ def _make_multi_stage_reduction_sdfg() -> dace.SDFG: """Creates an SDFG that has a two stage copy reduction.""" - sdfg = dace.SDFG(util.unique_name("multi_stage_reduction")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("multi_stage_reduction")) state: dace.SDFGState = sdfg.add_state(is_start_block=True) # This is the size of the arrays, if not mentioned here, then its size is 10. @@ -221,7 +221,7 @@ def _make_not_fully_copied() -> dace.SDFG: Make an SDFG where two intermediate array is not fully copied. Thus the transformation only applies once, when `d` is removed. """ - sdfg = dace.SDFG(util.unique_name("not_fully_copied_intermediate")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("not_fully_copied_intermediate")) for name in ["a", "b", "c", "d", "e"]: sdfg.add_array( @@ -257,7 +257,7 @@ def _make_possible_cyclic_sdfg() -> dace.SDFG: If the transformation would remove `a1` then it would create a cycle. Thus the transformation should not apply. """ - sdfg = dace.SDFG(util.unique_name("possible_cyclic_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("possible_cyclic_sdfg")) anames = ["i1", "a1", "a2", "o1"] for name in anames: @@ -331,7 +331,7 @@ def make_inner_sdfg() -> dace.SDFG: inner_sdfg = make_inner_sdfg() - sdfg = dace.SDFG(util.unique_name("linear_chain_with_nested_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("linear_chain_with_nested_sdfg")) state = sdfg.add_state(is_start_block=True) array_size_increment = 10 @@ -370,7 +370,7 @@ def make_inner_sdfg() -> dace.SDFG: def _make_a1_has_output_sdfg() -> dace.SDFG: """Here `a1` has an output degree of 2, one to `a2` and one to another output.""" - sdfg = dace.SDFG(util.unique_name("a1_has_an_additional_output_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("a1_has_an_additional_output_sdfg")) state = sdfg.add_state(is_start_block=True) # All other arrays have a size of 10. @@ -411,7 +411,9 @@ def _make_copy_chain_with_reduction_node( dace_nodes.AccessNode, dace_nodes.AccessNode, ]: - sdfg = dace.SDFG(util.unique_name("copy_chain_remover_with_reduction_sdfg")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("copy_chain_remover_with_reduction_sdfg") + ) state = sdfg.add_state(is_start_block=True) if output_an_array: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py index 0db2706c0d..4bb61a01ad 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_create_local_double_buffering.py @@ -73,7 +73,7 @@ def _create_sdfg_double_read_part_2( def _create_sdfg_double_read( version: int, ) -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name(f"double_read_version_{version}")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"double_read_version_{version}")) state = sdfg.add_state(is_start_block=True) for name in "AB": sdfg.add_array( @@ -98,7 +98,7 @@ def _create_sdfg_double_read( def _create_non_scalar_read() -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name(f"non_scalar_read_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"non_scalar_read_sdfg")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( @@ -145,7 +145,7 @@ def test_local_double_buffering_double_read_sdfg(): def test_local_double_buffering_no_connection(): """There is no direct connection between read and write.""" - sdfg = dace.SDFG(util.unique_name("local_double_buffering_no_connection")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("local_double_buffering_no_connection")) state = sdfg.add_state(is_start_block=True) for name in "AB": sdfg.add_array( @@ -212,7 +212,7 @@ def test_local_double_buffering_no_connection(): def test_local_double_buffering_no_apply(): """Here it does not apply, because are all distinct.""" - sdfg = dace.SDFG(util.unique_name("local_double_buffering_no_apply")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("local_double_buffering_no_apply")) state = sdfg.add_state(is_start_block=True) for name in "AB": sdfg.add_array( @@ -237,7 +237,7 @@ def test_local_double_buffering_no_apply(): def test_local_double_buffering_already_buffered(): """It is already buffered.""" - sdfg = dace.SDFG(util.unique_name("local_double_buffering_no_apply")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("local_double_buffering_no_apply")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( "A", diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_dead_dataflow_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_dead_dataflow_elimination.py index ef13fd6ea7..375f8da1d5 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_dead_dataflow_elimination.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_dead_dataflow_elimination.py @@ -19,7 +19,7 @@ def _make_empty_memlets_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("empty_memlets")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("empty_memlets")) state = sdfg.add_state(is_start_block=True) anames = ["a", "b", "c"] @@ -46,7 +46,7 @@ def _make_empty_memlets_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: def _make_zero_iter_step_map() -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("empty_map")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("empty_map")) state = sdfg.add_state(is_start_block=True) anames = ["a", "b", "c"] diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py index cadf525159..a2673295a8 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py @@ -21,7 +21,7 @@ def _mk_distributed_buffer_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("distributed_buffer_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("distributed_buffer_sdfg")) for name in ["a", "b", "tmp"]: sdfg.add_array(name, shape=(10, 10), dtype=dace.float64, transient=False) @@ -84,7 +84,9 @@ def test_distributed_buffer_remover(): def _make_distributed_buffer_global_memory_data_race_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_race")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("distributed_buffer_global_memory_data_race") + ) arr_names = ["a", "b", "t"] for name in arr_names: sdfg.add_array( @@ -136,7 +138,9 @@ def test_distributed_buffer_global_memory_data_race(): def _make_distributed_buffer_global_memory_data_race_sdfg2() -> tuple[ dace.SDFG, dace.SDFGState, dace.SDFGState ]: - sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_race2_sdfg")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("distributed_buffer_global_memory_data_race2_sdfg") + ) arr_names = ["a", "b", "t"] for name in arr_names: sdfg.add_array( @@ -189,7 +193,9 @@ def test_distributed_buffer_global_memory_data_race2(): def _make_distributed_buffer_global_memory_data_no_rance() -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_no_rance_sdfg")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("distributed_buffer_global_memory_data_no_rance_sdfg") + ) arr_names = ["a", "t"] for name in arr_names: sdfg.add_array( @@ -235,7 +241,11 @@ def test_distributed_buffer_global_memory_data_no_rance(): def _make_distributed_buffer_global_memory_data_no_rance2() -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_no_rance2_sdfg")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name( + "distributed_buffer_global_memory_data_no_rance2_sdfg" + ) + ) arr_names = ["a", "t"] for name in arr_names: sdfg.add_array( @@ -291,7 +301,9 @@ def test_distributed_buffer_global_memory_data_no_rance2(): def _make_distributed_buffer_non_sink_temporary_sdfg() -> tuple[ dace.SDFG, dace.SDFGState, dace.SDFGState ]: - sdfg = dace.SDFG(util.unique_name("distributed_buffer_non_sink_temporary_sdfg")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("distributed_buffer_non_sink_temporary_sdfg") + ) state = sdfg.add_state(is_start_block=True) wb_state = sdfg.add_state_after(state) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_double_write_remover.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_double_write_remover.py index da70155cf3..6579f1219e 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_double_write_remover.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_double_write_remover.py @@ -25,7 +25,7 @@ def _make_double_write_single_consumer( ) -> tuple[ dace.SDFG, dace.SDFGState, dace_nodes.AccessNode, dace_nodes.AccessNode, dace_nodes.MapExit ]: - sdfg = dace.SDFG(util.unique_name("double_write_elimination_1")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("double_write_elimination_1")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( @@ -148,7 +148,7 @@ def _make_double_write_multi_consumer( dace_nodes.AccessNode, dace_nodes.MapExit, ]: - sdfg = dace.SDFG(util.unique_name("double_write_elimination_2")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("double_write_elimination_2")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( @@ -255,7 +255,9 @@ def _make_double_write_multi_producer_map( ) -> tuple[ dace.SDFG, dace.SDFGState, dace_nodes.AccessNode, dace_nodes.AccessNode, dace_nodes.MapExit ]: - sdfg = dace.SDFG(util.unique_name("double_write_elimination_multi_producer_map")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("double_write_elimination_multi_producer_map") + ) state = sdfg.add_state(is_start_block=True) sdfg.add_array( diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py index ec23b4ee24..39004d3786 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py @@ -32,7 +32,7 @@ def _make_if_block_with_tasklet( b2_type: dace.typeclass = dace.float64, output_type: dace.typeclass = dace.float64, ) -> dace_nodes.NestedSDFG: - inner_sdfg = dace.SDFG(util.unique_name("if_stmt_")) + inner_sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_stmt_")) types = {b1_name: b1_type, b2_name: b2_type, cond_name: dace.bool_, output_name: output_type} for name in {b1_name, b2_name, cond_name, output_name}: @@ -42,7 +42,7 @@ def _make_if_block_with_tasklet( transient=False, ) - if_region = dace.sdfg.state.ConditionalBlock(util.unique_name("if")) + if_region = dace.sdfg.state.ConditionalBlock(gtx_transformations.utils.unique_name("if")) inner_sdfg.add_node(if_region, is_start_block=True) then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=inner_sdfg) @@ -87,7 +87,7 @@ def _make_if_block_with_tasklet( def _make_map_with_conditional_blocks() -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name("map_with_conditional_blocks")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("map_with_conditional_blocks")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py index dd4927bb03..2fcacd191d 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py @@ -15,6 +15,7 @@ from gt4py.next.program_processors.runners.dace.transformations import ( gpu_utils as gtx_dace_fieldview_gpu_utils, + utils as gtx_transformations_utils, ) from . import util @@ -42,7 +43,7 @@ def _get_trivial_gpu_promotable( tasklet_code: The body of the Tasklet inside the trivial map. trivial_map_range: Range of the trivial map, defaults to `"0"`. """ - sdfg = dace.SDFG(util.unique_name("gpu_promotable_sdfg")) + sdfg = dace.SDFG(gtx_transformations_utils.unique_name("gpu_promotable_sdfg")) state = sdfg.add_state("state", is_start_block=True) sdfg.add_symbol("N", dace.int32) @@ -155,7 +156,7 @@ def test_trivial_gpu_map_promoter_2(): @pytest.mark.parametrize("method", [0, 1]) def test_set_gpu_properties(method: int): """Tests the `gtx_dace_fieldview_gpu_utils.gt_set_gpu_blocksize()`.""" - sdfg = dace.SDFG(util.unique_name("gpu_properties_test")) + sdfg = dace.SDFG(gtx_transformations_utils.unique_name("gpu_properties_test")) state = sdfg.add_state(is_start_block=True) map_entries: dict[int, dace_nodes.MapEntry] = {} @@ -232,7 +233,7 @@ def test_set_gpu_properties(method: int): def test_set_gpu_properties_1D(): """Tests the `gtx_dace_fieldview_gpu_utils.gt_set_gpu_blocksize()` with 1D maps.""" - sdfg = dace.SDFG(util.unique_name("gpu_properties_test")) + sdfg = dace.SDFG(gtx_transformations_utils.unique_name("gpu_properties_test")) state = sdfg.add_state(is_start_block=True) map_entries: dict[int, dace_nodes.MapEntry] = {} @@ -299,7 +300,7 @@ def test_set_gpu_properties_1D(): def test_set_gpu_properties_2D_3D(): """Tests the `gtx_dace_fieldview_gpu_utils.gt_set_gpu_blocksize()` with 2D, 3D and 4D maps.""" - sdfg = dace.SDFG(util.unique_name("gpu_properties_test")) + sdfg = dace.SDFG(gtx_transformations_utils.unique_name("gpu_properties_test")) state = sdfg.add_state(is_start_block=True) map_entries: dict[int, dace_nodes.MapEntry] = {} @@ -355,7 +356,7 @@ def test_set_gpu_properties_2D_3D(): def test_set_gpu_maxnreg(): """Tests if gpu_maxnreg property is set correctly to GPU maps.""" - sdfg = dace.SDFG(util.unique_name("gpu_maxnreg_test")) + sdfg = dace.SDFG(gtx_transformations_utils.unique_name("gpu_maxnreg_test")) state = sdfg.add_state(is_start_block=True) dim = 2 shape = (10,) * (dim - 1) + (1,) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_horizontal_map_split_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_horizontal_map_split_fusion.py index d7fd220b5f..6bad175597 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_horizontal_map_split_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_horizontal_map_split_fusion.py @@ -34,7 +34,7 @@ def _make_sdfg_with_multiple_maps_that_share_inputs( - Outputs: out1[i, j], out2[i, j], out3[i, j], out4[i, j] """ shape = (N, N) - sdfg = dace.SDFG(util.unique_name("multiple_maps")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("multiple_maps")) state = sdfg.add_state(is_start_block=True) for name in ["a", "b", "c", "d", "out1", "out2", "out3", "out4"]: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_inline_fuser.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_inline_fuser.py index b0c7df401d..6e6afecfb0 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_inline_fuser.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_inline_fuser.py @@ -33,7 +33,7 @@ def _create_simple_fusable_sdfg() -> tuple[ dace_nodes.AccessNode, dace_graph.MultiConnectorEdge[dace.Memlet], ]: - sdfg = dace.SDFG(util.unique_name(f"simple_fusable_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"simple_fusable_sdfg")) state = sdfg.add_state(is_start_block=True) for name in "abc": @@ -121,7 +121,7 @@ def _make_laplap_sdfg( dace_graph.MultiConnectorEdge[dace.Memlet], dace_nodes.Tasklet, ]: - sdfg = dace.SDFG(util.unique_name(f"laplap1")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"laplap1")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( @@ -223,7 +223,7 @@ def _make_multiple_value_read_sdfg( dace_nodes.MapEntry, dace_graph.MultiConnectorEdge[dace.Memlet], ]: - sdfg = dace.SDFG(util.unique_name(f"multiple_value_generator")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"multiple_value_generator")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( @@ -367,7 +367,7 @@ def test_multiple_value_exchange_partial(): def _make_sdfg_with_dref_tasklet(): - sdfg = dace.SDFG(util.unique_name(f"sdfg_with_dref_target")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"sdfg_with_dref_target")) state = sdfg.add_state(is_start_block=True) for name in "abc": diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py index a94e3262c5..683ad9ebf0 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py @@ -32,7 +32,7 @@ def _get_simple_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], np can be taken out. This is because how it is constructed. However, applying some simplistic transformations will enable the transformation. """ - sdfg = dace.SDFG(util.unique_name("simple_block_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("simple_block_sdfg")) state = sdfg.add_state("state", is_start_block=True) sdfg.add_symbol("N", dace.int32) sdfg.add_symbol("M", dace.int32) @@ -55,7 +55,7 @@ def _get_chained_sdfg() -> tuple[dace.SDFG, Callable[[np.ndarray, np.ndarray], n The bottom Tasklet is the only dependent Tasklet. """ - sdfg = dace.SDFG(util.unique_name("chained_block_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("chained_block_sdfg")) state = sdfg.add_state("state", is_start_block=True) sdfg.add_symbol("N", dace.int32) sdfg.add_symbol("M", dace.int32) @@ -159,7 +159,7 @@ def _get_sdfg_with_empty_memlet( is either dependent or independent), the access node between the tasklets and the second tasklet that is always dependent. """ - sdfg = dace.SDFG(util.unique_name("empty_memlet_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("empty_memlet_sdfg")) state = sdfg.add_state("state", is_start_block=True) sdfg.add_symbol("N", dace.int32) sdfg.add_symbol("M", dace.int32) @@ -404,7 +404,7 @@ def test_direct_map_exit_connection() -> dace.SDFG: Because the tasklet is connected to the map exit it can not be independent. """ - sdfg = dace.SDFG(util.unique_name("mapped_tasklet_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("mapped_tasklet_sdfg")) state = sdfg.add_state("state", is_start_block=True) sdfg.add_array("a", (10,), dace.float64, transient=False) sdfg.add_array("b", (10, 30), dace.float64, transient=False) @@ -591,7 +591,7 @@ def _make_loop_blocking_sdfg_with_inner_map( The function will return the SDFG, the state and the map entry for the outer and inner map. """ - sdfg = dace.SDFG(util.unique_name("sdfg_with_inner_map")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("sdfg_with_inner_map")) state = sdfg.add_state(is_start_block=True) for name in "AB": @@ -756,7 +756,7 @@ def _make_loop_blocking_sdfg_with_independent_inner_map() -> tuple[ """ Creates a nested Map that is independent from the blocking parameter. """ - sdfg = dace.SDFG(util.unique_name("sdfg_with_inner_independent_map")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("sdfg_with_inner_independent_map")) state = sdfg.add_state(is_start_block=True) sdfg.add_array("A", shape=(40, 3), dtype=dace.float64, transient=False) @@ -833,7 +833,7 @@ def _make_loop_blocking_with_reduction( Depending on `reduction_is_dependent` the node is either dependent or not. """ sdfg = dace.SDFG( - util.unique_name( + gtx_transformations.utils.unique_name( "sdfg_with_" + ("" if reduction_is_dependent else "in") + "dependent_reduction" ) ) @@ -954,7 +954,7 @@ def _make_mixed_memlet_sdfg( - `tskl1`. - `tskl2`. """ - sdfg = dace.SDFG(util.unique_name("mixed_memlet_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("mixed_memlet_sdfg")) state = sdfg.add_state(is_start_block=True) names_array = ["A", "B", "C"] names_scalar = ["tmp1", "tmp2"] @@ -1148,7 +1148,7 @@ def test_loop_blocking_mixed_memlets_2(): def test_loop_blocking_no_independent_nodes(): import dace - sdfg = dace.SDFG(util.unique_name("mixed_memlet_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("mixed_memlet_sdfg")) state = sdfg.add_state(is_start_block=True) names = ["A", "B", "C"] for aname in names: @@ -1210,7 +1210,7 @@ def test_loop_blocking_no_independent_nodes(): def _make_only_last_two_elements_sdfg() -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name("simple_block_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("simple_block_sdfg")) state = sdfg.add_state("state", is_start_block=True) sdfg.add_symbol("N", dace.int32) sdfg.add_symbol("B", dace.int32) @@ -1279,7 +1279,7 @@ def test_blocking_size_too_big(): Here the blocking size is larger than the size in that dimension. Thus the transformation will not apply. """ - sdfg = dace.SDFG(util.unique_name("blocking_size_too_large")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("blocking_size_too_large")) state = sdfg.add_state(is_start_block=True) for name in "ab": @@ -1325,7 +1325,7 @@ def test_blocking_size_too_big(): def _make_loop_blocking_sdfg_with_semi_independent_map() -> tuple[ dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapEntry ]: - sdfg = dace.SDFG(util.unique_name("sdfg_with_inner_semi_independent_map")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("sdfg_with_inner_semi_independent_map")) state = sdfg.add_state(is_start_block=True) sdfg.add_array("A", shape=(40, 3), dtype=dace.float64, transient=False) @@ -1415,7 +1415,7 @@ def test_loop_blocking_sdfg_with_semi_independent_map(): def _make_loop_blocking_only_independent_inner_map() -> tuple[ dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapEntry ]: - sdfg = dace.SDFG(util.unique_name("sdfg_with_only_independent_inner_map")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("sdfg_with_only_independent_inner_map")) state = sdfg.add_state(is_start_block=True) sdfg.add_array("A", shape=(40, 3), dtype=dace.float64, transient=False) @@ -1479,7 +1479,7 @@ def _make_loop_blocking_output_access_node( ) -> tuple[ dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.Tasklet, dace_nodes.AccessNode ]: - sdfg = dace.SDFG(util.unique_name("sdfg_with_direct_output_access_node")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("sdfg_with_direct_output_access_node")) state = sdfg.add_state(is_start_block=True) sdfg.add_array("A", shape=(40,), dtype=dace.float64, transient=False) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_make_transients_persistent.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_make_transients_persistent.py index d8cf8e33f8..1e7ce47197 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_make_transients_persistent.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_make_transients_persistent.py @@ -21,7 +21,9 @@ def _make_transients_persistent_inner_access_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("transients_persistent_inner_access_sdfg")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("transients_persistent_inner_access_sdfg") + ) state = sdfg.add_state(is_start_block=True) for name in "abc": diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py index aa5fce9d76..1bf30f3697 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_buffer_elimination.py @@ -48,7 +48,7 @@ def _make_test_sdfg( if out_offset is None: out_offset = in_offset - sdfg = dace.SDFG(util.unique_name("map_buffer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("map_buffer")) state = sdfg.add_state(is_start_block=True) names = {input_name, tmp_name, output_name} for name in names: @@ -220,7 +220,7 @@ def test_map_buffer_elimination_offset_5(): def test_map_buffer_elimination_not_apply(): """Indirect accessing, because of this the double buffer is needed.""" - sdfg = dace.SDFG(util.unique_name("map_buffer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("map_buffer")) state = sdfg.add_state(is_start_block=True) names = ["A", "tmp", "idx"] @@ -269,7 +269,7 @@ def test_map_buffer_elimination_with_nested_sdfgs(): stride1, stride2, stride3 = [dace.symbol(f"stride{i}", dace.int32) for i in range(3)] # top-level sdfg - sdfg = dace.SDFG(util.unique_name("map_buffer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("map_buffer")) inp, inp_desc = sdfg.add_array("__inp", (10,), dace.float64) out, out_desc = sdfg.add_array( "__out", (10, 10, 10), dace.float64, strides=(stride1, stride2, stride3) @@ -278,14 +278,14 @@ def test_map_buffer_elimination_with_nested_sdfgs(): state = sdfg.add_state() tmp_node = state.add_access(tmp) - nsdfg1 = dace.SDFG(util.unique_name("map_buffer")) + nsdfg1 = dace.SDFG(gtx_transformations.utils.unique_name("map_buffer")) inp1, inp1_desc = nsdfg1.add_array("__inp", (10,), dace.float64) out1, out1_desc = nsdfg1.add_array("__out", (10, 10), dace.float64) tmp1, _ = nsdfg1.add_temp_transient_like(out1_desc) state1 = nsdfg1.add_state() tmp1_node = state1.add_access(tmp1) - nsdfg2 = dace.SDFG(util.unique_name("map_buffer")) + nsdfg2 = dace.SDFG(gtx_transformations.utils.unique_name("map_buffer")) inp2, _ = nsdfg2.add_array("__inp", (10,), dace.float64) out2, out2_desc = nsdfg2.add_array("__out", (10,), dace.float64) tmp2, _ = nsdfg2.add_temp_transient_like(out2_desc) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion_utils.py index 83e0bd921a..3fd392946c 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion_utils.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_fusion_utils.py @@ -15,6 +15,7 @@ from gt4py.next.program_processors.runners.dace.transformations import ( map_fusion_utils as gtx_map_fusion_utils, + utils as gtx_transformations_utils, ) import numpy as np @@ -24,7 +25,7 @@ def test_copy_map_graph(): N = dace.symbol("N", dace.int32) - sdfg = dace.SDFG(util.unique_name("copy_map_graph")) + sdfg = dace.SDFG(gtx_transformations_utils.unique_name("copy_map_graph")) A, A_desc = sdfg.add_array("A", [N], dtype=dace.float64) B = sdfg.add_datadesc("B", A_desc.clone()) st = sdfg.add_state() @@ -103,7 +104,7 @@ def test_copy_map_graph(): def test_split_overlapping_map_range(map_ranges): first_ndrange, second_ndrange = map_ranges[0:2] - sdfg = dace.SDFG(util.unique_name("split_overlapping_map_range")) + sdfg = dace.SDFG(gtx_transformations_utils.unique_name("split_overlapping_map_range")) st = sdfg.add_state() first_map_entry, _ = st.add_map("first", first_ndrange) second_map_entry, _ = st.add_map("second", second_ndrange) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py index c067e47f4f..95132c39f0 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py @@ -62,7 +62,7 @@ def _perform_reorder_test( def _make_test_sdfg(map_params: list[str]) -> dace.SDFG: """Generate an SDFG for the test.""" - sdfg = dace.SDFG(util.unique_name("gpu_promotable_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("gpu_promotable_sdfg")) state: dace.SDFGState = sdfg.add_state("state", is_start_block=True) dim = len(map_params) for aname in ["a", "b"]: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_promoter.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_promoter.py index 435c37d629..cef3058ab0 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_promoter.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_promoter.py @@ -31,7 +31,7 @@ def _make_serial_map_promotion_sdfg( ) -> tuple[dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapEntry]: shape_1d = (N,) shape_2d = (N, N) - sdfg = dace.SDFG(util.unique_name("serial_promotable_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("serial_promotable_sdfg")) state = sdfg.add_state(is_start_block=True) # 1D Arrays @@ -252,7 +252,7 @@ def test_serial_map_promotion_on_symbolic_range(use_symbolic_range): def test_serial_map_promotion_2d_top_1d_bottom(): """Does not apply because the bottom map needs to be promoted.""" - sdfg = dace.SDFG(util.unique_name("serial_map_promoter_2d_map_on_top")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("serial_map_promoter_2d_map_on_top")) state = sdfg.add_state(is_start_block=True) # 2D Arrays @@ -311,7 +311,7 @@ def test_serial_map_promotion_2d_top_1d_bottom(): def _make_horizontal_promoter_sdfg( d1_map_is_vertical: bool, ) -> tuple[dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapEntry]: - sdfg = dace.SDFG(util.unique_name("serial_map_promoter_tester")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("serial_map_promoter_tester")) state = sdfg.add_state(is_start_block=True) h_idx = gtx_dace_lowering.get_map_variable(gtx_common.Dimension("boden")) @@ -447,7 +447,9 @@ def test_horizonal_promotion_promotion_and_merge(d1_map_is_vertical: bool): def _make_sdfg_different_1d_map_name( d1_map_param: str, ) -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("serial_map_promoter_different_names_" + d1_map_param)) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("serial_map_promoter_different_names_" + d1_map_param) + ) state = sdfg.add_state(is_start_block=True) # 1D Arrays diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_splitter.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_splitter.py index 6d569e79b5..28f0653922 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_splitter.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_splitter.py @@ -22,7 +22,7 @@ def _make_sdfg_simple() -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("simple_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("simple_sdfg")) state = sdfg.add_state(is_start_block=True) for name in "abc": diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_to_copy.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_to_copy.py index f4c9020de4..d6dc0ae54a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_to_copy.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_to_copy.py @@ -21,7 +21,7 @@ def _make_sdfg_1( consumer_is_map: bool, ) -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("simple_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("simple_sdfg")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py index 5338b49920..89f8439472 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_dataflow_into_if_body.py @@ -37,7 +37,7 @@ def _make_if_block( b2_type: dace.typeclass = dace.float64, output_type: dace.typeclass = dace.float64, ) -> dace_nodes.NestedSDFG: - inner_sdfg = dace.SDFG(util.unique_name("if_stmt_")) + inner_sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_stmt_")) types = {b1_name: b1_type, b2_name: b2_type, cond_name: dace.bool_, output_name: output_type} for name in {b1_name, b2_name, cond_name, output_name}: @@ -47,7 +47,7 @@ def _make_if_block( transient=False, ) - if_region = dace.sdfg.state.ConditionalBlock(util.unique_name("if")) + if_region = dace.sdfg.state.ConditionalBlock(gtx_transformations.utils.unique_name("if")) inner_sdfg.add_node(if_region, is_start_block=True) then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=inner_sdfg) @@ -126,7 +126,7 @@ def test_if_mover_independent_branches(): d = b ``` """ - sdfg = dace.SDFG(util.unique_name("independent_branches")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("independent_branches")) state = sdfg.add_state(is_start_block=True) # Inputs @@ -237,7 +237,7 @@ def test_if_mover_independent_branches(): def test_if_mover_invalid_if_block(): - sdfg = dace.SDFG(util.unique_name("invalid")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("invalid")) state = sdfg.add_state(is_start_block=True) input_names = ["a", "b", "c", "d"] @@ -346,7 +346,7 @@ def test_if_mover_dependent_branch_1(): d = b ``` """ - sdfg = dace.SDFG(util.unique_name("if_mover_dependent_branches")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_dependent_branches")) state = sdfg.add_state(is_start_block=True) # Inputs @@ -487,7 +487,7 @@ def test_if_mover_dependent_branch_2(): d = b1 ``` """ - sdfg = dace.SDFG(util.unique_name("if_mover_dependent_branches_2")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_dependent_branches_2")) state = sdfg.add_state(is_start_block=True) # Inputs @@ -587,7 +587,7 @@ def test_if_mover_dependent_branch_3(): Very similar test to `test_if_mover_dependent_branch_1()`, but the common data is an AccessNode outside the Map. """ - sdfg = dace.SDFG(util.unique_name("if_mover_dependent_branches")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_dependent_branches")) state = sdfg.add_state(is_start_block=True) gnames = ["a", "b", "c", "d", "cond"] @@ -683,7 +683,7 @@ def test_if_mover_no_ops(): ``` I.e. there is no gain from moving something inside the body. """ - sdfg = dace.SDFG(util.unique_name("if_mover_no_ops")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_no_ops")) state = sdfg.add_state(is_start_block=True) # Inputs @@ -743,7 +743,7 @@ def test_if_mover_one_branch_is_nothing(): I.e. in one case something can be moved in but there is nothing to move for the other branch. """ - sdfg = dace.SDFG(util.unique_name("if_mover_one_branch_is_nothing")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_one_branch_is_nothing")) state = sdfg.add_state(is_start_block=True) # Inputs @@ -833,7 +833,7 @@ def test_if_mover_chain(): e = aa if cc else bb ``` """ - sdfg = dace.SDFG(util.unique_name("if_mover_chain_of_blocks")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_chain_of_blocks")) state = sdfg.add_state(is_start_block=True) # Inputs @@ -947,7 +947,7 @@ def test_if_mover_chain(): def test_if_mover_symbolic_tasklet(): - sdfg = dace.SDFG(util.unique_name("if_mover_symbols_in_tasklets")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_symbols_in_tasklets")) state = sdfg.add_state(is_start_block=True) for i in [1, 2]: @@ -1062,7 +1062,7 @@ def test_if_mover_access_node_between(): e = aa if cc else bb ``` """ - sdfg = dace.SDFG(util.unique_name("if_mover_chain_of_blocks")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("if_mover_chain_of_blocks")) state = sdfg.add_state(is_start_block=True) # Inputs diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py index 2ca57691f2..51a577e467 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_move_tasklet_into_map.py @@ -26,7 +26,7 @@ def _make_movable_tasklet( ) -> tuple[ dace.SDFG, dace.SDFGState, dace_nodes.Tasklet, dace_nodes.AccessNode, dace_nodes.MapEntry ]: - sdfg = dace.SDFG(util.unique_name("gpu_promotable_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("gpu_promotable_sdfg")) state = sdfg.add_state("state", is_start_block=True) sdfg.add_scalar("outer_scalar", dtype=dace.float64, transient=True) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_multi_state_global_self_copy_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_multi_state_global_self_copy_elimination.py index fc07d3a831..30dd8de029 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_multi_state_global_self_copy_elimination.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_multi_state_global_self_copy_elimination.py @@ -32,7 +32,7 @@ def apply_distributed_self_copy_elimination( def _make_not_apply_because_of_write_to_g_sdfg() -> dace.SDFG: """This SDFG is not eligible, because there is a write to `G`.""" - sdfg = dace.SDFG(util.unique_name("not_apply_because_of_write_to_g_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("not_apply_because_of_write_to_g_sdfg")) # This is the `G` array. sdfg.add_array(name="a", shape=(10,), dtype=dace.float64, transient=False) @@ -83,7 +83,7 @@ def _make_eligible_sdfg_1() -> dace.SDFG: The main difference is that there is no mutating write to `a` and thus the transformation applies. """ - sdfg = dace.SDFG(util.unique_name("_make_eligible_sdfg_1")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("_make_eligible_sdfg_1")) # This is the `G` array. sdfg.add_array(name="a", shape=(10,), dtype=dace.float64, transient=False) @@ -126,7 +126,7 @@ def _make_eligible_sdfg_1() -> dace.SDFG: def _make_multiple_temporaries_sdfg1() -> dace.SDFG: """Generates an SDFG in which `G` is saved into different temporaries.""" - sdfg = dace.SDFG(util.unique_name("multiple_temporaries")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("multiple_temporaries")) # This is the `G` array. sdfg.add_array(name="a", shape=(10,), dtype=dace.float64, transient=False) @@ -170,7 +170,7 @@ def _make_multiple_temporaries_sdfg2() -> dace.SDFG: generated by `_make_multiple_temporaries_sdfg()` is that the temporaries are used sequentially. """ - sdfg = dace.SDFG(util.unique_name("multiple_temporaries_sequential")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("multiple_temporaries_sequential")) # This is the `G` array. sdfg.add_array(name="a", shape=(10,), dtype=dace.float64, transient=False) @@ -281,7 +281,7 @@ def _make_non_eligible_because_of_pseudo_temporary() -> dace.SDFG: Note that in this particular case it would be possible, but we do not support it. """ - sdfg = dace.SDFG(util.unique_name("multiple_temporaries_sequential")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("multiple_temporaries_sequential")) # This is the `G` array. sdfg.add_array(name="a", shape=(10,), dtype=dace.float64, transient=False) @@ -321,7 +321,7 @@ def _make_wb_single_state_sdfg() -> dace.SDFG: This pattern is handled by the `SingleStateGlobalSelfCopyElimination` transformation. """ - sdfg = dace.SDFG(util.unique_name("single_state_write_back_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("single_state_write_back_sdfg")) sdfg.add_array("g", shape=(10,), dtype=dace.float64, transient=False) sdfg.add_array("t", shape=(10,), dtype=dace.float64, transient=True) @@ -351,7 +351,7 @@ def _make_wb_single_state_sdfg() -> dace.SDFG: def _make_non_eligible_sdfg_with_branches(): """Creates an SDFG with two different definitions of `T`.""" - sdfg = dace.SDFG(util.unique_name("non_eligible_sdfg_with_branches_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("non_eligible_sdfg_with_branches_sdfg")) # This is the `G` array, it is also used as output. sdfg.add_array("a", shape=(10,), dtype=dace.float64, transient=False) @@ -400,7 +400,7 @@ def _make_write_into_global_at_t_definition() -> tuple[ This SDFG is different from `_make_not_apply_because_of_write_to_g_sdfg` as the write happens before we define `t`. """ - sdfg = dace.SDFG(util.unique_name("write_into_global_at_t_definition")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("write_into_global_at_t_definition")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_multiple_copies_global.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_multiple_copies_global.py index b763393253..963b3321bc 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_multiple_copies_global.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_multiple_copies_global.py @@ -22,7 +22,7 @@ def test_complex_copies_global_access_node(): N = 64 K = 80 - sdfg = dace.SDFG(util.unique_name("vertically_implicit_solver_like_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("vertically_implicit_solver_like_sdfg")) A, _ = sdfg.add_array("A", [N, K + 1], dtype=dace.float64) B, _ = sdfg.add_array("B", [N, K + 1], dtype=dace.float64) tmp0, _ = sdfg.add_temp_transient([N], dtype=dace.float64) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_aliasing_scalars.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_aliasing_scalars.py index a755504f40..e7aa91c517 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_aliasing_scalars.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_aliasing_scalars.py @@ -24,7 +24,7 @@ def _make_map_with_scalar_copies() -> tuple[ dace.SDFG, dace.SDFGState, dace_nodes.MapEntry, dace_nodes.MapExit ]: - sdfg = dace.SDFG(util.unique_name("scalar_elimination")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("scalar_elimination")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_point_view.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_point_view.py index 04312fb7c5..d86529f79d 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_point_view.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_point_view.py @@ -25,7 +25,7 @@ def _make_sdfg_with_map_with_view( use_array_as_temp: bool = False, ) -> dace.SDFG: shape = (N, N) - sdfg = dace.SDFG(util.unique_name("simple_map_with_view")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("simple_map_with_view")) state = sdfg.add_state(is_start_block=True) for name in ["a", "out"]: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_single_state_global_self_copy_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_single_state_global_self_copy_elimination.py index 306f05454e..76e99a4678 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_single_state_global_self_copy_elimination.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_single_state_global_self_copy_elimination.py @@ -25,7 +25,7 @@ def _make_self_copy_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: """Generates an SDFG that contains the self copying pattern.""" - sdfg = dace.SDFG(util.unique_name("self_copy_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("self_copy_sdfg")) state = sdfg.add_state(is_start_block=True) for name in "GT": @@ -46,7 +46,7 @@ def _make_self_copy_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: def _make_direct_self_copy_elimination_used_sdfg() -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name("direct_self_copy_elimination_used")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("direct_self_copy_elimination_used")) state = sdfg.add_state(is_start_block=True) for name in "ABCG": @@ -88,7 +88,7 @@ def _make_self_copy_sdfg_with_multiple_paths() -> tuple[ `SingleStateGlobalDirectSelfCopyElimination` transformation can not handle this case, but its split node can handle it. """ - sdfg = dace.SDFG(util.unique_name("self_copy_sdfg_with_multiple_paths")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("self_copy_sdfg_with_multiple_paths")) state = sdfg.add_state(is_start_block=True) for name in "GT": @@ -132,7 +132,9 @@ def _make_concat_where_like( j_idx = 0 j_range = "1:10" - sdfg = dace.SDFG(util.unique_name(f"self_copy_sdfg_concat_where_like_{j_desc}_level")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name(f"self_copy_sdfg_concat_where_like_{j_desc}_level") + ) state = sdfg.add_state(is_start_block=True) sdfg.add_array( @@ -215,7 +217,11 @@ def _make_concat_where_like_with_silent_write_to_g1( dace_nodes.AccessNode, dace_nodes.AccessNode, ]: - sdfg = dace.SDFG(util.unique_name(f"self_copy_sdfg_concat_where_like_multiple_writes_to_g1")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name( + f"self_copy_sdfg_concat_where_like_multiple_writes_to_g1" + ) + ) state = sdfg.add_state(is_start_block=True) sdfg.add_array( @@ -307,7 +313,9 @@ def _make_concat_where_like_41_to_60( dace_nodes.AccessNode, dace_nodes.AccessNode, ]: - sdfg = dace.SDFG(util.unique_name(f"self_copy_sdfg_concat_where_like_41_to_60")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name(f"self_copy_sdfg_concat_where_like_41_to_60") + ) state = sdfg.add_state(is_start_block=True) sdfg.add_array( @@ -443,7 +451,7 @@ def _make_concat_where_like_not_possible() -> tuple[ """Because the "Bulk Map" writes more into `tmp` than is written back the transformation is not applicable. """ - sdfg = dace.SDFG(util.unique_name(f"self_copy_too_big_bulk_write")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"self_copy_too_big_bulk_write")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( @@ -499,7 +507,7 @@ def _make_multi_t_patch_sdfg() -> tuple[ uninitialized because it is not read. """ - sdfg = dace.SDFG(util.unique_name(f"multi_t_patch_description_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"multi_t_patch_description_sdfg")) state = sdfg.add_state(is_start_block=True) for name in "gabt": @@ -582,7 +590,7 @@ def _make_not_everything_is_written_back( in fact running `SplitAccessNode` would do the job. """ - sdfg = dace.SDFG(util.unique_name(f"not_everything_is_written_back")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"not_everything_is_written_back")) state = sdfg.add_state(is_start_block=True) consumer_shape = (10,) if consume_all_of_t else (6,) @@ -645,7 +653,7 @@ def _make_not_everything_is_written_back( def _make_write_write_conflict( conflict_at_t: bool, ) -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name(f"write_write_conflict_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"write_write_conflict_sdfg")) state = sdfg.add_state(is_start_block=True) for name in "gto": @@ -717,7 +725,7 @@ def _make_write_write_conflict( def _make_read_write_conflict() -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name(f"write_write_conflict_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name(f"write_write_conflict_sdfg")) state = sdfg.add_state(is_start_block=True) for name in "gtom": @@ -896,7 +904,7 @@ def test_global_self_copy_elimination_tmp_downstream(): def test_direct_global_self_copy_simple(): - sdfg = dace.SDFG(util.unique_name("simple_direct_self_copy")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("simple_direct_self_copy")) state = sdfg.add_state(is_start_block=True) sdfg.add_array( diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_split_access_node.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_split_access_node.py index be7ed5405d..1bfa02cfe1 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_split_access_node.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_split_access_node.py @@ -64,7 +64,7 @@ def _perform_test( def test_map_producer_ac_consumer(): """The data is generated by a Map and then consumed by an AccessNode.""" - sdfg = dace.SDFG(util.unique_name("map_producer_ac_consumer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("map_producer_ac_consumer")) state = sdfg.add_state(is_start_block=True) for name in "abtcd": @@ -96,7 +96,7 @@ def test_map_producer_ac_consumer(): def test_map_producer_map_consumer(): """The data is generated by a Map and then consumed by another Map.""" - sdfg = dace.SDFG(util.unique_name("map_producer_map_consumer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("map_producer_map_consumer")) state = sdfg.add_state(is_start_block=True) for name in "abtcd": @@ -135,7 +135,7 @@ def test_map_producer_map_consumer(): def test_ac_producer_ac_consumer(): - sdfg = dace.SDFG(util.unique_name("ac_producer_ac_consumer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("ac_producer_ac_consumer")) state = sdfg.add_state(is_start_block=True) for name in "abtcd": @@ -159,7 +159,7 @@ def test_ac_producer_ac_consumer(): def test_ac_producer_map_consumer(): - sdfg = dace.SDFG(util.unique_name("ac_producer_map_consumer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("ac_producer_map_consumer")) state = sdfg.add_state(is_start_block=True) for name in "abtcd": @@ -193,7 +193,9 @@ def test_simple_splitable_ac_source_not_full_consume(): """Similar to `test_simple_splitable_ac_source_full_consume`, but one consumer does not fully consumer what is produced. """ - sdfg = dace.SDFG(util.unique_name("simple_splitable_ac_source_not_full_consume")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("simple_splitable_ac_source_not_full_consume") + ) state = sdfg.add_state(is_start_block=True) for name in "abtcde": @@ -221,7 +223,9 @@ def test_simple_splitable_ac_source_multiple_consumer(): """Similar to `test_simple_splitable_ac_source_not_full_consume`, but there are multiple consumer, per producer. """ - sdfg = dace.SDFG(util.unique_name("simple_splitable_ac_source_multiple_consumer")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("simple_splitable_ac_source_multiple_consumer") + ) state = sdfg.add_state(is_start_block=True) for name in "abtcde": @@ -259,7 +263,9 @@ def _make_transient_producer_sdfg( depending on the value of `partial_read`. """ sdfg = dace.SDFG( - util.unique_name("partial_ac_read" + ("_partial_read" if partial_read else "_full_read")) + gtx_transformations.utils.unique_name( + "partial_ac_read" + ("_partial_read" if partial_read else "_full_read") + ) ) state = sdfg.add_state(is_start_block=True) for name in ["a", "b", "c", "d", "t1", "t2"]: @@ -302,7 +308,7 @@ def test_transient_producer_partial_read(): def test_overlapping_consume_ac_source(): """There are 2 producers, but only one consumer that needs both producers.""" - sdfg = dace.SDFG(util.unique_name("overlapping_consume_ac_source")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("overlapping_consume_ac_source")) state = sdfg.add_state(is_start_block=True) for name in "abtc": @@ -333,7 +339,7 @@ def _make_map_producer_multiple_consumer( `partial_read` it will either read the full data for only parts of it. """ sdfg = dace.SDFG( - util.unique_name( + gtx_transformations.utils.unique_name( "map_producer_map_consumer" + ("_partial_read" if partial_read else "_full_read") ) ) @@ -392,7 +398,7 @@ def test_map_producer_multi_consumer_partialread(): def test_same_producer(): - sdfg = dace.SDFG(util.unique_name("same_producer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("same_producer")) state = sdfg.add_state(is_start_block=True) for name in "abt": @@ -454,7 +460,7 @@ def test_map_producer_complex_map_consumer(): b[15:20] -> e[15:20] c[20:25] -> e[20:25] """ - sdfg = dace.SDFG(util.unique_name("map_producer_ac_consumer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("map_producer_ac_consumer")) state = sdfg.add_state(is_start_block=True) for name in "abtcde": @@ -531,7 +537,7 @@ def test_map_producer_complex_map_consumer(): def test_map_producer_map_consumer_complex(): """The data is generated by a Map and then consumed by another Map.""" - sdfg = dace.SDFG(util.unique_name("map_producer_map_consumer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("map_producer_map_consumer")) state = sdfg.add_state(is_start_block=True) for name in "abtcde": @@ -623,7 +629,7 @@ def test_map_producer_ac_consumer_complex(): " nodes that have as inputs Maps whose outputs are assigned to AccessNodes." ) - sdfg = dace.SDFG(util.unique_name("map_producer_map_consumer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("map_producer_map_consumer")) state = sdfg.add_state(is_start_block=True) for name in "abtcde": @@ -693,7 +699,7 @@ def test_ac_producer_complex_map_consumer(): " nodes that have AccessNodes as input and Maps as outputs." ) - sdfg = dace.SDFG(util.unique_name("map_producer_ac_consumer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("map_producer_ac_consumer")) state = sdfg.add_state(is_start_block=True) for name in "abtcde": @@ -760,7 +766,7 @@ def test_ac_producer_complex_ac_consumer(): b[33:38] -> e[22:27] c[14:19] -> e[27:32] """ - sdfg = dace.SDFG(util.unique_name("ac_producer_complex_ac_consumer")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("ac_producer_complex_ac_consumer")) state = sdfg.add_state(is_start_block=True) for name in "abtcde": @@ -817,7 +823,7 @@ def test_ac_producer_ac_consumer_complex(): b[28:33] -> d[27:32] b[33:38] -> e[34:39] """ - sdfg = dace.SDFG(util.unique_name("ac_producer_ac_consumer_complex")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("ac_producer_ac_consumer_complex")) state = sdfg.add_state(is_start_block=True) for name in "abtcde": @@ -868,7 +874,7 @@ def test_unused_partial_read_from_inout_node(): a[0:5] -> c[0:5] t[5:10] -> c[5:10] """ - sdfg = dace.SDFG(util.unique_name("full_write_from_global_to_split_node")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("full_write_from_global_to_split_node")) for data in "abc": sdfg.add_array(data, [10], dace.float64) t, _ = sdfg.add_temp_transient([10], dace.float64) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_split_memlet.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_split_memlet.py index e5d5532ade..9dda69f253 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_split_memlet.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_split_memlet.py @@ -25,7 +25,9 @@ def _make_split_edge_two_ac_producer_one_ac_consumer_1d_sdfg() -> tuple[dace.SDFG, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("split_edge_two_ac_producer_one_ac_consumer_1d")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("split_edge_two_ac_producer_one_ac_consumer_1d") + ) state = sdfg.add_state(is_start_block=True) sdfg.add_array("src1", shape=(10,), dtype=dace.float64, transient=False) @@ -79,7 +81,7 @@ def _make_split_edge_mock_apply_diffusion_to_w_sdfg() -> tuple[ ]: # Test is roughly equivalent to what we see in `apply_diffusion_to_w` # Although instead of Maps we sometimes use direct edges. - sdfg = dace.SDFG(util.unique_name("split_edge_mock_apply_diffusion_to_w")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("split_edge_mock_apply_diffusion_to_w")) state = sdfg.add_state(is_start_block=True) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_splitting_tools.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_splitting_tools.py index d8373593c5..06e2caf187 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_splitting_tools.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_splitting_tools.py @@ -37,7 +37,7 @@ def _make_distributed_split_sdfg() -> tuple[ dace.SDFGState, dace_nodes.AccessNode, ]: - sdfg = dace.SDFG(util.unique_name("distributed_split_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("distributed_split_sdfg")) state = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state) @@ -115,7 +115,7 @@ def test_distributed_split(): def _make_split_node_simple_sdfg() -> tuple[ dace.SDFG, dace.SDFGState, dace_nodes.AccessNode, dace_nodes.MapExit, dace_nodes.MapExit ]: - sdfg = dace.SDFG(util.unique_name("single_state_split")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("single_state_split")) state = sdfg.add_state(is_start_block=True) for name in "abt": @@ -201,7 +201,7 @@ def test_simple_node_split(): def _make_split_edge_sdfg() -> tuple[ dace.SDFG, dace.SDFGState, dace_nodes.AccessNode, dace_nodes.AccessNode ]: - sdfg = dace.SDFG(util.unique_name("split_edge")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("split_edge")) state = sdfg.add_state(is_start_block=True) for name in "abt": @@ -281,7 +281,7 @@ def test_split_edge(): def test_split_edge_2d(): - sdfg = dace.SDFG(util.unique_name("split_edge_2d")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("split_edge_2d")) state = sdfg.add_state(is_start_block=True) for name in "ab": @@ -422,7 +422,7 @@ def test_subset_merging_stability(): def _make_sdfg_for_deterministic_splitting() -> tuple[ dace.SDFG, dace.SDFGState, dace_nodes.AccessNode ]: - sdfg = dace.SDFG(util.unique_name("deterministic_splitter")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("deterministic_splitter")) state = sdfg.add_state(is_start_block=True) for name in "abtcd": diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_state_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_state_fusion.py index 653a36bad9..d5734f9f27 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_state_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_state_fusion.py @@ -25,7 +25,7 @@ def _make_simple_two_state_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("simple_linear_states")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("simple_linear_states")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -71,7 +71,7 @@ def _make_simple_two_state_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGS def _make_global_in_both_read_and_write() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: """The first state contains a read to a global and the second contains a write to the global.""" - sdfg = dace.SDFG(util.unique_name("global_in_both_states_read_and_write")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("global_in_both_states_read_and_write")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -106,7 +106,7 @@ def _make_global_in_both_read_and_write() -> tuple[dace.SDFG, dace.SDFGState, da def _make_global_both_state_read() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: """In both states the same global is read.""" - sdfg = dace.SDFG(util.unique_name("global_read_in_the_same_state")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("global_read_in_the_same_state")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -140,7 +140,7 @@ def _make_global_both_state_read() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFG def _make_global_both_state_write() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: """In both states the same global is written.""" - sdfg = dace.SDFG(util.unique_name("global_write_in_the_same_state")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("global_write_in_the_same_state")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -176,7 +176,9 @@ def _make_empty_state( first_state_empty: bool, ) -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: sdfg = dace.SDFG( - util.unique_name("global_" + ("first" if first_state_empty else "second") + "_state_empty") + gtx_transformations.utils.unique_name( + "global_" + ("first" if first_state_empty else "second") + "_state_empty" + ) ) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -204,7 +206,7 @@ def _make_empty_state( def _make_global_merge_1() -> dace.SDFG: """The first state writes to a global while the second state reads and writes to it.""" - sdfg = dace.SDFG(util.unique_name("global_merge_1")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("global_merge_1")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -237,7 +239,7 @@ def _make_global_merge_1() -> dace.SDFG: def _make_global_merge_2() -> dace.SDFG: """In both states the global data is read and written to.""" - sdfg = dace.SDFG(util.unique_name("global_merge_2")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("global_merge_2")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -269,7 +271,7 @@ def _make_global_merge_2() -> dace.SDFG: def _make_swapping_sdfg() -> dace.SDFG: """Makes an SDFG that implements `x, y = y, x` with Memlets.""" - sdfg = dace.SDFG(util.unique_name("swapping_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("swapping_sdfg")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -299,7 +301,7 @@ def _make_swapping_sdfg() -> dace.SDFG: def _make_non_concurrent_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("non_concurrent_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("non_concurrent_sdfg")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -345,7 +347,7 @@ def _make_non_concurrent_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGSta def _make_double_producer_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("double_producer_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("double_producer_sdfg")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -394,7 +396,7 @@ def _make_double_producer_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGSt def _make_double_consumer_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("double_consumer_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("double_consumer_sdfg")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -468,7 +470,7 @@ def _make_double_consumer_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGSt def _make_hidden_double_producer_sdfg() -> tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]: - sdfg = dace.SDFG(util.unique_name("hidden_double_producer_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("hidden_double_producer_sdfg")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) @@ -624,7 +626,7 @@ def test_empty_second_state(): def test_both_states_are_empty(): - sdfg = dace.SDFG(util.unique_name("full_empty_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("full_empty_sdfg")) state1 = sdfg.add_state(is_start_block=True) state2 = sdfg.add_state_after(state1) assert util.count_nodes(sdfg, dace_nodes.AccessNode) == 0 diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py index e2fabc6383..f1c8da0143 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -25,7 +25,7 @@ def _make_strides_propagation_level3_sdfg() -> dace.SDFG: """Generates the level 3 SDFG (nested-nested) SDFG for `test_strides_propagation()`.""" - sdfg = dace.SDFG(util.unique_name("level3")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("level3")) state = sdfg.add_state(is_start_block=True) names = ["a3", "c3"] @@ -58,7 +58,7 @@ def _make_strides_propagation_level2_sdfg() -> tuple[dace.SDFG, dace_nodes.Neste The function returns the level 2 SDFG and the NestedSDFG node that contains the level 3 SDFG. """ - sdfg = dace.SDFG(util.unique_name("level2")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("level2")) state = sdfg.add_state(is_start_block=True) names = ["a2", "a2_alias", "b2", "c2"] @@ -125,7 +125,7 @@ def _make_strides_propagation_level1_sdfg() -> tuple[ - The NestedSDFG node that contains the lebel 3 SDFG (member of the level 2 SDFG). """ - sdfg = dace.SDFG(util.unique_name("level1")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("level1")) state = sdfg.add_state(is_start_block=True) names = ["a1", "b1", "c1"] @@ -238,7 +238,9 @@ def test_strides_propagation(): def _make_strides_propagation_dependent_symbol_nsdfg() -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name("strides_propagation_dependent_symbol_nsdfg")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("strides_propagation_dependent_symbol_nsdfg") + ) state = sdfg.add_state(is_start_block=True) array_names = ["a2", "b2"] @@ -266,7 +268,9 @@ def _make_strides_propagation_dependent_symbol_nsdfg() -> dace.SDFG: def _make_strides_propagation_dependent_symbol_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: - sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_dependent_symbol_sdfg")) + sdfg_level1 = dace.SDFG( + gtx_transformations.utils.unique_name("strides_propagation_dependent_symbol_sdfg") + ) state = sdfg_level1.add_state(is_start_block=True) array_names = ["a1", "b1"] @@ -345,7 +349,9 @@ def test_strides_propagation_symbolic_expression(): def _make_strides_propagation_shared_symbols_nsdfg() -> dace.SDFG: - sdfg = dace.SDFG(util.unique_name("strides_propagation_shared_symbols_nsdfg")) + sdfg = dace.SDFG( + gtx_transformations.utils.unique_name("strides_propagation_shared_symbols_nsdfg") + ) state = sdfg.add_state(is_start_block=True) # NOTE: Both arrays have the same symbols used for strides. @@ -379,7 +385,9 @@ def _make_strides_propagation_shared_symbols_nsdfg() -> dace.SDFG: def _make_strides_propagation_shared_symbols_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: - sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_shared_symbols_sdfg")) + sdfg_level1 = dace.SDFG( + gtx_transformations.utils.unique_name("strides_propagation_shared_symbols_sdfg") + ) state = sdfg_level1.add_state(is_start_block=True) # NOTE: Both arrays use the same symbols as strides. @@ -471,7 +479,9 @@ def ref(a1, b1): def _make_strides_propagation_stride_1_nsdfg() -> dace.SDFG: - sdfg_level1 = dace.SDFG(util.unique_name("strides_propagation_stride_1_nsdfg")) + sdfg_level1 = dace.SDFG( + gtx_transformations.utils.unique_name("strides_propagation_stride_1_nsdfg") + ) state = sdfg_level1.add_state(is_start_block=True) a_stride_sym = dace.symbol("a_stride", dtype=dace.uint32) @@ -500,7 +510,7 @@ def _make_strides_propagation_stride_1_nsdfg() -> dace.SDFG: def _make_strides_propagation_stride_1_sdfg() -> tuple[dace.SDFG, dace_nodes.NestedSDFG]: - sdfg = dace.SDFG(util.unique_name("strides_propagation_stride_1_sdfg")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("strides_propagation_stride_1_sdfg")) state = sdfg.add_state(is_start_block=True) a_stride_sym = dace.symbol("a_stride", dtype=dace.uint32) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_vertical_map_split_fusion.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_vertical_map_split_fusion.py index f6e85171d8..44a0a95f2e 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_vertical_map_split_fusion.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_vertical_map_split_fusion.py @@ -23,7 +23,9 @@ def serial_map_sdfg(N, extra_intermediate_edge=False): sdfg = dace.SDFG( - util.unique_name("serial_map" if extra_intermediate_edge else "serial_map_extra_edge") + gtx_transformations.utils.unique_name( + "serial_map" if extra_intermediate_edge else "serial_map_extra_edge" + ) ) A, _ = sdfg.add_array("A", [N], dtype=dace.float64) B, _ = sdfg.add_array("B", [N], dtype=dace.float64) @@ -150,7 +152,7 @@ def test_vertical_map_fusion_disabled(): @pytest.mark.parametrize("run_map_fusion", [True, False]) def test_vertical_map_fusion_with_neighbor_access(run_map_fusion: bool): N = 80 - sdfg = dace.SDFG(util.unique_name("simple")) + sdfg = dace.SDFG(gtx_transformations.utils.unique_name("simple")) A, _ = sdfg.add_array("A", shape=(N,), dtype=dace.float64, strides=(1,)) B, _ = sdfg.add_array("B", shape=(N,), dtype=dace.float64, strides=(1,)) C, _ = sdfg.add_array("C", shape=(N,), dtype=dace.float64, strides=(1,)) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py index 4c8e7f3899..fec2b09537 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/util.py @@ -6,7 +6,6 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import uuid from typing import Literal, Union, overload, Any import numpy as np @@ -14,6 +13,9 @@ import copy from dace.sdfg import nodes as dace_nodes from dace import data as dace_data +from gt4py.next.program_processors.runners.dace.transformations import ( + utils as gtx_transformations_utils, +) @overload @@ -58,15 +60,6 @@ def count_nodes( return len(found_nodes) -def unique_name(name: str) -> str: - """Adds a unique string to `name`.""" - maximal_length = 200 - unique_sufix = str(uuid.uuid1()).replace("-", "_") - if len(name) > (maximal_length - len(unique_sufix)): - name = name[: (maximal_length - len(unique_sufix) - 1)] - return f"{name}_{unique_sufix}" - - def compile_and_run_sdfg( sdfg: dace.SDFG, *args: Any, @@ -82,7 +75,7 @@ def compile_and_run_sdfg( with dace.config.set_temporary("compiler.use_cache", value=False): sdfg_clone = copy.deepcopy(sdfg) - sdfg_clone.name = unique_name(sdfg_clone.name) + sdfg_clone.name = gtx_transformations_utils.unique_name(sdfg_clone.name) sdfg_clone._recompile = True sdfg_clone._regenerate_code = True # TODO(phimuell): Find out if it has an effect. csdfg = sdfg_clone.compile() From 3319c000848799c9ba2ca80f839f65f46c506799 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 3 Feb 2026 14:35:29 +0100 Subject: [PATCH 23/34] Apply review comments on RemoveScalarCopies MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Philip Müller --- .../runners/dace/transformations/__init__.py | 4 +- .../dace/transformations/auto_optimize.py | 2 +- ...ing_scalars.py => remove_scalar_copies.py} | 48 +++++++++++++------ ...calars.py => test_remove_scalar_copies.py} | 8 ++-- 4 files changed, 39 insertions(+), 23 deletions(-) rename src/gt4py/next/program_processors/runners/dace/transformations/{remove_aliasing_scalars.py => remove_scalar_copies.py} (85%) rename tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/{test_remove_aliasing_scalars.py => test_remove_scalar_copies.py} (93%) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py index 1fe71d20da..0c1ec3424f 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/__init__.py @@ -54,7 +54,7 @@ ) from .redundant_array_removers import CopyChainRemover, DoubleWriteRemover, gt_remove_copy_chain from .remove_access_node_copies import RemoveAccessNodeCopies -from .remove_aliasing_scalars import RemoveAliasingScalars +from .remove_scalar_copies import RemoveScalarCopies from .remove_views import RemovePointwiseViews from .scan_loop_unrolling import ScanLoopUnrolling from .simplify import ( @@ -107,8 +107,8 @@ "MultiStateGlobalSelfCopyElimination", "MultiStateGlobalSelfCopyElimination2", "RemoveAccessNodeCopies", - "RemoveAliasingScalars", "RemovePointwiseViews", + "RemoveScalarCopies", "ScanLoopUnrolling", "SingleStateGlobalDirectSelfCopyElimination", "SingleStateGlobalSelfCopyElimination", diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 10a6256fef..74bdb03351 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -730,7 +730,7 @@ def _gt_auto_process_dataflow_inside_maps( single_use_data = find_single_use_data.apply_pass(sdfg, None) sdfg.apply_transformations_repeated( - gtx_transformations.RemoveAliasingScalars( + gtx_transformations.RemoveScalarCopies( single_use_data=single_use_data, ), validate=False, diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/remove_aliasing_scalars.py b/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py similarity index 85% rename from src/gt4py/next/program_processors/runners/dace/transformations/remove_aliasing_scalars.py rename to src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py index 3572f43980..fc7b3288be 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/remove_aliasing_scalars.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py @@ -15,7 +15,27 @@ @dace_properties.make_properties -class RemoveAliasingScalars(dace_transformation.SingleStateTransformation): +class RemoveScalarCopies(dace_transformation.SingleStateTransformation): + """Removes copies between two scalar transient variables. + Exaxmple: + ___ + / \ + | A | + \___/ + | + \/ + ___ + / \ + | B | + \___/ + is transformed to + ___ + / \ + | A | + \___/ + and all uses of B are replaced with A. + """ + first_access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) second_access_node = dace_transformation.PatternNode(dace_nodes.AccessNode) @@ -59,20 +79,24 @@ def can_be_applied( second_node_desc = second_node.desc(sdfg) scope_dict = graph.scope_dict() - if first_node not in scope_dict or second_node not in scope_dict: - return False # Make sure that both access nodes are in the same scope if scope_dict[first_node] != scope_dict[second_node]: return False # Make sure that both access nodes are transients - if not first_node_desc.transient or not second_node_desc.transient: + if not (first_node_desc.transient and second_node_desc.transient): + return False + + edges = list(graph.edges_between(first_node, second_node)) + if len(edges) != 1: + return False + + # Check that the second access node has only one incoming edge, which is the one from the first access node. + if graph.in_degree(second_node) != 1: return False - edges = graph.edges_between(first_node, second_node) - assert len(edges) == 1 - edge = next(iter(edges)) + edge = edges[0] # Check if the edge transfers only one element if edge.data.num_elements() != 1: @@ -85,7 +109,8 @@ def can_be_applied( if out_edges.data.num_elements() != 1: return False - # Make sure that the edge subset is 1 + # Make sure that the data descriptors of both access nodes are scalars + # TODO(iomaganaris): We could extend this transfromation to handle AccessNodes that are arrays with 1 element as well. if not isinstance(first_node_desc, dace.data.Scalar) or not isinstance( second_node_desc, dace.data.Scalar ): @@ -97,13 +122,6 @@ def can_be_applied( ): return False - # Make sure that both access nodes are transients - if not first_node_desc.transient or not second_node_desc.transient: - return False - - if graph.in_degree(second_node) != 1: - return False - # Make sure that both access nodes are single use data if self.assume_single_use_data: single_use_data = {sdfg: {first_node.data}} diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_aliasing_scalars.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_scalar_copies.py similarity index 93% rename from tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_aliasing_scalars.py rename to tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_scalar_copies.py index e7aa91c517..c7616541db 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_aliasing_scalars.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_remove_scalar_copies.py @@ -46,10 +46,8 @@ def _make_map_with_scalar_copies() -> tuple[ tmp0, tmp1, tmp2 = (state.add_access(f"tmp{i}") for i in range(3)) me, mx = state.add_map("copy_map", ndrange={"__i": "0:10"}) - me.add_in_connector("IN_a") - me.add_out_connector("OUT_a") - mx.add_in_connector("IN_b") - mx.add_out_connector("OUT_b") + me.add_scope_connectors("a") + mx.add_scope_connectors("b") state.add_edge(a, None, me, "IN_a", dace.Memlet("a[__i]")) state.add_edge(me, "OUT_a", tmp0, None, dace.Memlet("a[__i]")) state.add_edge(tmp0, None, tmp1, None, dace.Memlet("tmp1[0]")) @@ -72,7 +70,7 @@ def test_remove_double_write_single_consumer(): find_single_use_data = dace_analysis.FindSingleUseData() single_use_data = find_single_use_data.apply_pass(sdfg, None) sdfg.apply_transformations_repeated( - gtx_transformations.RemoveAliasingScalars( + gtx_transformations.RemoveScalarCopies( single_use_data=single_use_data, assume_single_use_data=False, ), From 202aac61703a6428968492f2f66f5669292bacf9 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 3 Feb 2026 14:37:49 +0100 Subject: [PATCH 24/34] Remove _dacegraphs --- _dacegraphs/invalid.sdfgz | Bin 5370 -> 0 bytes 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 _dacegraphs/invalid.sdfgz diff --git a/_dacegraphs/invalid.sdfgz b/_dacegraphs/invalid.sdfgz deleted file mode 100644 index 2869814a1c9a9c302e27afa792eef634827ef0c0..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 5370 zcmV95~>_x$uSL%B9$2`B7ndQRH^Z9b2 zE@%33Y5u*QWhBE?3p2yms~I7l%%s!|esZitCaLADd;>ghrdBKU`SAtCYb!r0<@$SZ z;oWVU>-LLxcV-PIYez&bHlc}4cxV&a z*o22RVSaTzYv99Qjk;8;uy!0)g_bs6sl|Hk7DYqms~dMn_4H|1&zyGk?CDmouxw_Z zpcdoRP%oF4FI>H|N+qY~mrD&~X_nSiEik{@Sb3Fd7fif3pIMa`bTqBj`%*2=Z`65N z!LNV*d0P9X`C^_w$9vqL&-0%1n|T5KX(BbjsWd22BtzU&hNSfZVZt26NhUcJ_qOhZ zT3;0RWWuD95t1^@wN4bKMp2T%fi!VcCxm32WGT^t+BDIU5SLQyOy)DHG51A@%5qE> z%j(zhp}tj9{?}aP?xA{6Iqfb>C>FN*z@h~ZjH`|##h_N8<}HgL~oy}yzraiq92Nj-tAtp&ebZPUCtN9g#fX= z3mmXq%jwPifQOr_MHM1j_qUw(^GkO>Y5ZUSQ~7fBVQp3_Hy0nw+N?j!muj)NQs*DO z|K;z$|NQO$eK^ng_2&=J8u|Le!u^@gHh#ovy0_w{-5O{FS+n;KD(LBzQ6uF-Kx^0Tv3cNpJ0?^8Q07(_#C+o*pefGt%r9LVhm->k`CC?BY~}rust={+D6zaNhu=|<_$@TIFlM> zne<3XZxoX*$zU4M4qJ94u-z7*z=+zOur-CZhKg(c5V8N=FGSKqkm09?+|Kf-AuTm0hULk;rrBPh_glg7--t;)o`>Arepu{xOXY5NU3N_U=^O^l=Jej#+npmH+tE2Ldx&=S5RpB4vd4T;+_G>6>b0)x zqaS`Mr9HcAGxz88?#fKvOO`iz$6NmK`)|#!DNY8x=bq|BDkMpSMV@kpxXCi4d5V$8 zm}16+mYq{2Ie49_H`f)nJq^``bgpUXTsvOpnzqiheL6SOKBb^CKrkVOnS?U}W)cJ= zD>6zP%e>8K*Om`{mJiu)YW?p^*VMWB=4?4)!@G1xG}5FLna!YECSOhd2B=v~{^zR+ zexH1?&#qMw$6B&fB@Fk zpSO*^y%Kdjam!(1c%kome9u#22M}s?P8!a(=@$FcCPcBh-KMqj#%8Hwi}@d-a%{7v zYX(PkEM4E^vGE2q9^j6~rW@3B-;5-xx96L7sp)Q`)LuQ zG&4dbD6xhiSi~=cSCk_~oY02pL~hGmYoAx+7tp%Ed_UN}e)OIG-z|5Hz9aadI+hup zN(tpxpkm-Q;YfS$k&s?;sGS10lW8zbHCL;tYP;3TSIf)qE|-7&;?AL(&8!l{QdPl` z$TB9O+qBoP4k>V2i0V_M661xW*f1}fSm?H~6z%nGt8<6jO#_zimlVK@Yo=l8!d$f9 zvF{A9cS)n;PW`B>cjl?RcF)15?!VT;O?DlD>i(l$7qG4ic<@~T?h%; zP;Q6H-7LTJhQRR@wEWmwi{>TC2AZ!dQzjzlp32H);2m9ynT@aMdCRdcdR{F*x$=_= zR^1=4^q4PUzhkDC_VeC)XYb*#vqv5X%q_`tqEio=(Imz2sq)gn{y9@f!FGf8+*l@( zt}=F(Fn}@xwxQZ?sgY0*DCtwP#@zzfiSS0+n>W(P8^_8UfzYhFrWr~>T~3g7!ho|x zYLF^81nfSqz|jZnLOv36kkes~x;@hXyEG5ItuJ-7OcuLlhqL=qJ?_QcAAmjC|DMqI z_=_-J>Ds0p4DT1=MckVgapc7(+(aIVo?#;}U<11Lrt07cA*9e+bZ%mIl%`>WduVYU z^zv33!V-$8%3vcLlr2aSr4cyVX(DKvxs;9c^`LLw?hN0kc~a9fq+$0aGvk!IDxL#j zd%_ZhWR`%J9JVRa%sRwKYOI2FJLAWi@vu1K>9K`%T-Y14X;YA+Cx}7#23Y9^ zcMF|@yG7!hahk#3T#ST1egx=iIs377na|ntF_-y%yt&NhU*CTQzqRM?3E;NDPA^0A zM|TxwRP--TdC#Z9^mENZ__`jx(VN(eVXx>E{x{y%;t(*jzxy~4w3*BU?rL{9Xl}b^ zM}#4RDs+v#m(jMdBl8QlnniF#aE!73Jn;Lxy#PfWORQkXdJ92j5&`?Inj3_{4U7p0 z*kO4ZOnNI5u^q8pi&9u}h|q}8s57(IJ^ z4x5?;^m8#R!n(E`Ko2WW!Sf6tQEhNulm25Kk_k|-^}<}fL7kd9)IY0;xBait;-ogmYq(CXmC3 zaqvZhWGJBbS4d-#A^{=$g%Um?)V~^+tr8T`XKNDbvvuk&S@c9)D+KQ=Ar8DFN{C;^ zC5!fvh$Y5q4S6UK#O&7~P^2CtVu?W}O;WA|WU{h6x@1L{tQy`(dh4Od+yjoGCu znK=#NCF67=4uCkFh|q}BiKh|1vZI+!JgeH(WY`g!dX%aRq>jvKq$LE2z=CJiU@0-> zuUqZvO8H~A+LZ*VUHM^Z*R#MD>1d6xme@)}S7dXOU{mzAcv@soJ8eVXoprWL)2Nwl zB&sbRjiygYi=uZg8f#51#DP&(8#$<@IQ_NDeTMZYDJNRVHAuzU!94wR%tMk-1Pb^1 ziJAM!nL`E5lMu}p>Ar+^U&6XC;b(^!9UUL#Km~M&9XLcAG9Wd@koBzIn#RGZ#$hQc zh9s#NsBIjiZajT`W1v+&BwX022`j>sm|+U(L5iQ=nxB9YKL@vjOHyAGyQBlyC5`NI z{Ol48*OAj&^xzkAiv03hc&z-vLH7F-OpjC-PnhgtV7iN7@V^*~;sGp*Ba6OB`ioFB z4IP~WlVA)?gArPrL8&l?Cc_9V)R+k|dVF_iv}>OvLB_$RZ|ElLlZDwEHBm+nM6!s; zC-U)m2~3V1p=s0v8a+_Sj|P?P!5cajD8oE|dO>@Drf(P!!*zYblzqdreS_3}L-c)v zg$l!@3jM?i!{rJ?1q*{E3&TVU;j)EL;X<%6}Q!9u^MNyjl%>=$wW zhH*bc0vG}Fb^$X~7&ufK7*^ZcQKl}z9!GtAjt0sEhX@77orK4kN}Q=gHhoUBbIhb4 zqu8e7CH{!EX&f8}*`~b!{058^*274lLyVMu2|}W|63vy!$ZyA789$ZC7zWFT=|qMn z6p5=UabO*`s?sCT$Sbl~Xb+2pKjX4WpTr}FZYcNqu#c92q`lP5=UQqL;F{UbpQ=wX zl6#dwdxwYSBnkA>Y-*^6`XBC^__xI}pKSaSzc_gZ#X}OT9I=C9aH^H)p*ZQCdga92 z59x_H2!$nA8Iq7CO50!D{gC#PAC8fnCAuF*_rtzPTZW`kO-LQ1`^#s;!GvZR3e+knM#DrY0p%~OkOj_bY{c^HgP%=2S=RFM8L%9 z%=5{w*^y0WUO;nwXd0Z?rMd1-i8DxZ&7VqheHLTT91$dGrwA1&DO3=h1B*q+I}`@783Td%GuVY?hH=+C@5=uS0A-SK#?B4 zARMMZX^N3ci_pXlvM(u>z`l;tnkf(AVfB^*XzMw5Iv>%(cU_W>#WN>pUvjYiRJE0@z!g) z&F(G(`!?+RSkrv>%L=NlZM&BR9>AXwU8pb(l3qX*WGWfR-#7zLT;VxMJ;V>OwcXb> z)sJ1+*@fLTu<1Z7#BU}xg-n1T=LJn% z=7_^-8nCAtyYSQ(>qsY~?WB^6d=|w6FX5g*mBX6<20!vbXjWk|m z(LLdfahj~-p4dNu71K+iA+|G4z++v()HB0aAy0fJ zG}l^c8aXX;THUJL08Z19)8c@5Y6k@&S;8btIa%R6$`ye)QbbGSwbHOVs?)SKp=;Fg zp0bvg)H{;WF<8w+R*P_dY7YftA^3xn96?&kB3Cv1BU48p3kSyxCtjsiMplcgwjYXX zAgkHPYH@-288$wJ76#LdAcKpIPi+MRO=GaOMqsl*d74O?y>4QAwLSepq*m`BpB+kS zv+Y#vk<=p8>o-0Roq$r75XLO{#S`g^AYKy0x)9EbdPTtL%6ZrUjg0n|7>(94niz%A zt|O^Mw0D!*q4Q6hf_TFbu*(z$zXOg50u={@Gg49z#!VuLj-(bzEgg&0!eVQ-#H(!J zok#cRB(o4xkfTH}SW%JSBP}7(DWbw=puyysmk+TvH%~AAEuX&=1h_u%?@qAa?Et?! zL4C_mx4VIS%VTBS*-sFM7DCOId&}&#OEko5p*#Lv~%gMtRVDhh_&T#*lPX18q$@%<~TTGypC-e1WvCJoGvU2B8FSl}b zQowRz9YpAzUl(tiFD8(+WU_!*xp4N^Z@&Bf{p8y})z$T-`*L!vAPTR88{sbish#9Y zc$&Xin6m7D%<~IC1{}I7!u`${=M%MBsn6vJc Date: Tue, 3 Feb 2026 14:52:37 +0100 Subject: [PATCH 25/34] Added more elaborative comment in FuseHorizontalConditionBlocks for removal of conditional block --- .../dace/transformations/fuse_horizontal_conditionblocks.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 6ae3fefcd2..19ddb98ff9 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -289,7 +289,10 @@ def _find_corresponding_state_in_second( if edge.dst == second_cb: graph.remove_edge(edge) - # Need to remove both references to remove NestedSDFG from graph + # TODO(iomaganaris): Atm need to remove both references to remove NestedSDFG from graph + # second_conditional_block is inside the SDFG of NestedSDFG second_cb and removing only + # one of them keeps a reference to the other one so none is properly deleted from the SDFG. + # For now remove both but maybe this can be improved in the future. graph.remove_node(second_conditional_block) graph.remove_node(second_cb) From c55a04a5df167a9bc8ff74521c6dd13b784b7c13 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 3 Feb 2026 15:26:29 +0100 Subject: [PATCH 26/34] make sure that the symbol mapping is the same between the fused nested sdfgs --- .../dace/transformations/fuse_horizontal_conditionblocks.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 19ddb98ff9..762fae9ae1 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -96,6 +96,12 @@ def can_be_applied( if first_cb.sdfg.number_of_nodes() > 1 or second_cb.sdfg.number_of_nodes() > 1: return False + # Check that the symbol mappings are compatible + sym_map1 = first_cb.symbol_mapping + sym_map2 = second_cb.symbol_mapping + if any(str(sym_map1[sym]) != str(sym_map2[sym]) for sym in sym_map2 if sym in sym_map1): + return False + # Get the actual conditional blocks first_conditional_block = next(iter(first_cb.sdfg.nodes())) second_conditional_block = next(iter(second_cb.sdfg.nodes())) From d232b7349feb31c0add37517b96094f4553830dd Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 3 Feb 2026 15:30:23 +0100 Subject: [PATCH 27/34] Added check for NestedSDFGs in FuseHorizontalConditionBlocks --- .../runners/dace/transformations/auto_optimize.py | 2 ++ .../transformations/fuse_horizontal_conditionblocks.py | 7 +++++++ 2 files changed, 9 insertions(+) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 74bdb03351..aeb5ca2699 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -737,6 +737,8 @@ def _gt_auto_process_dataflow_inside_maps( validate_all=validate_all, ) + # Make sure that this runs before MoveDataflowIntoIfBody because atm it doesn't handle + # NestedSDFGs inside the ConditionalBlocks it fuses. sdfg.apply_transformations_repeated( gtx_transformations.FuseHorizontalConditionBlocks(), validate=True, diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 762fae9ae1..f29cdca619 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -167,6 +167,13 @@ def can_be_applied( ): return False + # TODO(iomaganaris): Currently we do not handle NestedSDFGs inside the conditional blocks + # however this should be very close to handling Tasklets which is currently working. Since + # a test is missing and there is no immediate need for this feature, we leave it for future work. + for node in second_cb.sdfg.nodes(): + if isinstance(node, dace_nodes.NestedSDFG): + return False + return True def apply( From ed7835d2e4625f4e635c2d02456091ff424af5b8 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 3 Feb 2026 15:32:37 +0100 Subject: [PATCH 28/34] Handle single use data check only for second AccessNode in RemoveScalarCopies MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Philip Müller --- .../transformations/remove_scalar_copies.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py b/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py index fc7b3288be..fd4c3fc340 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py @@ -12,7 +12,7 @@ from dace import properties as dace_properties, transformation as dace_transformation from dace.sdfg import nodes as dace_nodes from dace.transformation import helpers as dace_helpers - +from dace.transformation.passes import analysis as dace_analysis @dace_properties.make_properties class RemoveScalarCopies(dace_transformation.SingleStateTransformation): @@ -110,7 +110,7 @@ def can_be_applied( return False # Make sure that the data descriptors of both access nodes are scalars - # TODO(iomaganaris): We could extend this transfromation to handle AccessNodes that are arrays with 1 element as well. + # TODO(iomaganaris): We could extend this transfromation to handle AccessNodes that are arrays with 1 element as well if not isinstance(first_node_desc, dace.data.Scalar) or not isinstance( second_node_desc, dace.data.Scalar ): @@ -122,20 +122,11 @@ def can_be_applied( ): return False - # Make sure that both access nodes are single use data - if self.assume_single_use_data: - single_use_data = {sdfg: {first_node.data}} - if self._single_use_data is None: - find_single_use_data = first_node.FindSingleUseData() - single_use_data = find_single_use_data.apply_pass(sdfg, None) - else: - single_use_data = self._single_use_data - if first_node.data not in single_use_data[sdfg]: - return False + # Make sure that the second AccessNode data is single use data since we're only going to remove that one if self.assume_single_use_data: single_use_data = {sdfg: {second_node.data}} if self._single_use_data is None: - find_single_use_data = second_node.FindSingleUseData() + find_single_use_data = dace_analysis.FindSingleUseData() single_use_data = find_single_use_data.apply_pass(sdfg, None) else: single_use_data = self._single_use_data From 678f18f85cdb8079ff529c06831d575d9d2ab308 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Wed, 4 Feb 2026 13:50:55 +0100 Subject: [PATCH 29/34] Fix issues with true/false_branch_x_x_x --- .../transformations/fuse_horizontal_conditionblocks.py | 8 ++++---- .../runners/dace/transformations/remove_scalar_copies.py | 1 + .../test_fuse_horizontal_conditionblocks.py | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index f29cdca619..b5f40f5508 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -119,10 +119,10 @@ def can_be_applied( state.name for state in second_conditional_block.all_states() ] if not ( - "true_branch" in first_conditional_block_state_names - and "false_branch" in first_conditional_block_state_names - and "true_branch" in second_conditional_block_state_names - and "false_branch" in second_conditional_block_state_names + any("true_branch" in name for name in first_conditional_block_state_names) + and any("false_branch" in name for name in first_conditional_block_state_names) + and any("true_branch" in name for name in second_conditional_block_state_names) + and any("false_branch" in name for name in second_conditional_block_state_names) ): return False diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py b/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py index fd4c3fc340..e4ddd56148 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py @@ -14,6 +14,7 @@ from dace.transformation import helpers as dace_helpers from dace.transformation.passes import analysis as dace_analysis + @dace_properties.make_properties class RemoveScalarCopies(dace_transformation.SingleStateTransformation): """Removes copies between two scalar transient variables. diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py index 39004d3786..2eb11a55d7 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py @@ -46,7 +46,7 @@ def _make_if_block_with_tasklet( inner_sdfg.add_node(if_region, is_start_block=True) then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=inner_sdfg) - tstate = then_body.add_state("true_branch", is_start_block=True) + tstate = then_body.add_state("true_branch_0_1_2_3_4", is_start_block=True) tasklet = tstate.add_tasklet( "true_tasklet", inputs={"__tasklet_in"}, @@ -69,7 +69,7 @@ def _make_if_block_with_tasklet( ) else_body = dace.sdfg.state.ControlFlowRegion("else_body", sdfg=inner_sdfg) - fstate = else_body.add_state("false_branch", is_start_block=True) + fstate = else_body.add_state("false_branch_0_1_2_3_4", is_start_block=True) fstate.add_nedge( fstate.add_access(b2_name), fstate.add_access(output_name), From 562742be51469c31f6bfcdd9d605d60df0ef77a5 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 10 Feb 2026 10:41:41 +0100 Subject: [PATCH 30/34] Address Philip's comments --- .../fuse_horizontal_conditionblocks.py | 109 +++++++++--------- .../transformations/remove_scalar_copies.py | 3 +- 2 files changed, 56 insertions(+), 56 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index b5f40f5508..5a6420e0a8 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -93,7 +93,7 @@ def can_be_applied( return False # Make sure that the conditional blocks contain only one conditional block each - if first_cb.sdfg.number_of_nodes() > 1 or second_cb.sdfg.number_of_nodes() > 1: + if first_cb.sdfg.number_of_nodes() != 1 or second_cb.sdfg.number_of_nodes() != 1: return False # Check that the symbol mappings are compatible @@ -170,9 +170,10 @@ def can_be_applied( # TODO(iomaganaris): Currently we do not handle NestedSDFGs inside the conditional blocks # however this should be very close to handling Tasklets which is currently working. Since # a test is missing and there is no immediate need for this feature, we leave it for future work. - for node in second_cb.sdfg.nodes(): - if isinstance(node, dace_nodes.NestedSDFG): - return False + for state in second_cb.sdfg.states(): + for node in state.nodes(): + if isinstance(node, dace_nodes.NestedSDFG): + return False return True @@ -197,30 +198,23 @@ def apply( # Store the new names for the arrays in the second conditional block to avoid name clashes and add their data descriptors # to the first conditional block SDFG - second_arrays_rename_map = {} + second_arrays_rename_map: dict[str, str] = {} for data_name, data_desc in original_arrays_second_conditional_block.items(): if data_name == "__cond": continue - if data_name in original_arrays_first_conditional_block: - new_data_name = gtx_transformations.utils.unique_name(data_name) + "_from_cb_fusion" - second_arrays_rename_map[data_name] = new_data_name - data_desc_renamed = copy.deepcopy(data_desc) - data_desc_renamed.name = new_data_name - if new_data_name not in first_cb.sdfg.arrays: - first_cb.sdfg.add_datadesc(new_data_name, data_desc_renamed) - else: - second_arrays_rename_map[data_name] = data_name - if data_name not in first_cb.sdfg.arrays: - first_cb.sdfg.add_datadesc(data_name, copy.deepcopy(data_desc)) + new_data_name = gtx_transformations.utils.unique_name(data_name) + "_from_cb_fusion" + second_arrays_rename_map[data_name] = new_data_name + data_desc_renamed = copy.deepcopy(data_desc) + first_cb.sdfg.add_datadesc(new_data_name, data_desc_renamed) second_conditional_states = list(second_conditional_block.all_states()) # Move the connectors from the second conditional block to the first in_connectors_to_move = {k: v for k, v in second_cb.in_connectors.items() if k != "__cond"} out_connectors_to_move = second_cb.out_connectors - in_connectors_to_move_rename_map = {} - out_connectors_to_move_rename_map = {} - for original_in_connector_name, _v in in_connectors_to_move.items(): + in_connectors_to_move_rename_map: dict[str, str] = {} + out_connectors_to_move_rename_map: dict[str, str] = {} + for original_in_connector_name in in_connectors_to_move: new_connector_name = original_in_connector_name if new_connector_name in first_cb.in_connectors: new_connector_name = second_arrays_rename_map[original_in_connector_name] @@ -231,7 +225,7 @@ def apply( dace_helpers.redirect_edge( state=graph, edge=edge, new_dst_conn=new_connector_name, new_dst=first_cb ) - for original_out_connector_name, _v in out_connectors_to_move.items(): + for original_out_connector_name in out_connectors_to_move: new_connector_name = original_out_connector_name if new_connector_name in first_cb.out_connectors: new_connector_name = second_arrays_rename_map[original_out_connector_name] @@ -246,22 +240,24 @@ def apply( def _find_corresponding_state_in_second( inner_state: dace.SDFGState, ) -> dace.SDFGState: - inner_state_name = inner_state.name - true_branch = "true_branch" in inner_state_name - corresponding_state_in_second = None - for state in second_conditional_states: - if true_branch and "true_branch" in state.name: - corresponding_state_in_second = state - break - elif not true_branch and "false_branch" in state.name: - corresponding_state_in_second = state - break - return corresponding_state_in_second + is_true_branch = "true_branch" in inner_state.name + branch_type = "true_branch" if is_true_branch else "false_branch" + return next(state for state in second_conditional_states if branch_type in state.name) + + corresponding_states_first_to_second = { + state: _find_corresponding_state_in_second(state) + for state in first_conditional_block.all_states() + } # Copy first the nodes from the second conditional block to the first - nodes_renamed_map = {} + # Create a dictionary that maps the original nodes in the second conditional + # block to the new nodes in the first conditional block to be able to properly connect the edges later + nodes_renamed_map: dict[dace_nodes.Node, dace_nodes.Node] = {} + # Save edges of second conditional block to a state to be able to delete the nodes from the second conditional block + second_state_edges_to_add: dict[dace.SDFGState, dace_graph.Edge] = {} for first_inner_state in first_conditional_block.all_states(): - corresponding_state_in_second = _find_corresponding_state_in_second(first_inner_state) + corresponding_state_in_second = corresponding_states_first_to_second[first_inner_state] + second_state_edges_to_add[corresponding_state_in_second] = [] nodes_to_move = list(corresponding_state_in_second.nodes()) for node in nodes_to_move: new_node = node @@ -270,34 +266,39 @@ def _find_corresponding_state_in_second( new_data_name = second_arrays_rename_map[node.data] new_node = dace_nodes.AccessNode(new_data_name) nodes_renamed_map[node] = new_node + second_state_edges_to_add[corresponding_state_in_second].extend( + corresponding_state_in_second.in_edges(node) + ) + second_state_edges_to_add[corresponding_state_in_second].extend( + corresponding_state_in_second.out_edges(node) + ) + # Remove the original node from the second conditional block to avoid any potential issues + # with the nodes coexisting in two states + corresponding_state_in_second.remove_node(node) first_inner_state.add_node(new_node) # Then copy the edges - second_to_first_connections = {} + second_to_first_connections: dict[str, str] = {} for node in nodes_renamed_map: if isinstance(node, dace_nodes.AccessNode): second_to_first_connections[node.data] = nodes_renamed_map[node].data for first_inner_state in first_conditional_block.all_states(): - corresponding_state_in_second = _find_corresponding_state_in_second(first_inner_state) - nodes_to_move = list(corresponding_state_in_second.nodes()) - for node in nodes_to_move: - for edge in list(corresponding_state_in_second.out_edges(node)): - dst = edge.dst - if dst in nodes_to_move: - new_memlet = copy.deepcopy(edge.data) - if edge.data.data in second_to_first_connections: - new_memlet.data = second_to_first_connections[edge.data.data] - first_inner_state.add_edge( - nodes_renamed_map[node], - nodes_renamed_map[node].data - if isinstance(node, dace_nodes.AccessNode) and edge.src_conn - else edge.src_conn, - nodes_renamed_map[dst], - second_to_first_connections[dst.data] - if isinstance(edge.dst, dace_nodes.AccessNode) and edge.dst_conn - else edge.dst_conn, - new_memlet, - ) + corresponding_state_in_second = corresponding_states_first_to_second[first_inner_state] + for edge in second_state_edges_to_add[corresponding_state_in_second]: + new_memlet = copy.deepcopy(edge.data) + if edge.data.data in second_to_first_connections: + new_memlet.data = second_to_first_connections[edge.data.data] + first_inner_state.add_edge( + nodes_renamed_map[edge.src], + nodes_renamed_map[edge.src].data + if isinstance(edge.src, dace_nodes.AccessNode) and edge.src_conn + else edge.src_conn, + nodes_renamed_map[edge.dst], + second_to_first_connections[edge.dst.data] + if isinstance(edge.dst, dace_nodes.AccessNode) and edge.dst_conn + else edge.dst_conn, + new_memlet, + ) for edge in list(graph.out_edges(conditional_access_node)): if edge.dst == second_cb: graph.remove_edge(edge) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py b/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py index e4ddd56148..2cfef5f59d 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/remove_scalar_copies.py @@ -92,13 +92,12 @@ def can_be_applied( edges = list(graph.edges_between(first_node, second_node)) if len(edges) != 1: return False + edge = edges[0] # Check that the second access node has only one incoming edge, which is the one from the first access node. if graph.in_degree(second_node) != 1: return False - edge = edges[0] - # Check if the edge transfers only one element if edge.data.num_elements() != 1: return False From 636f4c5137b2809f0d2e0f76f79935308edb2b72 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 10 Feb 2026 14:08:33 +0100 Subject: [PATCH 31/34] Fix issues --- .../fuse_horizontal_conditionblocks.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 5a6420e0a8..b6190de55f 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -216,8 +216,7 @@ def apply( out_connectors_to_move_rename_map: dict[str, str] = {} for original_in_connector_name in in_connectors_to_move: new_connector_name = original_in_connector_name - if new_connector_name in first_cb.in_connectors: - new_connector_name = second_arrays_rename_map[original_in_connector_name] + new_connector_name = second_arrays_rename_map[original_in_connector_name] in_connectors_to_move_rename_map[original_in_connector_name] = new_connector_name first_cb.add_in_connector(new_connector_name) for edge in graph.in_edges(second_cb): @@ -227,8 +226,7 @@ def apply( ) for original_out_connector_name in out_connectors_to_move: new_connector_name = original_out_connector_name - if new_connector_name in first_cb.out_connectors: - new_connector_name = second_arrays_rename_map[original_out_connector_name] + new_connector_name = second_arrays_rename_map[original_out_connector_name] out_connectors_to_move_rename_map[original_out_connector_name] = new_connector_name first_cb.add_out_connector(new_connector_name) for edge in graph.out_edges(second_cb): @@ -262,9 +260,8 @@ def _find_corresponding_state_in_second( for node in nodes_to_move: new_node = node if isinstance(node, dace_nodes.AccessNode): - if node.data in first_cb.in_connectors or node.data in first_cb.out_connectors: - new_data_name = second_arrays_rename_map[node.data] - new_node = dace_nodes.AccessNode(new_data_name) + new_data_name = second_arrays_rename_map[node.data] + new_node = dace_nodes.AccessNode(new_data_name) nodes_renamed_map[node] = new_node second_state_edges_to_add[corresponding_state_in_second].extend( corresponding_state_in_second.in_edges(node) From caf69a359a9c7023cfbedfd2bfeef4595f6fee44 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 10 Feb 2026 16:26:50 +0100 Subject: [PATCH 32/34] Address Philip's comments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Philip Müller --- .../fuse_horizontal_conditionblocks.py | 120 ++++++++---------- 1 file changed, 52 insertions(+), 68 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index b6190de55f..8962b9be9d 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -189,17 +189,16 @@ def apply( first_conditional_block = next(iter(first_cb.sdfg.nodes())) second_conditional_block = next(iter(second_cb.sdfg.nodes())) - # Store original arrays to check later that all the necessary arrays have been moved - original_arrays_first_conditional_block = first_conditional_block.sdfg.arrays.copy() - original_arrays_second_conditional_block = second_conditional_block.sdfg.arrays.copy() - total_original_arrays = len(original_arrays_first_conditional_block) + len( - original_arrays_second_conditional_block + # Store number of original arrays to check later that all the necessary arrays have been moved + total_original_arrays = len(first_conditional_block.sdfg.arrays) + len( + second_conditional_block.sdfg.arrays ) - # Store the new names for the arrays in the second conditional block to avoid name clashes and add their data descriptors - # to the first conditional block SDFG + # Store the new names for the arrays of the second conditional block (transients and globals) to avoid name clashes and add their data descriptors + # to the first conditional block SDFG. We don't have to add `__cond` because we know it's the same for both conditional blocks. + # TODO(iomaganaris): Remove inputs to the conditional block that come from the same AccessNodes (same data) second_arrays_rename_map: dict[str, str] = {} - for data_name, data_desc in original_arrays_second_conditional_block.items(): + for data_name, data_desc in second_conditional_block.sdfg.arrays.items(): if data_name == "__cond": continue new_data_name = gtx_transformations.utils.unique_name(data_name) + "_from_cb_fusion" @@ -210,30 +209,24 @@ def apply( second_conditional_states = list(second_conditional_block.all_states()) # Move the connectors from the second conditional block to the first - in_connectors_to_move = {k: v for k, v in second_cb.in_connectors.items() if k != "__cond"} - out_connectors_to_move = second_cb.out_connectors - in_connectors_to_move_rename_map: dict[str, str] = {} - out_connectors_to_move_rename_map: dict[str, str] = {} - for original_in_connector_name in in_connectors_to_move: - new_connector_name = original_in_connector_name - new_connector_name = second_arrays_rename_map[original_in_connector_name] - in_connectors_to_move_rename_map[original_in_connector_name] = new_connector_name - first_cb.add_in_connector(new_connector_name) - for edge in graph.in_edges(second_cb): - if edge.dst_conn == original_in_connector_name: - dace_helpers.redirect_edge( - state=graph, edge=edge, new_dst_conn=new_connector_name, new_dst=first_cb - ) - for original_out_connector_name in out_connectors_to_move: - new_connector_name = original_out_connector_name - new_connector_name = second_arrays_rename_map[original_out_connector_name] - out_connectors_to_move_rename_map[original_out_connector_name] = new_connector_name - first_cb.add_out_connector(new_connector_name) - for edge in graph.out_edges(second_cb): - if edge.src_conn == original_out_connector_name: - dace_helpers.redirect_edge( - state=graph, edge=edge, new_src_conn=new_connector_name, new_src=first_cb - ) + for edge in graph.in_edges(second_cb): + if edge.dst_conn == "__cond": + continue + first_cb.add_in_connector(second_arrays_rename_map[edge.dst_conn]) + dace_helpers.redirect_edge( + state=graph, + edge=edge, + new_dst_conn=second_arrays_rename_map[edge.dst_conn], + new_dst=first_cb, + ) + for edge in graph.out_edges(second_cb): + first_cb.add_out_connector(second_arrays_rename_map[edge.src_conn]) + dace_helpers.redirect_edge( + state=graph, + edge=edge, + new_src_conn=second_arrays_rename_map[edge.src_conn], + new_src=first_cb, + ) def _find_corresponding_state_in_second( inner_state: dace.SDFGState, @@ -242,20 +235,14 @@ def _find_corresponding_state_in_second( branch_type = "true_branch" if is_true_branch else "false_branch" return next(state for state in second_conditional_states if branch_type in state.name) - corresponding_states_first_to_second = { - state: _find_corresponding_state_in_second(state) - for state in first_conditional_block.all_states() - } - # Copy first the nodes from the second conditional block to the first # Create a dictionary that maps the original nodes in the second conditional # block to the new nodes in the first conditional block to be able to properly connect the edges later nodes_renamed_map: dict[dace_nodes.Node, dace_nodes.Node] = {} - # Save edges of second conditional block to a state to be able to delete the nodes from the second conditional block - second_state_edges_to_add: dict[dace.SDFGState, dace_graph.Edge] = {} for first_inner_state in first_conditional_block.all_states(): - corresponding_state_in_second = corresponding_states_first_to_second[first_inner_state] - second_state_edges_to_add[corresponding_state_in_second] = [] + corresponding_state_in_second = _find_corresponding_state_in_second(first_inner_state) + # Save edges of second conditional block to a state to be able to delete the nodes from the second conditional block + edges_to_copy = list(corresponding_state_in_second.edges()) nodes_to_move = list(corresponding_state_in_second.nodes()) for node in nodes_to_move: new_node = node @@ -263,43 +250,40 @@ def _find_corresponding_state_in_second( new_data_name = second_arrays_rename_map[node.data] new_node = dace_nodes.AccessNode(new_data_name) nodes_renamed_map[node] = new_node - second_state_edges_to_add[corresponding_state_in_second].extend( - corresponding_state_in_second.in_edges(node) - ) - second_state_edges_to_add[corresponding_state_in_second].extend( - corresponding_state_in_second.out_edges(node) - ) # Remove the original node from the second conditional block to avoid any potential issues # with the nodes coexisting in two states corresponding_state_in_second.remove_node(node) first_inner_state.add_node(new_node) - # Then copy the edges - second_to_first_connections: dict[str, str] = {} - for node in nodes_renamed_map: - if isinstance(node, dace_nodes.AccessNode): - second_to_first_connections[node.data] = nodes_renamed_map[node].data - for first_inner_state in first_conditional_block.all_states(): - corresponding_state_in_second = corresponding_states_first_to_second[first_inner_state] - for edge in second_state_edges_to_add[corresponding_state_in_second]: - new_memlet = copy.deepcopy(edge.data) - if edge.data.data in second_to_first_connections: - new_memlet.data = second_to_first_connections[edge.data.data] - first_inner_state.add_edge( - nodes_renamed_map[edge.src], - nodes_renamed_map[edge.src].data - if isinstance(edge.src, dace_nodes.AccessNode) and edge.src_conn - else edge.src_conn, - nodes_renamed_map[edge.dst], - second_to_first_connections[edge.dst.data] - if isinstance(edge.dst, dace_nodes.AccessNode) and edge.dst_conn - else edge.dst_conn, - new_memlet, + for edge_to_copy in edges_to_copy: + new_edge = first_inner_state.add_edge( + nodes_renamed_map[edge_to_copy.src], + edge_to_copy.src_conn, + nodes_renamed_map[edge_to_copy.dst], + edge_to_copy.dst_conn, + edge_to_copy.data, ) + if ( + not new_edge.data.is_empty() + ) and new_edge.data.data in second_arrays_rename_map: + new_edge.data.data = second_arrays_rename_map[new_edge.data.data] + for edge in list(graph.out_edges(conditional_access_node)): if edge.dst == second_cb: graph.remove_edge(edge) + # Copy missing symbols from second conditional block to the first one + missing_symbols = { + sym2: val2 + for sym2, val2 in second_cb.symbol_mapping.items() + if sym2 not in first_cb.symbol_mapping + } + for missing_symb, symb_def in missing_symbols.items(): + first_cb.symbol_mapping[missing_symb] = symb_def + first_cb.add_symbol( + missing_symb, second_cb.sdfg.symbols[missing_symb], find_new_name=False + ) + # TODO(iomaganaris): Atm need to remove both references to remove NestedSDFG from graph # second_conditional_block is inside the SDFG of NestedSDFG second_cb and removing only # one of them keeps a reference to the other one so none is properly deleted from the SDFG. From 4f03b21943b759d1f02801b9221a35681e9918d4 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Tue, 10 Feb 2026 16:37:03 +0100 Subject: [PATCH 33/34] Removed check for nestedsdfg --- .../transformations/fuse_horizontal_conditionblocks.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index 8962b9be9d..ba9a3d9ba5 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -167,14 +167,6 @@ def can_be_applied( ): return False - # TODO(iomaganaris): Currently we do not handle NestedSDFGs inside the conditional blocks - # however this should be very close to handling Tasklets which is currently working. Since - # a test is missing and there is no immediate need for this feature, we leave it for future work. - for state in second_cb.sdfg.states(): - for node in state.nodes(): - if isinstance(node, dace_nodes.NestedSDFG): - return False - return True def apply( From 902eed0c207ab9d66a8c94f615dc5355fbbe2c60 Mon Sep 17 00:00:00 2001 From: Ioannis Magkanaris Date: Wed, 11 Feb 2026 10:03:44 +0100 Subject: [PATCH 34/34] Added test for symbols --- .../fuse_horizontal_conditionblocks.py | 2 +- .../test_fuse_horizontal_conditionblocks.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py index ba9a3d9ba5..0c3d93587c 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/fuse_horizontal_conditionblocks.py @@ -272,7 +272,7 @@ def _find_corresponding_state_in_second( } for missing_symb, symb_def in missing_symbols.items(): first_cb.symbol_mapping[missing_symb] = symb_def - first_cb.add_symbol( + first_cb.sdfg.add_symbol( missing_symb, second_cb.sdfg.symbols[missing_symb], find_new_name=False ) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py index 2eb11a55d7..45b3c5ff5b 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_fuse_horizontal_conditionblocks.py @@ -47,11 +47,13 @@ def _make_if_block_with_tasklet( then_body = dace.sdfg.state.ControlFlowRegion("then_body", sdfg=inner_sdfg) tstate = then_body.add_state("true_branch_0_1_2_3_4", is_start_block=True) + inner_sdfg.add_symbol("multiplier", dace.float64) + tasklet = tstate.add_tasklet( "true_tasklet", inputs={"__tasklet_in"}, outputs={"__tasklet_out"}, - code="__tasklet_out = __tasklet_in * 2.0", + code="__tasklet_out = __tasklet_in * multiplier", ) tstate.add_edge( tstate.add_access(b1_name), @@ -79,11 +81,13 @@ def _make_if_block_with_tasklet( if_region.add_branch(dace.sdfg.state.CodeBlock(cond_name), then_body) if_region.add_branch(dace.sdfg.state.CodeBlock(f"not {cond_name}"), else_body) - return state.add_nested_sdfg( + nested_sdfg = state.add_nested_sdfg( sdfg=inner_sdfg, inputs={b1_name, b2_name, cond_name}, outputs={output_name}, ) + nested_sdfg.symbol_mapping["multiplier"] = 2.0 + return nested_sdfg def _make_map_with_conditional_blocks() -> dace.SDFG: @@ -188,6 +192,10 @@ def test_fuse_horizontal_condition_blocks(): n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, dace.sdfg.state.ConditionalBlock) ] assert len(new_conditional_blocks) == 1 + conditional_block = new_conditional_blocks[0] + assert ( + len(conditional_block.sdfg.symbols) == 1 and "multiplier" in conditional_block.sdfg.symbols + ) util.compile_and_run_sdfg(sdfg, **res) assert util.compare_sdfg_res(ref=ref, res=res)