Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Tensor Subclasses] Support func calling only Subclass(...) #1393

Closed
wants to merge 7 commits into from

Conversation

crcrpar
Copy link
Collaborator

@crcrpar crcrpar commented Nov 2, 2024

What does this PR do?

The scope of this PR is Tensor Subclasses that call torch.Tensor._make_wrapper_subclass in their dunder new (and define __torch_dispatch__, __tensor_flatten__, and __tensor_unflatten__).

This PR adds a new proxy for such tensor subclasses and implements a lookaside for _make_wrapper_subclass that returns an instance of the new proxy.

MySubclass(...) calls MySubclass.__new__(cls, ...) before calling __init__(...) on the return value of the dunder new.
Since this PR has the lookaside and it returns an instance of a proxy, not MySubclass, the __init__ of the new proxy is called.
This is the reason the new proxy has if-else branches inside of its dunder new.

Caveat: This assumes that Subclass.__new__ does not have kwargs, only positional args.

@crcrpar crcrpar changed the title Support func calling only Subclass(...) [Tensor Subclasses] Support func calling only Subclass(...) Nov 2, 2024
Comment on lines +765 to +783
ucls = unwrap(cls)
usize = unwrap(size)
udtype = unwrap(dtype)
udevice = unwrap(device)
urequires_grad = unwrap(requires_grad)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should be clear about ignored args

requires_grad=requires_grad,
tensors=tensors,
non_tensors=non_tensors,
history=[t.history for t in tensors],
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

history should be able to get better

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the purpose of history?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dunno, something disallows history being empty/none

@@ -1880,6 +1880,111 @@ def real(self):
return method(self)


class SubclassTensorProxy(TensorProxy):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this should be able to express __tensor_flatten__(self) -> tuple[list[str], dict[str, Any]] (and __tensor_unflatten__(inner_tensors: dict[str, Tensor], metadata: dict[str, Any], outer_size, outer_stride) -> MySubclass).
For it to happen, somewhere I have to give instances of this class the attribute names of tensors and non-tensor values.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you add a comment describing what this class is intended for?

@crcrpar crcrpar force-pushed the crpa/subclss-tensor-init branch from 0b69d52 to 9d77226 Compare November 3, 2024 06:45
assert scale.numel() == 1, f"Invalid `scale`: {scale}"
dtype = x.dtype
device = x.device
self = torch.Tensor._make_wrapper_subclass(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to use make_wrapper_subclass when requires_grad=False?

Here the behavior is different depending on the requires_grad value:
https://github.com/albanD/subclass_zoo/blob/ec47458346c2a1cfcd5e676926a4bbc6709ff62e/base_tensor.py#L12-L15

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The link uses _make_subclass not _make_wrapper_subclass and the last update is 2 years ago, so it doesn't sound convincing to me

@@ -743,6 +744,42 @@ def grad_transform(*args, **kwargs):
return forward_result


@register_general_jit_lookaside(torch.Tensor._make_wrapper_subclass)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it need to be a jit lookaside? Can the implementation be moved to thunder/torch/__init__.py if @torchsymbol is used?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I rather want to hide this method as possible. so making it a lookaside feels more right than a torchsymbol

@@ -369,7 +369,7 @@ def _alias_tensor_of_args_kwargs_dict(*args, **kwargs) -> dict[int, list[int]]:
data_ptr_to_tensor_group_index = {}
tensor_group_index_to_tensor_indices = defaultdict(list)
for idx, t in enumerate(flat_args):
if pytorch.is_tensor(t) and t.layout == pytorch.strided:
if type(t) in {pytorch.Tensor, pytorch.nn.Parameter} and t.layout == pytorch.strided:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please submit this fix with a test in a separate pull request?

Copy link
Collaborator Author

@crcrpar crcrpar Nov 6, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That pull request would need a subclass in the test then I'm not quite convinced by the option

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do subclasses appear here? Do all subclasses have the actual torch.Tensor type?

@crcrpar crcrpar force-pushed the crpa/subclss-tensor-init branch from 21c2af8 to 11fea26 Compare November 6, 2024 08:50
…ass` lookaside

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
no `__torch_dispatch__` support at all.

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
somehow, apparently

Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
Signed-off-by: Masaki Kozuki <mkozuki@nvidia.com>
@crcrpar crcrpar force-pushed the crpa/subclss-tensor-init branch from 11fea26 to 2e4ecad Compare November 6, 2024 13:52
@crcrpar crcrpar marked this pull request as ready for review November 6, 2024 13:52
kwarg_non_tensors = kwargs.pop("non_tensors", [])
subclass_type = kwargs.pop("subclass_type", None)

# If tensors (and non_tensors) are not empty, then it should be the path of `_make_wrapper_subclass`
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you elaborate on this comment?

non_tensors = list(filter(lambda t: not isinstance(t, TensorProxy), flat_args))
has_name_before_init = hasattr(self, "_name")

is_dunder_init_following_make_wrapper_subclass: bool = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's going on here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

def __init__(self, *args, **kwargs):
from thunder.core.pytree import tree_flatten

kwarg_tensors = kwargs.pop("tensors", [])
Copy link
Collaborator

@mruberry mruberry Nov 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

__init__(self, *args, tensors=[], non_tensors=[], subclass_type=None, **kwargs)

?

flat_args, spec = tree_flatten((args, kwargs))
tensors = list(filter(lambda t: isinstance(t, TensorProxy), flat_args))
non_tensors = list(filter(lambda t: not isinstance(t, TensorProxy), flat_args))
has_name_before_init = hasattr(self, "_name")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can this happen?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If thunder sees a tensor wrapper subclass of MySubclass(...) that has its own dunder new calling _make_wrapper_subclass in a function thunder's tracing, the lookaside of https://github.com/Lightning-AI/lightning-thunder/pull/1393/files#diff-3d1ea50ad3b0e3ad6fc369f91a7e42011d1d33d770ce25f800637c99de85f4b5R762 creates an instance of SubclassTensorProxy, then the dunder init of that instance is called, not the dunder init of MySubclass instance.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess I'm still confused about how and why the proxy class is entangled with the actual subclass


self._tensors = tensors
self._non_tensors = non_tensors
bsym = prims.tensor_subclass_ctor.bind(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is interesting. Why does this class either create a new proxy or change the actual trace?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this path is happening while tracing and the call of Proxy.__init__ is not recorded by default in a trace so it's necessary to deliberately register a boundsymbol to the currently active scope

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK; but why does proxy creation get recorded into the trace?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if a function creates a subclass inside it then shouldn't a trace of it have a BoundSymbol to represent it?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's OK to put tensor subclass creation into a trace (although I'm curious if we can dce the creation if the subclass is just flattened and used once later), but I'm not sure why the operator that creates the actual tensor subclass is also a constructor for the proxy. The existing tensor proxies don't entangle their tensor factory methods with the creation of the proxy, for example

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The goal is to divorce the proxy from the actual runtime object to simplify the code and its concepts.

How else would we infer the decomposition of operations called on a tensor subclass, and if the result is also a member of the subclass?

I understand that init gets called on the object at runtime, of course. The goal is not to someone circumvent the proper construction of the object at runtime (unless it can be elided for performance).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe I'm not understanding this well, but I'm a bit worried that we don't have well-defined semantics here and that the traces are not representing what's up.

We should absolutely know when we want to construct and use the subclass object (if we return it or something wants the subclass object) and when we don't (all other cases) and should also represent what will be the compute in the trace.

We also want to minimize admin overhead during the compute, which we are not doing a great job today (having looked at wall clock vs. GPU self time for Llama 1b today), so adding the overhead of dealing with subclasses at compute time should likely be minimized.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

excuse me for my idling on this. I guess I understand your points a bit better. The separation does sound nice. For it, I think one way is to do the interpretation of __torch_dispatch__ while acquiring the initial trace, not what #1394 is doing, i.e., get the initial trace with opaque tensor subclass proxies before unrolling their __torch_dispatch__. With this, some traces like ones we'd get from torchao float8 programs would be completely free from tensor subclass proxies. If any of arguments or return values of a function is a tensor subclass, we might end up having one or more proxy classes for tensor subclasses observed in it and we might need to implement a dynamic registration of them while interpretation, which I almost have no idea at the moment.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to talk through this more over VC if it would be helpful, too!

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tensor subclass proxy objects would have type(tensor_subclass_object) as their attribute but not actual instances.

I'd be rather hesitant to complicate the current source code of interpreter by introducing mechanism which unrolls tensor subclasses, defines a proxy class for the observed tensor subclass, and registers that proxy class to the namespace.
Currently, the implementation itself isn't the best for sure, but the unrolling of tensor subclasses happens after the initial computation is acquired (This initial trace would have some wrong expressions, e.g., when a tensor subclass' __torch_dispatch__ returns one or multiple objects of that tensor subclass), thus I think the change of this PR (more precisely, #1394 and #1415) would have less interactions.

else:
cur_tail_scope.append(bsym)

def replace(self, **changes):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is a replace function needed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because I don't want .replace(...) call to replace subclass tensor proxy with tensor proxy

in preference to the old values but overridable by keyword arguments.
Note that the copy will use the current (environment) tracectx."""

like = changes.get("like")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If changes doesn't have like as a key won't this throw a key error?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If key is not found then the default value, in this case, None is returned. https://docs.python.org/3/library/stdtypes.html#dict.get

thunder_fsdp_padding_size,
) = _infer_tensor_properties(
like,
changes.get("shape", self._shape if like is None else None),
Copy link
Collaborator

@mruberry mruberry Nov 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not

*[changes.get(key, None) for key in ('shape', 'device', ...)]

?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied these lines from TensorProxy.replace and I'm lazy enough not to do that

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But doesn't the current code mean that if like is specified then the shape of the like tensor is overriden with the shape of the current tensor? Is that what's intended?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think changes > like > self is the priority order.

@instantiate(
dtypes=(thunder.core.dtypes.float32,),
)
def test_func_of_subclass_ctor_wrapper(executor, device, _):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these tests just check that the tensor subclass can be constructed?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, because I was struggling to support previously, in #1345. The others are in #1394

Copy link
Collaborator

@t-vi t-vi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On my part this is good to merge (thanks @crcrpar and @IvanYashchuk for the sync with @lantiga and me last week). We can revise the bits from this PR when the need arises while we build on it.

@mruberry do you think we should go ahead?

@crcrpar crcrpar marked this pull request as draft November 29, 2024 00:25
@crcrpar
Copy link
Collaborator Author

crcrpar commented Nov 29, 2024

I'd like to move to #1394. #1394 is based on this one so I think changing the target of #1394 to main would be better than merge 1393 then 1394

@t-vi
Copy link
Collaborator

t-vi commented Nov 29, 2024

Works for me, too.

@crcrpar crcrpar closed this Dec 7, 2024
@crcrpar crcrpar deleted the crpa/subclss-tensor-init branch December 7, 2024 07:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants