diff --git a/thunder/tests/test_inplace_functionalization.py b/thunder/tests/test_inplace_functionalization.py
index 0eb591c2f8..0e005fe4b5 100644
--- a/thunder/tests/test_inplace_functionalization.py
+++ b/thunder/tests/test_inplace_functionalization.py
@@ -2,6 +2,7 @@
 from collections.abc import Callable
 from dataclasses import dataclass
 from functools import partial
+from typing import TYPE_CHECKING
 
 import pytest
 import torch.testing
@@ -13,6 +14,9 @@
 from thunder.tests.make_tensor import make_tensor
 from thunder.torch import _torch_to_thunder_function_map, _inplace_to_out_of_place
 
+if TYPE_CHECKING:
+    from thunder.core.symbol import Symbol
+
 
 # `SampleInput`s of ops with `inplace` argument do not seem to come with `inplace` arg, so give it to them.
 def sample_generator_wrapper(sample_generator):
@@ -36,7 +40,7 @@ def inplace_masked_fill_sample_generator(op, device, dtype, requires_grad, **kwa
     yield SampleInput(a, pred, value)
 
 
-_torchsymbol_to_torch: dict[Sybmol, Callable] = {v: k for k, v in _torch_to_thunder_function_map.items()}
+_torchsymbol_to_torch: dict[Symbol, Callable] = {v: k for k, v in _torch_to_thunder_function_map.items()}
 _functional_to_inplace: dict[Callable, Callable] = {
     functional: inplace for inplace, (functional, index) in _inplace_to_out_of_place.items() if index == -1
 }
@@ -125,33 +129,39 @@ def test_functionalization(op: OpInfo, device: str, dtype: dtypes.dtype, executo
     )
 
 
-# TODO(crcrpar): Investigate the numerical accuracy when `train=True` and dtype is fp32.
-# with RTX6000 Ada and CUDA 12.3, I see somewhat huge error:
-# E       AssertionError: Tensor-likes are not close!
-# E
-# E       Mismatched elements: 913 / 1000 (91.3%)
-# E       Greatest absolute difference: 0.000273287296295166 at index (0, 50) (up to 1e-05 allowed)
-# E       Greatest relative difference: 0.4177769422531128 at index (0, 727) (up to 1.3e-06 allowed)
+@pytest.fixture
+def turn_off_tf32_and_set_seed(monkeypatch):
+    import torch
+
+    monkeypatch.setenv("NVIDIA_TF32_OVERRIDE", "0")
+    torch.manual_seed(42)
+
+
 @requiresCUDA
 @pytest.mark.parametrize("train", (False, True))
-def test_parse_resnet18(train: bool):
+def test_parse_resnet18(turn_off_tf32_and_set_seed, train: bool):
+    from contextlib import nullcontext
     import thunder
 
     torchvision = pytest.importorskip("torchvision")
 
     device = torch.device("cuda")
-    dtype = torch.float64 if train else torch.float32
+    dtype = torch.float32
     with device:
-        model: nn.Module = torchvision.models.resnet18(weights=None).to(device=device, dtype=dtype)
-        ref_model: nn.Module = torchvision.models.resnet18(weights=None).to(device=device, dtype=dtype)
+        model = torchvision.models.resnet18(weights=None).to(device=device, dtype=dtype)
+        ref_model = torchvision.models.resnet18(weights=None).to(device=device, dtype=dtype)
     if not train:
         model = model.eval()
         ref_model = ref_model.eval()
+        ctx = torch.no_grad
+    else:
+        ctx = nullcontext
     ref_model.load_state_dict(model.state_dict())
     jitted = thunder.jit(model)
     x = make_tensor((1, 3, 224, 224), dtype=dtype, device=device)
 
-    torch.testing.assert_close(jitted(x), ref_model(x))
+    with ctx():
+        torch.testing.assert_close(jitted(x), ref_model(x))
 
 
 @instantiate(
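
The two pieces of machinery the new test leans on can be exercised in isolation. The sketch below is not part of the patch: it shows how a pytest `monkeypatch` fixture scopes the `NVIDIA_TF32_OVERRIDE` environment override to a single test, and how choosing between `torch.no_grad` and `contextlib.nullcontext` lets the same assertion run with or without gradient tracking. The `test_context_switch_pattern` function and the `torch.nn.Linear` stand-in for resnet18 are purely illustrative.

# Minimal standalone sketch of the fixture/context-manager pattern used above
# (illustrative only; names are hypothetical except the fixture itself).
from contextlib import nullcontext

import pytest
import torch


@pytest.fixture
def turn_off_tf32_and_set_seed(monkeypatch):
    # monkeypatch.setenv restores (or removes) the variable when the test
    # finishes, so the TF32 override cannot leak into other tests.
    monkeypatch.setenv("NVIDIA_TF32_OVERRIDE", "0")
    torch.manual_seed(42)


@pytest.mark.parametrize("train", (False, True))
def test_context_switch_pattern(turn_off_tf32_and_set_seed, train: bool):
    model = torch.nn.Linear(4, 4)  # stand-in for torchvision's resnet18
    if not train:
        model = model.eval()
        ctx = torch.no_grad  # eval path: compare outputs without autograd
    else:
        ctx = nullcontext  # train path: keep gradient tracking enabled

    x = torch.randn(2, 4)
    with ctx():  # both names are context-manager factories, hence the call
        y = model(x)
    assert y.requires_grad == train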