Added support F.one_hot #128

Merged: 30 commits into Lightning-AI:main, Apr 7, 2024

Conversation

@shaharelys (Contributor):

Before submitting
  • Was this discussed/approved via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs? No, should I?
  • Did you write any new necessary tests?

What does this PR do?

Implements F.one_hot.

Fixes #64
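
For reference, a minimal example of the PyTorch behavior being implemented (output shown in comments):

    import torch
    import torch.nn.functional as F

    F.one_hot(torch.tensor([0, 2, 1]), num_classes=3)
    # tensor([[1, 0, 0],
    #         [0, 0, 1],
    #         [0, 1, 0]])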

PR review

This PR is open for review, though once again I'm not certain it's complete. I added comments under one_hot for the things I wasn't sure of. Tests have passed. The docs have not been updated (I had a look at the README but wasn't sure I understood it). Feedback is welcome!

Did you have fun?

I sure did! 🙃

shaharelys and others added 20 commits March 28, 2024 21:43
make = partial(make_tensor, device=device, dtype=torch.long, requires_grad=requires_grad)

test_shapes = [
(10,),
@mruberry (Collaborator) · Apr 3, 2024:

Let's add a tensor with no dimensions (its shape is ()) and a tensor with no elements (like (0, 512)), too
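
A sketch of how the shape list might be extended along these lines (the exact list in the PR is not reproduced here):

    test_shapes = [
        (),        # tensor with no dimensions (a scalar)
        (0, 512),  # tensor with no elements
        (10,),
    ]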

@nikitaved (Contributor) · Apr 4, 2024:

Yeah, empty tensors + 0-dim tensors are very important. The underlying ops dispatch to PyTorch native implementations (scatter_add), but these were written by me before the widespread adoption of 0-dim tensors... So we'd better double-check that PyTorch does the right thing here as well. And if PyTorch falls short, we could file an issue there and short-circuit our implementation here.
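
As a quick sanity check of the native behavior (a sketch, not from the thread; worth verifying as suggested above), PyTorch's one_hot appears to handle both cases when num_classes is given explicitly:

    import torch
    import torch.nn.functional as F

    F.one_hot(torch.tensor(2), num_classes=4).shape
    # torch.Size([4])
    F.one_hot(torch.zeros(0, 512, dtype=torch.long), num_classes=4).shape
    # torch.Size([0, 512, 4])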

Contributor:

We are still missing empty inputs. Or is there an error?

@shaharelys (Author) · Apr 8, 2024:

@nikitaved I've added an empty input (0, 512) as you guys suggested (:

@nikitaved (Contributor) · Apr 8, 2024:

@shaharelys, what about scalar inputs, i.e. inputs with no dimensions (shape=())?

@shaharelys (Author):

Hey @nikitaved! Sorry for the delay. I did not add these. Should we also add these now?

@mruberry requested a review from @nikitaved on April 3, 2024, 18:20
@mruberry (Collaborator) commented on Apr 3, 2024:

Hey @shaharelys! This looks pretty good. I made some comments for your review. I also added @nikitaved as a reviewer.

@shaharelys (Author) commented on Apr 3, 2024:

@mruberry
Thx a lot! Will look into these 🙏🏼

Comment on lines 3502 to 3504
src = ones_like(index, dtype=dtypes.int64)

return scatter_add(canvas, dim=-1, index=index, src=src)
@nikitaved (Contributor) · Apr 4, 2024:

Seems fine for now, but what happens is that we create a tensor full of ones, because tensor creation and scatter_add are not going to be fused together. We could:

  • create a tensor with a single 1 and broadcast it, or
  • leave it as is, in the hope that at some point our backends/executors will improve this bit (when scatter_add becomes available in nvFuser, for example). Worth adding a comment, I guess?

PyTorch uses scatter, which apparently allows scalars as source inputs...
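
For illustration, a minimal PyTorch sketch of that scatter-with-a-scalar-source pattern (not code from this PR):

    import torch

    idx = torch.tensor([0, 2, 1])
    canvas = torch.zeros(3, 3, dtype=torch.long)
    canvas.scatter_(-1, idx.unsqueeze(-1), 1)  # scalar source: no tensor of ones is materialized
    # tensor([[1, 0, 0],
    #         [0, 0, 1],
    #         [0, 1, 0]])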

@shaharelys (Author) · Apr 6, 2024:

Hey @nikitaved! Not sure I understand the implementation suggested in "...and broadcast it". I tried a naive implementation by replacing,
src = ones_like(index, device=a.device, dtype=dtypes.int64)
with,
src = tensor([1], device=a.device, dtype=dtypes.int64)
and got,

thunder/core/prims.py:3009: in scatter_add_meta
    utils.check(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

cond = False, s = <function scatter_add_meta.<locals>.<lambda> at 0x7bb2a8c043a0>
exception_type = <class 'RuntimeError'>

    def check(cond: bool, s: Callable[[], str], exception_type: type[Exception] = RuntimeError) -> None:
        """Helper function for raising an error_type (default: RuntimeError) if a boolean condition fails.
    
        s is a callable producing a string to avoid string construction if the error check is passed.
        """
        if not cond:
>           raise exception_type(s())
E           RuntimeError: Expected index (rank=3) to have the same rank as value (rank=1)

thunder/core/baseutils.py:103: RuntimeError

@nikitaved (Contributor) · Apr 8, 2024:

You can read about broadcasting here: https://numpy.org/doc/stable/user/basics.broadcasting.html. The idea is to create a tensor holding a single 1 and reshape/broadcast it into a shape that, in this case, should match either the index or the source, as I reckon. This solution should be future-proof and optimal, but it could be done as a follow-up.
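
A rough sketch of the idea (hypothetical code: the helper names full and expand are assumptions here and may not match thunder's actual internal API):

    # create a single 1 with the same rank as index, then expand it to index's
    # shape; the expanded tensor is a broadcasted view, so no kernel has to
    # fill memory with ones
    src = full((1,) * index.ndim, 1, device=a.device, dtype=dtypes.int64)
    src = expand(src, index.shape)
    return scatter_add(canvas, dim=-1, index=index, src=src)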

@shaharelys (Author) · Apr 8, 2024:

@nikitaved Cool, and should this be a more efficient operation than the current implementation?

@nikitaved (Contributor) · Apr 8, 2024:

Well, yes. It will spare us launching a kernel that fills memory with 1s. This might become redundant once the target executor can fuse ones and scatter_add together; in the context of nvFuser, this means putting the code of ones and scatter_add into a single CUDA kernel. Do not worry for now if you do not understand these things, and the change is not that critical right now. You can learn more about executors from the documentation, which, alas, has to be built locally.

@shaharelys (Author):

Hey @mruberry, @nikitaved ! I've reviewed and addressed most of the comments. However, there are a few points I'm uncertain about. I'll comment directly under those for clarity.

Looking forward to your feedback!

@t-vi enabled auto-merge (squash) on April 7, 2024, 12:26
@t-vi (Collaborator) left a comment.

@t-vi merged commit cd80d08 into Lightning-AI:main on Apr 7, 2024
@nikitaved (Contributor) commented on Apr 7, 2024:

@t-vi , I have not checked the changes yet, nor have I answered the posited questions :)
