Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

prims.div: equalize NVFuser prims with PyTorch for integer Tensors #821

Merged
merged 1 commit into from
Jul 22, 2024

Conversation

nikitaved
Copy link
Contributor

NVFuser implementation of prims.div used reciprocal. That caused upcasting for integer inputs and, ultimately, rendered meta checks incorrect.

Fixes #808
Fixes #818

@nikitaved nikitaved force-pushed the nikitaved/prims_div_nvfuser_fix branch from c68844e to 8e10ba4 Compare July 22, 2024 12:16
@nikitaved nikitaved force-pushed the nikitaved/prims_div_nvfuser_fix branch from 8e10ba4 to a4b5bf1 Compare July 22, 2024 12:30
@nikitaved
Copy link
Contributor Author

It is part one - the fix for integer types. We also have issues with floating types that need to get fixed to re-enable all the consistency tests.

Copy link
Collaborator

@t-vi t-vi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @nikitaved, awesome stuff!

@t-vi t-vi enabled auto-merge (squash) July 22, 2024 12:49
@t-vi t-vi merged commit 44d2f3c into main Jul 22, 2024
36 checks passed
@t-vi t-vi deleted the nikitaved/prims_div_nvfuser_fix branch July 22, 2024 12:50
return thunder.prims.div(a, b)

def truediv(a, b):
return a // b
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is floor division, not true division

Copy link
Contributor Author

@nikitaved nikitaved Jul 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch! Made a typo - will fix that. Thank you!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

C division has always been the true division for me :)

for f in (thunder.jit(div), thunder.jit(truediv)):
rout = f(x.cpu(), y.cpu()).to(device)
jout = f(x, y)
assert rout.equal(jout)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think these results should be equal (I think we're exposing a bug in the nvFuser implementation of prims.div). prims.div should implement truncation division, which is not the same as Python's floor division. See, for example, the PyTorch executor's implementation of prims.div:

def _div_prim_impl(a: Number | torch.Tensor, b: Number | torch.Tensor) -> torch.Tensor:

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

@nikitaved nikitaved Jul 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

div calls to prims.div, not ltorch.div, so they should be equivalent.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, but doesn't truediv call floor division?

Copy link
Contributor Author

@nikitaved nikitaved Jul 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Floor division has a decomp that is more than just prims.div, and this is what is being tested here.

Copy link
Collaborator

@mruberry mruberry Jul 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, but it's a common misconception that they are. If the result is negative then they round differently. Floor division rounds down, truncation division rounds towards zero.

In Python

-9 // 5
-2

In C++

#include <iostream>

int main() {
    int a = (-9 / 5);
    std::cout << a; 
}

prints -1. This is because -9 / 5 is -1.8, and floor division takes the floor of -1.8, which is -2, while truncation division truncates the fractional part of -1.8, giving -1.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am confused -- I am not comparing these two operations head to head, I just make sure that the cpu version is consistent with the nvfuser version for both jitted prims.div and * // *.

You're correct, of course, and I just misread the code. You are comparing that the CPU and CUDA versions are consistent with each other. The fact that truncation division and floor division are distinct for integers is a separate issue, and doesn't impact this code. My mistake.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fyi @kevinstephano @jjsjann123

Aha! @IvanYashchuk told me that nvFuser's div is performing truncation division for integers now. I mistakenly thought it wasn't. That's why this test passes

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

while we are looking at torchex, I think the div prim impl has a typo (twice a instead of a and b) in the if condition:

if dtypes.is_exact_dtype(to_dtype(a.dtype)) and dtypes.is_exact_dtype(to_dtype(a.dtype)):

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

while we are looking at torchex, I think the div prim impl has a typo (twice a instead of a and b) in the if condition:

if dtypes.is_exact_dtype(to_dtype(a.dtype)) and dtypes.is_exact_dtype(to_dtype(a.dtype)):

I think you're right!

@nikitaved nikitaved changed the title prims.div: equalize NVFuser prims with PyTorch prims.div: equalize NVFuser prims with PyTorch for integer Tensors Jul 22, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
3 participants