From 726dadfc1ab3bda13595e1018785dbd34ebd3fc6 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Sat, 10 Jan 2026 10:27:49 +0100 Subject: [PATCH 1/6] Fix nanobind segfault --- src/gt4py/next/otf/compilation/compiler.py | 10 +++++----- src/gt4py/next/program_processors/runners/gtfn.py | 8 +++++++- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index e03fa84e50..e4ffbff7dd 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -42,6 +42,8 @@ def __call__( ) -> stages.BuildSystemProject[SrcL, LS, TgtL]: ... +_MODULE_CACHE = [] + @dataclasses.dataclass(frozen=True) class Compiler( workflow.ChainableWorkflowMixin[ @@ -83,12 +85,10 @@ def __call__( f"On-the-fly compilation unsuccessful for '{inp.program_source.entry_point.name}'." ) - compiled_prog = getattr( - importer.import_from_path(src_dir / new_data.module), new_data.entry_point_name - ) - - return compiled_prog + m = importer.import_from_path(src_dir / new_data.module) + _MODULE_CACHE.append(_MODULE_CACHE) + return getattr(m, new_data.entry_point_name) class CompilerFactory(factory.Factory): class Meta: diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 7bd796785d..6084914fd5 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -42,6 +42,12 @@ def convert_arg(arg: Any) -> Any: return arg +import faulthandler +import signal +faulthandler.enable() +faulthandler.register(signal.SIGUSR1) + + def convert_args( inp: stages.CompiledProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU ) -> stages.CompiledProgram: @@ -53,7 +59,7 @@ def decorated_program( # Note: this function is on the hot path and needs to have minimal overhead. 
if out is not None: args = (*args, out) - converted_args = (convert_arg(arg) for arg in args) + converted_args = [convert_arg(arg) for arg in args] conn_args = extract_connectivity_args(offset_provider, device) opt_kwargs: dict[str, Any] = {} From 39e9c37e069754e6655c6d639259f0d4a3929b46 Mon Sep 17 00:00:00 2001 From: tehrengruber Date: Sat, 10 Jan 2026 10:50:57 +0100 Subject: [PATCH 2/6] Cleanup --- src/gt4py/next/otf/compilation/compiler.py | 9 +++++++-- src/gt4py/next/program_processors/runners/gtfn.py | 8 +------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index e4ffbff7dd..7a0400ea91 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -10,6 +10,7 @@ import dataclasses import pathlib +import types from typing import Protocol, TypeVar import factory @@ -42,7 +43,8 @@ def __call__( ) -> stages.BuildSystemProject[SrcL, LS, TgtL]: ... -_MODULE_CACHE = [] +_MODULES: list[types.ModuleType] = [] + @dataclasses.dataclass(frozen=True) class Compiler( @@ -86,10 +88,13 @@ def __call__( ) m = importer.import_from_path(src_dir / new_data.module) - _MODULE_CACHE.append(_MODULE_CACHE) + # Keep a reference to the module so they are not garbage collected. This avoids a SEGFAULT + # in nanobind when calling the compiled program. 
+ _MODULES.append(m) return getattr(m, new_data.entry_point_name) + class CompilerFactory(factory.Factory): class Meta: model = Compiler diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 6084914fd5..7bd796785d 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -42,12 +42,6 @@ def convert_arg(arg: Any) -> Any: return arg -import faulthandler -import signal -faulthandler.enable() -faulthandler.register(signal.SIGUSR1) - - def convert_args( inp: stages.CompiledProgram, device: core_defs.DeviceType = core_defs.DeviceType.CPU ) -> stages.CompiledProgram: @@ -59,7 +53,7 @@ def decorated_program( # Note: this function is on the hot path and needs to have minimal overhead. if out is not None: args = (*args, out) - converted_args = [convert_arg(arg) for arg in args] + converted_args = (convert_arg(arg) for arg in args) conn_args = extract_connectivity_args(offset_provider, device) opt_kwargs: dict[str, Any] = {} From 1f4add078d471d0a1d017b4d6176f566b65c76de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 4 Feb 2026 13:47:54 +0100 Subject: [PATCH 3/6] Try a different workaround --- src/gt4py/next/otf/compilation/compiler.py | 24 +++++++++++++--------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index 7a0400ea91..a3237a2cd1 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -10,8 +10,7 @@ import dataclasses import pathlib -import types -from typing import Protocol, TypeVar +from typing import Protocol, TypeVar, cast import factory @@ -43,9 +42,6 @@ def __call__( ) -> stages.BuildSystemProject[SrcL, LS, TgtL]: ... 
-_MODULES: list[types.ModuleType] = [] - - @dataclasses.dataclass(frozen=True) class Compiler( workflow.ChainableWorkflowMixin[ @@ -88,11 +84,19 @@ def __call__( ) m = importer.import_from_path(src_dir / new_data.module) - # Keep a reference to the module so they are not garbage collected. This avoids a SEGFAULT - # in nanobind when calling the compiled program. - _MODULES.append(m) - - return getattr(m, new_data.entry_point_name) + func = getattr(m, new_data.entry_point_name) + + # Since nanobind 2.10, calling functions with ndarray args crashes (SEGFAULT) + # when there are no live references to their extension module (see: https://github.com/wjakob/nanobind/issues/1283) + # Here we dynamically create a new callable class holding a reference to the + # module and the function, whose `__call__` is exactly the `__call__` method + # of the returned (nanobind) nbfunction object. As long as this object is alive, + # the module reference is kept alive too, preventing the SEGFAULT. + managed_entry_point = type( + f"{m.__name__}__{id(m)}", (), dict(__call__=func.__call__, module_ref=m, func_ref=func) + )() + + return cast(stages.CompiledProgram, managed_entry_point) class CompilerFactory(factory.Factory): From 8bc2d7b1565a0e7ba5ee91e8a0cad4be4115d203 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 4 Feb 2026 14:48:54 +0100 Subject: [PATCH 4/6] Add parts of copilot suggestions --- src/gt4py/next/otf/compilation/compiler.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index a3237a2cd1..d1d09041ea 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -93,8 +93,15 @@ def __call__( # of the returned (nanobind) nbfunction object. As long as this object is alive, # the module reference is kept alive too, preventing the SEGFAULT. 
managed_entry_point = type( - f"{m.__name__}__{id(m)}", (), dict(__call__=func.__call__, module_ref=m, func_ref=func) - )() + f"{m.__name__}_managed_wrapper", + (), + dict( + __call__=func.__call__, + __doc__=getattr(func, "__doc__", None), + __hash__=func.__hash__, + __eq__=func.__eq__, + ), + )(module_ref=m, func_ref=func) return cast(stages.CompiledProgram, managed_entry_point) From 194ea5f7cafde92b40d4cca1ce34eefcd02cf8f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 4 Feb 2026 14:51:07 +0100 Subject: [PATCH 5/6] Fix format --- src/gt4py/next/otf/compilation/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index d1d09041ea..aeceb46c7b 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -100,7 +100,7 @@ def __call__( __doc__=getattr(func, "__doc__", None), __hash__=func.__hash__, __eq__=func.__eq__, - ), + ), )(module_ref=m, func_ref=func) return cast(stages.CompiledProgram, managed_entry_point) From 9f886df574560f0d7b17b0151b02775dd21b20d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Enrique=20Gonz=C3=A1lez=20Paredes?= Date: Wed, 4 Feb 2026 15:46:41 +0100 Subject: [PATCH 6/6] Fix leftovers from previous experiments --- src/gt4py/next/otf/compilation/compiler.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py index aeceb46c7b..1f9c85b07f 100644 --- a/src/gt4py/next/otf/compilation/compiler.py +++ b/src/gt4py/next/otf/compilation/compiler.py @@ -100,8 +100,10 @@ def __call__( __doc__=getattr(func, "__doc__", None), __hash__=func.__hash__, __eq__=func.__eq__, + module_ref=m, + func_ref=func, ), - )(module_ref=m, func_ref=func) + )() return cast(stages.CompiledProgram, managed_entry_point)