From 0ab8a45d7f9ec0c484562fdc4fb407120c093dfb Mon Sep 17 00:00:00 2001
From: Yan Wang
Date: Wed, 26 Jun 2024 13:03:30 +0200
Subject: [PATCH] Test for resnet18 accuracy (#645)

---
 .../tests/test_inplace_functionalization.py  | 45 ++++++++++++++-----
 1 file changed, 34 insertions(+), 11 deletions(-)

diff --git a/thunder/tests/test_inplace_functionalization.py b/thunder/tests/test_inplace_functionalization.py
index 01010cd029..7879a305d9 100644
--- a/thunder/tests/test_inplace_functionalization.py
+++ b/thunder/tests/test_inplace_functionalization.py
@@ -7,9 +7,11 @@
 import pytest
 import torch.testing
 
+import thunder
+import thunder.core.devices as devices
 from thunder.core import dtypes
 from thunder.core.prims import PrimIDs
-from thunder.tests.framework import instantiate, ops, requiresCUDA, NOTHING
+from thunder.tests.framework import instantiate, ops, requiresCUDA, NOTHING, TorchExecutor
 from thunder.tests.opinfos import opinfos, OpInfo, make_number, SampleInput
 from thunder.tests.make_tensor import make_tensor
 from thunder.torch import _torch_to_thunder_function_map, _inplace_to_out_of_place
@@ -137,31 +139,52 @@ def turn_off_tf32_and_set_seed(monkeypatch):
     torch.manual_seed(42)
 
 
+@instantiate(
+    dtypes=(thunder.float32, thunder.float64),
+    devicetypes=(devices.DeviceType.CUDA,),
+    decorators=(pytest.mark.parametrize("train", (False, True)),),
+)
 @requiresCUDA
-@pytest.mark.parametrize("train", (False, True))
-def test_parse_resnet18(turn_off_tf32_and_set_seed, train: bool):
+def test_parse_resnet18(executor, device, dtype, turn_off_tf32_and_set_seed, train: bool):
     from contextlib import nullcontext
+
     import thunder
 
     torchvision = pytest.importorskip("torchvision")
 
-    device = torch.device("cuda")
-    dtype = torch.float32
-    with device:
-        model = torchvision.models.resnet18(weights=None).to(device=device, dtype=dtype)
-        ref_model = torchvision.models.resnet18(weights=None).to(device=device, dtype=dtype)
+    tdtype = thunder.torch.to_torch_dtype(dtype)
+    model = torchvision.models.resnet18(weights=None).to(device=device, dtype=tdtype)
+    ref_model = torchvision.models.resnet18(weights=None).to(device=device, dtype=tdtype)
     if not train:
         model = model.eval()
         ref_model = ref_model.eval()
        ctx = torch.no_grad
     else:
+        model = model.train()
+        ref_model = ref_model.train()
         ctx = nullcontext
     ref_model.load_state_dict(model.state_dict())
 
-    jitted = thunder.jit(model)
-    x = make_tensor((1, 3, 224, 224), dtype=dtype, device=device)
+    jitted = executor.make_callable(model)
+    x = make_tensor((1, 3, 224, 224), dtype=tdtype, device=device)
+
     with ctx():
-        torch.testing.assert_close(jitted(x), ref_model(x))
+        out1 = ref_model(x)
+        out2 = jitted(x)
+        torch.testing.assert_close(out1, out2)
+    # Backward fails with nvfuserExecutor: RuntimeError: Unsupported iterable object type for define_vector! Index:0
+    # TorchExecutor has a numerical accuracy error when `train=True` and the dtype is fp32, so gradients
+    # are only compared for fp64. With an RTX 6000 Ada and CUDA 12.3, the fp32 mismatch is large:
+    # E       AssertionError: Tensor-likes are not close!
+    # E
+    # E       Mismatched elements: 9401 / 9408 (99.9%)
+    # E       Greatest absolute difference: 0.07035164535045624 at index (4, 1, 0, 3) (up to 1e-05 allowed)
+    # E       Greatest relative difference: 343.7076110839844 at index (5, 0, 5, 4) (up to 1.3e-06 allowed)
+    # E       The failure occurred for item [0]
+    if train and executor == TorchExecutor and dtype == thunder.float64:
+        torch_grads = torch.autograd.grad(out1, ref_model.parameters(), torch.ones_like(out1))
+        thunder_grads = torch.autograd.grad(out2, jitted.parameters(), torch.ones_like(out2))
+        torch.testing.assert_close(torch_grads, thunder_grads)
 
 
 @instantiate(
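
Note for reviewers: below is a minimal standalone sketch of the comparison this test performs, for reproducing it outside the test harness. It assumes CUDA and torchvision are available and uses the public `thunder.jit` in place of the test framework's `executor.make_callable`; the fp64 train configuration mirrors the one case whose gradients are compared above.

import torch
import torchvision
import thunder

torch.manual_seed(42)
device, dtype = "cuda", torch.float64

# Two identically initialized ResNet-18s: an eager reference and one to be jitted.
model = torchvision.models.resnet18(weights=None).to(device=device, dtype=dtype).train()
ref_model = torchvision.models.resnet18(weights=None).to(device=device, dtype=dtype).train()
ref_model.load_state_dict(model.state_dict())

jitted = thunder.jit(model)
x = torch.randn(1, 3, 224, 224, device=device, dtype=dtype)

# Forward outputs must match between eager and jitted execution.
out_ref = ref_model(x)
out_jit = jitted(x)
torch.testing.assert_close(out_ref, out_jit)

# Seed backward with ones and compare gradients w.r.t. all parameters.
ref_grads = torch.autograd.grad(out_ref, ref_model.parameters(), torch.ones_like(out_ref))
jit_grads = torch.autograd.grad(out_jit, jitted.parameters(), torch.ones_like(out_jit))
torch.testing.assert_close(ref_grads, jit_grads)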