Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 28 additions & 19 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import typing
import warnings
from collections.abc import Callable
from typing import Any, Generic, Optional, TypeVar
from typing import Any, Generic, Optional, Sequence, TypeVar

from gt4py import eve
from gt4py._core import definitions as core_defs
Expand Down Expand Up @@ -75,27 +75,35 @@ def __gt_type__(self) -> ts.CallableType: ...
def with_backend(self, backend: next_backend.Backend) -> Self:
return dataclasses.replace(self, backend=backend)

def with_connectivities(
self,
connectivities: common.OffsetProvider, # TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information
def with_compilation_options(
self, **compilation_options: Unpack[options.CompilationOptionsArgs]
) -> Self:
return dataclasses.replace(
self,
compilation_options=dataclasses.replace(
self.compilation_options, connectivities=connectivities
self.compilation_options, **compilation_options
),
)

@functools.cached_property
def _compiled_programs(self) -> compiled_program.CompiledProgramsPool:
# This cached property initializer is only called when JITting the first
# program variant of the pool. If the program is compiled by directly
# calling `compile()`, the pool is initialized with the options passed
# to `compile()` instead of re-using the existing compilations options.
return self._make_compiled_programs_pool(
static_params=self.compilation_options.static_params or (),
)

def _make_compiled_programs_pool(
self,
static_params: Sequence[str],
) -> compiled_program.CompiledProgramsPool:
if self.backend is None or self.backend == eve.NOTHING:
raise RuntimeError("Cannot compile a program without backend.")

if self.compilation_options.static_params is None:
object.__setattr__(self.compilation_options, "static_params", ())

argument_descriptor_mapping = {
arguments.StaticArg: self.compilation_options.static_params,
arguments.StaticArg: static_params,
}

program_type = ffront_type_info.type_in_program_context(self.__gt_type__())
Expand All @@ -114,7 +122,6 @@ def compile(
| common.OffsetProvider
| list[common.OffsetProviderType | common.OffsetProvider]
| None = None,
enable_jit: bool | None = None,
**static_args: list[xtyping.MaybeNestedInTuple[core_defs.Scalar]],
) -> Self:
"""
Expand All @@ -125,13 +132,16 @@ def compile(
but adds the compiled variants to the current program instance.
"""
# TODO(havogt): we should reconsider if we want to return a new program on `compile` (and
# rename to `with_static_args` or similar) once we have a better understanding of the
# use-cases.
# rename to `with_static_args` or similar) once we have a better understanding of the
# use-cases.
# check if pool has already been initialized. since this is also a cached property go via
# the dict directly. Note that we don't need to check any args, since the pool checks
# this on compile anyway.
if "_compiled_programs" not in self.__dict__:
self.__dict__["_compiled_programs"] = self._make_compiled_programs_pool(
static_params=tuple(static_args.keys()),
)

if enable_jit is not None:
object.__setattr__(self.compilation_options, "enable_jit", enable_jit)
if self.compilation_options.static_params is None:
object.__setattr__(self.compilation_options, "static_params", tuple(static_args.keys()))
if self.compilation_options.connectivities is None and offset_provider is None:
raise ValueError(
"Cannot compile a program without connectivities / OffsetProviderType."
Expand Down Expand Up @@ -267,10 +277,10 @@ def with_grid_type(self, grid_type: common.GridType) -> Program:

def with_static_params(self, *static_params: str | None) -> Program:
if not static_params or (static_params == (None,)):
_static_params = None
_static_params: tuple[str, ...] = ()
else:
assert all(p is not None for p in static_params)
_static_params = typing.cast(tuple[str], static_params)
_static_params = typing.cast(tuple[str, ...], static_params)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: do you understand why the cast is used/needed here after the assert?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unclear:

error: Incompatible types in assignment (expression has type "tuple[str | None, ...]", variable has type "tuple[str, ...]")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I get it. mypy does not understand the all() inside the assert. A custom typeguard would be needed here, but it can be done in a different PR.

return dataclasses.replace(
self,
compilation_options=dataclasses.replace(
Expand Down Expand Up @@ -435,7 +445,6 @@ def compile(
| common.OffsetProvider
| list[common.OffsetProviderType | common.OffsetProvider]
| None = None,
enable_jit: bool | None = None,
**static_args: list[xtyping.MaybeNestedInTuple[core_defs.Scalar]],
) -> Self:
raise NotImplementedError("Compilation of programs with bound arguments is not implemented")
Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/next/otf/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ class CompilationOptions:
# mostly important for testing. Users should not rely on it.
enable_jit: bool = dataclasses.field(default_factory=lambda: config.ENABLE_JIT_DEFAULT)

#: if the user requests static params, they will be used later to initialize CompiledPrograms
static_params: Sequence[str] | None = (
None # TODO: describe that this value will eventually be a sequence of strings
)
#: If the user requests static params, they will be used later to initialize CompiledPrograms.
#: By default the set of static params is set when compiling for the first time, e.g. on call
#: when jitting is enabled, or on a call to `compile`.
static_params: Sequence[str] | None = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question: wouldn't it be simpler to just default to an empty tuple?

Suggested change
static_params: Sequence[str] | None = None
static_params: Sequence[str] = ()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I first changed it to be (), but I found it hard to understand under which circumstances it is ok for this value to be different than what is inside the compiled programs pool. I therefore went back to None here and set the value to () in _compiled_programs when the value is actually selected. I'll change it back when you like.


# TODO(ricoh): replace with common.OffsetProviderType once the temporary pass doesn't require the runtime information
#: A dictionary holding static/compile-time information about the offset providers.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ def sdfg():
tmp_field = xp.empty_like(out_field)
lap_program.with_grid_type(cartesian_case.grid_type).with_backend(
backend
).with_connectivities(gtx_common.offset_provider_to_type(cartesian_case.offset_provider))(
in_field, tmp_field
)
).with_compilation_options(
connectivities=gtx_common.offset_provider_to_type(cartesian_case.offset_provider)
)(in_field, tmp_field)
lap_program.with_grid_type(cartesian_case.grid_type).with_backend(
backend
).with_connectivities(gtx_common.offset_provider_to_type(cartesian_case.offset_provider))(
tmp_field, out_field
)
).with_compilation_options(
connectivities=gtx_common.offset_provider_to_type(cartesian_case.offset_provider)
)(tmp_field, out_field)

# use unique SDFG folder in dace cache to avoid clashes between parallel pytest workers
with dace.config.set_temporary("cache", value="unique"):
Expand Down Expand Up @@ -109,7 +109,7 @@ def test_sdfgConvertible_connectivities(unstructured_case): # noqa: F811
allocator=allocator,
)

testee2 = testee.with_backend(backend).with_connectivities({"E2V": e2v})
testee2 = testee.with_backend(backend).with_compilation_options(connectivities={"E2V": e2v})

@dace.program
def sdfg(
Expand All @@ -118,7 +118,9 @@ def sdfg(
offset_provider: OffsetProvider_t,
connectivities: dace.compiletime,
):
testee2.with_connectivities(connectivities)(a, out, offset_provider=offset_provider)
testee2.with_compilation_options(connectivities=connectivities)(
a, out, offset_provider=offset_provider
)
return out

connectivities = {"E2V": e2v} # replace 'e2v' with 'e2v.__gt_type__()' when GTIR is AOT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from gt4py import next as gtx
from gt4py._core import definitions as core_defs
from gt4py.next import errors, config
from gt4py.next.otf import compiled_program, options
from gt4py.next.otf import compiled_program, options, arguments
from gt4py.next.ffront.decorator import Program
from gt4py.next.ffront.fbuiltins import int32, neighbor_sum

Expand Down Expand Up @@ -202,6 +202,11 @@ def testee(
return testee


@pytest.fixture
def compile_testee_unstructured_no_jit(compile_testee_unstructured):
return compile_testee_unstructured.with_compilation_options(enable_jit=False)


def test_compile_unstructured(unstructured_case, compile_testee_unstructured):
if unstructured_case.backend is None:
pytest.skip("Embedded compiled program doesn't make sense.")
Expand Down Expand Up @@ -236,23 +241,24 @@ def skip_value_mesh_descriptor(exec_alloc_descriptor):


def test_compile_unstructured_jit(
unstructured_case, compile_testee_unstructured, skip_value_mesh_descriptor
unstructured_case, compile_testee_unstructured_no_jit, skip_value_mesh_descriptor
):
if unstructured_case.backend is None:
pytest.skip("Embedded compiled program doesn't make sense.")

# compiled for skip_value_mesh and simple_mesh
compile_testee_unstructured.compile(
compile_testee_unstructured_no_jit.compile(
offset_provider=[
skip_value_mesh_descriptor.offset_provider,
unstructured_case.offset_provider,
],
enable_jit=False,
)

# and executing the simple_mesh
args, kwargs = cases.get_default_data(unstructured_case, compile_testee_unstructured)
compile_testee_unstructured(*args, offset_provider=unstructured_case.offset_provider, **kwargs)
args, kwargs = cases.get_default_data(unstructured_case, compile_testee_unstructured_no_jit)
compile_testee_unstructured_no_jit(
*args, offset_provider=unstructured_case.offset_provider, **kwargs
)

v2e_numpy = unstructured_case.offset_provider[V2E.value].asnumpy()
assert np.allclose(
Expand All @@ -262,73 +268,72 @@ def test_compile_unstructured_jit(


def test_compile_unstructured_wrong_offset_provider(
unstructured_case, compile_testee_unstructured, skip_value_mesh_descriptor
unstructured_case, compile_testee_unstructured_no_jit, skip_value_mesh_descriptor
):
if unstructured_case.backend is None:
pytest.skip("Embedded compiled program doesn't make sense.")

# compiled for skip_value_mesh
compile_testee_unstructured.compile(
compile_testee_unstructured_no_jit.compile(
offset_provider=skip_value_mesh_descriptor.offset_provider,
enable_jit=False,
)

# but executing the simple_mesh
args, kwargs = cases.get_default_data(unstructured_case, compile_testee_unstructured)
args, kwargs = cases.get_default_data(unstructured_case, compile_testee_unstructured_no_jit)

# make sure the backend is never called
object.__setattr__(compile_testee_unstructured, "backend", _raise_on_compile)
object.__setattr__(compile_testee_unstructured_no_jit, "backend", _raise_on_compile)

with pytest.raises(RuntimeError, match="No program.*static.*arg.*"):
compile_testee_unstructured(
compile_testee_unstructured_no_jit(
*args, offset_provider=unstructured_case.offset_provider, **kwargs
)


def test_compile_unstructured_modified_offset_provider(
unstructured_case, compile_testee_unstructured, skip_value_mesh_descriptor
unstructured_case, compile_testee_unstructured_no_jit, skip_value_mesh_descriptor
):
if unstructured_case.backend is None:
pytest.skip("Embedded compiled program doesn't make sense.")

# compiled for skip_value_mesh
compile_testee_unstructured.compile(
compile_testee_unstructured_no_jit.compile(
offset_provider=skip_value_mesh_descriptor.offset_provider,
enable_jit=False,
)

# but executing the simple_mesh
args, kwargs = cases.get_default_data(unstructured_case, compile_testee_unstructured)
args, kwargs = cases.get_default_data(unstructured_case, compile_testee_unstructured_no_jit)

# make sure the backend is never called
object.__setattr__(compile_testee_unstructured, "backend", _raise_on_compile)
object.__setattr__(compile_testee_unstructured_no_jit, "backend", _raise_on_compile)

with pytest.raises(RuntimeError, match="No program.*static.*arg.*"):
compile_testee_unstructured(
compile_testee_unstructured_no_jit(
*args, offset_provider=unstructured_case.offset_provider, **kwargs
)


def test_compile_unstructured_for_two_offset_providers(
unstructured_case, compile_testee_unstructured, skip_value_mesh_descriptor
unstructured_case, compile_testee_unstructured_no_jit, skip_value_mesh_descriptor
):
if unstructured_case.backend is None:
pytest.skip("Embedded compiled program doesn't make sense.")

# compiled for skip_value_mesh and simple_mesh
compile_testee_unstructured.compile(
compile_testee_unstructured_no_jit.compile(
offset_provider=[
skip_value_mesh_descriptor.offset_provider,
unstructured_case.offset_provider,
],
enable_jit=False,
)

# make sure the backend is never called
object.__setattr__(compile_testee_unstructured, "backend", _raise_on_compile)
object.__setattr__(compile_testee_unstructured_no_jit, "backend", _raise_on_compile)

args, kwargs = cases.get_default_data(unstructured_case, compile_testee_unstructured)
compile_testee_unstructured(*args, offset_provider=unstructured_case.offset_provider, **kwargs)
args, kwargs = cases.get_default_data(unstructured_case, compile_testee_unstructured_no_jit)
compile_testee_unstructured_no_jit(
*args, offset_provider=unstructured_case.offset_provider, **kwargs
)

v2e_numpy = unstructured_case.offset_provider[V2E.value].asnumpy()
assert np.allclose(
Expand Down Expand Up @@ -400,7 +405,9 @@ def test_compile_variants(cartesian_case, compile_variants_testee):
# make sure the backend is never called
object.__setattr__(compile_variants_testee, "backend", _raise_on_compile)

assert compile_variants_testee.compilation_options.static_params == (
assert compile_variants_testee._compiled_programs.argument_descriptor_mapping[
arguments.StaticArg
] == (
"scalar_int",
"scalar_float",
"scalar_bool",
Expand Down Expand Up @@ -579,7 +586,14 @@ def test_compile_variants_not_compiled_then_reset_static_params(
field_b = cases.allocate(cartesian_case, compile_variants_testee, "field_b")()

# the compile_variants_testee has static_params set and is compiled (in a previous test)
assert len(compile_variants_testee.compilation_options.static_params) > 0
assert (
len(
compile_variants_testee._compiled_programs.argument_descriptor_mapping[
arguments.StaticArg
]
)
> 0
)
assert compile_variants_testee._compiled_programs is not None

# but now we reset the compiled programs
Expand Down Expand Up @@ -630,7 +644,14 @@ def test_compile_variants_not_compiled_then_set_new_static_params(
field_b = cases.allocate(cartesian_case, compile_variants_testee, "field_b")()

# the compile_variants_testee has static_params set and is compiled (in a previous test)
assert len(compile_variants_testee.compilation_options.static_params) > 0
assert (
len(
compile_variants_testee._compiled_programs.argument_descriptor_mapping[
arguments.StaticArg
]
)
> 0
)
assert compile_variants_testee._compiled_programs is not None

# but now we reset the compiled programs and fix to other static params
Expand Down
Loading