Skip to content

Commit

Permalink
prims.div: equalize NVFuser prims with PyTorch (#821)
Browse files Browse the repository at this point in the history
  • Loading branch information
nikitaved authored Jul 22, 2024
1 parent 5cdf37b commit 44d2f3c
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
9 changes: 6 additions & 3 deletions thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1740,14 +1740,17 @@ def bitwise_xor(a: TensorProxy | Number, b: TensorProxy | Number, *, fd: FusionD
register_supported(PrimIDs.BITWISE_XOR, bitwise_xor, _elementwise_binary_check)


# TODO nvFuser's div operation is not equivalent to the div primitive
# (mruberry) I need to investigate if nvFuser exposes a truncation division operation
def div(a: TensorProxy | Number, b: TensorProxy | Number, *, fd: FusionDefinition, lc_to_nv_map: dict) -> Any:
nva = getnv(a, fd, lc_to_nv_map)
nvb = getnv(b, fd, lc_to_nv_map)

a_dtype = dtypes.to_dtype(a)
b_dtype = dtypes.to_dtype(b)

if dtypes.is_integer_dtype(a_dtype) and dtypes.is_integer_dtype(b_dtype):
return fd.ops.div(nva, nvb)

# NOTE It's currently significantly faster for nvFuser to multiply the reciprocal than divide
# return fd.ops.div(nva, nvb)
return fd.ops.mul(nva, fd.ops.reciprocal(nvb))


Expand Down
30 changes: 30 additions & 0 deletions thunder/tests/test_nvfuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,3 +961,33 @@ def foo(t, ab):
out_ref = foo(t, ab)

assert out.equal(out_ref)


# TODO: we should improve our consistency testing
# to also include checks for the result of meta functions.
@instantiate(
dtypes=(thunder.int64, thunder.int32),
executors=(nvFuserExecutor,),
)
def test_div_truediv_integer_tensors_consistency_nvfuser(executor, device, thunder_dtype):
dtype = ltorch.to_torch_dtype(thunder_dtype)

def div(a, b):
return thunder.prims.div(a, b)

def truediv(a, b):
return a // b

def make_integer_tensor():
half_len = 5
t = torch.tensor([*range(-half_len, 0), *range(1, half_len + 1)], device=device, dtype=dtype)
perm = torch.randperm(2 * half_len)
return t[perm]

x = make_integer_tensor()
y = make_integer_tensor()

for f in (thunder.jit(div), thunder.jit(truediv)):
rout = f(x.cpu(), y.cpu()).to(device)
jout = f(x, y)
assert rout.equal(jout)

0 comments on commit 44d2f3c

Please sign in to comment.