From 3bd3093894ad448d1e822fd6f8df361b99cd0562 Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Mon, 28 Oct 2024 12:17:28 +0100
Subject: [PATCH 01/10] import Config from thunder.tests.litgpt_model (#1351)

---
 thunder/tests/test_jit_general.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py
index aa0fe8728c..bc3517eb8b 100644
--- a/thunder/tests/test_jit_general.py
+++ b/thunder/tests/test_jit_general.py
@@ -675,7 +675,7 @@ def test_nanogpt():
     ("cpu", "cuda", "meta"),
 )
 def test_litgpt_variants(name, device):
-    from litgpt.config import Config
+    from thunder.tests.litgpt_model import Config
     from litgpt.model import GPT

     if device == "cuda" and not torch.cuda.is_available():
@@ -733,7 +733,7 @@ def test_litgpt_variants(name, device):
     ("cpu", "cuda"),
 )
 def test_litgpt_variants_kvcache(name, device):
-    from litgpt.config import Config
+    from thunder.tests.litgpt_model import Config
     from litgpt.model import GPT
     import torch._dynamo  # this monkeypatches torch.manual_seed

From 1575797a8d21e0b845e188ddd6af305be6457b1c Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Mon, 28 Oct 2024 22:22:21 +0900
Subject: [PATCH 02/10] functional jit, bye (#1355)

Signed-off-by: Masaki Kozuki
---
 thunder/core/interpreter.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py
index 84577f3d66..8f60b515e8 100644
--- a/thunder/core/interpreter.py
+++ b/thunder/core/interpreter.py
@@ -64,9 +64,8 @@
 # of where values originated. This mode is used by the general jit. This is done by
 # wrapping all values in WrappedValues.
 #
-# Both thunder.jit and thunder.functional.jit use extensions (in jit_ext.py) to
-# create Thunder Programs. They use callbacks and additional lookasides to
-# add their functionality.
+# thunder.jit uses extensions (in jit_ext.py) to create Thunder Programs.
+# They use callbacks and additional lookasides to add their functionality.
 #
 # The Thunder program constructed has three parts, a "prologue trace", a
 # "computation trace", and (optionally) an "epilogue trace". The prologue trace has
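The interpreter comment above describes the prologue/computation/(optional) epilogue split that thunder.jit produces. For orientation, a minimal sketch of inspecting those traces after a call (illustration only, not part of the patch; it assumes the top-level last_traces/last_prologue_traces introspection helpers):

    import torch
    import thunder

    def f(x):
        return torch.sin(x) + 1

    jf = thunder.jit(f)
    jf(torch.randn(4))

    # Last recorded computation trace and prologue trace for jf.
    print(thunder.last_traces(jf)[-1])
    print(thunder.last_prologue_traces(jf)[-1])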
From 287abf79ac7193f5d3dad6f7bdd7d5a42849ae6b Mon Sep 17 00:00:00 2001
From: Thomas Viehmann
Date: Tue, 29 Oct 2024 09:05:11 +0100
Subject: [PATCH 03/10] skip test triggering nvfuser bug (#1356)

---
 thunder/tests/test_jit_general.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/thunder/tests/test_jit_general.py b/thunder/tests/test_jit_general.py
index bc3517eb8b..be89967f70 100644
--- a/thunder/tests/test_jit_general.py
+++ b/thunder/tests/test_jit_general.py
@@ -680,6 +680,8 @@ def test_litgpt_variants(name, device):

     if device == "cuda" and not torch.cuda.is_available():
         pytest.skip("CUDA not available")
+    if device == "cuda" and name == "falcon-7b-like":
+        pytest.skip("NVFuser reenable when https://github.com/NVIDIA/Fuser/issues/3292 is fixed")

     device = torch.device(device)

From 6e41c90a6b5b9d9d1716960fdbe33b09ce20ba7d Mon Sep 17 00:00:00 2001
From: Luca Antiga
Date: Tue, 29 Oct 2024 09:52:13 +0100
Subject: [PATCH 04/10] Add recipes and high level entrypoint (#1353)

---
 thunder/__init__.py           |  8 ++++
 thunder/core/recipe.py        | 77 +++++++++++++++++++++++++++++++++++
 thunder/recipes/__init__.py   |  0
 thunder/recipes/hf_bert.py    | 24 +++++++++++
 thunder/tests/test_recipes.py | 39 ++++++++++++++++++
 5 files changed, 148 insertions(+)
 create mode 100644 thunder/core/recipe.py
 create mode 100644 thunder/recipes/__init__.py
 create mode 100644 thunder/recipes/hf_bert.py
 create mode 100644 thunder/tests/test_recipes.py

diff --git a/thunder/__init__.py b/thunder/__init__.py
index e78839d68e..d2f596a90b 100644
--- a/thunder/__init__.py
+++ b/thunder/__init__.py
@@ -39,6 +39,7 @@
     check_inplace_to_views,
     functionalize_inplace_ops,
 )
+from thunder.core.recipe import Recipe, Lookaside
 from thunder.common import (
     CompileData,
     CompileStats,
@@ -265,6 +266,13 @@ def _recursive_jit_call_warning() -> None:
     )


+def compile(fn: Callable, recipe: Recipe | None):
+    if recipe is None:
+        return thunder.jit(fn)
+
+    return recipe.apply(fn)
+
+
 # This function will replace compile() (below) before RC1
 # TODO RC1 Consider adding a debug_log parameter to control debug printing
 # TODO RC1 Consider renaming compile_options to additional_compile_options
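The new entrypoint above falls back to thunder.jit when no recipe is given and otherwise hands compilation over to recipe.apply. A minimal usage sketch (the function below is illustrative only, not part of the patch):

    import torch
    import thunder

    def f(x):
        return torch.nn.functional.relu(x) * 2

    # With recipe=None this is equivalent to thunder.jit(f); passing a Recipe
    # instance instead routes compilation through recipe.apply(f).
    jf = thunder.compile(f, recipe=None)
    print(jf(torch.randn(4)))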
diff --git a/thunder/core/recipe.py b/thunder/core/recipe.py
new file mode 100644
index 0000000000..f025c10ce0
--- /dev/null
+++ b/thunder/core/recipe.py
@@ -0,0 +1,77 @@
+from typing import List, Dict
+
+from thunder.core.transform_common import Transform
+from thunder.extend import Executor, get_default_executors
+
+import torch
+
+
+class Lookaside:
+    def __init__(self, fn, replace_with):
+        self._fn = fn
+        self._replace_with = replace_with
+
+
+class Recipe:
+    # thunder.jit | torch.compile
+    compiler = "thunder.jit"
+
+    def __init__(self):
+        pass
+
+    def validate(self, model):
+        # this is supposed to raise
+        pass
+
+    # def setup_operators(self) -> List[Operator]:
+    #     # this is for registering custom kernels on the fly
+    #     return None
+
+    def setup_lookasides(self) -> list[Lookaside] | None:
+        return None
+
+    def setup_transforms(self) -> list[Transform] | None:
+        return None
+
+    def setup_executors(self):
+        return get_default_executors()
+
+    def setup_config(self) -> dict:
+        return {}
+
+    def apply(self, model):
+        self.validate(model)
+
+        self.config = self.setup_config()
+        lookasides = self.setup_lookasides()
+
+        from thunder.core import jit_ext, interpreter
+
+        if lookasides is not None:
+            for lookaside in lookasides:
+                wrapped_replacement_fn = interpreter.interpreter_needs_wrap(lookaside._replace_with)
+                jit_ext._general_jit_lookaside_map[lookaside._fn] = wrapped_replacement_fn
+
+        self.lookasides = lookasides
+        self.executors = self.setup_executors()
+        self.transforms = self.setup_transforms()
+
+        if self.compiler == "thunder.jit":
+            from thunder import jit
+
+            thunder_model = jit(model, transforms=self.transforms, executors=self.executors, **self.config)
+
+        elif self.compiler == "torch.compile":
+            from thunder.dynamo import ThunderCompiler
+
+            thunder_backend = ThunderCompiler(transforms=self.transforms, executors=self.executors, **self.config)
+            thunder_model = torch.compile(model, backend=thunder_backend)
+
+        else:
+            raise AttributeError(f"Compiler must be one of 'thunder.jit', 'torch.compile'. Found: {self.compiler}.")
+
+        return thunder_model
+
+
+class DynamoRecipe(Recipe):
+    compiler = "torch.compile"

diff --git a/thunder/recipes/__init__.py b/thunder/recipes/__init__.py
new file mode 100644
index 0000000000..e69de29bb2

diff --git a/thunder/recipes/hf_bert.py b/thunder/recipes/hf_bert.py
new file mode 100644
index 0000000000..abfb518905
--- /dev/null
+++ b/thunder/recipes/hf_bert.py
@@ -0,0 +1,24 @@
+import thunder
+import transformers
+import torch
+
+
+class HFBertBasic(thunder.Recipe):
+    def validate(self, model):
+        if not isinstance(model, transformers.BertForSequenceClassification):
+            raise ValueError("The model must be a BertForSequenceClassification")
+
+    def setup_lookasides(self):
+        warn_lookaside = thunder.Lookaside(
+            fn=transformers.modeling_utils.PreTrainedModel.warn_if_padding_and_no_attention_mask,
+            replace_with=lambda *args: None,
+        )
+
+        if hasattr(torch, "compiler") and hasattr(torch.compiler, "is_compiling"):
+            is_compiling = torch.compiler.is_compiling
+        else:
+            is_compiling = torch._dynamo.is_compiling
+
+        is_compiling_lookaside = thunder.Lookaside(fn=is_compiling, replace_with=lambda *_: True)
+
+        return [warn_lookaside, is_compiling_lookaside]

diff --git a/thunder/tests/test_recipes.py b/thunder/tests/test_recipes.py
new file mode 100644
index 0000000000..def8a590a5
--- /dev/null
+++ b/thunder/tests/test_recipes.py
@@ -0,0 +1,39 @@
+import thunder
+import transformers
+import torch
+
+from torch.testing import assert_close, make_tensor
+
+
+def test_recipe_basic_bert():
+    bert = transformers.BertForSequenceClassification(transformers.BertConfig())
+    del bert.bert.encoder.layer[1:]
+    bert.eval()
+
+    inp = torch.randint(1, 20, (1, 32))
+
+    from thunder.recipes.hf_bert import HFBertBasic
+
+    thunder_bert = thunder.compile(bert, recipe=HFBertBasic())
+
+    actual = thunder_bert(inp)
+    expected = bert(inp)
+
+    assert_close(actual, expected)
+
+
+def test_recipe_basic_bert_dynamo():
+    bert = transformers.BertForSequenceClassification(transformers.BertConfig())
+    del bert.bert.encoder.layer[1:]
+    bert.eval()
+
+    inp = torch.randint(1, 20, (1, 32))
+
+    from thunder.core.recipe import DynamoRecipe
+
+    thunder_bert = thunder.compile(bert, recipe=DynamoRecipe())
+
+    actual = thunder_bert(inp)
+    expected = bert(inp)
+
+    assert_close(actual, expected)
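The Recipe base class exposes a handful of override points (validate, setup_lookasides, setup_transforms, setup_executors, setup_config), and HFBertBasic shows the intended pattern. A hypothetical sketch of a user-defined recipe (the helper function and its replacement below are made up for illustration and are not part of the patch):

    import torch
    import thunder

    def noisy_helper(x):
        print("tracing me is unwanted")  # hypothetical function we want to silence
        return x

    class QuietRecipe(thunder.Recipe):
        def validate(self, model):
            if not isinstance(model, torch.nn.Module):
                raise ValueError("QuietRecipe expects an nn.Module")

        def setup_lookasides(self):
            # Swap noisy_helper for an identity while the model is being traced.
            return [thunder.Lookaside(fn=noisy_helper, replace_with=lambda x: x)]

    # compiled = thunder.compile(model, recipe=QuietRecipe())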
From 2f6e48d851b5fe97979f0d4482515dc1e211be93 Mon Sep 17 00:00:00 2001
From: Kshiteej K
Date: Tue, 29 Oct 2024 10:02:02 +0100
Subject: [PATCH 05/10] fix leak when view of output is saved for backward in autograd.Function (#1352)

---
 thunder/executors/torch_autograd.py | 19 ++++++++++++
 thunder/tests/test_core.py          | 39 +++++++++++++++++++++++++++
 thunder/tests/test_grad.py          |  3 ++-
 3 files changed, 60 insertions(+), 1 deletion(-)

diff --git a/thunder/executors/torch_autograd.py b/thunder/executors/torch_autograd.py
index ff6348b0d8..5374b23afe 100644
--- a/thunder/executors/torch_autograd.py
+++ b/thunder/executors/torch_autograd.py
@@ -70,10 +70,29 @@ def forward(
         ctx.saved_other = saved_other
         ctx.compiled_backward = compiled_backward

+        # NOTE [Saved view of output of torch.autograd.Function leaks]
+        # We detach here to avoid a bug in PyTorch where
+        # it leaks memory if view of the output of torch.autograd.Function
+        # is saved for backward.
+        # See - https://github.com/pytorch/pytorch/issues/94990#issuecomment-1435181804
+        # NOTE - Detaching here would lead to problem with higher order differentiation but
+        # this is ok for now because ThunderFunction is only `once_differentiable`.
+        def detach_if_tensor(t):
+            # Some operations may claim to return Tensor (as per their meta function)
+            # but may return None at Runtime (eg. noticed this for sdpa)
+            if isinstance(t, torch.Tensor):
+                return t.detach()
+            return t
+
+        saved_tensors = tuple(map(detach_if_tensor, saved_tensors))
+
         # We must save tensors using ctx.save_for_backward
         ctx.save_for_backward(*saved_tensors)
         return flat_output

+    # NOTE: If `torch.autograd.function.once_differentiable` is to be removed,
+    # one must take care of correctly removing the `detach_if_tensor` above.
+    # For more context, see NOTE [Saved view of output of torch.autograd.Function leaks] above.
     @staticmethod
     @torch.autograd.function.once_differentiable
     def backward(ctx, *args):

diff --git a/thunder/tests/test_core.py b/thunder/tests/test_core.py
index 345eabd921..3f1bba6437 100644
--- a/thunder/tests/test_core.py
+++ b/thunder/tests/test_core.py
@@ -2966,3 +2966,42 @@ def fn():
     assert jfn() == 2  # Verify that jfn now returns 2
     assert thunder.cache_hits(jfn) == 1
     assert thunder.cache_misses(jfn) == 2
+
+
+def test_saved_view_of_output_of_autograd_function_does_not_leak():
+    # Verify that we have side-stepped the bug in torch.autograd.Function
+    # where saving a view of the output for backward leads to leak.
+    # See NOTE [Saved view of output of torch.autograd.Function leaks]
+    def fn(idx, weight):
+        tok_emb = torch.nn.functional.embedding(idx, weight)
+        emb = torch.reshape(tok_emb, (2, 32))
+        matmul = emb @ emb.T
+        return tok_emb, matmul
+
+    weight = make_tensor((16, 32), dtype=torch.float, device="cpu", requires_grad=True)
+    x = make_tensor((1, 2), dtype=torch.int64, low=0, high=10, device="cpu")
+
+    jfn = thunder.jit(fn)
+
+    # Computation Trace for jfn
+    # We save view of the output `tok_emb` for backward.
+    # @torch.no_grad()
+    # @no_autocast
+    # def computation(idx, t_wte_weight):
+    #     # idx: "cuda:0 i64[1, 2]"
+    #     # t_wte_weight: "cuda:0 f32[16, 32]"
+    #     tok_emb = torch.nn.functional.embedding(idx, t_wte_weight, None, None, 2.0, False, False)  # tok_emb: "cuda:0 f32[1, 2, 32]"
+    #     [emb, t4] = nvFusion0(tok_emb)
+    #         # emb = prims.reshape(tok_emb, (2, 32))  # emb: "cuda:0 f32[2, 32]"
+    #         # t4 = prims.transpose(emb, (1, 0))  # t4: "cuda:0 f32[32, 2]"
+    #     matmul = torch.matmul(emb, t4)  # matmul: "cuda:0 f32[2, 2]"
+    #     return {'output': (tok_emb, matmul), 'flat_args': [idx, t_wte_weight], 'flat_output': (tok_emb, matmul)}, ((emb, idx, t4), ())
+
+    prev_iter_refs = []
+    for iter_n in range(4):
+        tok_emb, _ = jfn(x, weight)
+        if iter_n < 3:
+            prev_iter_refs.append(weakref.ref(tok_emb))
+
+    for ref in prev_iter_refs:
+        assert ref() is None

diff --git a/thunder/tests/test_grad.py b/thunder/tests/test_grad.py
index 33b4e4df2c..a7b0898c7f 100644
--- a/thunder/tests/test_grad.py
+++ b/thunder/tests/test_grad.py
@@ -1725,7 +1725,8 @@ def f(x):
     # With activation checkpointing, we are saving only the original input.
     # The intermediate values are recomputed during backward pass.
     assert len(out.grad_fn.saved_tensors) == 1
-    assert out.grad_fn.saved_tensors[0] is x
+    # We detach the saved tensors (which returns a new Python tensor backed by same storage)
+    assert out.grad_fn.saved_tensors[0].data_ptr() == x.data_ptr()

     g = torch.ones_like(out)
     out.backward(g)
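The pattern the NOTE above works around can show up with any custom autograd.Function that saves a view of its own output; a standalone sketch of the workaround in plain PyTorch (not Thunder code, illustration only):

    import torch

    class Double(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            out = x * 2
            view = out.reshape(out.shape)  # a view of the output
            # Saving the view directly can keep the output alive
            # (pytorch/pytorch#94990); detaching breaks the cycle and is
            # acceptable here because backward is once_differentiable.
            ctx.save_for_backward(view.detach())
            return out

        @staticmethod
        @torch.autograd.function.once_differentiable
        def backward(ctx, grad_out):
            (saved,) = ctx.saved_tensors  # kept only to mirror the pattern
            return grad_out * 2

    y = Double.apply(torch.randn(3, requires_grad=True))
    y.sum().backward()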
From b1dae82456e5b5690d7b026cdc6730be2510176e Mon Sep 17 00:00:00 2001
From: Martyna Patelka <149149379+mpatel31415@users.noreply.github.com>
Date: Tue, 29 Oct 2024 10:07:43 +0100
Subject: [PATCH 06/10] Mpatel31415/fix for missing tokens per sec (#1347)

---
 thunder/benchmarks/benchmark_litgpt.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/thunder/benchmarks/benchmark_litgpt.py b/thunder/benchmarks/benchmark_litgpt.py
index afb9d79f3e..4da107438a 100644
--- a/thunder/benchmarks/benchmark_litgpt.py
+++ b/thunder/benchmarks/benchmark_litgpt.py
@@ -819,8 +819,9 @@ def benchmark_main(return_metrics_as_json=False, json_path="", **kwargs) -> None

         print(f"Average iter time: {benchmark.perf_metrics['average_iter_time']:.2f} ms")
         print(f"Memory used: {benchmark.perf_metrics['memory_used_GB']:.02f} GB")
-        print(f"Tokens/s: {benchmark.perf_metrics['tokens_per_sec']:.02f}")
-        print(f"Tokens/s/GPU: {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f}")
+        if "tokens_per_sec" in benchmark.perf_metrics:
+            print(f"Tokens/s: {benchmark.perf_metrics['tokens_per_sec']:.02f}")
+            print(f"Tokens/s/GPU: {(benchmark.perf_metrics['tokens_per_sec']/world_size):.02f}")
         print(f"TFLOP/s: {benchmark.perf_metrics['model_flop_per_sec'] / 1e12:.02f}")

     if benchmark.dump_memory_snapshot:

From a7618f68c758684a77fecd9cdcf7f94a9c4f5447 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 29 Oct 2024 09:28:25 +0000
Subject: [PATCH 07/10] [pre-commit.ci] pre-commit suggestions (#1358)

Co-authored-by: Thomas Viehmann
---
 .pre-commit-config.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 73bdeaade1..b235526420 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -23,7 +23,7 @@ repos:
       - id: detect-private-key

   - repo: https://github.com/asottile/pyupgrade
-    rev: v3.18.0
+    rev: v3.19.0
     hooks:
       - id: pyupgrade
        args: ["--py310-plus"]
From 49e4b57f70d247ac1aea248c245c85f767f57183 Mon Sep 17 00:00:00 2001
From: Masaki Kozuki
Date: Tue, 29 Oct 2024 18:38:45 +0900
Subject: [PATCH 08/10] [benchmark] fix litgpt config import (#1346)

---
 thunder/benchmarks/distributed.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/thunder/benchmarks/distributed.py b/thunder/benchmarks/distributed.py
index bc02528fe7..338ea91663 100644
--- a/thunder/benchmarks/distributed.py
+++ b/thunder/benchmarks/distributed.py
@@ -12,9 +12,8 @@
     NanoGPTConfig,
     NanoGPTBenchmark,
     LitGPTBenchmark,
-    LitGPTConfig,
 )
-from thunder.tests.litgpt_model import name_to_config
+from thunder.tests.litgpt_model import name_to_config, Config as LitGPTConfig
 from thunder.distributed import FSDPBucketingStrategy
 from thunder.distributed import FSDPType

From b28d5b3536e60fb0b30896bdd4df6e288cf6a5c8 Mon Sep 17 00:00:00 2001
From: Kshiteej K
Date: Tue, 29 Oct 2024 10:45:38 +0100
Subject: [PATCH 09/10] grad rule for copy_with_setitem (#1322)

---
 thunder/core/transforms.py | 20 ++++++++++++
 thunder/tests/test_ops.py  | 62 ++++++++++++++++++++++++++++++++------
 2 files changed, 72 insertions(+), 10 deletions(-)

diff --git a/thunder/core/transforms.py b/thunder/core/transforms.py
index 13437d488e..116e4094ed 100644
--- a/thunder/core/transforms.py
+++ b/thunder/core/transforms.py
@@ -1414,6 +1414,26 @@ def _maximum_grad(a: TensorProxy, b: TensorProxy, /):
 # This operation creates no grad associations
 register_grad(pids.SHAPE, prims.shape)

+
+def _copy_with_setitem_grad(a: TensorProxy, index, value: Number | TensorProxy):
+    fwd = prims.copy_with_setitem(a, index, value)
+    g = get_grad(fwd)
+
+    a_grad = prims.copy_with_setitem(g, index, 0)
+    put_grad(a, a_grad)
+
+    if isinstance(value, TensorProxy):
+        value_grad = g[index]
+        expanded_dims = value_grad.ndim - value.ndim
+        if expanded_dims > 0:
+            value_grad = prims.sum(value_grad, tuple(range(expanded_dims)))
+        put_grad(value, value_grad)
+
+    return fwd
+
+
+register_grad(pids.COPY_WITH_SETITEM, _copy_with_setitem_grad)
+
 #
 # Phantom grad transform helpers
 #

diff --git a/thunder/tests/test_ops.py b/thunder/tests/test_ops.py
index a92f9650a5..a588e94f3b 100644
--- a/thunder/tests/test_ops.py
+++ b/thunder/tests/test_ops.py
@@ -239,18 +239,60 @@ def foo():
     tfoo()


-def test_setitem():
-    def fn(a):
-        a[:3] = 2
+@pytest.mark.parametrize("requires_grad", (True, False))
+def test_setitem(requires_grad):
+
+    def _test_forward_and_backward(fn, a, value):
+        a_ref = a.detach().clone()
+        a_ref.requires_grad_(a.requires_grad)
+
+        if isinstance(value, torch.Tensor):
+            value_ref = value.detach().clone()
+            value_ref.requires_grad_(value.requires_grad)
+        else:
+            value_ref = value
+
+        out_ref = fn(a_ref, value_ref)
+        jf = thunder.jit(fn)
+        out = jf(a, value)
+        assert_close(a, a_ref)
+        assert_close(out, out_ref)
+
+        if requires_grad:
+            g = torch.randn_like(out)
+            inputs = (a, value) if isinstance(value, torch.Tensor) else (a,)
+            actual_grad = torch.autograd.grad(out, inputs, g)
+
+            inputs_ref = (a_ref, value_ref) if isinstance(value, torch.Tensor) else (a_ref,)
+            expected_grad = torch.autograd.grad(out_ref, inputs_ref, g)
+            assert_close(actual_grad, expected_grad)
+
+    def clone_if_requires_grad(a):
+        if requires_grad:
+            # Without the clone
+            # PyTorch eager errors with
+            # `RuntimeError: a view of a leaf Variable that requires grad is being used in an in-place operation.`
+            # and thunder has silent correctness issue - https://github.com/Lightning-AI/lightning-thunder/issues/1284
+            return a.clone()
+        return a
+
+    def fn(a, value):
+        a = clone_if_requires_grad(a)
+        a[:3] = value
         return a * 2

-    a_ref = torch.ones(5)
-    out_ref = fn(a_ref)
-    a = torch.ones(5)
-    jf = thunder.jit(fn)
-    out = jf(a)
-    assert_close(a, a_ref)
-    assert_close(out, out_ref)
+    # set value: scalar
+    _test_forward_and_backward(fn, torch.randn(5, requires_grad=requires_grad), 2.0)
+
+    # set value: tensor which needs to be broadcasted
+    _test_forward_and_backward(
+        fn, torch.randn(5, requires_grad=requires_grad), torch.tensor(2.0, requires_grad=requires_grad)
+    )
+
+    # set value: tensor of same rank
+    _test_forward_and_backward(
+        fn, torch.randn(5, requires_grad=requires_grad), torch.tensor([1.0, 2.0, 3.0], requires_grad=requires_grad)
+    )


 # TODO: Add random operator support to OpInfo
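The grad rule above routes the cotangent g to `a` everywhere except the written positions and to `value` at those positions, summed over any dimensions introduced by broadcasting. A plain-PyTorch sketch of the same bookkeeping, checked against autograd (illustration only, not Thunder code):

    import torch

    a = torch.randn(5, requires_grad=True)
    value = torch.tensor(2.0, requires_grad=True)
    index = slice(0, 3)

    out = a.clone()
    out[index] = value  # value is broadcast over the written slice
    g = torch.randn(5)
    ga, gv = torch.autograd.grad(out, (a, value), g)

    expected_ga = g.clone()
    expected_ga[index] = 0        # grad w.r.t. a: cotangent with the written slice zeroed
    expected_gv = g[index].sum()  # grad w.r.t. value: cotangent at the slice, summed over broadcast dims

    torch.testing.assert_close(ga, expected_ga)
    torch.testing.assert_close(gv, expected_gv)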
From 9c916d9df73f3920b51e5951303a76b25ab2d4d4 Mon Sep 17 00:00:00 2001
From: Ivan Yashchuk
Date: Wed, 30 Oct 2024 01:01:08 +0200
Subject: [PATCH 10/10] Skip importing Apex if torch distributed is not available (#1359)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 thunder/executors/apex_fused_rms_norm_impl.py |  6 ++++++
 thunder/tests/test_apex_fused_norms.py        | 11 ++++++++++-
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/thunder/executors/apex_fused_rms_norm_impl.py b/thunder/executors/apex_fused_rms_norm_impl.py
index e97a657b88..6fb6ac55fe 100644
--- a/thunder/executors/apex_fused_rms_norm_impl.py
+++ b/thunder/executors/apex_fused_rms_norm_impl.py
@@ -14,6 +14,12 @@
 APEX_FUSED_NORMS_AVAILABLE = True
 try:
+    # Fused layer norm is only importable if torch.distributed is available
+    # https://github.com/NVIDIA/apex/issues/1853
+    from torch.distributed import is_available
+
+    if not is_available():
+        raise ImportError
     import fused_layer_norm_cuda
     from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction
 except ImportError:

diff --git a/thunder/tests/test_apex_fused_norms.py b/thunder/tests/test_apex_fused_norms.py
index 57ae6637b1..d2a4734863 100644
--- a/thunder/tests/test_apex_fused_norms.py
+++ b/thunder/tests/test_apex_fused_norms.py
@@ -3,14 +3,19 @@
 from torch.testing import assert_close

 fused_layer_norm_cuda = pytest.importorskip("fused_layer_norm_cuda")
-from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction
+
+from torch.distributed import is_available
 from thunder.executors.apexex import apex_ex
 import thunder


+# See https://github.com/NVIDIA/apex/issues/1853
+@pytest.mark.skipif(not is_available(), reason="torch.distributed is not available")
 @pytest.mark.parametrize("requires_grad", [True, False])
 @pytest.mark.parametrize("memory_efficient", [True, False])
 def test_apex_fused_rms_norm(requires_grad, memory_efficient):
+    from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction
+
     def fn(x, weight, normalized_shape, eps):
         return FusedRMSNormAffineMixedDtypesFunction.apply(x, weight, normalized_shape, eps, memory_efficient)

@@ -34,9 +39,13 @@ def fn(x, weight, normalized_shape, eps):
     assert_close(actual_grad, expected_grad)


+# See https://github.com/NVIDIA/apex/issues/1853
+@pytest.mark.skipif(not is_available(), reason="torch.distributed is not available")
 @pytest.mark.parametrize("requires_grad", [True, False])
 @pytest.mark.parametrize("memory_efficient", [True, False])
 def test_apex_fused_rms_norm_autoregister(requires_grad, memory_efficient):
+    from apex.normalization.fused_layer_norm import FusedRMSNormAffineMixedDtypesFunction
+
     def fn(x, weight, normalized_shape, eps):
         return FusedRMSNormAffineMixedDtypesFunction.apply(x, weight, normalized_shape, eps, memory_efficient)