prims.div: equalize NVFuser prims with PyTorch for integer Tensors #821

Conversation
Force-pushed from c68844e to 8e10ba4
Force-pushed from 8e10ba4 to a4b5bf1
This is part one: the fix for integer types. We also have issues with floating-point types that need to be fixed to re-enable all the consistency tests.
Thank you @nikitaved, awesome stuff!
    return thunder.prims.div(a, b)

def truediv(a, b):
    return a // b
This is floor division, not true division
Good catch! Made a typo - will fix that. Thank you!
C division has always been the true division for me :)
for f in (thunder.jit(div), thunder.jit(truediv)):
    rout = f(x.cpu(), y.cpu()).to(device)
    jout = f(x, y)
    assert rout.equal(jout)
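For context, a self-contained version of this check might look like the sketch below; the tensor values, dtype, and device are assumptions for illustration, not the PR's actual test fixtures.

```python
import torch
import thunder

device = "cuda"
# Hypothetical inputs, chosen to include sign combinations where floor and
# truncation division would disagree.
x = torch.tensor([-9, 9, -10, 10], dtype=torch.int64, device=device)
y = torch.tensor([5, -5, 5, -5], dtype=torch.int64, device=device)

def div(a, b):
    return thunder.prims.div(a, b)

def truediv(a, b):  # note: despite the name, // is floor division (see the discussion)
    return a // b

for f in (thunder.jit(div), thunder.jit(truediv)):
    rout = f(x.cpu(), y.cpu()).to(device)  # reference result, presumably via the PyTorch executor on CPU
    jout = f(x, y)                         # result via the nvFuser executor on CUDA
    assert rout.equal(jout)
```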
I don't think these results should be equal (I think we're exposing a bug in the nvFuser implementation of prims.div). prims.div should implement truncation division, which is not the same as Python's floor division. See, for example, the PyTorch executor's implementation of prims.div:
def _div_prim_impl(a: Number | torch.Tensor, b: Number | torch.Tensor) -> torch.Tensor:
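For illustration, here is a minimal, self-contained sketch of that behavior. The real implementation checks exactness via thunder's dtypes.is_exact_dtype/to_dtype helpers (quoted, with a suspected typo, later in this thread); the function name and dtype check below are stand-ins.

```python
import torch

def div_prim_sketch(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Exact (integer/bool) inputs: C-style truncation division, rounding towards zero.
    if not (a.dtype.is_floating_point or a.dtype.is_complex
            or b.dtype.is_floating_point or b.dtype.is_complex):
        return torch.div(a, b, rounding_mode="trunc")
    # Floating-point/complex inputs: true division.
    return a / b
```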
div calls to prims.div, not ltorch.div, so they should be equivalent.
Yes, but doesn't truediv call floor division?
Floor division has a decomp that is more than just prims.div, and this is what is being tested here.
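As an illustration of such a decomposition (a standard way to express floor division in terms of truncation division; thunder's actual decomp may differ):

```python
import torch

def floor_div_via_trunc(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    q = torch.div(a, b, rounding_mode="trunc")  # truncation division, rounds towards zero
    rem = a - q * b
    # Adjust down by one where the exact quotient is negative and non-integral,
    # which is exactly where floor and truncation division disagree.
    needs_fixup = (rem != 0) & ((a < 0) ^ (b < 0))
    return q - needs_fixup.to(q.dtype)
```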
No, but it's a common misconception that they are. If the result is negative then they round differently. Floor division rounds down, truncation division rounds towards zero.
In Python,

    -9 // 5

is -2. In C++,

    #include <iostream>
    int main() {
        int a = (-9 / 5);
        std::cout << a;
    }

prints -1. This is because -9 / 5 is -1.8, and floor division takes the floor of -1.8, which is -2, while truncation division truncates the fractional part of -1.8, giving -1.
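The same difference is visible directly in PyTorch, which exposes both rounding modes on torch.div:

```python
import torch

a, b = torch.tensor(-9), torch.tensor(5)
print(torch.div(a, b, rounding_mode="floor"))  # tensor(-2): Python-style floor division
print(torch.div(a, b, rounding_mode="trunc"))  # tensor(-1): C-style truncation division
```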
I am confused -- I am not comparing these two operations head to head; I just make sure that the CPU version is consistent with the nvFuser version for both the jitted prims.div and //.
You're correct, of course, and I just misread the code. You are comparing that the CPU and CUDA versions are consistent with each other. The fact that truncation division and floor division are distinct for integers is a separate issue, and doesn't impact this code. My mistake.
Aha! @IvanYashchuk told me that nvFuser's div is performing truncation division for integers now. I mistakenly thought it wasn't. That's why this test passes.
While we are looking at torchex, I think the div prim impl has a typo (twice a instead of a and b) in the if condition:

    if dtypes.is_exact_dtype(to_dtype(a.dtype)) and dtypes.is_exact_dtype(to_dtype(a.dtype)):
> While we are looking at torchex, I think the div prim impl has a typo (twice a instead of a and b) in the if condition:
> if dtypes.is_exact_dtype(to_dtype(a.dtype)) and dtypes.is_exact_dtype(to_dtype(a.dtype)):
I think you're right!
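If so, the fix would presumably be to check b's dtype in the second clause, along these lines (a suggested change, not the merged code):

```python
if dtypes.is_exact_dtype(to_dtype(a.dtype)) and dtypes.is_exact_dtype(to_dtype(b.dtype)):
```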
Title changed: "prims.div: equalize NVFuser prims with PyTorch" → "prims.div: equalize NVFuser prims with PyTorch for integer Tensors"
The NVFuser implementation of prims.div used reciprocal. That caused upcasting for integer inputs and, ultimately, rendered meta checks incorrect.

Fixes #808
Fixes #818
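For context on why going through reciprocal breaks integer inputs, here is a small PyTorch illustration of the same effect (demonstration only; nvFuser's internals are not shown):

```python
import torch

a = torch.tensor([-9, 10], dtype=torch.int64)
b = torch.tensor([5, 5], dtype=torch.int64)

via_reciprocal = a * torch.reciprocal(b)            # integer inputs get upcast to floating point
via_trunc = torch.div(a, b, rounding_mode="trunc")  # stays integral, rounds towards zero

print(via_reciprocal.dtype, via_trunc.dtype)  # torch.float32 torch.int64
```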