
[Tensor Subclasses] Trace transform to interpret __torch_dispatch__ #1394

Open · crcrpar wants to merge 20 commits into main from crpa/subclass-tensor-ops
Conversation

@crcrpar (Collaborator) commented Nov 3, 2024

  • Add a generic proxy class to represent tensor wrapper subclasses that call torch.Tensor._make_wrapper_subclass in their __new__ and define their own __torch_dispatch__ (see the sketch just after this list).
  • Add a trace transform that evaluates the BoundSymbols of a trace one by one, so that the trace is kept as free of actual tensor subclass objects as possible and the behavior defined by __torch_dispatch__ is written out explicitly in the trace.
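
For context, here is a minimal sketch (editorial, not code from this PR) of the kind of tensor wrapper subclass the proxy is meant to represent; the class name MyTensorSubclass and the inner attribute are illustrative assumptions:

import torch

class MyTensorSubclass(torch.Tensor):
    @staticmethod
    def __new__(cls, inner: torch.Tensor):
        # Wrapper subclass: the outer tensor only carries metadata, the data lives on `inner`.
        return torch.Tensor._make_wrapper_subclass(
            cls, inner.shape, dtype=inner.dtype, device=inner.device, requires_grad=inner.requires_grad
        )

    def __init__(self, inner: torch.Tensor):
        self.inner = inner

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap positional subclass arguments, run the underlying aten op, and rewrap tensor results.
        unwrapped = [a.inner if isinstance(a, MyTensorSubclass) else a for a in args]
        out = func(*unwrapped, **kwargs)
        return MyTensorSubclass(out) if isinstance(out, torch.Tensor) else out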

@t-vi (Collaborator) commented Nov 7, 2024

To my mind, we would get better error handling (e.g. stack traces) if we resolved __torch_dispatch__ earlier, by checking for subclasses in the torchsymbol handling logic. Also, without doing so I don't think we have much information about the outputs (e.g. are they subclasses again or the original class, what are their shapes, etc.), so we cannot even evaluate metadata in control flow.


@mruberry (Collaborator)

There's a lot going on with this PR, and it's pretty complicated. Maybe we should schedule an online sync, @crcrpar and @IvanYashchuk, to see if we can make it more incremental?

@t-vi (Collaborator) commented Dec 21, 2024

I would still prefer to do the flattening during interpretation for the benefit of getting error messages with backtraces.

@crcrpar (Collaborator, Author) commented Dec 22, 2024

There's a lot going on with this PR, and it's pretty complicated. Maybe we should schedule an online sync, @crcrpar and @IvanYashchuk, to see if we can make it more incremental?

We don't have a time slot that works for all of us.
What would make the PRs feel incremental to you?
Would it look better to you to split this into two PRs, one adding the proxy and the other adding the trace transform?

EDIT: #1583, #1584, and #1585 are a sequence of PRs that cover this one and #1415.

I would still prefer to do the flattening during interpretation for the benefit of getting error messages with backtraces.

Embarrassingly, I am not familiar with the interpretation implementation at all. Could you give me some pointers on how to add the flattening so that it happens during interpretation?

EDIT: __torch_dispatch__ overrides the behavior of ops used in backward, and the customized behavior could use torch ops that do not have a backward definition. Thus I do think we'd need to put the flattening AFTER the forward-backward split, as in #1415.

@mruberry (Collaborator)

There's a lot going on with this PR, and it's pretty complicated. Maybe we should schedule an online sync, @crcrpar and @IvanYashchuk, to see if we can make it more incremental?

We don't have a time slot that works for all of us. What would make the PRs feel incremental to you? Would it look better to you to split this into two PRs, one adding the proxy and the other adding the trace transform?

OK; we can try to work asynchronously. For a first incremental PR, would you create a PR adding support for aten operators? In particular, if someone were to call something like torch.ops.aten.add then how would that work? Can that operator be added to the torch operations (and be in the torch language context), or should it be added to its own aten language context that has a separate file (or set of files)?
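
(For context, an editorial illustration rather than part of the comment: torch.ops.aten.add is directly callable from Python, so a user program or a __torch_dispatch__ implementation can contain such a call verbatim.)

import torch

a = torch.randn(3)
b = torch.randn(3)
c = torch.ops.aten.add(a, b)           # equivalent to torch.add(a, b)
d = torch.ops.aten.add(a, b, alpha=2)  # the aten add overload also exposes the alpha argument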

@crcrpar (Collaborator, Author) commented Dec 23, 2024

if someone were to call something like torch.ops.aten.add then how would that work?

So far in my implementation there's an optimistic mapping from core aten ops to ltorch ops:

for node in list_of_function_call_node:
    if not hasattr(ltorch, node.target._opname):
        msg = (
            f"`thunder.torch` does not have corresponding op for {node.target._opname}. "
            "Think about adding it to thunder/torch/default_torch_ops.py"
            f"\nThe op is found while flattening the following BoundSymbol:\n{bsym}"
            f"\ntorch.fx graph:\n{fx_graph.print_readable(print_output=False)}"
        )
        raise RuntimeError(msg)
    ltorch_ops_for_node_of_ops.append(getattr(ltorch, node.target._opname))

By the way, if we're to cover core aten ops, then I'd say it'd be worth considering using thunder as a custom backend after AOTAutograd.
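
(Editorial illustration, not from the thread: one way a compiler callable can be plugged in after AOTAutograd, so the graphs it receives already consist of core aten ops. The my_compiler function is a stand-in, not thunder's actual entry point.)

import torch
from torch._dynamo.backends.common import aot_autograd

def my_compiler(gm: torch.fx.GraphModule, example_inputs):
    # By this point AOTAutograd has decomposed the program into core aten ops.
    gm.print_readable()
    return gm.forward  # return any callable with the same calling convention

backend = aot_autograd(fw_compiler=my_compiler)
compiled = torch.compile(torch.nn.Linear(4, 4), backend=backend)
out = compiled(torch.randn(2, 4))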

Can that operator be added to the torch operations (and be in the torch language context), or should it be added to its own aten language context that has a separate file (or set of files)?

Currently torchsymbol already handles some core aten ops:

elif hasattr(torch.ops.aten, name):
    id = f"torch.ops.aten.{name}"

For example, @torchsymbol(torch.ops.aten.embedding_backward) and torch.ops.aten._adaptive_avg_pool2d_backward are registered core aten ops. So I think extending thunder/torch/__init__.py would be fair.
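
(Editorial sketch of how such a registration could look inside thunder/torch/__init__.py, mirroring the existing aten entries; the function body, the reuse of the ltorch add implementation, and the exact signature are assumptions, not code from this PR.)

# Hypothetical registration; assumes it lives in thunder/torch/__init__.py where
# torchsymbol, TensorLike, Number, and the existing ltorch add are already defined.
@torchsymbol(torch.ops.aten.add)
def aten_add(a: TensorLike, b: TensorLike, *, alpha: None | Number = None) -> TensorLike:
    return add(a, b, alpha=alpha)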

@mruberry (Collaborator) commented Dec 23, 2024

if someone were to call something like torch.ops.aten.add then how would that work?

So far in my implementation there's an optimistic mapping from core aten ops to ltorch ops:

for node in list_of_function_call_node:
    if not hasattr(ltorch, node.target._opname):
        msg = (
            f"`thunder.torch` does not have corresponding op for {node.target._opname}. "
            "Think about adding it to thunder/torch/default_torch_ops.py"
            f"\nThe op is found while flattening the following BoundSymbol:\n{bsym}"
            f"\ntorch.fx graph:\n{fx_graph.print_readable(print_output=False)}"
        )
        raise RuntimeError(msg)
    ltorch_ops_for_node_of_ops.append(getattr(ltorch, node.target._opname))
By the way, if we're to cover core aten ops, then I'd say it'd be worth considering using thunder as a custom backend after AOTAutograd.

Can that operator be added to the torch operations (and be in the torch language context), or should it be added to its own aten language context that has a separate file (or set of files)?

Currently torchsymbol already handles some core aten ops:

elif hasattr(torch.ops.aten, name):
    id = f"torch.ops.aten.{name}"

For example, @torchsymbol(torch.ops.aten.embedding_backward) and torch.ops.aten._adaptive_avg_pool2d_backward are registered core aten ops. So I think extending thunder/torch/__init__.py would be fair.

OK; expanding thunder/torch/__init__.py sounds good for now. Let's not "optimistically" try to map ATen operations to torch operations for the moment, but just treat them like different operations.

Would you submit a PR adding torch.ops.aten.add to the torch operations?

EDITED BELOW.

As a follow-up PR to that, what about working with a program like

# Original program
def foo(x):
  return x

# Trace
def computation(x):
  # x: "MyTensorSubclass[cuda:0 f32[12, 12]]" 
  return x

Where the initial trace shows the tensor subclass and its flattened information, and the prologue validates the subclass and its flattening. Then I'd be curious to see addition with that tensor, like this:

# Original program
def foo(x):
  return x + 1

# Trace
def computation(x):
  # x: "MyTensorSubclass[cuda:0 f32[12, 12]]" 
  t0 = MyTensorSubclass.torch.add(x, 1)  # t0: "MyTensorSubclass[cuda:0 f32[12, 12]]"
    # t1 = flatten_tensor_subclass(MyTensorSubclass, x)
    # t2 = torch.ops.aten.add(t1, 1)
      # <decomposition of aten.add into prims would go here> 
    # t0 = unflatten_tensor_subclass(MyTensorSubclass, t2)
  return t0

This can be translated for execution by PyTorch, but I think working through this will be interesting. Then the follow-up question is what the grad transform for it looks like, and how this operation should be translated for execution by nvFuser.
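
(Editorial note, not part of the comment: the hypothetical flatten_tensor_subclass / unflatten_tensor_subclass operations above line up with PyTorch's traceable wrapper subclass protocol, where the subclass itself declares how it flattens into plain tensors and how it is rebuilt; the class and attribute names below continue the earlier illustrative sketch.)

import torch

class MyTensorSubclass(torch.Tensor):
    # ... __new__, __init__, __torch_dispatch__ as in the earlier sketch ...

    def __tensor_flatten__(self):
        # names of the inner plain-tensor attributes, plus any non-tensor context
        return ["inner"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
        return MyTensorSubclass(inner_tensors["inner"])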

@crcrpar (Collaborator, Author) commented Dec 24, 2024

IMHO, it sounds more natural to me to register core aten ops in the thunder.torch namespace after #1583 is merged.

Registration would then come after the aforementioned PR and before #1584 and #1585, followed by some refinement of the prologue and of how traces with tensor subclasses look, accompanied by #1584.

the follow-up question is what the grad transform for it looks like, and how this operation should be translated for execution by nvFuser.

With the experience of #1585, I do think we'd have to let the trace be split into forward and backward before interpreting __torch_dispatch__, partly because the extended behavior of certain ops can depend on ops without any backward definition.
