diff --git a/src/gt4py/cartesian/backend/__init__.py b/src/gt4py/cartesian/backend/__init__.py index 7dcb7c9380..b3bc5e64fa 100644 --- a/src/gt4py/cartesian/backend/__init__.py +++ b/src/gt4py/cartesian/backend/__init__.py @@ -8,6 +8,8 @@ from warnings import warn +from gt4py.cartesian import config as gt_config + from .base import REGISTRY, Backend, BaseBackend, BasePyExtBackend, from_name, register from .debug_backend import DebugBackend from .gtcpp_backend import GTCpuIfirstBackend, GTCpuKfirstBackend, GTGpuBackend @@ -31,12 +33,13 @@ ] -try: - from .dace_backend import DaceCPUBackend, DaceCPUKFirstBackend, DaceGPUBackend +if gt_config.GT4PY_CART_ENABLE_DACE: + try: + from .dace_backend import DaceCPUBackend, DaceCPUKFirstBackend, DaceGPUBackend - __all__ += ["DaceCPUBackend", "DaceCPUKFirstBackend", "DaceGPUBackend"] -except ImportError: - warn( - "GT4Py was unable to load DaCe. DaCe backends (`dace:cpu`, `dace:cpu_kfirst`, and `dace:gpu`) will not be available.", - stacklevel=2, - ) + __all__ += ["DaceCPUBackend", "DaceCPUKFirstBackend", "DaceGPUBackend"] + except ImportError: + warn( + "GT4Py was unable to load DaCe. DaCe backends (`dace:cpu`, `dace:cpu_kfirst`, and `dace:gpu`) will not be available.", + stacklevel=2, + ) diff --git a/src/gt4py/cartesian/config.py b/src/gt4py/cartesian/config.py index a9ded21ec7..c649d63f37 100644 --- a/src/gt4py/cartesian/config.py +++ b/src/gt4py/cartesian/config.py @@ -30,6 +30,8 @@ GT4PY_COMPILE_OPT_LEVEL: str = os.environ.get("GT4PY_COMPILE_OPT_LEVEL", "3") GT4PY_EXTRA_COMPILE_OPT_FLAGS: str = os.environ.get("GT4PY_EXTRA_COMPILE_OPT_FLAGS", "") +GT4PY_CART_ENABLE_DACE: bool = bool(int(os.environ.get("GT4PY_CART_ENABLE_DACE", "1"))) + # Settings dict GT4PY_EXTRA_COMPILE_ARGS: str = os.environ.get("GT4PY_EXTRA_COMPILE_ARGS", "") extra_compile_args: List[str] = ( diff --git a/src/gt4py/cartesian/gtscript.py b/src/gt4py/cartesian/gtscript.py index 56bf8874a8..da02e7951c 100644 --- a/src/gt4py/cartesian/gtscript.py +++ b/src/gt4py/cartesian/gtscript.py @@ -20,14 +20,15 @@ import numpy as np -from gt4py.cartesian import definitions as gt_definitions +from gt4py.cartesian import config as gt_config, definitions as gt_definitions from gt4py.cartesian.lazy_stencil import LazyStencil -try: - from gt4py.cartesian.backend.dace_lazy_stencil import DaCeLazyStencil -except ImportError: - DaCeLazyStencil = LazyStencil # type: ignore +if gt_config.GT4PY_CART_ENABLE_DACE: + try: + from gt4py.cartesian.backend.dace_lazy_stencil import DaCeLazyStencil + except ImportError: + DaCeLazyStencil = LazyStencil # type: ignore # GTScript builtins MATH_BUILTINS = {