Automatic registration of torch operators using FakeTensor #554

Merged 58 commits on Aug 6, 2024
Commits
06df3cd
Experiments of auto fallback to torch operator using FakeTensor
kiya00 Jun 7, 2024
84f786b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 7, 2024
b884970
add vjp
kiya00 Jun 11, 2024
c27b942
fix
kiya00 Jun 11, 2024
f1b7147
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 11, 2024
6fe0259
fix rebase: add Function types in supported tree_flatten types
kiya00 Jul 3, 2024
8ba119a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 3, 2024
56da491
clean up
kiya00 Jul 5, 2024
8bd9898
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 5, 2024
81a3dc8
fix: to_printable dead loop when torch.Size; add type for tree_flatten
kiya00 Jul 9, 2024
03e4476
fix
kiya00 Jul 9, 2024
3ae5980
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 9, 2024
23bc30b
Fix: TypeError: tree_flatten of type <class 'method_descriptor'> is n…
kiya00 Jul 12, 2024
5a10851
Fix: add torch.Tensor.ops; catch vjp exception; ret None when output …
kiya00 Jul 12, 2024
f325be7
add test
kiya00 Jul 12, 2024
f0a75b7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2024
a1e6120
add test
kiya00 Jul 15, 2024
c1ed22b
exclude inplace ops
kiya00 Jul 16, 2024
fcf5e5c
fix test using mock
kiya00 Jul 16, 2024
bb61fc6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 16, 2024
b8ca656
tmp rm test
kiya00 Jul 16, 2024
377b8cb
add more supported types
kiya00 Jul 17, 2024
30a15bb
Use torch.overrides.get_overridable_functions to decide which op need…
kiya00 Jul 18, 2024
b666702
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 18, 2024
821fb8d
add enable_fallback_to_torch compile option, move registration into g…
kiya00 Jul 18, 2024
2b30a7c
recover test
kiya00 Jul 18, 2024
421929b
Put the operator list in a file and use static registration
kiya00 Jul 24, 2024
79b6b83
modify tests, too slow
kiya00 Jul 24, 2024
40fec06
fix the backward output: add grad=none for non-differential input, us…
kiya00 Jul 25, 2024
019d63e
tmp tests
kiya00 Jul 25, 2024
98be150
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 25, 2024
277c937
fix comments: use captured torchfn instead of passing it in residuals
kiya00 Jul 25, 2024
1f8f357
Fix test
kiya00 Jul 25, 2024
b1017ce
Fix test: use first 5 sample inputs to run faster
kiya00 Jul 29, 2024
cd6ff05
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 30, 2024
bc1d0fc
Apply suggestions from code review
kiya00 Jul 31, 2024
6d9b3d6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 31, 2024
453343a
Print out auto registered ops in examine, check there's no overlap be…
kiya00 Jul 31, 2024
3c315b4
fix for comments
kiya00 Jul 31, 2024
39d74f5
Fix comments and enable test cases using LAPACK
kiya00 Aug 1, 2024
f7f2b83
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 1, 2024
98740f4
rm some ops, ipu, mkldnn etc.
kiya00 Aug 1, 2024
623b1f5
fix comments
kiya00 Aug 1, 2024
d2c9129
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 1, 2024
9330aff
Merge branch 'main' into faketensor
kiya00 Aug 1, 2024
e2939a3
rebuild
kiya00 Aug 2, 2024
df4a244
fix test: skipCPUIfNoLapack
kiya00 Aug 2, 2024
af1c7e5
Apply suggestions from code review
kiya00 Aug 2, 2024
7934257
fix comments
kiya00 Aug 2, 2024
12404dc
Apply suggestions from code review
kiya00 Aug 5, 2024
5770e6f
fix comments
kiya00 Aug 5, 2024
5a8f463
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 5, 2024
45bed89
Merge branch 'main' into faketensor
kiya00 Aug 5, 2024
445147d
fix
kiya00 Aug 5, 2024
fc751ea
fix tests
kiya00 Aug 5, 2024
a57bde3
reBuild
kiya00 Aug 5, 2024
305015f
Merge branch 'main' into faketensor
kiya00 Aug 6, 2024
62f4d10
Merge branch 'main' into faketensor
kiya00 Aug 6, 2024
4 changes: 4 additions & 0 deletions thunder/core/pytree.py
@@ -44,9 +44,13 @@ def tree_flatten(args, namespace=""):
torch.Size,
torch.finfo,
dtypes.signedinteger,
# FakeTensor type is used for automatic registration of torch ops
torch._subclasses.fake_tensor.FakeTensor,
torch.device,
}
and not isinstance(args, (ProxyInterface))
and not dataclasses.is_dataclass(args)
and not type(args).__module__.startswith("torch.return_types")
):
raise TypeError(f"tree_flatten of type {type(args)} is not supported.")
return optree.tree_flatten(args, none_is_leaf=True, namespace=namespace)
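
The FakeTensor entry above exists because the automatic registration path runs the wrapped torch operator on fake inputs to infer output metadata, and the resulting FakeTensors then pass through tree_flatten. A minimal sketch of the underlying PyTorch mechanism (this illustrates FakeTensorMode itself, not Thunder's registration code):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    # Factory calls made inside the mode produce FakeTensors that carry only
    # shape/dtype/device metadata and no real storage.
    x = torch.empty(4, 8, dtype=torch.float32)
    # Running the real torch op on fake inputs yields a FakeTensor output,
    # so output shape and dtype can be read without doing any computation.
    out = torch.nn.functional.gelu(x)
    print(type(out).__name__, out.shape, out.dtype)
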
10 changes: 10 additions & 0 deletions thunder/examine/__init__.py
@@ -8,9 +8,11 @@
from thunder.core.proxies import TensorProxy
from thunder.core.symbol import BoundSymbol
from thunder.torch import _torch_to_thunder_function_map
from thunder.torch.default_torch_ops import torch_auto_registered_ops
from thunder.core.langctxs import resolve_language, LanguageContext, Languages
import torch
from warnings import warn
from itertools import chain


# TODO Maybe make collect_into a set?
@@ -73,9 +75,13 @@ def examine(fn: Callable, *args, show_call_stack: bool | int = False, **kwargs):

# Step 1 Identifies supported (and unsupported) operations
supported_ops = set()
all_auto_registered_ops = list(chain(*torch_auto_registered_ops.values()))
auto_registered_ops = set()
for name, op in collected_ops.keys():
if op in _torch_to_thunder_function_map:
supported_ops.add((name, op))
if op in all_auto_registered_ops:
auto_registered_ops.add(name)
elif name.startswith("_TensorBase.") or name.startswith("TensorBase.") or name.startswith("Tensor."):
# Identifies properties and methods
# NOTE The approach of testing if the name starts with "_TensorBase." or "Tensor." seems a little hacky
Expand Down Expand Up @@ -111,6 +117,10 @@ def examine(fn: Callable, *args, show_call_stack: bool | int = False, **kwargs):
print(
f"Found {len(collected_ops)} distinct operations, of which {len(supported_ops)} ({len(supported_ops) / len(collected_ops) * 100:.1f}%) are supported"
)
if len(auto_registered_ops) != 0:
print(f"Note {len(auto_registered_ops)} operators are automatically registered: ")
for n in auto_registered_ops:
print(n)

# Terminates early if there are unsupported operations or there was a preprocessing exception
if len(unsupported_ops) > 0:
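
With this change, examine also reports which of the supported operations came from the automatic registration path. A hedged usage sketch (whether a particular operator such as torch.median is auto-registered depends on the operator list in thunder/torch/default_torch_ops.py added by this PR):

import torch
import thunder.examine

def fn(x):
    # torch.median stands in here for any operator covered by the auto-registration
    # list rather than a manual mapping; substitute an op from default_torch_ops.py.
    return torch.median(x)

thunder.examine.examine(fn, torch.randn(4))
# Expected report, following the print statements above (exact wording and counts
# depend on the operator used):
#   Found 1 distinct operations, of which 1 (100.0%) are supported
#   Note 1 operators are automatically registered:
#   median
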
182 changes: 182 additions & 0 deletions thunder/tests/test_auto_register_torchops.py
@@ -0,0 +1,182 @@
from functools import partial
from unittest.mock import patch

import pytest
import thunder
import thunder.torch.default_torch_ops as ops
import torch

from thunder.tests.framework import requiresCUDA, TorchExecutor
from thunder.tests.make_tensor import make_tensor
from thunder.tests.opinfos import get_opinfo, OpInfo
from thunder.tests.test_einops import skipIfNoCUDA
from torch.testing._internal.common_device_type import skipCPUIfNoLapack
from torch.testing._internal.common_methods_invocations import op_db

_name2func = {}
[_name2func.setdefault(v.__name__, v) for v in ops.torch_auto_registered_ops[torch]]
[_name2func.setdefault(f"nn.functional.{v.__name__}", v) for v in ops.torch_auto_registered_ops[torch.nn.functional]]
# Use the sample input from torch.xx to test torch.tensor.xx
[_name2func.setdefault(f"Tensor.{v.__name__}", v) for v in ops.torch_auto_registered_ops[torch.Tensor]]
_opinfos = [opinfo for opinfo in op_db if opinfo.name in _name2func]


# Note that successfully catching an exception in this test is also noted as passed
@skipIfNoCUDA
@pytest.mark.parametrize("op_info,", _opinfos, ids=list(map(lambda opinfo: opinfo.name, _opinfos)))
@pytest.mark.parametrize("requires_grad", [True, False], ids=("train", "inference"))
@pytest.mark.parametrize("device,", ["cuda", "cpu"])
def test_torch_ops_trace(device, requires_grad, op_info):
from itertools import islice

# for op_info in op_infos:
if device == "cuda" and torch.float32 not in op_info.dtypesIfCUDA:
return
if device == "cpu" and not torch.float32 in op_info.dtypes:
return
# No cuda backend support
if op_info.name in ("nonzero_static",) and device == "cuda":
return
if device == "cpu" and not torch._C.has_lapack and skipCPUIfNoLapack in op_info.decorators:
return
funcs = [_name2func[op_info.name], _name2func.get(f"Tensor.{op_info.name}", None)]
for func in funcs:
if func is None:
continue
# It takes too long, test only the first 5 sample inputs
gen = islice(
op_info.sample_inputs_func(
op_info, device=torch.device(device), dtype=torch.float32, requires_grad=requires_grad
),
5,
)
for sample in gen:
try:
jfun = thunder.jit(func)
out = jfun(sample.input, *sample.args, **sample.kwargs)
except Exception as e:
assert isinstance(e, NotImplementedError)
assert str(e).startswith(f"Exception encountered when doing automatic registration") or str(
e
).startswith(f"Unsupported type:")
break
else:
if requires_grad:
trc = thunder.last_backward_traces(jfun)[-1]
fwd_trc = thunder.last_traces(jfun)[-1]
# skip if it is not differentiable
outs = fwd_trc.output[0]["output"]
outs = outs if isinstance(outs, tuple) else (outs,)
if all(not thunder.core.dtypes.is_inexact_dtype(o.dtype) for o in outs):
continue
vjp_op_name = f"{op_info.name.split('.')[-1]}_vjp"
if op_info.name == "mm":
assert any(bsym.sym.name.endswith(vjp_op_name) for bsym in trc.bound_symbols)
else:
assert any(bsym.sym.name == vjp_op_name for bsym in trc.bound_symbols)
else:
fwd_trc = thunder.last_traces(jfun)[-1]
assert any(
bsym.sym.name.endswith(op_info.name.split(".")[-1]) and not bsym.subsymbols
for bsym in fwd_trc.bound_symbols
)


# Replace manual registration of some operations with automatic registration for network test cases
_skip_ops_nanogpt = [
get_opinfo("layer_norm"),
get_opinfo("linear"),
get_opinfo("gelu"),
get_opinfo("scaled_dot_product_attention"),
]
_skip_ops_alexnet = [
get_opinfo("conv2d"),
get_opinfo("linear"),
get_opinfo("adaptive_avg_pool2d"),
get_opinfo("max_pool2d"),
]
_disable_opinfos = _skip_ops_nanogpt + _skip_ops_alexnet
_tmp_general_jit_lookaside_map = dict(thunder.core.jit_ext._general_jit_lookaside_map)
list(_tmp_general_jit_lookaside_map.pop(k.torch_reference, None) for k in _disable_opinfos)
_tmp_torch_to_thunder_function_map = dict(thunder.torch._torch_to_thunder_function_map)
list(_tmp_torch_to_thunder_function_map.pop(k.torch_reference, None) for k in _disable_opinfos)
_tmp_minimal_lookaside_map = dict(thunder.core.jit_ext._minimal_lookaside_map)
list(_tmp_minimal_lookaside_map.pop(k.torch_reference, None) for k in _disable_opinfos)
from thunder.torch import register_default_torch_op


# mock all the global variables that are modified during registration
@patch.dict(thunder.core.jit_ext._general_jit_lookaside_map, _tmp_general_jit_lookaside_map, clear=True)
@patch.dict(thunder.torch._torch_to_thunder_function_map, _tmp_torch_to_thunder_function_map, clear=True)
@patch.dict(thunder.core.jit_ext._minimal_lookaside_map, _tmp_minimal_lookaside_map, clear=True)
@patch.dict(thunder.executors.torchex.ex._implmap, {})
@patch.dict(thunder.executors.torchex.ex._opmap, {})
@patch.dict(thunder.core.transforms.augmented_forward_impls, {})
@patch.dict(thunder.core.transforms.backward_impls, {})
class TestFallbackToTorch:
def _tmp_update_jit_lookup(self, torchfn):
from thunder.core.interpreter import interpreter_needs_wrap
from thunder.core.jit_ext import (
_general_jit_lookaside_map,
ensure_recursive_proxies,
record_source_loc_in_symbol_header,
)

_general_jit_lookaside_map.update(
{
torchfn: ensure_recursive_proxies(
interpreter_needs_wrap(
record_source_loc_in_symbol_header(thunder.torch._torch_to_thunder_function_map[torchfn])
)
)
}
)

@requiresCUDA
def test_nanogpt_block(self):
import thunder.tests.nanogpt_model as nanogpt_model

for op in _skip_ops_nanogpt:
if op.name == "gelu":
register_default_torch_op(op.torch_reference, torch)
else:
register_default_torch_op(op.torch_reference, torch.nn.functional)
self._tmp_update_jit_lookup(op.torch_reference)
tdtype = torch.float32
device = torch.device("cuda")
executor = TorchExecutor
make = partial(make_tensor, dtype=tdtype, device=device)

config = nanogpt_model.GPTConfig(dropout=0)
model = nanogpt_model.Block(config).to(device=device, dtype=tdtype)
jitted = executor.make_callable(model)

x = make((2, config.block_size, config.n_embd))

cache_entry, _, _ = thunder.compile_data(jitted).get_computation_and_inputs(x)
bwd_trcs = cache_entry.backward_traces
for op in _skip_ops_nanogpt:
vjp_op_name = f"{op.name}_vjp"
assert any(bsym.sym.name == vjp_op_name for bsym in bwd_trcs[-1].bound_symbols)

@requiresCUDA
def test_alexnet(self):
torchvision = pytest.importorskip("torchvision")

for op in _skip_ops_alexnet:
register_default_torch_op(op.torch_reference, torch.nn.functional)
self._tmp_update_jit_lookup(op.torch_reference)
tdtype = torch.float32
device = torch.device("cuda")
model = torchvision.models.alexnet(weights=None).to(device=device, dtype=tdtype)
model = model.train()

executor = TorchExecutor
jitted = executor.make_callable(model)
x = make_tensor((1, 3, 224, 224), dtype=tdtype, device=device)

cache_entry, _, _ = thunder.compile_data(jitted).get_computation_and_inputs(x)
bwd_trcs = cache_entry.backward_traces
for op in _skip_ops_alexnet:
vjp_op_name = f"{op.name}_vjp"
assert any(bsym.sym.name == vjp_op_name for bsym in bwd_trcs[-1].bound_symbols)
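
These tests call the registration entry point directly; a minimal sketch of that call pattern, with the signature of register_default_torch_op inferred only from its usage in this test file (in the PR itself the operators listed in thunder/torch/default_torch_ops.py are registered statically, so a manual call like this is purely illustrative):

import torch
from thunder.torch import register_default_torch_op

# Register a torch.nn.functional operator through the FakeTensor-based fallback
# path; conv2d is just a placeholder borrowed from the _skip_ops_alexnet list above.
register_default_torch_op(torch.nn.functional.conv2d, torch.nn.functional)
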
6 changes: 3 additions & 3 deletions thunder/tests/test_jit_general.py
@@ -57,9 +57,9 @@ def no_error(x):
# randn_like is in ltorch
return torch.randn_like(x)

def should_error(x):
def should_error(x, y):
# rand_like is not yet in ltorch
return torch.rand_like(x)
return torch.allclose(x, y)

x = torch.rand(1)

@@ -68,7 +68,7 @@ def should_error(x):

jshould_error = thunder.jit(should_error)
with pytest.raises(NotImplementedError):
jshould_error(x)
jshould_error(x, x)


def test_binary_add_tensors():