diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 1d04c21fc3..91aba1ba4c 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -15,6 +15,7 @@ import dace from dace import data as dace_data from dace.sdfg import nodes as dace_nodes, propagation as dace_propagation, utils as dace_sdutils +from dace.transformation import dataflow as dace_dataflow from dace.transformation.auto import auto_optimize as dace_aoptimize from dace.transformation.passes import analysis as dace_analysis @@ -130,6 +131,7 @@ def gt_auto_optimize( assume_pointwise: bool = True, optimization_hooks: Optional[dict[GT4PyAutoOptHook, GT4PyAutoOptHookFun]] = None, demote_fields: Optional[list[str]] = None, + fuse_tasklets: bool = False, validate: bool = True, validate_all: bool = False, **kwargs: Any, @@ -197,6 +199,7 @@ def gt_auto_optimize( see `GT4PyAutoOptHook` for more information. demote_fields: Consider these fields as transients for the purpose of optimization. Use at your own risk. See Notes for all implications. + fuse_tasklets: Reduces the number of Tasklets by fusing them. validate: Perform validation during the steps. validate_all: Perform extensive validation. @@ -324,6 +327,7 @@ def gt_auto_optimize( blocking_only_if_independent_nodes=blocking_only_if_independent_nodes, scan_loop_unrolling=scan_loop_unrolling, scan_loop_unrolling_factor=scan_loop_unrolling_factor, + fuse_tasklets=fuse_tasklets, validate_all=validate_all, ) @@ -660,6 +664,7 @@ def _gt_auto_process_dataflow_inside_maps( blocking_only_if_independent_nodes: Optional[bool], scan_loop_unrolling: bool, scan_loop_unrolling_factor: int, + fuse_tasklets: bool, validate_all: bool, ) -> dace.SDFG: """Optimizes the dataflow inside the top level Maps of the SDFG inplace. 
@@ -674,6 +679,35 @@ def _gt_auto_process_dataflow_inside_maps( time, so the compiler will fully unroll them anyway. """ + # Separate Tasklets into dependent and independent parts to promote data + # reusability. It is important that this step is performed before + # `TaskletFusion` is used. + if blocking_dim is not None: + sdfg.apply_transformations_once_everywhere( + gtx_transformations.LoopBlocking( + blocking_size=blocking_size, + blocking_parameter=blocking_dim, + require_independent_nodes=blocking_only_if_independent_nodes, + ), + validate=False, + validate_all=validate_all, + ) + + # Merge Tasklets into bigger ones. + # NOTE: Empirical observations for Graupel have shown that this leads to an increase + # in performance; however, it has to be run before `GT4PyMoveTaskletIntoMap` + # (not fully clear why, probably a compiler artefact) as well as + # `MoveDataflowIntoIfBody` (not fully clear either, whether `TaskletFusion` makes + # things simpler or prevents it from doing certain negative things). + # TODO(phimuell): Investigate more. + # TODO(phimuell): Restrict it to Tasklets only inside Maps. + if fuse_tasklets: + sdfg.apply_transformations_repeated( + dace_dataflow.TaskletFusion, + validate=False, + validate_all=validate_all, + ) + # Constants (tasklets are needed to write them into a variable) should not be # arguments to a kernel but be present inside the body. sdfg.apply_transformations_once_everywhere( @@ -681,6 +715,8 @@ def _gt_auto_process_dataflow_inside_maps( validate=False, validate_all=validate_all, ) + + # TODO(phimuell): Figure out if this is needed. gtx_transformations.gt_simplify( sdfg, skip=gtx_transformations.constants._GT_AUTO_OPT_INNER_DATAFLOW_STAGE_SIMPLIFY_SKIP_LIST, @@ -688,19 +724,6 @@ def _gt_auto_process_dataflow_inside_maps( validate_all=validate_all, ) - # Blocking is performed first, because this ensures that as much as possible - # is moved into the k independent part. 
- if blocking_dim is not None: - sdfg.apply_transformations_once_everywhere( - gtx_transformations.LoopBlocking( - blocking_size=blocking_size, - blocking_parameter=blocking_dim, - require_independent_nodes=blocking_only_if_independent_nodes, - ), - validate=False, - validate_all=validate_all, - ) - # Move dataflow into the branches of the `if` such that they are only evaluated # if they are needed. Important to call it repeatedly. # TODO(phimuell): It is unclear if `MoveDataflowIntoIfBody` should be called @@ -714,6 +737,8 @@ def _gt_auto_process_dataflow_inside_maps( validate=False, validate_all=validate_all, ) + + # TODO(phimuell): Figure out if this is needed. gtx_transformations.gt_simplify( sdfg, skip=gtx_transformations.constants._GT_AUTO_OPT_INNER_DATAFLOW_STAGE_SIMPLIFY_SKIP_LIST,