Operator support for F.one_hot #64

Closed
kyo-takano opened this issue Mar 24, 2024 · 16 comments · Fixed by #128
Labels
enhancement New feature or request good first issue Good for newcomers help wanted Extra attention is needed operators

Comments

@kyo-takano

kyo-takano commented Mar 24, 2024

🐛 Bug

thunder fails when attempting to compile a graph containing torch.nn.functional.one_hot in the forward pass.
The error message indicates that the input to the method must be a Tensor, but a TensorProxy is received instead.

To Reproduce

Steps to reproduce the behavior:

  • Define a PyTorch model class with a forward pass involving F.one_hot to convert the input tensor to a one-hot encoded representation.
  • Create an instance of the model and evaluate it on a random input tensor.
  • Compile the model using thunder.jit.
  • Call the compiled model with the same input tensor.

Example

import torch
import torch.nn as nn
import torch.nn.functional as F

import thunder


class MLP(nn.Module):
    def __init__(self, hidden_size=1024):
        super(MLP, self).__init__()
        self.hidden = nn.Linear(6 * 256, hidden_size, bias=False)
        self.head = nn.Linear(hidden_size, 32000, bias=False)

    def forward(self, inputs):
        x = F.one_hot(inputs, 6).reshape(-1, 6 * 256).float()
        x = self.hidden(x)
        logits = self.head(x)
        return logits


x = torch.randint(0, 6, (1, 256))

model = MLP(1024).eval()
print(model(x))

model = thunder.jit(model)
print(model(x))
Output
tensor([[-0.1134, -0.0827, -0.0205,  ...,  0.0757,  0.0066,  0.0974]],
       grad_fn=<MmBackward0>)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-6-6425e5faad6e> in <cell line: 23>()
     21 
     22 model = thunder.jit(model)
---> 23 print(model(x))

16 frames
/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _wrapped_call_impl(self, *args, **kwargs)
   1509             # type ignore was added because at this point one knows that
   1510             # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any]
-> 1511             name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None  # type: ignore[index, operator] # noqa: B950
   1512             if name:
   1513                 tracing_state.push_scope(name)

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1518         finally:
   1519             if recording_scopes:
-> 1520                 tracing_state.pop_scope()
   1521         return result
   1522 

/usr/local/lib/python3.10/dist-packages/thunder/__init__.py in forward(self, *args, **kwargs)
    192 
    193     def forward(self, *args, **kwargs):
--> 194         res = self._forward_fn(*args, **kwargs)
    195         return res
    196 

/usr/local/lib/python3.10/dist-packages/thunder/__init__.py in fn_(*args, **kwargs)
    609         cs.calls += 1
    610 
--> 611         cache_entry, inps, pro_to_epi = get_computation_and_inputs(*args, **kwargs)
    612         cs.last_trace_host_execution_start = time.time_ns()
    613 

/usr/local/lib/python3.10/dist-packages/thunder/__init__.py in cache_info_wrapper(*args, **kwargs)
    260         tok = _cache_info_ctx.set({})
    261         try:
--> 262             res = fn(*args, **kwargs)
    263         finally:
    264             _cache_info_ctx.reset(tok)

/usr/local/lib/python3.10/dist-packages/thunder/__init__.py in get_computation_and_inputs(*args, **kwargs)
    496                 prologue_trc: TraceCtx
    497                 computation_trc: TraceCtx
--> 498                 prologue_trc, computation_trc, *maybe_epilogue = interpreter(
    499                     fn, args, kwargs, sharp_edges=cd.sharp_edges
    500                 )

/usr/local/lib/python3.10/dist-packages/thunder/__init__.py in _general_frontend(fn, args, kwargs, sharp_edges)
    173 # Translates the Python function to a thunder program using the thunder interpreter
    174 def _general_frontend(fn: Callable, args, kwargs, /, *, sharp_edges: SHARP_EDGES_OPTIONS) -> tuple[TraceCtx, TraceCtx]:
--> 175     return thunder_general_jit(fn, args, kwargs, sharp_edges=sharp_edges)
    176 
    177 

/usr/local/lib/python3.10/dist-packages/thunder/core/jit_ext.py in thunder_general_jit(fn, args, kwargs, sharp_edges)
   1384     with general_jit_ctx(ctx):
   1385         with tracectx(computation_trace):
-> 1386             result = jfn(*args, **kwargs)
   1387             prims.python_return(result)
   1388             process_recorded_modifications(ctx, epilogue_trace)

/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py in fn_(*args, **kwargs)
   6578                 assert isinstance(e, BaseException), e
   6579                 runtimectx.curexc = None
-> 6580                 raise e
   6581 
   6582             return interpretation_result

/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py in fn_2()
   6541                 def getfn():
   6542                     def fn_2(args, kwargs):
-> 6543                         return fn(*args, **kwargs)
   6544 
   6545                     return fn_2

/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py in _impl()
   5940 
   5941         def _impl(fn, *args, **kwargs):
-> 5942             return fn.__func__(fn.__self__, *args, **kwargs)
   5943 
   5944         return _interpret_call(_impl, wrapped_fn, *args, **kwargs)  # type: ignore

/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py in _wrapped_call_impl()
   1509             # type ignore was added because at this point one knows that
   1510             # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any]
-> 1511             name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None  # type: ignore[index, operator] # noqa: B950
   1512             if name:
   1513                 tracing_state.push_scope(name)

/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py in _impl()
   5940 
   5941         def _impl(fn, *args, **kwargs):
-> 5942             return fn.__func__(fn.__self__, *args, **kwargs)
   5943 
   5944         return _interpret_call(_impl, wrapped_fn, *args, **kwargs)  # type: ignore

/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py in _call_impl()
   1518         finally:
   1519             if recording_scopes:
-> 1520                 tracing_state.pop_scope()
   1521         return result
   1522 

/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py in _impl()
   5940 
   5941         def _impl(fn, *args, **kwargs):
-> 5942             return fn.__func__(fn.__self__, *args, **kwargs)
   5943 
   5944         return _interpret_call(_impl, wrapped_fn, *args, **kwargs)  # type: ignore

/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py in forward()
      9 
     10     def forward(self, inputs):
---> 11         x = F.one_hot(inputs, 6).reshape(-1, 6 * 256).float()
     12         x = self.hidden(x)
     13         logits = self.head(x)

/usr/local/lib/python3.10/dist-packages/thunder/core/interpreter.py in _call_dispatch(compilectx, runtimectx, fn, *args, **kwargs)
   6067         kwargs_ = {unwrap(k): unwrap(v) for k, v in kwargs.items()}
   6068         try:
-> 6069             opaque_result: Any = fn(*args_, **kwargs_)
   6070         except Exception as e:
   6071             runtimectx.curexc = e

TypeError: one_hot(): argument 'input' (position 1) must be Tensor, not TensorProxy

Environment

  • OS: Ubuntu/Google Colab
  • Python Version: 3.10
  • PyTorch Version: 2.3.0.dev20240314+cu121
  • Thunder Version: 0.1.0
  • Installation:
pip install --pre 'nvfuser-cu121[torch]' --extra-index-url https://pypi.nvidia.com
pip install lightning-thunder

Additional context

  • Other functional methods like F.relu don't seem to raise the issue.
@kyo-takano kyo-takano added bug Something isn't working help wanted Extra attention is needed labels Mar 24, 2024
@t-vi
Collaborator

t-vi commented Mar 24, 2024

Hello @kyo-takano, thank you for trying out thunder! This is the thunder way (we should have a FAQ) of saying "this operator is not supported yet in thunder". We are actually working on giving a better error message.

If you want to have a comprehensive check on whether a model you want to run has ops we don't support yet, you can use examine to check.

Implementing an operator like one_hot is not entirely trivial, but I would guess that it could be a great entry point for an aspiring thunder developer. 😉 (Reminds me a lot of the early PyTorch days; my first PR almost 7 years ago was to get double derivatives for clamp going.) one_hot might be decomposed using scatter add.
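A minimal sketch of the scatter-add decomposition in plain eager PyTorch (not thunder code; the name `one_hot_via_scatter_add` is hypothetical): allocate a zero tensor with an extra trailing dimension of size `num_classes`, then scatter-add a 1 at each input index along that dimension.

```python
import torch
import torch.nn.functional as F


def one_hot_via_scatter_add(indices: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical one_hot decomposition: zeros + scatter_add along a new last dim."""
    # The output has the input's shape plus a trailing class dimension.
    out = torch.zeros(*indices.shape, num_classes, dtype=torch.long, device=indices.device)
    # scatter_add requires index and src to have as many dims as out.
    index = indices.unsqueeze(-1)
    src = torch.ones_like(index)
    # Add 1 at position indices[...] along the last dimension (out-of-place).
    return out.scatter_add(-1, index, src)


x = torch.randint(0, 6, (2, 4))
print(torch.equal(one_hot_via_scatter_add(x, 6), F.one_hot(x, 6)))
```

Since each input position contributes exactly one 1, scatter_add and a plain scatter give the same result here; scatter_add is simply the primitive that was suggested.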

@t-vi t-vi changed the title Compilation Issue with F.one_hot Operator support for F.one_hot Mar 24, 2024
@t-vi t-vi added enhancement New feature or request good first issue Good for newcomers operators and removed bug Something isn't working labels Mar 24, 2024
@kyo-takano
Author

Thanks @t-vi

This is just a heads up, as I'd expect to see many issues of the same kind, given that functional operations like this are used everywhere in PyTorch implementations.

This bug is something that does not occur with the native compiler regardless of the chosen backend. So, if thunder.jit is intended to be a drop-in replacement for torch.compile, it might be necessary to address this issue to ensure a seamless transition for users, depending on the number & significance of such incompatible operations.

@lantiga
Copy link
Collaborator

lantiga commented Mar 25, 2024

Thank you @kyo-takano, the limitation is not related to functional operations per se, but the number of ops we currently support. They will grow quickly, but it’ll take a bit.

For this reason (and we’ll add a FAQ soon), Thunder is not intended to be a drop-in replacement for torch.compile at this stage, with the same coverage as torch.compile. However, we do focus on:

What model are you targeting specifically?

@kyo-takano
Author

Thank you @lantiga

the limitation is [...] the number of ops we currently support

Thunder is not intended to be a drop-in replacement for torch.compile at this stage

Understood. Thanks for clarifying these points.

What model are you targeting specifically?

I don't have a particular model I need to compile with thunder right away.
I just thought it would be beneficial to flag an issue about a widely used operation that is currently incompatible.

@shaharelys
Contributor

Hi everyone, I'm new to contributing to open source and very interested in the Thunder project. I noticed this issue and I'd love to take this opportunity to contribute. I have some experience with Python and PyTorch and am eager to learn more about Thunder and PyTorch internals. Could I take on this issue? Any guidance or suggestions on how to get started would be greatly appreciated!

@shaharelys
Contributor

Thank you for the support. I am starting to work on implementing support for the one_hot operator as discussed. I'll keep the thread updated with my progress and any questions I might have.

Could one of the collaborators please assign this issue to me? Thanks!

@lantiga
Collaborator

lantiga commented Mar 25, 2024

Awesome, welcome aboard! Looking forward to your contribution. One super-useful thing you could do is record your journey as a first contributor so we can create more onboarding material and address the papercuts.

@nikitaved
Contributor

nikitaved commented Mar 25, 2024

Hey, @kyo-takano, thank you for the issue!

Hey, @shaharelys, super excited about your interest in helping us out!

A couple of, hopefully, helpful notes:

  • As @t-vi pointed out, this operation might not be trivial to implement, so I would advise familiarizing yourself with how it is implemented in PyTorch. Special attention has to be paid to handling invalid and edge-case inputs.
  • Once the behavior of one_hot is understood, it is time to think about implementing it! It is very likely that one_hot can be implemented as a composition of other torch-like operations, so that it, as we say, "decomposes" into "primitive" operations. So, if your algorithm depends on scatter_add, we can check whether this operation is available to us. I like using git grep, and it gives the following:
>>> git grep "def scatter_add"
thunder/clang/__init__.py:def scatter_add(a: TensorProxy, /, indices: TensorProxy, value: TensorProxy, dim: int) -> TensorProxy:
thunder/core/prims.py:def scatter_add_meta(a: TensorProxy, /, index: TensorProxy, value: TensorProxy, dim: int) -> TensorProxy:
thunder/tests/opinfos.py:def scatter_add_sample_generator(op, device, dtype, requires_grad, **kwargs):
thunder/torch/__init__.py:def scatter_add(a: TensorLike, /, dim: int, index: TensorLike, src: TensorLike) -> TensorLike:
  • Once all the required "primitive" operations are available (defined in clang/__init__.py and/or core/prims.py and/or torch/__init__.py), the high-level decompositions are implemented in thunder/torch/__init__.py, and these try to match the behavior of PyTorch. A good, rather complex example of a non-trivial decomposition is the interpolate operation:
    @torchsymbol(torch.nn.functional.interpolate, is_method=False)
  • Never forget to test things! The grep output above contains the line thunder/tests/opinfos.py:def scatter_add_sample_generator(op, device, dtype, requires_grad, **kwargs):. It defines a generator that produces the inputs scatter_add is tested on. You could look at the other occurrences of scatter_add in that file to figure out the remaining parts needed to get the testing going. Once the code is filled in, you can run the one_hot-related tests with pytest -sv thunder/tests/test_ops.py -k one_hot.
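To pin down the edge cases mentioned in the first point, a quick probe of eager PyTorch's F.one_hot on CPU. The helper `raises` is only for illustration, and only the exception type is checked because error messages can change between PyTorch versions.

```python
import torch
import torch.nn.functional as F


def raises(fn) -> bool:
    """Return True if calling fn() raises a RuntimeError."""
    try:
        fn()
    except RuntimeError:
        return True
    return False


idx = torch.tensor([0, 2, 1])
print(F.one_hot(idx).shape)     # num_classes=-1 infers max(idx) + 1 = 3
print(F.one_hot(idx, 5).shape)  # explicit num_classes
# An empty input is fine as long as num_classes is given explicitly.
print(F.one_hot(torch.empty(0, dtype=torch.long), 4).shape)

print(raises(lambda: F.one_hot(torch.tensor([-1, 0]), 3)))          # negative class index
print(raises(lambda: F.one_hot(torch.tensor([0, 3]), 3)))           # index >= num_classes
print(raises(lambda: F.one_hot(torch.tensor([0.0, 1.0]), 3)))       # non-int64 input
print(raises(lambda: F.one_hot(torch.empty(0, dtype=torch.long))))  # cannot infer classes from empty input
```

A thunder decomposition would need to decide which of these checks it can reproduce: the dtype check depends only on metadata, while the value checks depend on runtime data.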

If you have any questions, do not hesitate to ping me! Have fun and thank you!

@shaharelys
Contributor

shaharelys commented Mar 25, 2024

@nikitaved, thank you for the detailed guidance!

  • I will dive into PyTorch's implementation to understand one_hot more deeply.

  • As for the mention of def scatter_add: since I'm still familiarizing myself with the project's structure, its role and how it helps achieve our goal with one_hot aren't fully clear to me yet.

  • Could you recommend specific areas of the codebase or modules as good starting points? From your comment, I gather that clang/__init__.py, core/prims.py, and torch/__init__.py are crucial. Are there other parts I should focus on?

  • Am I correct in understanding that the approach you are suggesting is to decompose one_hot into a set of primitive operations already supported by Thunder?

Also, I'd like to know the typical timeline for resolving issues like this. Recognizing that I'm new and still catching up, I want to set realistic expectations for my contribution timeline. What is usually anticipated?

Thank you for your support!

@nikitaved
Contributor

nikitaved commented Mar 25, 2024

@nikitaved, thank you for the detailed guidance!

Always happy to help!

Could you recommend specific areas of the codebase or modules as good starting points? From your comment, I gather that clang/__init__.py, core/prims.py, and torch/__init__.py are crucial. Are there other parts I should focus on?

These should be more than enough to start with operations.
I would recommend specifically focusing on thunder/torch/__init__.py for now. prims and clang could be optional.
Also have a look at sample input generators in thunder/tests/opinfos.py.
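Sample inputs for one_hot need int64 indices in [0, num_classes) plus an explicit num_classes. A framework-agnostic sketch of such a generator, covering 0-d, empty, and multi-dimensional inputs. The name `one_hot_sample_inputs` and its signature are hypothetical; thunder's actual generators in thunder/tests/opinfos.py take (op, device, dtype, requires_grad, **kwargs) and yield their own sample-input objects.

```python
import itertools

import torch
import torch.nn.functional as F


def one_hot_sample_inputs(device: str = "cpu"):
    """Hypothetical generator of (indices, num_classes) pairs for testing one_hot."""
    shapes = [(), (0,), (5,), (2, 3), (2, 1, 4)]
    class_counts = [1, 3, 7]
    for shape, num_classes in itertools.product(shapes, class_counts):
        # Indices must be int64 and lie in [0, num_classes).
        indices = torch.randint(0, num_classes, shape, device=device, dtype=torch.long)
        yield indices, num_classes


# Sanity-check each sample against eager PyTorch as the reference.
for indices, num_classes in one_hot_sample_inputs():
    expected = F.one_hot(indices, num_classes)
    assert expected.shape == (*indices.shape, num_classes)
    # Every one-hot vector contains exactly one 1.
    assert bool((expected.sum(-1) == 1).all())
```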

Am I correct in understanding that the approach you are suggesting is to decompose one_hot into a set of primitive operations already supported by Thunder?

You are correct! This simplifies things a lot!

Also, I'd like to know the typical timeline for resolving issues like this. Recognizing that I'm new and still catching up, I want to set realistic expectations for my contribution timeline. What is usually anticipated?

Take as much time as you need. The outcome may turn out to be just as simple as adding a gelu.

@torchsymbol(torch.nn.functional.gelu, is_method=False)

In fact, if you think that we are missing some simple activation functions, you could open an issue, work on them, and then return to one_hot once you are comfortable with the codebase.
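As an illustration of how small such a decomposition can be: the exact (non-tanh) GELU is x * Φ(x) = 0.5 * x * (1 + erf(x / sqrt(2))), which needs only elementwise operations. A plain-PyTorch sketch (the name `gelu_decomposed` is hypothetical; this is not thunder's code), checked against F.gelu with its default approximate="none":

```python
import math

import torch
import torch.nn.functional as F


def gelu_decomposed(x: torch.Tensor) -> torch.Tensor:
    """Exact GELU written with elementwise ops only: 0.5 * x * (1 + erf(x / sqrt(2)))."""
    return 0.5 * x * (1.0 + torch.erf(x / math.sqrt(2.0)))


x = torch.linspace(-4.0, 4.0, steps=101)
print(torch.allclose(gelu_decomposed(x), F.gelu(x), atol=1e-6))
```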

@shaharelys
Contributor

shaharelys commented Mar 26, 2024

Awesome, welcome aboard! Looking forward to your contribution. One super-useful thing you could do is record your journey as a first contributor so we can create more onboarding material and address the papercuts.

Hey @lantiga, per your request, I will try to list things that took me some time to understand and where I think better onboarding material could help. The easiest way for me is to post them in this thread from time to time. Will that work for you?

Just note that I am very new to open source and still a junior, so some of my comments might stem from my lack of experience rather than from a lack of onboarding material. Anyway, it is hard for me to tell the difference, so this is just a little heads-up (:

Here are my initial comments (some are also questions):
Context: today I did my first review of core/prims, core/symbol, clang/__init__, and thunder/torch/__init__.

  1. At the high level, I'd love to get some insight into the idea of Thunder. Like, how could you possibly make torch faster?
  2. Should all PyTorch primitives also be listed as thunder primitives? (After a short glimpse into the PyTorch prims file, I guess the answer is no, but why?)
  3. Could you clarify the roles of clang, prims, and symbol?
  4. What is the role of TensorProxy as opposed to Tensor and the other Proxy types?
  5. In general, a bit more documentation could help. A good example is clang/__init__, which states '# This file defines the operations in thunder.jit's "core" language.'. For other modules, it took me some time to understand that the 'header' comment actually refers to the whole file.

@nikitaved
Contributor

Hey, @shaharelys ! Congrats on your hardswish PR! Do you need more help with this one? I can expand more on your questions posted above in your next PR. What do you think?

@shaharelys
Copy link
Contributor

shaharelys commented Apr 2, 2024

Hey @nikitaved! Actually, I think I can handle this for now; the PR should be ready soon.

But since you asked, just a quick question; I've already got the basic structure set up:

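import torch
from torch import Tensor


def one_hot(tensor: Tensor, num_classes: int = -1) -> Tensor:
    # Infer the number of classes from the data (this is data-dependent)
    if num_classes == -1:
        num_classes = int(torch.max(tensor)) + 1

    # Zero "canvas" with an extra trailing dimension of size num_classes
    canvas = torch.zeros(*tensor.shape, num_classes, dtype=torch.long, device=tensor.device)
    # Positions of the ones along that trailing dimension
    index = tensor.unsqueeze(-1)
    src = torch.ones_like(index, dtype=torch.long)

    return canvas.scatter_add_(dim=-1, index=index, src=src)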

I could find all needed methods as building blocks, but not torch.max.
How would you go about this?

@nikitaved
Copy link
Contributor

@shaharelys I guess we should have argmax.
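Assuming argmax is available as a building block, the maximum value itself can be recovered by flattening and indexing with the argmax result. A plain-PyTorch sketch (the helper name `max_via_argmax` is hypothetical). Note that turning the result into a Python int for num_classes still requires the concrete value at runtime.

```python
import torch


def max_via_argmax(t: torch.Tensor) -> torch.Tensor:
    """Return the maximum element of t (as a 0-d tensor) using argmax plus indexing."""
    flat = t.reshape(-1)
    # argmax over the flattened tensor gives the position of a maximum element.
    return flat[torch.argmax(flat)]


t = torch.tensor([[3, 0], [5, 2]])
print(max_via_argmax(t))
# num_classes for the num_classes == -1 case; int() needs the runtime value.
print(int(max_via_argmax(t)) + 1)
```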

@nikitaved
Contributor

nikitaved commented Apr 2, 2024

@shaharelys, I also suspect that we might have issues with num_classes == -1 because we need max().item() to implement that. Or do we properly proxify scalars now, @mruberry?
We can still implement this function at least partially and file an issue for potential extensions.
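Why num_classes == -1 is problematic for a tracer: with the default, the output shape depends on the values in the input, not just on its shape and dtype. Two inputs with identical metadata produce differently shaped outputs in eager PyTorch:

```python
import torch
import torch.nn.functional as F

# Same shape (2,), same dtype (int64), different values.
a = torch.tensor([0, 1])
b = torch.tensor([0, 5])

# With num_classes=-1 the trailing dimension is max(input) + 1, a runtime value.
print(F.one_hot(a).shape)  # torch.Size([2, 2])
print(F.one_hot(b).shape)  # torch.Size([2, 6])

# With an explicit num_classes the output shape follows from metadata alone.
print(F.one_hot(a, 6).shape == F.one_hot(b, 6).shape)  # True
```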

@mruberry
Collaborator

mruberry commented Apr 2, 2024

@shaharelys, I also suspect that we might have issues with num_classes == -1 because we need max().item() to implement that. Or do we properly proxify scalars now, @mruberry? We can still implement this function at least partially and file an issue for potential extensions.

I believe the latest is that we treat numeric inputs as constants, and NumberProxies should only be generated for numbers whose value is only determined at runtime (like those coming from calls to .item()). If that's not the case there's probably a bug.

num_classes == -1 is an issue because the shape of the output tensor would only be determined at runtime, which we do not support. We could look at extending the one_hot function in the future to support a shape parameter, like JAX's nonzero (it calls the parameter size), but I'm not immediately sure how that would work, and I agree with you that we should just limit the functionality for now and raise a NotImplementedError if num_classes == -1.
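A sketch of the restricted behavior proposed here, in plain PyTorch (the name `one_hot_static` is hypothetical): reject num_classes == -1 with NotImplementedError, and otherwise build the result with a broadcasted comparison against arange(num_classes). This comparison is an alternative to the scatter_add decomposition discussed above; its output shape depends only on the input shape and num_classes. Unlike eager PyTorch, it does not validate index values (that check would also be data-dependent), so out-of-range indices simply produce all-zero rows.

```python
import torch
import torch.nn.functional as F


def one_hot_static(indices: torch.Tensor, num_classes: int = -1) -> torch.Tensor:
    """Hypothetical one_hot that only supports a statically known num_classes."""
    if num_classes == -1:
        # The output shape would depend on max(indices), i.e. on runtime data.
        raise NotImplementedError("one_hot with num_classes=-1 is not supported")
    if indices.dtype != torch.long:
        # Mirrors eager PyTorch, which raises RuntimeError for non-int64 input.
        raise RuntimeError("one_hot is only applicable to index tensors of type int64")
    classes = torch.arange(num_classes, device=indices.device)
    # Shape (..., 1) compared with (num_classes,) broadcasts to (..., num_classes).
    return (indices.unsqueeze(-1) == classes).to(torch.long)


x = torch.randint(0, 4, (3, 5))
assert torch.equal(one_hot_static(x, 4), F.one_hot(x, 4))
try:
    one_hot_static(x)
except NotImplementedError as e:
    print("rejected:", e)
```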

6 participants