Skip to content

Commit

Permalink
prims.div: equalize NVFuser prims with PyTorch
Browse files Browse the repository at this point in the history
  • Loading branch information
nikitaved committed Jul 22, 2024
1 parent 5cdf37b commit c68844e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
7 changes: 6 additions & 1 deletion thunder/executors/nvfuserex_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1746,8 +1746,13 @@ def div(a: TensorProxy | Number, b: TensorProxy | Number, *, fd: FusionDefinitio
nva = getnv(a, fd, lc_to_nv_map)
nvb = getnv(b, fd, lc_to_nv_map)

a_dtype = dtypes.to_dtype(a)
b_dtype = dtypes.to_dtype(b)

if dtypes.is_integer_dtype(a_dtype) and dtypes.is_integer_dtype(b_dtype):
return fd.ops.div(nva, nvb)

# NOTE Outside the all-integer case handled above, it's currently significantly faster for nvFuser to multiply by the reciprocal than to divide
# return fd.ops.div(nva, nvb)
return fd.ops.mul(nva, fd.ops.reciprocal(nvb))


Expand Down
28 changes: 28 additions & 0 deletions thunder/tests/test_nvfuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,3 +961,31 @@ def foo(t, ab):
out_ref = foo(t, ab)

assert out.equal(out_ref)


# TODO: we should improve our consistency testing
# to also include checks for the result of meta functions.
@instantiate(
    dtypes=NOTHING,
    executors=(nvFuserExecutor,),
)
def test_div_truediv_integer_tensors_consistency_nvfuser(executor, device, _):
    """Check integer division gives the same result for CPU and ``device`` inputs.

    Two functions are compiled with ``thunder.jit``: one calling
    ``thunder.prims.div`` directly and one using the ``//`` operator.
    Each is run on CPU copies of the inputs and on the original ``device``
    inputs, and the two results must be exactly equal.

    NOTE: despite "truediv" in the test name, the operator exercised is
    ``//`` (floor division), not ``/``.

    ``executor`` and the third positional argument are supplied by
    ``instantiate`` and are not used in the body.
    """

    def prim_div(lhs, rhs):
        return thunder.prims.div(lhs, rhs)

    def floordiv(lhs, rhs):
        return lhs // rhs

    def shuffled_nonzero_integers():
        # Non-zero integers in [-bound, bound]; excluding zero keeps the
        # divisor tensor free of zeros.
        bound = 5
        values = [v for v in range(-bound, bound + 1) if v != 0]
        base = torch.tensor(values, device=device)
        order = torch.randperm(len(values))
        return base[order]

    x = shuffled_nonzero_integers()
    y = shuffled_nonzero_integers()

    compiled = [thunder.jit(fn) for fn in (prim_div, floordiv)]
    for jfn in compiled:
        # Reference: run on CPU copies, then move the result to `device`.
        expected = jfn(x.cpu(), y.cpu()).to(device)
        actual = jfn(x, y)
        assert torch.equal(expected, actual)

0 comments on commit c68844e

Please sign in to comment.