[fix] mul chalf: cpu scalar and cuda tensor case
pytorch#76158 added `chalf` support for `mul` on CUDA incorrectly: the jitted kernel it used does not handle a CPU scalar operand.

The following sample
```python
torch.mul(torch.zeros(3, device='cuda'), 2.5) # CUDA Tensor and CPU Scalar
```

fails with
```
RuntimeError: iter.device(arg).is_cuda() INTERNAL ASSERT FAILED at "../aten/src/ATen/native/cuda/JitLoops.cuh":83, please report a bug to PyTorch. argument 2: expected a CUDA device but found cpu
```
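The failing case can be reproduced (and, after this fix, verified) with a small guarded script. This is a sketch based on the repro in the commit message; the CUDA guard is added so it is a no-op on machines without a GPU:

```python
import torch

# Repro from the commit message: a CUDA tensor multiplied by a CPU
# Python scalar, using the chalf (complex32) dtype this commit fixes.
# Guarded so the script is a no-op without a GPU.
if torch.cuda.is_available():
    x = torch.zeros(3, device='cuda', dtype=torch.chalf)
    y = torch.mul(x, 2.5)  # CUDA tensor and CPU scalar
    assert y.dtype == torch.chalf
```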
Pull Request resolved: pytorch#76364
Approved by: https://github.com/mruberry
kshitij12345 authored and pytorchmergebot committed Apr 30, 2022
1 parent 100e72f commit 1f1d0b3
Showing 2 changed files with 15 additions and 1 deletion.
`aten/src/ATen/native/cuda/BinaryMulDivKernel.cu` — 1 addition, 1 deletion:

```diff
@@ -184,7 +184,7 @@ void mul_kernel_cuda(TensorIteratorBase& iter) {
         return a * b;
       }
     );
-    jitted_gpu_kernel<mul_name, scalar_t, scalar_t, 2>(iter, mul_string);
+    opmath_jitted_gpu_kernel_with_scalars<mul_name, scalar_t, scalar_t>(iter, mul_string);
 #else
     using opmath_t = at::opmath_type<scalar_t>;
     opmath_gpu_kernel_with_scalars<scalar_t>(iter, MulFunctor<opmath_t>());
```
`test/test_binary_ufuncs.py` — 14 additions:

```diff
@@ -3560,6 +3560,20 @@ def test_helper(x, q):
             x = make_tensor((2, 3, 4), dtype=x_dtype, device=device)
             test_helper(x, q)

+    @onlyCUDA
+    @dtypes(torch.chalf,)
+    def test_mul_chalf_tensor_and_cpu_scalar(self, device, dtype):
+        # Tests that Tensor and CPU Scalar work for `mul` for chalf.
+        # Ideally, this should be covered by `test_complex_half_reference_testing`
+        # from test_ops.py by checking reference_samples from the OpInfo.
+        # But currently that doesn't work as sample generation requires support of
+        # `index_select` which is not implemented for `complex32` at the
+        # time of writing this test.
+        # TODO: Remove this test once the above issue is fixed.
+        # Ref: https://github.com/pytorch/pytorch/pull/76364
+        x = make_tensor((2, 2), device=device, dtype=dtype)
+        self.assertEqual(x * 2.5, x * torch.tensor(2.5, device=device, dtype=dtype))
+
+
 tensor_binary_ops = [
     '__lt__', '__le__',
```
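The kernel change swaps the plain jitted kernel for a scalar-aware variant that also computes in a wider "opmath" type. A minimal sketch of that opmath pattern in plain PyTorch (the helper name `mul_opmath` is hypothetical, not ATen's API):

```python
import torch

# Sketch of the "opmath" pattern: a low-precision complex input is upcast
# to a wider compute type for the arithmetic, then the result is cast back
# to the input dtype. `mul_opmath` is a hypothetical illustration only.
def mul_opmath(a: torch.Tensor, scalar: float) -> torch.Tensor:
    opmath = torch.complex64 if a.dtype == torch.chalf else a.dtype
    return (a.to(opmath) * scalar).to(a.dtype)

x = torch.ones(3, dtype=torch.chalf)
y = mul_opmath(x, 2.5)  # scalar stays a plain Python number, as in the fix
assert y.dtype == torch.chalf
```

The actual ATen path additionally keeps the CPU scalar out of the device-pointer checks that raised the internal assert, which is what `opmath_jitted_gpu_kernel_with_scalars` provides over `jitted_gpu_kernel`.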
