From 5c87061fb23a2999db70fa15ef852a5dc959c99e Mon Sep 17 00:00:00 2001
From: Till Ehrengruber
Date: Tue, 10 Feb 2026 14:08:21 +0100
Subject: [PATCH 1/3] Remove `make_dace_backend` and absorb functionality in existing factory

---
 src/gt4py/next/factory_utils.py               | 119 +++++++++++++
 .../runners/dace/workflow/backend.py          | 164 ++++++++----------
 .../runners/dace/workflow/factory.py          |  50 +++---
 .../next/program_processors/runners/gtfn.py   |  85 ++++-----
 .../ffront_tests/test_execution.py            |   3 +
 tests/next_tests/unit_tests/factory_utils.py  |  48 +++++
 6 files changed, 311 insertions(+), 158 deletions(-)
 create mode 100644 src/gt4py/next/factory_utils.py
 create mode 100644 tests/next_tests/unit_tests/factory_utils.py

diff --git a/src/gt4py/next/factory_utils.py b/src/gt4py/next/factory_utils.py
new file mode 100644
index 0000000000..1923c2f092
--- /dev/null
+++ b/src/gt4py/next/factory_utils.py
@@ -0,0 +1,119 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2024, ETH Zurich
+# All rights reserved.
+#
+# Please, refer to the LICENSE file in the root directory.
+# SPDX-License-Identifier: BSD-3-Clause
+
+import functools
+from typing import Any, Callable, overload
+
+import factory
+
+
+class Factory(factory.Factory):
+    """
+    Factory that defaults all ``factory.Trait`` params to ``False``.
+
+    Ensures that every ``factory.Trait`` declared in ``Params`` is present in
+    the keyword arguments passed to ``create``, defaulting to ``False`` when
+    not explicitly provided. This allows trait-dependent declarations to
+    always access the trait flag, even when the caller does not mention it.
+    """
+
+    @classmethod
+    def create(cls, **kwargs: Any) -> Any:
+        # adjust keyword arguments so that trait options are available even when not given
+        # explicitly
+        for name, param in cls._meta.parameters.items():
+            if isinstance(param, factory.Trait):
+                kwargs.setdefault(name, False)
+        return super().create(**kwargs)
+
+
+class DynamicTransformer(factory.declarations.BaseDeclaration):
+    CAPTURE_OVERRIDES = True
+    UNROLL_CONTEXT_BEFORE_EVALUATION = False
+
+    def __init__(self, default: Any, *, transform: Callable[[Any, Any], Any]) -> None:
+        super().__init__()
+        self.default = default
+        self.transform = transform
+
+    def evaluate_pre(self, instance: Any, step: Any, overrides: dict[str, Any]) -> Any:
+        # The call-time value, if present, is set under the "" key.
+        value_or_declaration = overrides.pop("", self.default)
+
+        if isinstance(value_or_declaration, factory.Transformer.Force):
+            bypass_transform = True
+            value_or_declaration = value_or_declaration.forced_value
+        else:
+            bypass_transform = False
+
+        value = self._unwrap_evaluate_pre(
+            value_or_declaration,
+            instance=instance,
+            step=step,
+            overrides=overrides,
+        )
+        if bypass_transform:
+            return value
+
+        transform = self._unwrap_evaluate_pre(
+            self.transform,
+            instance=instance,
+            step=step,
+            overrides=overrides,
+        )
+
+        return transform(instance, value)
+
+
+@overload
+def dynamic_transformer(func: Callable[[Any, Any], Any], *, default: Any) -> DynamicTransformer: ...
+
+
+@overload
+def dynamic_transformer(
+    *, default: Any
+) -> Callable[[Callable[[Any, Any], Any]], DynamicTransformer]: ...
+
+
+def dynamic_transformer(
+    func: Callable[[Any, Any], Any] | None = None, *, default: Any
+) -> DynamicTransformer | Callable[[Callable[[Any, Any], Any]], DynamicTransformer]:
+    """
+    Decorator that creates a factory field whose value is always passed through a transform.
+
+    Works like ``factory.Transformer`` but the transform function receives the
+    full factory instance, so it can read other parameters/traits. The
+    *default* argument provides the base value (may be a factory declaration
+    such as ``factory.SelfAttribute``). Use ``factory.Transformer.Force`` to
+    bypass the transform at call-time.
+
+    Example:
+        >>> import dataclasses, factory
+        >>> @dataclasses.dataclass
+        ... class Person:
+        ...     name: str
+        ...     nickname: str
+        >>> class PersonFactory(factory.Factory):
+        ...     class Meta:
+        ...         model = Person
+        ...
+        ...     name = "Joe"
+        ...
+        ...     @dynamic_transformer(default=factory.SelfAttribute(".name"))
+        ...     def nickname(self, nickname):
+        ...         return f"{nickname}y"
+        >>> PersonFactory().nickname
+        'Joey'
+        >>> PersonFactory(name="John").nickname
+        'Johny'
+        >>> PersonFactory(nickname=factory.Transformer.Force("John")).nickname
+        'John'
+    """
+    if func is None:
+        return functools.partial(dynamic_transformer, default=default)
+    return DynamicTransformer(default=default, transform=func)
diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py
index 32e3ba8a31..63b3006a8a 100644
--- a/src/gt4py/next/program_processors/runners/dace/workflow/backend.py
+++ b/src/gt4py/next/program_processors/runners/dace/workflow/backend.py
@@ -9,75 +9,22 @@
 from __future__ import annotations
 
 import warnings
-from typing import Any, Final
+from typing import Callable, Final
 
 import factory
 
 import gt4py.next.allocators as next_allocators
 from gt4py._core import definitions as core_defs
-from gt4py.next import backend, common, config
+from gt4py.next import backend, common, config, factory_utils
 from gt4py.next.otf import stages, workflow
 from gt4py.next.program_processors.runners.dace.workflow.factory import DaCeWorkflowFactory
 
 
-class DaCeBackendFactory(factory.Factory):
+class DaCeBackendFactory(factory_utils.Factory):
     """
-    Workflow factory for the GTIR-DaCe backend.
+    Configurable factory for the DaCe backend.
 
-    Several parameters are inherithed from `backend.Backend`, see below the specific ones.
-
-    Args:
-        auto_optimize: Enables the SDFG transformation pipeline.
- """ - - class Meta: - model = backend.Backend - - class Params: - name_device = "cpu" - name_cached = "" - name_postfix = "" - gpu = factory.Trait( - allocator=next_allocators.StandardGPUFieldBufferAllocator(), - device_type=core_defs.CUPY_DEVICE_TYPE or core_defs.DeviceType.CUDA, - name_device="gpu", - ) - cached = factory.Trait( - executor=factory.LazyAttribute( - lambda o: workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function) - ), - name_cached="_cached", - ) - device_type = core_defs.DeviceType.CPU - hash_function = stages.compilation_hash - otf_workflow = factory.SubFactory( - DaCeWorkflowFactory, - device_type=factory.SelfAttribute("..device_type"), - auto_optimize=factory.SelfAttribute("..auto_optimize"), - ) - auto_optimize = factory.Trait(name_postfix="_opt") - - name = factory.LazyAttribute( - lambda o: f"run_dace_{o.name_device}{o.name_cached}{o.name_postfix}" - ) - - executor = factory.LazyAttribute(lambda o: o.otf_workflow) - allocator = next_allocators.StandardCPUFieldBufferAllocator() - transforms = backend.DEFAULT_TRANSFORMS - - -def make_dace_backend( - gpu: bool, - cached: bool = True, - auto_optimize: bool = True, - async_sdfg_call: bool = True, - optimization_args: dict[str, Any] | None = None, - use_metrics: bool = True, - use_zero_origin: bool = False, -) -> backend.Backend: - """Customize the dace backend with the given configuration parameters. - - Args: + Parameters: gpu: Enable GPU transformations and code generation. cached: Cache the lowered SDFG as a JSON file and the compiled programs. auto_optimize: Enable the SDFG auto-optimize pipeline. @@ -99,72 +46,107 @@ def make_dace_backend( A dace backend with custom configuration for the target device. """ - # The `gt_optimization_args` set contains the parameters of `gt_auto_optimize()` - # that are derived from the gt4py configuration, and therefore cannot be customized. - gt_optimization_args: Final[set[str]] = {"gpu", "constant_symbols", "unit_strides_kind"} - - if optimization_args is None: - optimization_args = {} - elif optimization_args and not auto_optimize: - warnings.warn("Optimizations args given, but auto-optimize is disabled.", stacklevel=2) - elif intersect_args := gt_optimization_args.intersection(optimization_args.keys()): - raise ValueError( - f"The following optimization arguments cannot be overriden: {intersect_args}." + class Meta: + model = backend.Backend + + class Params: + gpu = factory.Trait( + allocator=next_allocators.StandardGPUFieldBufferAllocator(), + device_type=core_defs.CUPY_DEVICE_TYPE or core_defs.DeviceType.CUDA, ) + cached: bool = True + auto_optimize: bool = True + async_sdfg_call: bool = factory.SelfAttribute(".gpu") # type: ignore[assignment] + use_metrics: bool = True + use_zero_origin: bool = False + device_type = core_defs.DeviceType.CPU + hash_function = stages.compilation_hash - # Set `unit_strides_kind` based on the gt4py env configuration. 
- optimization_args = optimization_args | { - "unit_strides_kind": common.DimensionKind.HORIZONTAL - if config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE - else None - } - - return DaCeBackendFactory( # type: ignore[return-value] # factory-boy typing not precise enough - gpu=gpu, - cached=cached, - auto_optimize=auto_optimize, - otf_workflow__cached_translation=cached, - otf_workflow__bare_translation__async_sdfg_call=(async_sdfg_call if gpu else False), - otf_workflow__bare_translation__auto_optimize_args=optimization_args, - otf_workflow__bare_translation__use_metrics=use_metrics, - otf_workflow__bare_translation__disable_field_origin_on_program_arguments=use_zero_origin, + @factory_utils.dynamic_transformer(default=factory.Dict({})) + def optimization_args(self, optimization_args: dict) -> dict: + # The `gt_optimization_args` set contains the parameters of `gt_auto_optimize()` + # that are derived from the gt4py configuration, and therefore cannot be customized. + gt_optimization_args: Final[set[str]] = {"gpu", "constant_symbols", "unit_strides_kind"} + + if optimization_args and not self.auto_optimize: + warnings.warn( + "Optimizations args given, but auto-optimize is disabled.", stacklevel=2 + ) + elif intersect_args := gt_optimization_args.intersection(optimization_args.keys()): + raise ValueError( + f"The following optimization arguments cannot be overriden: {intersect_args}." + ) + + # Set `unit_strides_kind` based on the gt4py env configuration. + return optimization_args | { + "unit_strides_kind": common.DimensionKind.HORIZONTAL + if config.UNSTRUCTURED_HORIZONTAL_HAS_UNIT_STRIDE + else None + } + + @factory.lazy_attribute + def name(self: factory.builder.Resolver) -> str: + name = "run_dace_" + name += "gpu" if self.gpu else "cpu" + if self.auto_optimize: + name += "_opt" + if self.cached: + name += "_cached" + return name + + @factory_utils.dynamic_transformer( + default=factory.SubFactory( + DaCeWorkflowFactory, + device_type=factory.SelfAttribute("..device_type"), + auto_optimize=factory.SelfAttribute("..auto_optimize"), + translation__async_sdfg_call=factory.SelfAttribute("...async_sdfg_call"), + translation__auto_optimize_args=factory.SelfAttribute("...optimization_args"), + translation__use_metrics=factory.SelfAttribute("...use_metrics"), + translation__disable_field_origin_on_program_arguments=factory.SelfAttribute( + "...use_zero_origin" + ), + ) ) + def executor(self: factory.builder.Resolver, value: Callable) -> Callable: + if self.cached: + return workflow.CachedStep(value, hash_function=self.hash_function) + return value + + allocator = next_allocators.StandardCPUFieldBufferAllocator() + transforms = backend.DEFAULT_TRANSFORMS + + +make_dace_backend = DaCeBackendFactory run_dace_cpu = make_dace_backend( gpu=False, cached=False, auto_optimize=True, - async_sdfg_call=False, ) run_dace_cpu_noopt = make_dace_backend( gpu=False, cached=False, auto_optimize=False, - async_sdfg_call=False, ) run_dace_cpu_cached = make_dace_backend( gpu=False, cached=True, auto_optimize=True, - async_sdfg_call=False, ) run_dace_gpu = make_dace_backend( gpu=True, cached=False, auto_optimize=True, - async_sdfg_call=True, ) run_dace_gpu_noopt = make_dace_backend( gpu=True, cached=False, auto_optimize=False, - async_sdfg_call=True, ) run_dace_gpu_cached = make_dace_backend( gpu=True, cached=True, auto_optimize=True, - async_sdfg_call=True, ) diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py index 
62febd0965..35a5f7e41f 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/factory.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/factory.py @@ -9,12 +9,12 @@ from __future__ import annotations import functools -from typing import Final +from typing import Callable, Final import factory from gt4py._core import definitions as core_defs, filecache -from gt4py.next import config +from gt4py.next import config, factory_utils from gt4py.next.otf import recipes, stages, workflow from gt4py.next.program_processors.runners.dace.workflow import ( bindings as bindings_step, @@ -36,35 +36,35 @@ class Meta: model = recipes.OTFCompileWorkflow class Params: + cache_translation: bool = True auto_optimize: bool = False device_type: core_defs.DeviceType = core_defs.DeviceType.CPU cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough lambda: config.CMAKE_BUILD_TYPE ) - cached_translation = factory.Trait( - translation=factory.LazyAttribute( - lambda o: workflow.CachedStep( - o.bare_translation, - hash_function=stages.fingerprint_compilable_program, - cache=filecache.FileCache(str(config.BUILD_CACHE_DIR / "translation_cache")), - ) - ), - ) - - bare_translation = factory.SubFactory( + @factory_utils.dynamic_transformer( + default=factory.SubFactory( DaCeTranslationStepFactory, device_type=factory.SelfAttribute("..device_type"), auto_optimize=factory.SelfAttribute("..auto_optimize"), ) + ) + def translation(self: factory.builder.Resolver, value: Callable) -> Callable: + if self.cache_translation: + return workflow.CachedStep( + value, + hash_function=stages.fingerprint_compilable_program, + cache=filecache.FileCache(str(config.BUILD_CACHE_DIR / "translation_cache")), + ) + return value - translation = factory.LazyAttribute(lambda o: o.bare_translation) - bindings = factory.LazyAttribute( - lambda o: functools.partial( - bindings_step.bind_sdfg, - bind_func_name=_GT_DACE_BINDING_FUNCTION_NAME, + @factory.lazy_attribute + def bindings(self: factory.builder.Resolver) -> Callable: + return functools.partial( + bindings_step.bind_sdfg, bind_func_name=_GT_DACE_BINDING_FUNCTION_NAME ) - ) + compilation = factory.SubFactory( DaCeCompilationStepFactory, bind_func_name=_GT_DACE_BINDING_FUNCTION_NAME, @@ -72,9 +72,9 @@ class Params: device_type=factory.SelfAttribute("..device_type"), cmake_build_type=factory.SelfAttribute("..cmake_build_type"), ) - decoration = factory.LazyAttribute( - lambda o: functools.partial( - decoration_step.convert_args, - device=o.device_type, - ) - ) + + @factory.lazy_attribute + def decoration( + self: factory.builder.Resolver, + ) -> Callable[[stages.CompiledProgram], stages.CompiledProgram]: + return functools.partial(decoration_step.convert_args, device=self.device_type) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 038f2959b6..d544d22586 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -7,7 +7,7 @@ # SPDX-License-Identifier: BSD-3-Clause import functools -from typing import Any +from typing import Any, Callable import factory import numpy as np @@ -15,7 +15,7 @@ import gt4py._core.definitions as core_defs import gt4py.next.allocators as next_allocators from gt4py._core import filecache -from gt4py.next import backend, common, config, field_utils +from gt4py.next import backend, common, config, factory_utils, field_utils from 
gt4py.next.embedded import nd_array_field from gt4py.next.instrumentation import metrics from gt4py.next.otf import recipes, stages, workflow @@ -111,6 +111,7 @@ class Meta: model = recipes.OTFCompileWorkflow class Params: + cache_translation: bool = True device_type: core_defs.DeviceType = core_defs.DeviceType.CPU cmake_build_type: config.CMakeBuildType = factory.LazyFunction( # type: ignore[assignment] # factory-boy typing not precise enough lambda: config.CMAKE_BUILD_TYPE @@ -119,67 +120,71 @@ class Params: lambda o: compiledb.CompiledbFactory(cmake_build_type=o.cmake_build_type) ) - cached_translation = factory.Trait( - translation=factory.LazyAttribute( - lambda o: workflow.CachedStep( - o.bare_translation, - hash_function=stages.fingerprint_compilable_program, - cache=filecache.FileCache(str(config.BUILD_CACHE_DIR / "gtfn_cache")), - ) - ), - ) - - bare_translation = factory.SubFactory( + @factory_utils.dynamic_transformer( + default=factory.SubFactory( gtfn_module.GTFNTranslationStepFactory, device_type=factory.SelfAttribute("..device_type"), ) - - translation = factory.LazyAttribute(lambda o: o.bare_translation) + ) + def translation(self: factory.builder.Resolver, value: Any) -> Any: + if self.cache_translation: + return workflow.CachedStep( + value, + hash_function=stages.fingerprint_compilable_program, + cache=filecache.FileCache(str(config.BUILD_CACHE_DIR / "gtfn_cache")), + ) + return value bindings: workflow.Workflow[stages.ProgramSource, stages.CompilableSource] = ( nanobind.bind_source ) + compilation = factory.SubFactory( compiler.CompilerFactory, cache_lifetime=factory.LazyFunction(lambda: config.BUILD_CACHE_LIFETIME), builder_factory=factory.SelfAttribute("..builder_factory"), ) - decoration = factory.LazyAttribute( - lambda o: functools.partial(convert_args, device=o.device_type) - ) + @factory.lazy_attribute + def decoration( + self: factory.builder.Resolver, + ) -> Callable[[stages.CompiledProgram], stages.CompiledProgram]: + return functools.partial(convert_args, device=self.device_type) -class GTFNBackendFactory(factory.Factory): + +class GTFNBackendFactory(factory_utils.Factory): class Meta: model = backend.Backend class Params: - name_device = "cpu" - name_cached = "" - name_temps = "" name_postfix = "" gpu = factory.Trait( allocator=next_allocators.StandardGPUFieldBufferAllocator(), device_type=core_defs.CUPY_DEVICE_TYPE or core_defs.DeviceType.CUDA, - name_device="gpu", - ) - cached = factory.Trait( - executor=factory.LazyAttribute( - lambda o: workflow.CachedStep(o.otf_workflow, hash_function=o.hash_function) - ), - name_cached="_cached", ) + cached = factory.Trait(executor__cache_translation=True) device_type = core_defs.DeviceType.CPU hash_function = stages.compilation_hash - otf_workflow = factory.SubFactory( + + @factory.lazy_attribute + def name(self: factory.builder.Resolver) -> str: + name = "run_gtfn_" + name += "gpu" if self.gpu else "cpu" + if self.cached: + name += "_cached" + name += self.name_postfix + return name + + @factory_utils.dynamic_transformer( + default=factory.SubFactory( GTFNCompileWorkflowFactory, device_type=factory.SelfAttribute("..device_type") ) - - name = factory.LazyAttribute( - lambda o: f"run_gtfn_{o.name_device}{o.name_temps}{o.name_cached}{o.name_postfix}" ) + def executor(self: factory.builder.Resolver, value: Callable) -> Callable: + if self.cached: + return workflow.CachedStep(value, hash_function=self.hash_function) + return value - executor = factory.LazyAttribute(lambda o: o.otf_workflow) allocator = 
next_allocators.StandardCPUFieldBufferAllocator()
     transforms = backend.DEFAULT_TRANSFORMS
@@ -187,17 +192,13 @@ class Params:
 
 run_gtfn = GTFNBackendFactory()
 
 run_gtfn_imperative = GTFNBackendFactory(
-    name_postfix="_imperative", otf_workflow__translation__use_imperative_backend=True
+    name_postfix="_imperative", executor__translation__use_imperative_backend=True
 )
 
-run_gtfn_cached = GTFNBackendFactory(cached=True, otf_workflow__cached_translation=True)
+run_gtfn_cached = GTFNBackendFactory(cached=True, executor__cache_translation=True)
 
 run_gtfn_gpu = GTFNBackendFactory(gpu=True)
 
-run_gtfn_gpu_cached = GTFNBackendFactory(
-    gpu=True, cached=True, otf_workflow__cached_translation=True
-)
+run_gtfn_gpu_cached = GTFNBackendFactory(gpu=True, cached=True, executor__cache_translation=True)
 
-run_gtfn_no_transforms = GTFNBackendFactory(
-    otf_workflow__bare_translation__enable_itir_transforms=False
-)
+run_gtfn_no_transforms = GTFNBackendFactory(executor__translation__enable_itir_transforms=False)
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
index 8060d5bb36..ed8261fda5 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
@@ -60,6 +60,9 @@ def testee(a: cases.IJKField) -> cases.IJKField:
         field_0 = field_tuple[0]
         field_1 = field_tuple[1]
         return field_0
+        # TODO: this breaks with dace: investigate
+        # field_tmp = 2*field_0
+        # return field_tmp
 
     cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a)
 
diff --git a/tests/next_tests/unit_tests/factory_utils.py b/tests/next_tests/unit_tests/factory_utils.py
new file mode 100644
index 0000000000..1e1323ab5c
--- /dev/null
+++ b/tests/next_tests/unit_tests/factory_utils.py
@@ -0,0 +1,48 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2024, ETH Zurich
+# All rights reserved.
+#
+# Please, refer to the LICENSE file in the root directory.
+# SPDX-License-Identifier: BSD-3-Clause
+
+import dataclasses
+
+import factory
+import pytest
+
+import factory
+
+from gt4py.next import factory_utils
+
+
+@dataclasses.dataclass
+class Person:
+    name: str
+    nickname: str
+
+
+class PersonFactory(factory.Factory):
+    class Meta:
+        model = Person
+
+    class Params:
+        endearment: bool = True
+
+    name = "Joe"
+
+    @factory_utils.dynamic_transformer(default=factory.SelfAttribute(".name"))
+    def nickname(self, nickname):
+        if self.endearment:
+            return f"{nickname}y"
+        return nickname
+
+
+def test_transformer_applies_transform_to_default_and_override():
+    # default value is transformed
+    person = PersonFactory()
+    assert person.nickname == "Joey"
+
+    # overridden `name` value is also transformed
+    john = PersonFactory(name="John")
+    assert john.nickname == "Johny"

From a546b40a2f7e7090e1f1f9f9a9967c69be2cf8fb Mon Sep 17 00:00:00 2001
From: Till Ehrengruber
Date: Tue, 10 Feb 2026 15:40:36 +0100
Subject: [PATCH 2/3] Undo unrelated change

---
 .../feature_tests/ffront_tests/test_execution.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
index ed8261fda5..8060d5bb36 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
@@ -60,9 +60,6 @@ def testee(a: cases.IJKField) -> cases.IJKField:
         field_0 = field_tuple[0]
         field_1 = field_tuple[1]
         return field_0
-        # TODO: this breaks with dace: investigate
-        # field_tmp = 2*field_0
-        # return field_tmp
 
     cases.verify_with_default_data(cartesian_case, testee, ref=lambda a: a)
 

From 7a730994037a7582b0c34ad57aaeeafc26a7d008 Mon Sep 17 00:00:00 2001
From: Till Ehrengruber
Date: Tue, 10 Feb 2026 15:44:05 +0100
Subject: [PATCH 3/3] Small fix

---
 .../unit_tests/otf_tests/test_compiled_program.py | 4 ++--
 .../codegens_tests/gtfn_tests/test_gtfn_module.py | 6 +++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py
index 96acf4edb5..ee7fc79446 100644
--- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py
+++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py
@@ -106,13 +106,13 @@ def test_inlining_of_scalar_works_integration():
     hijacked_program = None
 
     def pirate(program: toolchain.CompilableProgram):
-        # Replaces the gtfn otf_workflow: and steals the compilable program,
+        # Replaces the gtfn backend executor and steals the compilable program,
         # then returns a dummy "CompiledProgram" that does nothing.
        nonlocal hijacked_program
         hijacked_program = program
         return lambda *args, **kwargs: None
 
-    hacked_gtfn_backend = gtfn.GTFNBackendFactory(name_postfix="_custom", otf_workflow=pirate)
+    hacked_gtfn_backend = gtfn.GTFNBackendFactory(name_postfix="_custom", executor=pirate)
 
     testee = prog.with_backend(hacked_gtfn_backend).compile(cond=[True], offset_provider={})
     testee(
diff --git a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py
index 7f759ef504..02960ac6be 100644
--- a/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py
+++ b/tests/next_tests/unit_tests/program_processor_tests/codegens_tests/gtfn_tests/test_gtfn_module.py
@@ -134,11 +134,11 @@ def test_gtfn_file_cache(program_example):
         args=arguments.CompileTimeArgs.from_concrete(*parameters, **{"offset_provider": {}}),
     )
     cached_gtfn_translation_step = gtfn.GTFNBackendFactory(
-        gpu=False, cached=True, otf_workflow__cached_translation=True
+        gpu=False, cached=True, executor__cache_translation=True
     ).executor.step.translation
 
     bare_gtfn_translation_step = gtfn.GTFNBackendFactory(
-        gpu=False, cached=True, otf_workflow__cached_translation=False
+        gpu=False, cached=True, executor__cache_translation=False
     ).executor.step.translation
 
     cache_key = stages.fingerprint_compilable_program(compilable_program)
@@ -162,7 +162,7 @@ def test_gtfn_file_cache(program_example):
 
 def test_gtfn_file_cache_whole_workflow(cartesian_case_no_backend):
     cartesian_case = cartesian_case_no_backend
     cartesian_case.backend = gtfn.GTFNBackendFactory(
-        gpu=False, cached=True, otf_workflow__cached_translation=True
+        gpu=False, cached=True, executor__cache_translation=True
     )
     cartesian_case.allocator = next_allocators.StandardCPUFieldBufferAllocator()
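
Usage note (not part of the patch series above): a minimal sketch of how backends are constructed after these changes. Module paths, factory names, and keyword arguments are taken from PATCH 1/3; the local variable names are illustrative only.

    from gt4py.next.program_processors.runners import gtfn
    from gt4py.next.program_processors.runners.dace.workflow.backend import make_dace_backend

    # `make_dace_backend` is now an alias for `DaCeBackendFactory`; `gpu` is a trait that
    # `factory_utils.Factory` defaults to False, while `cached` and `auto_optimize` default to True.
    dace_backend = make_dace_backend(gpu=True, cached=True, auto_optimize=True)

    # Nested workflow options use factory-boy's double-underscore syntax, e.g. enabling the
    # translation cache of the GTFN compile workflow (this mirrors `run_gtfn_cached`).
    gtfn_backend = gtfn.GTFNBackendFactory(cached=True, executor__cache_translation=True)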