diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index 1788bd5430..5c410aca46 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -1740,14 +1740,24 @@ def bitwise_xor(a: TensorProxy | Number, b: TensorProxy | Number, *, fd: FusionD
 register_supported(PrimIDs.BITWISE_XOR, bitwise_xor, _elementwise_binary_check)
 
 
-# TODO nvFuser's div operation is not equivalent to the div primitive
-# (mruberry) I need to investigate if nvFuser exposes a truncation division operation
 def div(a: TensorProxy | Number, b: TensorProxy | Number, *, fd: FusionDefinition, lc_to_nv_map: dict) -> Any:
+    """Build the nvFuser expression that divides ``a`` by ``b``.
+
+    Both operands are first mapped to nvFuser values with ``getnv``. If both have
+    integer dtypes the result is ``fd.ops.div``; any other dtype combination is
+    expressed as ``a`` multiplied by the reciprocal of ``b``.
+    """
     nva = getnv(a, fd, lc_to_nv_map)
     nvb = getnv(b, fd, lc_to_nv_map)
 
+    a_dtype = dtypes.to_dtype(a)
+    b_dtype = dtypes.to_dtype(b)
+
+    # Integer operands bypass the reciprocal rewrite below and use nvFuser's div directly.
+    if dtypes.is_integer_dtype(a_dtype) and dtypes.is_integer_dtype(b_dtype):
+        return fd.ops.div(nva, nvb)
+
     # NOTE It's currently significantly faster for nvFuser to multiply the reciprocal than divide
-    # return fd.ops.div(nva, nvb)
     return fd.ops.mul(nva, fd.ops.reciprocal(nvb))
 
 
diff --git a/thunder/tests/test_nvfuser.py b/thunder/tests/test_nvfuser.py
index a4dad1f1aa..1f2d042242 100644
--- a/thunder/tests/test_nvfuser.py
+++ b/thunder/tests/test_nvfuser.py
@@ -961,3 +961,43 @@ def foo(t, ab):
     out_ref = foo(t, ab)
 
     assert out.equal(out_ref)
+
+
+# TODO: we should improve our consistency testing
+# to also include checks for the result of meta functions.
+@instantiate(
+    dtypes=(thunder.int64, thunder.int32),
+    executors=(nvFuserExecutor,),
+)
+def test_div_truediv_integer_tensors_consistency_nvfuser(executor, device, thunder_dtype):
+    """Check that integer division gives equal results for CPU copies and ``device`` tensors.
+
+    Covers both ``thunder.prims.div`` and the ``//`` operator on integer tensors
+    that contain negative values and no zeros.
+    """
+    # NOTE(review): ``executor`` is not referenced below; the callables are built with ``thunder.jit``.
+    dtype = ltorch.to_torch_dtype(thunder_dtype)
+
+    def div(a, b):
+        return thunder.prims.div(a, b)
+
+    # Named for what ``//`` does (floor division); the test name still says "truediv".
+    def floordiv(a, b):
+        return a // b
+
+    def make_integer_tensor():
+        # The values -half_len..-1 and 1..half_len in random order; zero is left out
+        # so that the tensor can be used as a divisor.
+        half_len = 5
+        t = torch.tensor([*range(-half_len, 0), *range(1, half_len + 1)], device=device, dtype=dtype)
+        perm = torch.randperm(2 * half_len)
+        return t[perm]
+
+    x = make_integer_tensor()
+    y = make_integer_tensor()
+
+    # Each jitted callable runs once on CPU copies and once on the tensors living on ``device``.
+    for f in (thunder.jit(div), thunder.jit(floordiv)):
+        rout = f(x.cpu(), y.cpu()).to(device)
+        jout = f(x, y)
+        assert rout.equal(jout)