diff --git a/thunder/tests/test_auto_register_torchops.py b/thunder/tests/test_auto_register_torchops.py
index f20fa1daf1..389b1e2c85 100644
--- a/thunder/tests/test_auto_register_torchops.py
+++ b/thunder/tests/test_auto_register_torchops.py
@@ -107,7 +107,7 @@ def test_torch_ops_trace(device, requires_grad):
     list(tmp2.pop(k.torch_reference, None) for k in disable_opinfos)
     tmp3 = dict(thunder.core.jit_ext._minimal_lookaside_map)
     list(tmp3.pop(k.torch_reference, None) for k in disable_opinfos)
-from thunder.torch import meta_adaptor, register_default_torch_op
+from thunder.torch import register_default_torch_op

 # mock all the global variables that are modified during registration
@@ -143,9 +143,9 @@ def test_nanogpt_block(self):
         for op in skip_ops_nanogpt:
             if op.name == "gelu":
-                register_default_torch_op(op.torch_reference, meta_adaptor(op.torch_reference), torch)
+                register_default_torch_op(op.torch_reference, torch)
             else:
-                register_default_torch_op(op.torch_reference, meta_adaptor(op.torch_reference), torch.nn.functional)
+                register_default_torch_op(op.torch_reference, torch.nn.functional)
             self._tmp_update_jit_lookup(op.torch_reference)
         tdtype = torch.float32
         device = torch.device("cuda")
@@ -169,7 +169,7 @@ def test_alexnet(self):
         torchvision = pytest.importorskip("torchvision")
         for op in skip_ops_alexnet:
-            register_default_torch_op(op.torch_reference, meta_adaptor(op.torch_reference), torch.nn.functional)
+            register_default_torch_op(op.torch_reference, torch.nn.functional)
             self._tmp_update_jit_lookup(op.torch_reference)
         tdtype = torch.float32
         device = torch.device("cuda")
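
For context, a minimal sketch of the call-site change in this patch, assuming register_default_torch_op now derives the meta function internally rather than taking an explicit meta_adaptor(fn) argument; the softmax reference below is illustrative and not taken from the diff:

    import torch
    from thunder.torch import register_default_torch_op

    fn = torch.nn.functional.softmax  # illustrative torch callable, not from the diff

    # before: register_default_torch_op(fn, meta_adaptor(fn), torch.nn.functional)
    # after:  the meta function is built inside register_default_torch_op,
    #         so callers pass only the torch reference and its owning module
    register_default_torch_op(fn, torch.nn.functional)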