diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index e03fa84e50..1f9c85b07f 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -10,7 +10,7 @@ import dataclasses import pathlib -from typing import Protocol, TypeVar +from typing import Protocol, TypeVar, cast import factory @@ -83,11 +83,29 @@ def __call__( f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." ) - compiled_prog = getattr( - importer.import_from_path(src_dir / new_data.module), new_data.entry_point_name - ) - - return compiled_prog + m = importer.import_from_path(src_dir / new_data.module) + func = getattr(m, new_data.entry_point_name) + + # Since nanobind 2.10, calling functions with ndarray args crashes (SEGFAULT) + # when there are not live references to their extension module (see: https://github.com/wjakob/nanobind/issues/1283) + # Here we dynamically create a new callable class holding a reference to the + # module and the function, whose `__call__` is exactly the `__call__` method + # of the returned (nanobind) nbfunction object. As long as this object is alive, + # the module reference is kept alive too, preventing the SEGFAULT. + managed_entry_point = type( + f"{m.__name__}_managed_wrapper", + (), + dict( + __call__=func.__call__, + __doc__=getattr(func, "__doc__", None), + __hash__=func.__hash__, + __eq__=func.__eq__, + module_ref=m, + func_ref=func, + ), + )() + + return cast(stages.CompiledProgram, managed_entry_point) class CompilerFactory(factory.Factory):