Skip to content

Commit

Permalink
Test for resnet18 accuracy (#645)
Browse files Browse the repository at this point in the history
  • Loading branch information
kiya00 authored Jun 26, 2024
1 parent f93e2bf commit 0ab8a45
Showing 1 changed file with 34 additions and 11 deletions.
45 changes: 34 additions & 11 deletions thunder/tests/test_inplace_functionalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import pytest
import torch.testing

import thunder
import thunder.core.devices as devices
from thunder.core import dtypes
from thunder.core.prims import PrimIDs
from thunder.tests.framework import instantiate, ops, requiresCUDA, NOTHING
from thunder.tests.framework import instantiate, ops, requiresCUDA, NOTHING, TorchExecutor
from thunder.tests.opinfos import opinfos, OpInfo, make_number, SampleInput
from thunder.tests.make_tensor import make_tensor
from thunder.torch import _torch_to_thunder_function_map, _inplace_to_out_of_place
Expand Down Expand Up @@ -137,31 +139,52 @@ def turn_off_tf32_and_set_seed(monkeypatch):
torch.manual_seed(42)


@instantiate(
dtypes=(thunder.float32, thunder.float64),
devicetypes=(devices.DeviceType.CUDA,),
decorators=(pytest.mark.parametrize("train", (False, True)),),
)
@requiresCUDA
@pytest.mark.parametrize("train", (False, True))
def test_parse_resnet18(turn_off_tf32_and_set_seed, train: bool):
def test_parse_resnet18(executor, device, dtype, turn_off_tf32_and_set_seed, train: bool):
from contextlib import nullcontext

import thunder

torchvision = pytest.importorskip("torchvision")

device = torch.device("cuda")
dtype = torch.float32
with device:
model = torchvision.models.resnet18(weights=None).to(device=device, dtype=dtype)
ref_model = torchvision.models.resnet18(weights=None).to(device=device, dtype=dtype)
tdtype = thunder.torch.to_torch_dtype(dtype)
model = torchvision.models.resnet18(weights=None).to(device=device, dtype=tdtype)
ref_model = torchvision.models.resnet18(weights=None).to(device=device, dtype=tdtype)
if not train:
model = model.eval()
ref_model = ref_model.eval()
ctx = torch.no_grad
else:
model = model.train()
ref_model = ref_model.train()
ctx = nullcontext
ref_model.load_state_dict(model.state_dict())

jitted = thunder.jit(model)
x = make_tensor((1, 3, 224, 224), dtype=dtype, device=device)
jitted = executor.make_callable(model)
x = make_tensor((1, 3, 224, 224), dtype=tdtype, device=device)

with ctx():
torch.testing.assert_close(jitted(x), ref_model(x))
out1 = ref_model(x)
out2 = jitted(x)
torch.testing.assert_close(out1, out2)
# Backward fails with nvfuserExecutor: RuntimeError: Unsupported iterable object type for define_vector! Index:0
# There is also a numerical accuracy error with TorchExecutor when `train=True` and the dtype is fp32,
# so gradients are only compared below for TorchExecutor with float64.
# With an RTX 6000 Ada and CUDA 12.3, the fp32 error is fairly large:
# E AssertionError: Tensor-likes are not close!
# E
# E Mismatched elements: 9401 / 9408 (99.9%)
# E Greatest absolute difference: 0.07035164535045624 at index (4, 1, 0, 3) (up to 1e-05 allowed)
# E Greatest relative difference: 343.7076110839844 at index (5, 0, 5, 4) (up to 1.3e-06 allowed)
# E The failure occurred for item [0]
if train and executor == TorchExecutor and dtype == thunder.float64:
torch_grads = torch.autograd.grad(out1, ref_model.parameters(), torch.ones_like(out1))
thunder_grads = torch.autograd.grad(out2, jitted.parameters(), torch.ones_like(out2))
torch.testing.assert_close(torch_grads, thunder_grads)


@instantiate(
Expand Down

0 comments on commit 0ab8a45

Please sign in to comment.