update meta-function for where to error on a single input aka where(mark) #146

Closed
nikitaved opened this issue Apr 9, 2024 · 4 comments · Fixed by #192
Labels
bug (Something isn't working), good first issue (Good for newcomers), help wanted (Extra attention is needed), operators

Comments

@nikitaved
Contributor

nikitaved commented Apr 9, 2024

🐛 Bug

As per #124 (comment), the error message in our tests is not particularly friendly, since Thunder is not aware of the single-argument overload of where.

Also, since this op is data-dependent, we cannot implement it just yet.

Therefore, it would be great to detect such usages and raise an error with a comprehensive explanation.
A good place to start is probably _where_meta.
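Why the single-argument form is data-dependent: per the PyTorch documentation, `torch.where(condition)` is identical to `torch.nonzero(condition, as_tuple=True)`, i.e. it returns the indices of the true elements. Its output length therefore depends on the values in `condition`, not just on its shape. The pure-Python sketch below uses a hypothetical `where_single` over flat lists (standing in for 1-D tensors) to show two inputs of the same shape producing outputs of different lengths, which a shape-only meta function cannot predict.

```python
def where_single(condition: list[bool]) -> list[int]:
    """Hypothetical stand-in for where(condition) on a 1-D input:
    returns the indices of the True elements."""
    return [i for i, c in enumerate(condition) if c]


# Both inputs have the same "shape" (length 4)...
x = [True, False, True, True]
y = [False, False, False, True]

# ...but the outputs differ in length, because the result depends on
# the values in the input, not only on the input's shape.
print(len(where_single(x)))  # 3
print(len(where_single(y)))  # 1
```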

cc @apaz-cli

@nikitaved nikitaved added the bug, good first issue, help wanted, and operators labels on Apr 9, 2024
@k223kim
Contributor

k223kim commented Apr 14, 2024

Hi @nikitaved!

I'm Kaeun, one of the new contributors on Thunder (and I am loving it!). Would it be ok if I try to tackle this issue?
Along with that I would like to share what I have in mind: I was thinking about modifying core/prims.py something like the following:

def _where_meta(pred: Number | TensorProxy, a: Number | TensorProxy | None = None, b: Number | TensorProxy | None = None, /) -> TensorProxy:
    # Proposed: reject the single-argument form where(condition) up front
    if a is None and b is None:
        raise NotImplementedError("Thunder does not support data-dependent operations yet!")

This also means I would have to accept None for a and b in clang/__init__.py and torch/__init__.py. Let me know what you think about this approach and I will throw a PR right away!

Best,
Kaeun

@mruberry
Collaborator


Hey @k223kim!

Your proposal makes sense, but I suggest tweaking it a little. Let's keep the primitive for where as-is, and instead just modify thunder.torch.where:

def where(pred: TensorLike, a: Number | TensorLike, b: Number | TensorLike, /) -> TensorLike:

It can check for a and b being None (and extend their type annotations, as you suggested).

The reason I'd prefer to just edit the torch definition of this operation is that I think the core language and primitive where operations are well-defined, and we probably wouldn't overload them with this version of where. Instead, in the future, we might sometimes map thunder.torch.where to clang.nonzero or clang.where depending on how it was called.
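A minimal sketch of the check mruberry describes, living only in the torch-level `where` while the ternary primitive stays untouched. It assumes plain Python lists in place of tensors; `_ternary_where` is a hypothetical stand-in for the existing clang/prims where, not Thunder's real code. The error message is the one suggested later in this thread. The extra `TypeError` guard for passing only one of `a`/`b` is an assumption of this sketch, not something discussed in the issue.

```python
from numbers import Number
from typing import Any, Optional, Union


def _ternary_where(pred: list, a: Any, b: Any) -> list:
    # Hypothetical stand-in for the existing three-argument where
    # (clang/prims in Thunder); here it operates on flat Python lists.
    def pick(x: Any, i: int) -> Any:
        # Scalars broadcast to every position; lists are indexed elementwise.
        return x[i] if isinstance(x, list) else x

    return [pick(a, i) if p else pick(b, i) for i, p in enumerate(pred)]


def where(pred: list, a: Optional[Union[Number, list]] = None,
          b: Optional[Union[Number, list]] = None, /) -> list:
    # The single-argument overload where(condition) is data-dependent
    # (it behaves like nonzero), so reject it with an explicit message.
    if a is None and b is None:
        raise NotImplementedError(
            "torch.where() does not support only specifying a condition"
        )
    # Assumption (not from the thread): supplying only one of a/b is a caller error.
    if a is None or b is None:
        raise TypeError("where() expects both a and b, or only a condition")
    return _ternary_where(pred, a, b)


print(where([True, False, True], 1, 0))  # [1, 0, 1]
try:
    where([True, False])
except NotImplementedError as err:
    print(err)  # torch.where() does not support only specifying a condition
```

Keeping the check at the torch layer means the primitive never sees `None`, which matches the point above that a future version could instead route this call to a nonzero-style operation.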

How does that sound?

@k223kim
Contributor

k223kim commented Apr 15, 2024

Hi @mruberry!

I see what you mean! Your suggestion seems to align better with how PyTorch operates so let me do that instead. Some additional questions regarding this:

  • Regarding the error message, I am thinking about something along the lines of "Thunder does not support torch.where(condition)". Do you have any particular suggestions?
  • Should the test case be updated in tests/opinfo.py?

Appreciate your comments and feedback :)

@mruberry
Collaborator


> Regarding the error message, I am thinking about something along the lines of "Thunder does not support torch.where(condition)". Do you have any particular suggestions?

That's a pretty good error message. I think it's slightly more consistent with our other "not supported" messages to say "torch.where() does not support only specifying a condition".

> Should the test case be updated in tests/opinfo.py?

That'd be great. The PR could update this OpInfo

where_opinfo = OpInfo(

to use thunder.torch.where and add an error input for the where(condition) case.
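A minimal sketch of what an error input for the where(condition) case could look like. It assumes the OpInfo accepts a generator yielding (sample input, expected exception type, expected message regex) triples. `SampleInput`, `where_error_generator`, and `check_error_inputs` here are simplified hypothetical stand-ins, not Thunder's actual test utilities, and `where` is a list-based stub that mirrors the proposed check.

```python
import re
from dataclasses import dataclass, field
from typing import Iterator


@dataclass
class SampleInput:
    # Hypothetical stand-in for a test-suite SampleInput.
    args: tuple = ()
    kwargs: dict = field(default_factory=dict)


def where(pred, a=None, b=None, /):
    # Stub mirroring the proposed thunder.torch.where check (lists, not tensors).
    if a is None and b is None:
        raise NotImplementedError(
            "torch.where() does not support only specifying a condition"
        )
    return [x if p else y for p, x, y in zip(pred, a, b)]


def where_error_generator() -> Iterator[tuple[SampleInput, type, str]]:
    # Hypothetical error input: where(condition) alone should be rejected.
    yield (
        SampleInput(args=([True, False, True],)),
        NotImplementedError,
        r"does not support only specifying a condition",
    )


def check_error_inputs(op, generator) -> int:
    # Runs each error input and verifies the exception type and message.
    checked = 0
    for sample, exc_type, pattern in generator():
        try:
            op(*sample.args, **sample.kwargs)
        except exc_type as e:
            assert re.search(pattern, str(e)), f"unexpected message: {e}"
            checked += 1
        else:
            raise AssertionError(f"{op.__name__} did not raise {exc_type.__name__}")
    return checked


print(check_error_inputs(where, where_error_generator))  # 1
```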


3 participants