Added support for F.one_hot #128
Conversation
make = partial(make_tensor, device=device, dtype=torch.long, requires_grad=requires_grad)

test_shapes = [
    (10,),
Let's add a tensor with no dimensions (its shape is ()) and a tensor with no elements (like (0, 512)), too.
Yeah, empty tensors + 0-dim tensors are very important. The underlying ops dispatch to PyTorch native implementations (scatter_add), but these were written by me before the widespread adoption of 0-dim tensors... So we'd better double-check that PyTorch does the right thing here as well. And if PyTorch falls short, we could file an issue there and short-circuit our implementation here.
We are still missing empty inputs. Or is there an error?
@nikitaved I've added an empty input (0, 512) as you guys suggested (:
@shaharelys, what about scalar inputs, the inputs with no dimensions (i.e. shape=())?
Hey @nikitaved! Sorry for the delay. I did not add these. Should we also add these now?
Hey @shaharelys! This looks pretty good. I made some comments for your review. I also added @nikitaved as a reviewer.
@mruberry
thunder/torch/__init__.py
Outdated
src = ones_like(index, dtype=dtypes.int64)

return scatter_add(canvas, dim=-1, index=index, src=src)
Seems fine for now, but what happens is that we create a tensor full of ones, because tensor creation and scatter_add are not going to be fused together. We could:
- create a tensor with a single 1 and broadcast it, or
- leave it as is, in the hope that at some point our backends/executors will improve this bit (when scatter_add is available in nvFuser, for example). Worth adding a comment, I guess?
PyTorch uses scatter and that, apparently, allows scalars as source inputs...
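For illustration, here is the scalar-source variant in plain PyTorch (a sketch of the alternative mentioned above, not this PR's code): Tensor.scatter accepts a Python scalar as the source, so no tensor of ones is ever materialized.

```python
import torch

a = torch.tensor([[0, 2], [1, 0]])  # class indices
num_classes = 3
canvas = torch.zeros(*a.shape, num_classes, dtype=torch.long)
# scatter with a scalar value writes 1 at each indexed position
# without allocating a ones tensor first
out = canvas.scatter(-1, a.unsqueeze(-1), 1)
```

This matches torch.nn.functional.one_hot(a, num_classes) element for element, while skipping the fill-with-ones kernel.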
Hey @nikitaved! Not sure I understand the implementation suggested in "...and broadcast it". I tried to implement it naively by replacing
src = ones_like(index, device=a.device, dtype=dtypes.int64)
with
src = tensor([1], device=a.device, dtype=dtypes.int64)
and got:
thunder/core/prims.py:3009: in scatter_add_meta
utils.check(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
cond = False, s = <function scatter_add_meta.<locals>.<lambda> at 0x7bb2a8c043a0>
exception_type = <class 'RuntimeError'>
def check(cond: bool, s: Callable[[], str], exception_type: type[Exception] = RuntimeError) -> None:
"""Helper function for raising an error_type (default: RuntimeError) if a boolean condition fails.
s is a callable producing a string to avoid string construction if the error check is passed.
"""
if not cond:
> raise exception_type(s())
E RuntimeError: Expected index (rank=3) to have the same rank as value (rank=1)
thunder/core/baseutils.py:103: RuntimeError
You can read about broadcasting here: https://numpy.org/doc/stable/user/basics.broadcasting.html. The idea is to create a tensor with a single 1 and reshape it into some other shape which, in this case, should match either the index or the source, as I reckon. This solution should be future-proof and optimal. But it could be done as a follow-up.
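A sketch of that idea in plain PyTorch (hypothetical names, and PyTorch ops rather than thunder's own primitives): the single 1 is expanded to index's shape, which satisfies scatter_add's same-rank requirement from the traceback above without launching a fill kernel, since expand returns a stride-0 view.

```python
import torch

def one_hot_sketch(a: torch.Tensor, num_classes: int) -> torch.Tensor:
    canvas = torch.zeros(*a.shape, num_classes, dtype=torch.long)
    index = a.unsqueeze(-1)
    # A single scalar 1 expanded to index's shape: expand() creates a
    # zero-copy view, so no kernel fills memory with ones, and the rank
    # now matches index as scatter_add requires.
    src = torch.ones((), dtype=torch.long).expand(index.shape)
    return canvas.scatter_add(-1, index, src)
```

The naive tensor([1]) attempt failed because its rank (1) did not match the index's rank; expanding to index.shape is exactly the reshape step being suggested.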
@nikitaved Cool, and should this be a more efficient operation than the current implementation?
Well, yes. It will spare us launching a kernel that fills memory with 1s. This might become redundant once the target executor can fuse ones and scatter_add together; in the context of nvFuser, that means putting the code of ones and scatter_add into a single CUDA kernel. Do not worry for now if you do not understand these things, and the change is not that critical for now. You can learn more about executors from the documentation, which, alas, has to be built locally.
Hey @mruberry, @nikitaved! I've reviewed and addressed most of the comments. However, there are a few points I'm uncertain about. I'll comment directly under those for clarity. Looking forward to your feedback!
Thank you @shaharelys @nikitaved @mruberry
@t-vi, I have not checked the changes yet, nor have I answered the posited questions :)
Before submitting

What does this PR do?
Implements F.one_hot. Fixes #64.

PR review
This PR is open for review, though again I'm not certain it's complete. I added comments under one_hot for things I wasn't sure of. Tests have passed. docs has not been updated (I had a look at the README, but I wasn't sure I understood it). Feedback is welcome!

Did you have fun?
I sure did! 🙃
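For reference, a minimal example of the behavior this PR implements, shown with PyTorch's own F.one_hot (which thunder's version is meant to mirror):

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 2, 1])
onehot = F.one_hot(labels, num_classes=3)
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 0]])
```

Each input index becomes a row with a single 1 at that position; the output gains one trailing dimension of size num_classes.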