-
Notifications
You must be signed in to change notification settings - Fork 84
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
prims.div
: equalize NVFuser prims with PyTorch for integer Tensors
#821
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||
---|---|---|---|---|---|---|---|---|
|
@@ -961,3 +961,33 @@ def foo(t, ab): | |||||||
out_ref = foo(t, ab) | ||||||||
|
||||||||
assert out.equal(out_ref) | ||||||||
|
||||||||
|
||||||||
# TODO: we should improve our consistency testing | ||||||||
# to also include checks for the result of meta functions. | ||||||||
@instantiate( | ||||||||
dtypes=(thunder.int64, thunder.int32), | ||||||||
executors=(nvFuserExecutor,), | ||||||||
) | ||||||||
def test_div_truediv_integer_tensors_consistency_nvfuser(executor, device, thunder_dtype): | ||||||||
dtype = ltorch.to_torch_dtype(thunder_dtype) | ||||||||
|
||||||||
def div(a, b): | ||||||||
return thunder.prims.div(a, b) | ||||||||
|
||||||||
def truediv(a, b): | ||||||||
return a // b | ||||||||
|
||||||||
def make_integer_tensor(): | ||||||||
half_len = 5 | ||||||||
t = torch.tensor([*range(-half_len, 0), *range(1, half_len + 1)], device=device, dtype=dtype) | ||||||||
perm = torch.randperm(2 * half_len) | ||||||||
return t[perm] | ||||||||
t-vi marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||
|
||||||||
x = make_integer_tensor() | ||||||||
y = make_integer_tensor() | ||||||||
|
||||||||
for f in (thunder.jit(div), thunder.jit(truediv)): | ||||||||
rout = f(x.cpu(), y.cpu()).to(device) | ||||||||
jout = f(x, y) | ||||||||
assert rout.equal(jout) | ||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think these results should be equal (I think we're exposing a bug in the nvFuser implementation of
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, but doesn't There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Floor division has a decomp that is more than just There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, but it's a common misconception that they are. If the result is negative then they round differently. Floor division rounds down, truncation division rounds towards zero. In Python
In C++
prints -1. This is because -9 / 5 is -1.8, and floor division takes the floor of -1.8, which is -2, while truncation division truncates the fractional part of -1.8, giving -1. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
You're correct, of course, and I just misread the code. You are comparing that the CPU and CUDA versions are consistent with each other. The fact that truncation division and floor division are distinct for integers is a separate issue, and doesn't impact this code. My mistake. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Aha! @IvanYashchuk told me that nvFuser's div is performing truncation division for integers now. I mistakenly thought it wasn't. That's why this test passes There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. while we are looking at torchex, I think the div prim impl has a typo (twice a instead of a and b) in the if condition:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think you're right! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is floor division, not true division
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch! Made a typo - will fix that. Thank you!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
C division has always been the true division for me :)