[torchao float8tensor] #1415

Draft
wants to merge 101 commits into main

Commits (101)
ae4a28b
workaround for `.__init__` call on the output of `_make_wrapper_subcl…
crcrpar Nov 2, 2024
b0a4710
the test case works
crcrpar Nov 2, 2024
bd01157
attribute access to subclass proxy seems functioning
crcrpar Nov 3, 2024
ac0d9fe
simplify if-else in `SubclassTensorProxy.__init__`
crcrpar Nov 3, 2024
d79afcc
stricter type check of tensors
crcrpar Nov 3, 2024
2c06edd
support `MySubclass(...)` called inside of `torch.autograd.Function`
crcrpar Nov 5, 2024
8c0f39e
explanation
crcrpar Nov 5, 2024
5ae4d2e
failing test case as starter
crcrpar Nov 3, 2024
2e6e008
add path of SubclassTensorProxy in `tensorproxy`
crcrpar Nov 5, 2024
2a99349
add no-op tensor subclass transform
crcrpar Nov 5, 2024
8d677ee
transfer #1345
crcrpar Nov 5, 2024
22b007b
make the subclass check more meticulous
crcrpar Nov 6, 2024
3e1d4b0
fake_tensor.foo -> foo
crcrpar Nov 6, 2024
5f1d65e
simplify `subclass_proxy_to_flatten`
crcrpar Nov 6, 2024
e2b3b43
handle `PrimIDs.RETURN` earlier
crcrpar Nov 6, 2024
aa4351c
give created subclass required attributes
crcrpar Nov 7, 2024
1d9e5a3
remove `subclass_type_to_attr_names`
crcrpar Nov 7, 2024
335df5c
remove `requires_desugarring`
crcrpar Nov 7, 2024
409798d
import cleanup
crcrpar Nov 7, 2024
e9027ba
avoid flattening non-tensor args of subclass ctor
crcrpar Nov 19, 2024
ca67d6f
add path of SubclassTensorProxy in `tensorproxy`
crcrpar Nov 5, 2024
6e3c8b2
phase 1 for backward test
crcrpar Nov 7, 2024
4cded75
check backward is runnable with subclass arguments
crcrpar Nov 7, 2024
5daa159
bwd run with tensor creation inside of trace
crcrpar Nov 7, 2024
d8ce1b1
flatten Function.apply of converter
crcrpar Nov 7, 2024
f7b4976
torchao small test
crcrpar Nov 8, 2024
76177fa
placeholder-ish attributes/methods for `_make_wrapper_subclass`
crcrpar Nov 8, 2024
4cb28cd
[autograd.Function lookaside] `dce` to wipe out redundant bsyms
crcrpar Nov 8, 2024
ca51fdb
Give unpack bsyms to traces generated inside `Function` lookaside
crcrpar Nov 8, 2024
b504417
some tweaks
crcrpar Nov 8, 2024
141cba0
revert pytree changes
crcrpar Nov 21, 2024
a1c6471
imports for tensor subclass ctor
crcrpar Nov 21, 2024
16fde2f
define bind-postprocess
crcrpar Nov 22, 2024
2ece401
xfail, for now
crcrpar Nov 22, 2024
778d8a0
fix type set creation and add bsym postprocess for torchex
crcrpar Nov 24, 2024
52a9c1b
printer translating thunder dtype/device to torch's
crcrpar Nov 24, 2024
c474c2d
meticulously set import_ctx of cls
crcrpar Nov 24, 2024
3cf828d
dry
crcrpar Nov 24, 2024
8082fc3
test failure info update
crcrpar Nov 24, 2024
3eef261
cosmetic
crcrpar Nov 24, 2024
6679d9a
better repr & type string
crcrpar Nov 25, 2024
c026711
use `transpose` instead of `permute`
crcrpar Nov 25, 2024
75a03ef
better typestring
crcrpar Nov 25, 2024
b46f24f
num bsyms check
crcrpar Nov 25, 2024
9db718e
allow tensor subclasses with non-empty metadata
crcrpar Nov 25, 2024
48958ab
bsyms is a list inside `trace_from_bsym_or_bsyms`
crcrpar Nov 25, 2024
373b39f
the typestring has a syntactic mistake; remove for now
crcrpar Nov 25, 2024
75a0423
better error message for missing support of ops
crcrpar Nov 25, 2024
8122c77
add `torch._scaled_mm` to auto register
crcrpar Nov 25, 2024
38530ff
update error message of missing op
crcrpar Nov 26, 2024
eaaa012
tree_flatten tensor subclass metadata values
crcrpar Nov 26, 2024
3790e42
make error msg verbose
crcrpar Nov 26, 2024
4e83361
better error message for failing map from fx node to ltorch op
crcrpar Nov 26, 2024
1078712
better error message
crcrpar Nov 26, 2024
f817a0b
register scaled_mm
crcrpar Nov 26, 2024
fe57ddb
note where new bsyms come from, especially torch dispatch
crcrpar Nov 26, 2024
730c287
cast `fx.immutable_{dict, list}` to `dict`/`list`
crcrpar Nov 26, 2024
e943081
printer and bind_postprocess for `__tensor_flatten__` & `__tensor_unf…
crcrpar Nov 26, 2024
3a62aaf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 26, 2024
06c7eeb
xfail reason
crcrpar Nov 26, 2024
1de88f3
cosmetic
crcrpar Nov 27, 2024
5c197cf
simplify subclass output handling
crcrpar Nov 27, 2024
4b9e67f
Unrolling tensor subclasses in fwd/bwd split (#1489)
crcrpar Nov 28, 2024
759e01d
reduce return values by one
crcrpar Nov 28, 2024
c4c89b0
clarify the error is numeric
crcrpar Nov 28, 2024
eae9834
add bfloat16 to test parametrization
crcrpar Nov 29, 2024
0f41b5e
torch_compile_ex style transform for execution
crcrpar Nov 29, 2024
70b576b
update test
crcrpar Nov 29, 2024
66d67f9
clarify nothing is put into thunder.jit when thunderfx
crcrpar Nov 29, 2024
bfdbe5a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 29, 2024
7f14349
shorter header for torch dispatch result
crcrpar Nov 29, 2024
481b3b2
try to tell if the trace is backward or not by checking certain bsyms
crcrpar Nov 29, 2024
96e5562
update test
crcrpar Nov 29, 2024
595d32c
check bsym.args itself before its first name
crcrpar Nov 29, 2024
6b4330d
warn tensor subclass support
crcrpar Nov 29, 2024
4d8d375
test update
crcrpar Nov 30, 2024
87e7354
more meticulous bsym check to tell if the trace is bwd
crcrpar Nov 30, 2024
46c485d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 30, 2024
d8309dc
remove `flat_trace_args_spec`
crcrpar Nov 30, 2024
c1246e6
fix wrong rebase output
crcrpar Dec 7, 2024
e226903
fix typo
crcrpar Dec 12, 2024
0c54374
add tensor subclass transform output to traces
crcrpar Dec 12, 2024
d28c6ae
bring back unexpectedly deleted line
crcrpar Dec 13, 2024
ee91436
add note about the behavioral difference
crcrpar Dec 13, 2024
9fe5deb
DCE for ``tensor.__tensor_flatten__``
crcrpar Dec 13, 2024
cc86afb
update regex of assert raises
crcrpar Dec 13, 2024
b793ffb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 13, 2024
cdefd06
`torch._scaled_matmul` decomposition
crcrpar Dec 16, 2024
47f5cc2
fix check and add missing cast
crcrpar Dec 16, 2024
c838ea0
no consumer map
crcrpar Dec 16, 2024
8d15622
add `flatten_tensor_subclass` to docs
crcrpar Dec 16, 2024
54645b3
fix column major check & add dtype check of data mat
crcrpar Dec 16, 2024
b7801b0
use existing ones
crcrpar Dec 16, 2024
4e6d9fd
fix nvfuser impl
crcrpar Dec 17, 2024
97e7ee2
add device check
crcrpar Dec 17, 2024
8693dc7
rename to dce from cse
crcrpar Dec 19, 2024
caa2242
rename to `enable_scaled_mm`
crcrpar Dec 19, 2024
6a6fe59
update cond
crcrpar Dec 19, 2024
fce281b
getnv(a,...) -> getnv(b,...)
crcrpar Dec 20, 2024
e5d26fc
remove nvfuser scaled_mm
crcrpar Dec 21, 2024
56c69df
remove flattening and unflattening
crcrpar Dec 23, 2024
1 change: 1 addition & 0 deletions docs/source/reference/transforms/index.rst
@@ -6,5 +6,6 @@ thunder.transforms
.. autosummary::
:toctree: generated/

flatten_tensor_subclasses
MaterializationTransform
ConstantFolding
8 changes: 7 additions & 1 deletion thunder/__init__.py
@@ -73,6 +73,7 @@
from thunder.core.interpreter import print_interpreter_log, print_to_log
from thunder.core.jit_ext import thunder_general_jit
from thunder.executors.torch_autograd import split_forward_backward, ThunderFunction
from thunder.transforms.tensor_subclasses import flatten_tensor_subclasses

# NOTE This import is intentionally pytorch so that thunder.torch doesn't import this
import torch as pytorch
@@ -370,7 +371,7 @@ def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]:
data_ptr_to_tensor_group_index = {}
tensor_group_index_to_tensor_indices = defaultdict(list)
for idx, t in enumerate(flat_args):
if pytorch.is_tensor(t) and t.layout == pytorch.strided:
if type(t) in {pytorch.Tensor, pytorch.nn.Parameter} and t.layout == pytorch.strided:
data_ptr = t.untyped_storage().data_ptr()
if data_ptr not in data_ptr_to_tensor_group_index:
data_ptr_to_tensor_group_index[data_ptr] = len(data_ptr_to_tensor_group_index)
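For context on the tightened check above: pytorch.is_tensor is an isinstance test and so also matches wrapper tensor subclasses, whereas the exact-type test keeps the data_ptr()-based aliasing analysis to plain strided tensors and parameters. A minimal sketch of the distinction (plain PyTorch, not code from this PR):

import torch

t = torch.ones(2)
p = torch.nn.Parameter(torch.ones(2))

assert torch.is_tensor(t) and torch.is_tensor(p)      # isinstance-based: True for both
assert type(t) in {torch.Tensor, torch.nn.Parameter}  # exact-type check: still True
assert type(p) in {torch.Tensor, torch.nn.Parameter}

# A wrapper subclass instance (e.g. one built with torch.Tensor._make_wrapper_subclass,
# as tensor subclasses such as torchao's Float8Tensor typically are) would pass
# torch.is_tensor() but fail the exact-type check, so it is excluded from the
# data_ptr()-based aliasing groups, since its untyped_storage() need not describe
# real, shareable storage.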
@@ -617,6 +618,7 @@ def get_computation_and_inputs(*args, **kwargs):
computation_trc = dce(computation_trc)
computation_traces.append(computation_trc)

_tensor_subclass_transform_applied = False
backward_trc = None
if not cd.disable_torch_autograd_support:
tensor_cls = (pytorch.Tensor, TensorProxy)
@@ -630,6 +632,10 @@
computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *inps)
# Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces
# by split_forward_backward
_tensor_subclass_transform_applied = True
if not _tensor_subclass_transform_applied:
computation_trc = flatten_tensor_subclasses(computation_trc)
computation_traces.append(computation_trc)

if backward_trc is None:
from thunder.executors.passes import transform_for_execution as transform_for_execution_pass
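For context, a hedged sketch of the end-to-end flow this wiring targets, assuming torchao's public convert_to_float8_training helper and mirroring the 64x64 bias-free Linear from the test failure quoted further down; this is illustrative, not code from this PR:

import torch
import thunder
from torchao.float8 import convert_to_float8_training

model = torch.nn.Sequential(torch.nn.Linear(64, 64, bias=False)).to("cuda")
convert_to_float8_training(model)  # swaps nn.Linear -> Float8Linear, which builds Float8Tensor subclasses
jitted = thunder.jit(model)        # the new flatten_tensor_subclasses pass runs while the traces are built
out = jitted(torch.randn(16, 64, device="cuda"))
out.sum().backward()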
59 changes: 59 additions & 0 deletions thunder/core/jit_ext.py
@@ -62,6 +62,7 @@
NumberProxy,
StringProxy,
TensorProxy,
SubclassTensorProxy,
FutureTensorProxy,
make_proxy_name,
Variable,
@@ -666,6 +667,7 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
So far, non-tensor ``ctx`` attributes seem to be folded into a trace.
"""
from thunder.core.baseutils import check, sequencify
from thunder.core.transform_common import dce

custom_autograd_function_cls = unwrap(obj)
custom_forward = custom_autograd_function_cls.forward
@@ -677,6 +679,7 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
)
if trace_of_fwd is INTERPRETER_SIGNALS.EXCEPTION_RAISED:
return trace_of_fwd
trace_of_fwd = dce(trace_of_fwd)
crcrpar (Collaborator, Author) commented:

double check if this is really necessary

crcrpar (Collaborator, Author) commented:

without dce's, torchao.float8 tests fail:

______________________________________ test_torchao_float8_linear_torch_cuda_thunder.dtypes.float32[False] ______________________________________

args = <thunder.core.interpreter.WrappedValue object at 0x7dd65e182380>
kwargs = <thunder.core.interpreter.WrappedValue object at 0x7dd65e181060>
runtimectx = <thunder.core.interpreter.InterpreterRuntimeCtx object at 0x7dd65e1d37f0>
fn_wrapped = <thunder.core.interpreter.WrappedValue object at 0x7dd6638f3190>
getfn = <function interpret.<locals>.fn_.<locals>.getfn at 0x7dd658f21750>
wrapped_fn_2 = <thunder.core.interpreter.WrappedValue object at 0x7dd6638f3250>
wrapped_closure = <thunder.core.interpreter.WrappedValue object at 0x7dd6638f1cc0>
wrapped_cell = <thunder.core.interpreter.WrappedValue object at 0x7dd662d6c670>, traceback_str = ''
msg = 'Encountered exception ValueError: Variable t_0 is being overwritten this is not allowed while tracing Sequential(\n  ...Linear(in_features=64, out_features=64, bias=False, cast_configs=i:dyn_ten_e4m3,w:dyn_ten_e4m3,go:dyn_ten_e5m2")\n):\n'

    @functools.wraps(fn)
    def fn_(*args, **kwargs) -> Any:
        runtimectx: InterpreterRuntimeCtx = InterpreterRuntimeCtx(debug_log=debug_log, record_history=record_history)

        with interpreter_ctx(compilectx, runtimectx):
            try:
                # we normalize the outmost function to be interpreted to take
                # args and kwargs as arguments (not *args and **kwargs).
                # We thus have three special INPUTs for the entry function: INPUT_ARGS, INPUT_KWARGS, INPUT_FN
                args = wrap(
                    args,
                    provenance=ProvenanceRecord(inst=PseudoInst.INPUT_ARGS, inputs=[]),
                )

                kwargs = wrap(
                    kwargs,
                    provenance=ProvenanceRecord(inst=PseudoInst.INPUT_KWARGS, inputs=[]),
                )

                fn_wrapped = wrap(
                    fn,
                    provenance=ProvenanceRecord(inst=PseudoInst.INPUT_FN, inputs=[]),
                )

                def getfn():
                    def fn_2(args, kwargs):
                        return fn(*args, **kwargs)

                    return fn_2

                wrapped_fn_2 = wrap_const(getfn())
                if compilectx._with_provenance_tracking:
                    wrapped_closure = wrap_attribute(
                        wrapped_fn_2.value.__closure__, wrapped_fn_2, wrap_const("__closure__")
                    )
                    wrapped_cell = wrap_binary_subscr(wrapped_closure.value[0], wrapped_closure, 0)
                    assert isinstance(wrapped_closure.item_wrappers, list)
                    wrapped_closure.item_wrappers[0] = wrapped_cell
                    populate_attribute_wrapper(wrapped_cell, "cell_contents", fn_wrapped)

>               interpretation_result: Any = _interpret_call(wrapped_fn_2, args, kwargs)

thunder/core/interpreter.py:7207:


# Forward.
unwrapped_custom_forward_args = tree_map(lambda a: unwrap(a), args)
@@ -690,6 +693,7 @@ def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwar
for a in filter(lambda a: isinstance(a, Proxy), trace_of_fwd.args)
]
trace_of_fwd.bound_symbols = unpack_bsyms + trace_of_fwd.bound_symbols
trace_of_fwd = dce(trace_of_fwd)
crcrpar marked this conversation as resolved.

@wraps(trace_of_fwd.python_callable())
def core_of_forward(*args, **kwargs):
@@ -736,6 +740,7 @@ def core_of_forward(*args, **kwargs):
for a in filter(lambda a: isinstance(a, Proxy), trace_of_backward.args)
]
trace_of_backward.bound_symbols = bwd_unpack_bsyms + trace_of_backward.bound_symbols
trace_of_backward = dce(trace_of_backward)
crcrpar marked this conversation as resolved.

bwd_trace_impl = TraceCtx()
bwd_trace_impl.bound_symbols.extend(trace_of_backward.bound_symbols)
@@ -769,6 +774,24 @@ def grad_transform(*args, **kwargs):
execution_transform=core_of_forward,
grad_transform=grad_transform,
)

added_bsym: BoundSymbol = get_jit_ctx().computation_trace.scopes[-1][-1]
import_ctx, call_ctx, object_ctx = {}, {}, {}
for bsym in trace_of_fwd.bound_symbols:
cur_import_ctx, cur_call_ctx, cur_object_ctx = bsym.gather_ctxs()
import_ctx.update(cur_import_ctx)
call_ctx.update(cur_call_ctx)
object_ctx.update(cur_object_ctx)

if import_ctx:
added_bsym._import_ctx.update(import_ctx)
if call_ctx:
if added_bsym._call_ctx is not None:
added_bsym._call_ctx.update(call_ctx)
else:
added_bsym._call_ctx = call_ctx
if object_ctx:
added_bsym._object_ctx.update(object_ctx)
crcrpar marked this conversation as resolved.
return forward_result


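As a concrete illustration of what this lookaside now has to handle (per the "support `MySubclass(...)` called inside of `torch.autograd.Function`" commit): a hypothetical custom Function whose forward constructs a wrapper tensor subclass, so the forward and backward traces built above must carry the subclass constructor and its import/call contexts. MySubclass and ScaleToSubclass below are invented for the sketch, not taken from this PR or torchao:

import torch
from torch.utils._pytree import tree_map

class MySubclass(torch.Tensor):
    # Minimal hypothetical wrapper subclass (stand-in for e.g. a float8 tensor).
    @staticmethod
    def __new__(cls, inner: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device,
            requires_grad=inner.requires_grad,
        )

    def __init__(self, inner: torch.Tensor):
        self._inner = inner

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap to the inner tensor, run the op, return the plain result.
        unwrap = lambda t: t._inner if isinstance(t, MySubclass) else t
        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))


class ScaleToSubclass(torch.autograd.Function):
    # The shape of Function the lookaside traces: forward builds the subclass.
    @staticmethod
    def forward(ctx, x: torch.Tensor, scale: float):
        ctx.scale = scale
        return MySubclass(x * scale)

    @staticmethod
    def backward(ctx, grad):
        return grad * ctx.scale, None  # no grad for the Python-number scale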
@@ -863,6 +886,42 @@ def grad_transform(*args, **kwargs):
return output


@register_general_jit_lookaside(torch.Tensor._make_wrapper_subclass)
def _make_wrapper_subclass(
cls: torch._C._TensorMeta,
size: Sequence[int],
strides: Sequence[int] | None = None,
storage_offset: int | None = None,
memory_format: torch.memory_format | None = None,
dtype: torch.dtype | None = None,
layout: torch.layout | None = torch.strided,
device: torch.device | None = None,
pin_memory: bool = False,
requires_grad: bool = False,
dispatch_sizes_strides_policy: str | None = None,
dispatch_device: bool = False,
dispatch_layout: bool = False,
_extra_dispatch_keys: torch.DispatchKeySet | None = None,
storage_size: int | None = None,
):
ucls = unwrap(cls)
usize = unwrap(size)
udtype = unwrap(dtype)
udevice = unwrap(device)
urequires_grad = unwrap(requires_grad)

subclass = SubclassTensorProxy(
None,
shape=usize,
device=udevice,
dtype=udtype,
requires_grad=urequires_grad,
history=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]),
subclass_type=ucls,
)
return wrap(subclass, provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, [cls.provenance]))


@register_general_jit_lookaside(torch.autocast.__enter__)
def autocast_enter(autocast_obj):
unwrap_autocast_obj = unwrap(autocast_obj)
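For reference, the `_make_wrapper_subclass` lookaside above fires on constructors of roughly this shape. Below is a simplified, hypothetical Float8Tensor-style subclass (not torchao's actual implementation) that also shows the `__tensor_flatten__`/`__tensor_unflatten__` protocol other commits in this PR teach the transform to print and post-process:

import torch

class MyFloat8Tensor(torch.Tensor):
    # Hypothetical sketch; __torch_dispatch__ omitted for brevity.
    @staticmethod
    def __new__(cls, data: torch.Tensor, scale: torch.Tensor, orig_dtype: torch.dtype):
        # Under thunder.jit, this call hits the lookaside above and produces a
        # SubclassTensorProxy instead of an actual wrapper tensor.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.shape,
            dtype=orig_dtype,
            device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data: torch.Tensor, scale: torch.Tensor, orig_dtype: torch.dtype):
        self._data = data    # float8 payload
        self._scale = scale  # fp32 scale

    def __tensor_flatten__(self):
        # Inner-tensor attribute names plus non-tensor metadata.
        return ["_data", "_scale"], {"orig_dtype": self.dtype}

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        return MyFloat8Tensor(inner_tensors["_data"], inner_tensors["_scale"], metadata["orig_dtype"])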