prims.div: equalize NVFuser prims with PyTorch for integer Tensors #821

Merged · 1 commit · Jul 22, 2024
9 changes: 6 additions & 3 deletions thunder/executors/nvfuserex_impl.py
@@ -1740,14 +1740,17 @@ def bitwise_xor(a: TensorProxy | Number, b: TensorProxy | Number, *, fd: FusionD
register_supported(PrimIDs.BITWISE_XOR, bitwise_xor, _elementwise_binary_check)


# TODO nvFuser's div operation is not equivalent to the div primitive
# (mruberry) I need to investigate if nvFuser exposes a truncation division operation
def div(a: TensorProxy | Number, b: TensorProxy | Number, *, fd: FusionDefinition, lc_to_nv_map: dict) -> Any:
nva = getnv(a, fd, lc_to_nv_map)
nvb = getnv(b, fd, lc_to_nv_map)

a_dtype = dtypes.to_dtype(a)
b_dtype = dtypes.to_dtype(b)

if dtypes.is_integer_dtype(a_dtype) and dtypes.is_integer_dtype(b_dtype):
return fd.ops.div(nva, nvb)

# NOTE It's currently significantly faster for nvFuser to multiply the reciprocal than divide
# return fd.ops.div(nva, nvb)
return fd.ops.mul(nva, fd.ops.reciprocal(nvb))


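For context on the diff above: nvFuser's `fd.ops` are not assumed available here, so this sketch restates the two code paths in plain PyTorch. `torch.div(..., rounding_mode="trunc")` models the new integer branch, and reciprocal-multiply models the floating-point branch; the tensors and values are illustrative.

```python
import torch

# Integer branch: prims.div is truncation division (round toward zero),
# which torch exposes via rounding_mode="trunc".
a = torch.tensor([-9, 9])
b = torch.tensor([5, 5])
print(torch.div(a, b, rounding_mode="trunc"))  # tensor([-1,  1])

# Floating-point branch: multiplying by the reciprocal mirrors the
# fd.ops.mul / fd.ops.reciprocal path, trading a little exactness
# for speed.
af, bf = a.float(), b.float()
approx = af * bf.reciprocal()
print(torch.allclose(approx, af / bf))  # True (to float32 precision)
```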
30 changes: 30 additions & 0 deletions thunder/tests/test_nvfuser.py
@@ -961,3 +961,33 @@ def foo(t, ab):
out_ref = foo(t, ab)

assert out.equal(out_ref)


# TODO: we should improve our consistency testing
# to also include checks for the result of meta functions.
@instantiate(
dtypes=(thunder.int64, thunder.int32),
executors=(nvFuserExecutor,),
)
def test_div_truediv_integer_tensors_consistency_nvfuser(executor, device, thunder_dtype):
dtype = ltorch.to_torch_dtype(thunder_dtype)

def div(a, b):
return thunder.prims.div(a, b)

def truediv(a, b):
return a // b
Collaborator: This is floor division, not true division.

@nikitaved (Contributor Author, Jul 22, 2024): Good catch! Made a typo - will fix that. Thank you!

@nikitaved (Contributor Author, Jul 22, 2024): C division has always been the true division for me :)


def make_integer_tensor():
half_len = 5
t = torch.tensor([*range(-half_len, 0), *range(1, half_len + 1)], device=device, dtype=dtype)
perm = torch.randperm(2 * half_len)
return t[perm]

x = make_integer_tensor()
y = make_integer_tensor()

for f in (thunder.jit(div), thunder.jit(truediv)):
rout = f(x.cpu(), y.cpu()).to(device)
jout = f(x, y)
assert rout.equal(jout)
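A side note on `make_integer_tensor`: the construction deliberately skips zero, so the result is always safe to use as a divisor. A standalone restatement in plain torch (no thunder or nvFuser assumed; the CPU-vs-nvFuser check is mocked here with two CPU runs):

```python
import torch

def make_integer_tensor(half_len=5, dtype=torch.int64):
    # Values -half_len..-1 and 1..half_len (zero is excluded so the
    # tensor can serve as a divisor), then shuffled.
    t = torch.tensor([*range(-half_len, 0), *range(1, half_len + 1)], dtype=dtype)
    return t[torch.randperm(2 * half_len)]

x, y = make_integer_tensor(), make_integer_tensor()
assert not (y == 0).any()  # no zero divisors by construction

# Stand-in for the CPU-vs-nvFuser consistency check: the same
# truncation division run twice must agree elementwise.
assert torch.div(x, y, rounding_mode="trunc").equal(
    torch.div(x.clone(), y.clone(), rounding_mode="trunc"))
```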
Collaborator: I don't think these results should be equal (I think we're exposing a bug in the nvFuser implementation of prims.div). prims.div should implement truncation division, which is not the same as Python's floor division. See, for example, the PyTorch executor's implementation of prims.div:

    def _div_prim_impl(a: Number | torch.Tensor, b: Number | torch.Tensor) -> torch.Tensor:

@nikitaved (Contributor Author, Jul 22, 2024): div calls prims.div, not ltorch.div, so they should be equivalent.

Collaborator: Yes, but doesn't truediv call floor division?

@nikitaved (Contributor Author, Jul 22, 2024): Floor division has a decomp that is more than just prims.div, and this is what is being tested here.

@mruberry (Collaborator, Jul 22, 2024): No, but it's a common misconception that they are. If the result is negative then they round differently. Floor division rounds down; truncation division rounds towards zero.

In Python:

    >>> -9 // 5
    -2

In C++:

    #include <iostream>

    int main() {
        int a = (-9 / 5);
        std::cout << a;
    }

prints -1. This is because -9 / 5 is -1.8: floor division takes the floor of -1.8, which is -2, while truncation division truncates the fractional part of -1.8, giving -1.
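The rounding difference is easy to verify with nothing but Python's standard library; this snippet is illustrative and independent of the PR:

```python
import math

q = -9 / 5                 # exact quotient: -1.8
floor_div = -9 // 5        # Python's //: rounds down (floor), giving -2
trunc_div = math.trunc(q)  # C-style: drops the fraction, giving -1

print(floor_div, trunc_div)  # -2 -1

# The two conventions agree whenever the quotient is non-negative:
assert 9 // 5 == math.trunc(9 / 5) == 1
```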

Contributor Author: I am confused -- I am not comparing these two operations head to head, I just make sure that the cpu version is consistent with the nvfuser version for both the jitted prims.div and //.

Collaborator: You're correct, of course, and I just misread the code. You are comparing that the CPU and CUDA versions are consistent with each other. The fact that truncation division and floor division are distinct for integers is a separate issue, and doesn't impact this code. My mistake.

Collaborator: fyi @kevinstephano @jjsjann123

Aha! @IvanYashchuk told me that nvFuser's div is performing truncation division for integers now. I mistakenly thought it wasn't. That's why this test passes.

Collaborator: While we are looking at torchex, I think the div prim impl has a typo (a twice, instead of a and b) in the if condition:

    if dtypes.is_exact_dtype(to_dtype(a.dtype)) and dtypes.is_exact_dtype(to_dtype(a.dtype)):

Collaborator (quoting the comment above): I think you're right!
