Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
kiya00 committed Aug 5, 2024
1 parent 45bed89 commit 445147d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions thunder/tests/test_auto_register_torchops.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def test_torch_ops_trace(device, requires_grad):
list(tmp2.pop(k.torch_reference, None) for k in disable_opinfos)
tmp3 = dict(thunder.core.jit_ext._minimal_lookaside_map)
list(tmp3.pop(k.torch_reference, None) for k in disable_opinfos)
from thunder.torch import meta_adaptor, register_default_torch_op
from thunder.torch import register_default_torch_op


# mock all the global variables that are modified during registration
Expand Down Expand Up @@ -143,9 +143,9 @@ def test_nanogpt_block(self):

for op in skip_ops_nanogpt:
if op.name == "gelu":
register_default_torch_op(op.torch_reference, meta_adaptor(op.torch_reference), torch)
register_default_torch_op(op.torch_reference, torch)
else:
register_default_torch_op(op.torch_reference, meta_adaptor(op.torch_reference), torch.nn.functional)
register_default_torch_op(op.torch_reference, torch.nn.functional)
self._tmp_update_jit_lookup(op.torch_reference)
tdtype = torch.float32
device = torch.device("cuda")
Expand All @@ -169,7 +169,7 @@ def test_alexnet(self):
torchvision = pytest.importorskip("torchvision")

for op in skip_ops_alexnet:
register_default_torch_op(op.torch_reference, meta_adaptor(op.torch_reference), torch.nn.functional)
register_default_torch_op(op.torch_reference, torch.nn.functional)
self._tmp_update_jit_lookup(op.torch_reference)
tdtype = torch.float32
device = torch.device("cuda")
Expand Down

0 comments on commit 445147d

Please sign in to comment.