From ca6214c9f0d99b53b21a224f12030e8950a790ce Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 5 Feb 2026 16:52:31 +0100 Subject: [PATCH 1/5] Fix different static args after `with_backend` --- src/gt4py/next/ffront/decorator.py | 41 +++++++++++-------- src/gt4py/next/otf/options.py | 4 +- .../otf_tests/test_compiled_program.py | 11 +++++ 3 files changed, 37 insertions(+), 19 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 48c5b92ed5..86554152a8 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -19,7 +19,7 @@ import typing import warnings from collections.abc import Callable -from typing import Any, Generic, Optional, TypeVar +from typing import Any, Generic, Optional, Sequence, TypeVar from gt4py import eve from gt4py._core import definitions as core_defs @@ -88,33 +88,39 @@ def with_connectivities( @functools.cached_property def _compiled_programs(self) -> compiled_program.CompiledProgramsPool: + return self._make_compiled_programs_pool( + static_params=self.compilation_options.static_params, + ) + + def _make_compiled_programs_pool( + self, + static_params: Sequence[str], + ) -> compiled_program.CompiledProgramsPool: if self.backend is None or self.backend == eve.NOTHING: raise RuntimeError("Cannot compile a program without backend.") - if self.compilation_options.static_params is None: - object.__setattr__(self.compilation_options, "static_params", ()) - argument_descriptor_mapping = { - arguments.StaticArg: self.compilation_options.static_params, + arguments.StaticArg: static_params, } program_type = ffront_type_info.type_in_program_context(self.__gt_type__()) assert isinstance(program_type, ts_ffront.ProgramType) - return compiled_program.CompiledProgramsPool( + pool = compiled_program.CompiledProgramsPool( backend=self.backend, definition_stage=self.definition_stage, program_type=program_type, argument_descriptor_mapping=argument_descriptor_mapping, # type: ignore[arg-type] # covariant `type[T]` not possible ) + return pool + def compile( self, offset_provider: common.OffsetProviderType | common.OffsetProvider | list[common.OffsetProviderType | common.OffsetProvider] | None = None, - enable_jit: bool | None = None, **static_args: list[xtyping.MaybeNestedInTuple[core_defs.Scalar]], ) -> Self: """ @@ -125,13 +131,17 @@ def compile( but adds the compiled variants to the current program instance. """ # TODO(havogt): we should reconsider if we want to return a new program on `compile` (and - # rename to `with_static_args` or similar) once we have a better understanding of the - # use-cases. + # rename to `with_static_args` or similar) once we have a better understanding of the + # use-cases. + # check if pool has already been initialized, since this is also a cached property go via + # the dict directly. Note that we don't need to check any args, since the pool checks + # this on compile anyway. + if "_compiled_programs" not in self.__dict__: + pool = self._make_compiled_programs_pool( + static_params=tuple(static_args.keys()), + ) + object.__setattr__(self, "_compiled_programs", pool) - if enable_jit is not None: - object.__setattr__(self.compilation_options, "enable_jit", enable_jit) - if self.compilation_options.static_params is None: - object.__setattr__(self.compilation_options, "static_params", tuple(static_args.keys())) if self.compilation_options.connectivities is None and offset_provider is None: raise ValueError( "Cannot compile a program without connectivities / OffsetProviderType." @@ -267,10 +277,10 @@ def with_grid_type(self, grid_type: common.GridType) -> Program: def with_static_params(self, *static_params: str | None) -> Program: if not static_params or (static_params == (None,)): - _static_params = None + _static_params: tuple[str, ...] = () else: assert all(p is not None for p in static_params) - _static_params = typing.cast(tuple[str], static_params) + _static_params = typing.cast(tuple[str, ...], static_params) return dataclasses.replace( self, compilation_options=dataclasses.replace( @@ -435,7 +445,6 @@ def compile( | common.OffsetProvider | list[common.OffsetProviderType | common.OffsetProvider] | None = None, - enable_jit: bool | None = None, **static_args: list[xtyping.MaybeNestedInTuple[core_defs.Scalar]], ) -> Self: raise NotImplementedError("Compilation of programs with bound arguments is not implemented") diff --git a/src/gt4py/next/otf/options.py b/src/gt4py/next/otf/options.py index 303996e458..bfb452941c 100644 --- a/src/gt4py/next/otf/options.py +++ b/src/gt4py/next/otf/options.py @@ -27,9 +27,7 @@ class CompilationOptions: enable_jit: bool = dataclasses.field(default_factory=lambda: config.ENABLE_JIT_DEFAULT) #: if the user requests static params, they will be used later to initialize CompiledPrograms - static_params: Sequence[str] | None = ( - None # TODO: describe that this value will eventually be a sequence of strings - ) + static_params: Sequence[str] = () # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information #: A dictionary holding static/compile-time information about the offset providers. diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index 96acf4edb5..7ff97491af 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -124,3 +124,14 @@ def pirate(program: toolchain.CompilableProgram): ) _verify_program_has_expected_true_value(hijacked_program.data) + + +def test_different_static_args_work_after_backend_change(): + prg1 = prog.with_backend(gtfn.run_gtfn) + prg2 = prog.with_backend(gtfn.run_gtfn) + + # compile with static args + prg1.compile(cond=[True], offset_provider={}) + + # compile without static args + prg2.compile(offset_provider={}) From c8962b463e3b15be00f21d3fab2d25695e67ca9d Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Thu, 5 Feb 2026 17:14:13 +0100 Subject: [PATCH 2/5] Cleanup --- src/gt4py/next/ffront/decorator.py | 16 ++++++++-------- .../feature_tests/dace/test_orchestration.py | 18 ++++++++++-------- .../ffront_tests/test_compiled_program.py | 11 +++++++++-- 3 files changed, 27 insertions(+), 18 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 86554152a8..0ed9589a68 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -75,19 +75,21 @@ def __gt_type__(self) -> ts.CallableType: ... def with_backend(self, backend: next_backend.Backend) -> Self: return dataclasses.replace(self, backend=backend) - def with_connectivities( - self, - connectivities: common.OffsetProvider, # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information - ) -> Self: + def with_compilation_option( + self, **compilation_options: Unpack[options.CompilationOptionsArgs] + ) -> Program: return dataclasses.replace( self, compilation_options=dataclasses.replace( - self.compilation_options, connectivities=connectivities + self.compilation_options, **compilation_options ), ) @functools.cached_property def _compiled_programs(self) -> compiled_program.CompiledProgramsPool: + # note(tehrengruber): If the program is compiled using `compile` this method + # is skipped and a pool with (potentially different) options as given to + # `compile` is used. return self._make_compiled_programs_pool( static_params=self.compilation_options.static_params, ) @@ -106,15 +108,13 @@ def _make_compiled_programs_pool( program_type = ffront_type_info.type_in_program_context(self.__gt_type__()) assert isinstance(program_type, ts_ffront.ProgramType) - pool = compiled_program.CompiledProgramsPool( + return compiled_program.CompiledProgramsPool( backend=self.backend, definition_stage=self.definition_stage, program_type=program_type, argument_descriptor_mapping=argument_descriptor_mapping, # type: ignore[arg-type] # covariant `type[T]` not possible ) - return pool - def compile( self, offset_provider: common.OffsetProviderType diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index c4f442e993..0906ec38c0 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -50,14 +50,14 @@ def sdfg(): tmp_field = xp.empty_like(out_field) lap_program.with_grid_type(cartesian_case.grid_type).with_backend( backend - ).with_connectivities(gtx_common.offset_provider_to_type(cartesian_case.offset_provider))( - in_field, tmp_field - ) + ).with_compilation_option( + connectivities=gtx_common.offset_provider_to_type(cartesian_case.offset_provider) + )(in_field, tmp_field) lap_program.with_grid_type(cartesian_case.grid_type).with_backend( backend - ).with_connectivities(gtx_common.offset_provider_to_type(cartesian_case.offset_provider))( - tmp_field, out_field - ) + ).with_compilation_option( + connectivities=gtx_common.offset_provider_to_type(cartesian_case.offset_provider) + )(tmp_field, out_field) # use unique SDFG folder in dace cache to avoid clashes between parallel pytest workers with dace.config.set_temporary("cache", value="unique"): @@ -109,7 +109,7 @@ def test_sdfgConvertible_connectivities(unstructured_case): # noqa: F811 allocator=allocator, ) - testee2 = testee.with_backend(backend).with_connectivities({"E2V": e2v}) + testee2 = testee.with_backend(backend).with_compilation_option(connectivities={"E2V": e2v}) @dace.program def sdfg( @@ -118,7 +118,9 @@ def sdfg( offset_provider: OffsetProvider_t, connectivities: dace.compiletime, ): - testee2.with_connectivities(connectivities)(a, out, offset_provider=offset_provider) + testee2.with_compilation_option(connectivities=connectivities)( + a, out, offset_provider=offset_provider + ) return out connectivities = {"E2V": e2v} # replace 'e2v' with 'e2v.__gt_type__()' when GTIR is AOT diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index 8437e71367..cc19faf6a6 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -16,7 +16,7 @@ from gt4py import next as gtx from gt4py._core import definitions as core_defs from gt4py.next import errors, config -from gt4py.next.otf import compiled_program, options +from gt4py.next.otf import compiled_program, options, arguments from gt4py.next.ffront.decorator import Program from gt4py.next.ffront.fbuiltins import int32, neighbor_sum @@ -579,7 +579,14 @@ def test_compile_variants_not_compiled_then_reset_static_params( field_b = cases.allocate(cartesian_case, compile_variants_testee, "field_b")() # the compile_variants_testee has static_params set and is compiled (in a previous test) - assert len(compile_variants_testee.compilation_options.static_params) > 0 + assert ( + len( + compile_variants_testee._compiled_programs.argument_descriptor_mapping[ + arguments.StaticArg + ] + ) + > 0 + ) assert compile_variants_testee._compiled_programs is not None # but now we reset the compiled programs From 86fce72510180426e50c48d78c243629c52c834b Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Fri, 6 Feb 2026 11:08:21 +0100 Subject: [PATCH 3/5] Cleanup --- src/gt4py/next/ffront/decorator.py | 6 +- src/gt4py/next/otf/options.py | 6 +- .../ffront_tests/test_compiled_program.py | 44 ++++++------ .../otf_tests/test_compiled_program.py | 67 +++++++++++++------ 4 files changed, 79 insertions(+), 44 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 0ed9589a68..f25863b334 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -77,7 +77,7 @@ def with_backend(self, backend: next_backend.Backend) -> Self: def with_compilation_option( self, **compilation_options: Unpack[options.CompilationOptionsArgs] - ) -> Program: + ) -> Self: return dataclasses.replace( self, compilation_options=dataclasses.replace( @@ -91,7 +91,7 @@ def _compiled_programs(self) -> compiled_program.CompiledProgramsPool: # is skipped and a pool with (potentially different) options as given to # `compile` is used. return self._make_compiled_programs_pool( - static_params=self.compilation_options.static_params, + static_params=self.compilation_options.static_params or (), ) def _make_compiled_programs_pool( @@ -133,7 +133,7 @@ def compile( # TODO(havogt): we should reconsider if we want to return a new program on `compile` (and # rename to `with_static_args` or similar) once we have a better understanding of the # use-cases. - # check if pool has already been initialized, since this is also a cached property go via + # check if pool has already been initialized. since this is also a cached property go via # the dict directly. Note that we don't need to check any args, since the pool checks # this on compile anyway. if "_compiled_programs" not in self.__dict__: diff --git a/src/gt4py/next/otf/options.py b/src/gt4py/next/otf/options.py index bfb452941c..a1f28b17f7 100644 --- a/src/gt4py/next/otf/options.py +++ b/src/gt4py/next/otf/options.py @@ -26,8 +26,10 @@ class CompilationOptions: # mostly important for testing. Users should not rely on it. enable_jit: bool = dataclasses.field(default_factory=lambda: config.ENABLE_JIT_DEFAULT) - #: if the user requests static params, they will be used later to initialize CompiledPrograms - static_params: Sequence[str] = () + #: If the user requests static params, they will be used later to initialize CompiledPrograms. + #: By default the set of static params is set when compiling for the first time, e.g. on call + #: when jitting is enabled, or on a call to `compiled`. + static_params: Sequence[str] | None = None # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information #: A dictionary holding static/compile-time information about the offset providers. diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index cc19faf6a6..2098dc080b 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -202,6 +202,11 @@ def testee( return testee +@pytest.fixture +def compile_testee_unstructured_no_jit(compile_testee_unstructured): + return compile_testee_unstructured.with_compilation_option(enable_jit=False) + + def test_compile_unstructured(unstructured_case, compile_testee_unstructured): if unstructured_case.backend is None: pytest.skip("Embedded compiled program doesn't make sense.") @@ -262,73 +267,72 @@ def test_compile_unstructured_jit( def test_compile_unstructured_wrong_offset_provider( - unstructured_case, compile_testee_unstructured, skip_value_mesh_descriptor + unstructured_case, compile_testee_unstructured_no_jit, skip_value_mesh_descriptor ): if unstructured_case.backend is None: pytest.skip("Embedded compiled program doesn't make sense.") # compiled for skip_value_mesh - compile_testee_unstructured.compile( + compile_testee_unstructured_no_jit.compile( offset_provider=skip_value_mesh_descriptor.offset_provider, - enable_jit=False, ) # but executing the simple_mesh - args, kwargs = cases.get_default_data(unstructured_case, compile_testee_unstructured) + args, kwargs = cases.get_default_data(unstructured_case, compile_testee_unstructured_no_jit) # make sure the backend is never called - object.__setattr__(compile_testee_unstructured, "backend", _raise_on_compile) + object.__setattr__(compile_testee_unstructured_no_jit, "backend", _raise_on_compile) with pytest.raises(RuntimeError, match="No program.*static.*arg.*"): - compile_testee_unstructured( + compile_testee_unstructured_no_jit( *args, offset_provider=unstructured_case.offset_provider, **kwargs ) def test_compile_unstructured_modified_offset_provider( - unstructured_case, compile_testee_unstructured, skip_value_mesh_descriptor + unstructured_case, compile_testee_unstructured_no_jit, skip_value_mesh_descriptor ): if unstructured_case.backend is None: pytest.skip("Embedded compiled program doesn't make sense.") # compiled for skip_value_mesh - compile_testee_unstructured.compile( + compile_testee_unstructured_no_jit.compile( offset_provider=skip_value_mesh_descriptor.offset_provider, - enable_jit=False, ) # but executing the simple_mesh - args, kwargs = cases.get_default_data(unstructured_case, compile_testee_unstructured) + args, kwargs = cases.get_default_data(unstructured_case, compile_testee_unstructured_no_jit) # make sure the backend is never called - object.__setattr__(compile_testee_unstructured, "backend", _raise_on_compile) + object.__setattr__(compile_testee_unstructured_no_jit, "backend", _raise_on_compile) with pytest.raises(RuntimeError, match="No program.*static.*arg.*"): - compile_testee_unstructured( + compile_testee_unstructured_no_jit( *args, offset_provider=unstructured_case.offset_provider, **kwargs ) def test_compile_unstructured_for_two_offset_providers( - unstructured_case, compile_testee_unstructured, skip_value_mesh_descriptor + unstructured_case, compile_testee_unstructured_no_jit, skip_value_mesh_descriptor ): if unstructured_case.backend is None: pytest.skip("Embedded compiled program doesn't make sense.") # compiled for skip_value_mesh and simple_mesh - compile_testee_unstructured.compile( + compile_testee_unstructured_no_jit.compile( offset_provider=[ skip_value_mesh_descriptor.offset_provider, unstructured_case.offset_provider, ], - enable_jit=False, ) # make sure the backend is never called - object.__setattr__(compile_testee_unstructured, "backend", _raise_on_compile) + object.__setattr__(compile_testee_unstructured_no_jit, "backend", _raise_on_compile) - args, kwargs = cases.get_default_data(unstructured_case, compile_testee_unstructured) - compile_testee_unstructured(*args, offset_provider=unstructured_case.offset_provider, **kwargs) + args, kwargs = cases.get_default_data(unstructured_case, compile_testee_unstructured_no_jit) + compile_testee_unstructured_no_jit( + *args, offset_provider=unstructured_case.offset_provider, **kwargs + ) v2e_numpy = unstructured_case.offset_provider[V2E.value].asnumpy() assert np.allclose( @@ -400,7 +404,9 @@ def test_compile_variants(cartesian_case, compile_variants_testee): # make sure the backend is never called object.__setattr__(compile_variants_testee, "backend", _raise_on_compile) - assert compile_variants_testee.compilation_options.static_params == ( + assert compile_variants_testee._compiled_programs.argument_descriptor_mapping[ + arguments.StaticArg + ] == ( "scalar_int", "scalar_float", "scalar_bool", diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index 7ff97491af..79d6666af4 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -55,19 +55,22 @@ def test_sanitize_static_args_wrong_type(): TDim = gtx.Dimension("TDim") -@gtx.field_operator -def fop(cond: bool, a: gtx.Field[gtx.Dims[TDim], float], b: gtx.Field[gtx.Dims[TDim], float]): - return a if cond else b - +@pytest.fixture +def testee_prog(): + @gtx.field_operator + def fop(cond: bool, a: gtx.Field[gtx.Dims[TDim], float], b: gtx.Field[gtx.Dims[TDim], float]): + return a if cond else b + + @gtx.program(backend=gtfn.run_gtfn) + def prog( + cond: bool, + a: gtx.Field[gtx.Dims[TDim], gtx.float64], + b: gtx.Field[gtx.Dims[TDim], gtx.float64], + out: gtx.Field[gtx.Dims[TDim], gtx.float64], + ): + fop(cond, a, b, out=out) -@gtx.program -def prog( - cond: bool, - a: gtx.Field[gtx.Dims[TDim], gtx.float64], - b: gtx.Field[gtx.Dims[TDim], gtx.float64], - out: gtx.Field[gtx.Dims[TDim], gtx.float64], -): - fop(cond, a, b, out=out) + return prog def _verify_program_has_expected_true_value(program: itir.Program): @@ -78,11 +81,11 @@ def _verify_program_has_expected_true_value(program: itir.Program): assert program.body[0].expr.args[0].value # is True -def test_inlining_of_scalars_works(): +def test_inlining_of_scalars_works(testee_prog): input_pair = toolchain.CompilableProgram( - data=prog.definition_stage, + data=testee_prog.definition_stage, args=arguments.CompileTimeArgs( - args=list(prog.past_stage.past_node.type.definition.pos_or_kw_args.values()), + args=list(testee_prog.past_stage.past_node.type.definition.pos_or_kw_args.values()), kwargs={}, offset_provider={}, column_axis=None, @@ -96,7 +99,7 @@ def test_inlining_of_scalars_works(): _verify_program_has_expected_true_value(transformed) -def test_inlining_of_scalar_works_integration(): +def test_inlining_of_scalar_works_integration(testee_prog): """ Test that `.compile` replaces the scalar arg in the program. Unlike the previous test, this test uses a full backend and makes sure the replacement step is there. @@ -114,7 +117,7 @@ def pirate(program: toolchain.CompilableProgram): hacked_gtfn_backend = gtfn.GTFNBackendFactory(name_postfix="_custom", otf_workflow=pirate) - testee = prog.with_backend(hacked_gtfn_backend).compile(cond=[True], offset_provider={}) + testee = testee_prog.with_backend(hacked_gtfn_backend).compile(cond=[True], offset_provider={}) testee( cond=True, a=gtx.zeros(domain={TDim: 1}, dtype=gtx.float64), @@ -126,12 +129,36 @@ def pirate(program: toolchain.CompilableProgram): _verify_program_has_expected_true_value(hijacked_program.data) -def test_different_static_args_work_after_backend_change(): - prg1 = prog.with_backend(gtfn.run_gtfn) - prg2 = prog.with_backend(gtfn.run_gtfn) +def test_different_static_args_work_after_backend_change(testee_prog): + prg1 = testee_prog.with_backend(gtfn.run_gtfn) + prg2 = testee_prog.with_backend(gtfn.run_gtfn) # compile with static args prg1.compile(cond=[True], offset_provider={}) # compile without static args prg2.compile(offset_provider={}) + + +def test_different_static_args_work_after_static_params_change(testee_prog): + testee_prog2 = testee_prog.with_compilation_option(static_params=["cond"]) + + # compile without static args + testee_prog.compile(offset_provider={}) + + # compile with static args + testee_prog2.compile(cond=[True], offset_provider={}) + + +def test_different_static_args_break_same_prg_after_static_params_change(testee_prog): + prg = testee_prog.with_compilation_option(static_params=[]) + + # compile without static args + prg.compile(offset_provider={}) + + # compile with different static args + with pytest.raises( + ValueError, + match="Argument descriptor StaticArg must be the same for all compiled programs", + ): + prg.compile(cond=[True], offset_provider={}) From 8a6bbfead9ee4f2c65ddbdf5ce3eec96dc856960 Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 6 Feb 2026 13:39:18 +0100 Subject: [PATCH 4/5] Cleanup --- src/gt4py/next/ffront/decorator.py | 12 +++++----- src/gt4py/next/otf/options.py | 2 +- .../feature_tests/dace/test_orchestration.py | 8 +++---- .../ffront_tests/test_compiled_program.py | 22 +++++++++++++------ .../otf_tests/test_compiled_program.py | 4 ++-- 5 files changed, 28 insertions(+), 20 deletions(-) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index f25863b334..a857b4e700 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -75,7 +75,7 @@ def __gt_type__(self) -> ts.CallableType: ... def with_backend(self, backend: next_backend.Backend) -> Self: return dataclasses.replace(self, backend=backend) - def with_compilation_option( + def with_compilation_options( self, **compilation_options: Unpack[options.CompilationOptionsArgs] ) -> Self: return dataclasses.replace( @@ -87,9 +87,10 @@ def with_compilation_option( @functools.cached_property def _compiled_programs(self) -> compiled_program.CompiledProgramsPool: - # note(tehrengruber): If the program is compiled using `compile` this method - # is skipped and a pool with (potentially different) options as given to - # `compile` is used. + # This cached property initializer is only called when JITting the first + # program variant of the pool. If the program is compiled by directly + # calling `compile()`, the pool is initialized with the options passed + # to `compile()` instead of re-using the existing compilations options. return self._make_compiled_programs_pool( static_params=self.compilation_options.static_params or (), ) @@ -137,10 +138,9 @@ def compile( # the dict directly. Note that we don't need to check any args, since the pool checks # this on compile anyway. if "_compiled_programs" not in self.__dict__: - pool = self._make_compiled_programs_pool( + self.__dict__["_compiled_programs"] = self._make_compiled_programs_pool( static_params=tuple(static_args.keys()), ) - object.__setattr__(self, "_compiled_programs", pool) if self.compilation_options.connectivities is None and offset_provider is None: raise ValueError( diff --git a/src/gt4py/next/otf/options.py b/src/gt4py/next/otf/options.py index a1f28b17f7..de9f35cfb0 100644 --- a/src/gt4py/next/otf/options.py +++ b/src/gt4py/next/otf/options.py @@ -28,7 +28,7 @@ class CompilationOptions: #: If the user requests static params, they will be used later to initialize CompiledPrograms. #: By default the set of static params is set when compiling for the first time, e.g. on call - #: when jitting is enabled, or on a call to `compiled`. + #: when jitting is enabled, or on a call to `compile`. static_params: Sequence[str] | None = None # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information diff --git a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py index 0906ec38c0..ca0a106d74 100644 --- a/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py +++ b/tests/next_tests/integration_tests/feature_tests/dace/test_orchestration.py @@ -50,12 +50,12 @@ def sdfg(): tmp_field = xp.empty_like(out_field) lap_program.with_grid_type(cartesian_case.grid_type).with_backend( backend - ).with_compilation_option( + ).with_compilation_options( connectivities=gtx_common.offset_provider_to_type(cartesian_case.offset_provider) )(in_field, tmp_field) lap_program.with_grid_type(cartesian_case.grid_type).with_backend( backend - ).with_compilation_option( + ).with_compilation_options( connectivities=gtx_common.offset_provider_to_type(cartesian_case.offset_provider) )(tmp_field, out_field) @@ -109,7 +109,7 @@ def test_sdfgConvertible_connectivities(unstructured_case): # noqa: F811 allocator=allocator, ) - testee2 = testee.with_backend(backend).with_compilation_option(connectivities={"E2V": e2v}) + testee2 = testee.with_backend(backend).with_compilation_options(connectivities={"E2V": e2v}) @dace.program def sdfg( @@ -118,7 +118,7 @@ def sdfg( offset_provider: OffsetProvider_t, connectivities: dace.compiletime, ): - testee2.with_compilation_option(connectivities=connectivities)( + testee2.with_compilation_options(connectivities=connectivities)( a, out, offset_provider=offset_provider ) return out diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py index 2098dc080b..732d74119f 100644 --- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py +++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_compiled_program.py @@ -204,7 +204,7 @@ def testee( @pytest.fixture def compile_testee_unstructured_no_jit(compile_testee_unstructured): - return compile_testee_unstructured.with_compilation_option(enable_jit=False) + return compile_testee_unstructured.with_compilation_options(enable_jit=False) def test_compile_unstructured(unstructured_case, compile_testee_unstructured): @@ -241,23 +241,24 @@ def skip_value_mesh_descriptor(exec_alloc_descriptor): def test_compile_unstructured_jit( - unstructured_case, compile_testee_unstructured, skip_value_mesh_descriptor + unstructured_case, compile_testee_unstructured_no_jit, skip_value_mesh_descriptor ): if unstructured_case.backend is None: pytest.skip("Embedded compiled program doesn't make sense.") # compiled for skip_value_mesh and simple_mesh - compile_testee_unstructured.compile( + compile_testee_unstructured_no_jit.compile( offset_provider=[ skip_value_mesh_descriptor.offset_provider, unstructured_case.offset_provider, ], - enable_jit=False, ) # and executing the simple_mesh - args, kwargs = cases.get_default_data(unstructured_case, compile_testee_unstructured) - compile_testee_unstructured(*args, offset_provider=unstructured_case.offset_provider, **kwargs) + args, kwargs = cases.get_default_data(unstructured_case, compile_testee_unstructured_no_jit) + compile_testee_unstructured_no_jit( + *args, offset_provider=unstructured_case.offset_provider, **kwargs + ) v2e_numpy = unstructured_case.offset_provider[V2E.value].asnumpy() assert np.allclose( @@ -643,7 +644,14 @@ def test_compile_variants_not_compiled_then_set_new_static_params( field_b = cases.allocate(cartesian_case, compile_variants_testee, "field_b")() # the compile_variants_testee has static_params set and is compiled (in a previous test) - assert len(compile_variants_testee.compilation_options.static_params) > 0 + assert ( + len( + compile_variants_testee._compiled_programs.argument_descriptor_mapping[ + arguments.StaticArg + ] + ) + > 0 + ) assert compile_variants_testee._compiled_programs is not None # but now we reset the compiled programs and fix to other static params diff --git a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py index 79d6666af4..51645fba02 100644 --- a/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py +++ b/tests/next_tests/unit_tests/otf_tests/test_compiled_program.py @@ -141,7 +141,7 @@ def test_different_static_args_work_after_backend_change(testee_prog): def test_different_static_args_work_after_static_params_change(testee_prog): - testee_prog2 = testee_prog.with_compilation_option(static_params=["cond"]) + testee_prog2 = testee_prog.with_compilation_options(static_params=["cond"]) # compile without static args testee_prog.compile(offset_provider={}) @@ -151,7 +151,7 @@ def test_different_static_args_work_after_static_params_change(testee_prog): def test_different_static_args_break_same_prg_after_static_params_change(testee_prog): - prg = testee_prog.with_compilation_option(static_params=[]) + prg = testee_prog.with_compilation_options(static_params=[]) # compile without static args prg.compile(offset_provider={}) From d00246ce8c2cb9403cb3178ff19aa3c33e6af6fe Mon Sep 17 00:00:00 2001 From: Till Ehrengruber Date: Fri, 6 Feb 2026 14:03:45 +0100 Subject: [PATCH 5/5] Fix dace test --- .../runners_tests/dace_tests/test_dace_bindings.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py index 6c538af4d6..f17e0ae57a 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/test_dace_bindings.py @@ -281,7 +281,7 @@ def testee_op( a[0] + 2 * a[1][0] + 3 * a[1][1] + 4 * b[0][0] + 5 * b[1] ) # skip 'a[1][2]' on purpose to cover unused scalar args - @gtx.program + @gtx.program(enable_jit=False) def testee( a: tuple[int32, tuple[int32, cases.IJKField, int32]], b: tuple[tuple[cases.IJKField], int32], # use 'b_0' to test tuple with single element @@ -327,7 +327,7 @@ def testee( program = ( testee.with_grid_type(gtx_common.GridType.CARTESIAN) .with_backend(backend) - .compile(enable_jit=False, offset_provider={}, **static_args) + .compile(offset_provider={}, **static_args) ) program(a, b, out=c, M=M, N=N, K=K) assert np.all(c.asnumpy() == ref) @@ -344,7 +344,7 @@ def testee_op(a: cases.VField) -> cases.VField: tmp_2 = neighbor_sum(tmp(V2E), axis=V2EDim) return tmp_2 - @gtx.program + @gtx.program(enable_jit=False) def testee(a: cases.VField, b: cases.VField): testee_op(a, out=b) @@ -384,7 +384,7 @@ def testee(a: cases.VField, b: cases.VField): program = ( testee.with_grid_type(gtx_common.GridType.UNSTRUCTURED) .with_backend(backend) - .compile(enable_jit=False, offset_provider=offset_provider, **static_args) + .compile(offset_provider=offset_provider, **static_args) ) program(a, b, offset_provider=offset_provider) assert np.all(b.asnumpy() == ref)