Skip to content

Commit

Permalink
Modified tests for cbrt, copysign, and rsqrt
Browse files Browse the repository at this point in the history
Now test more type combinations/output types
  • Loading branch information
ndgrigorian committed Oct 16, 2023
1 parent 0695fde commit 3e08194
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions dpctl/tests/elementwise/test_cbrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
import dpctl.tensor as dpt
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported

from .utils import _map_to_device_dtype, _real_fp_dtypes
from .utils import _map_to_device_dtype, _no_complex_dtypes, _real_fp_dtypes


@pytest.mark.parametrize("dtype", _real_fp_dtypes)
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
def test_cbrt_out_type(dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)
Expand Down
6 changes: 3 additions & 3 deletions dpctl/tests/elementwise/test_copysign.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
import dpctl.tensor as dpt
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported

from .utils import _compare_dtypes, _real_fp_dtypes
from .utils import _compare_dtypes, _no_complex_dtypes, _real_fp_dtypes


@pytest.mark.parametrize("op1_dtype", _real_fp_dtypes)
@pytest.mark.parametrize("op2_dtype", _real_fp_dtypes)
@pytest.mark.parametrize("op1_dtype", _no_complex_dtypes)
@pytest.mark.parametrize("op2_dtype", _no_complex_dtypes)
def test_copysign_dtype_matrix(op1_dtype, op2_dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(op1_dtype, q)
Expand Down
6 changes: 3 additions & 3 deletions dpctl/tests/elementwise/test_rsqrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,16 @@
import dpctl.tensor as dpt
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported

from .utils import _map_to_device_dtype, _real_fp_dtypes
from .utils import _map_to_device_dtype, _no_complex_dtypes, _real_fp_dtypes


@pytest.mark.parametrize("dtype", _real_fp_dtypes)
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
def test_rsqrt_out_type(dtype):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

x = dpt.asarray(1, dtype=dtype, sycl_queue=q)
expected_dtype = np.reciprocal(np.sqrt(1, dtype=dtype)).dtype
expected_dtype = np.reciprocal(np.sqrt(np.array(1, dtype=dtype))).dtype
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
assert dpt.rsqrt(x).dtype == expected_dtype

Expand Down

0 comments on commit 3e08194

Please sign in to comment.