Skip to content

Commit 2d4ad51

Browse files
Disallow bool scalar conversation and update tests
1 parent da2598c commit 2d4ad51

File tree

6 files changed

+29
-43
lines changed

6 files changed

+29
-43
lines changed

dpctl/tensor/_usmarray.pyx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,6 +1140,7 @@ cdef class usm_ndarray:
11401140

11411141
def __bool__(self):
11421142
if self.size == 1:
1143+
_check_0d_scalar_conversion(self)
11431144
view = _as_zero_dim_ndarray(self)
11441145
return view.__bool__()
11451146

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 6 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -286,18 +286,8 @@ def test_properties(dt):
286286
@pytest.mark.parametrize("shape", [tuple(), (1,), (1, 1), (1, 1, 1)])
287287
@pytest.mark.parametrize("dtype", ["|b1", "|u2", "|f4", "|i8"])
288288
class TestCopyScalar:
289-
def test_copy_bool_scalar_with_func(self, shape, dtype):
290-
try:
291-
X = dpt.usm_ndarray(shape, dtype=dtype)
292-
except dpctl.SyclDeviceCreationError:
293-
pytest.skip("No SYCL devices available")
294-
Y = np.arange(1, X.size + 1, dtype=dtype)
295-
X.usm_data.copy_from_host(Y.view("|u1"))
296-
Y.shape = tuple()
297-
assert bool(X) == bool(Y)
298-
299-
@pytest.mark.parametrize("func", [float, int, complex])
300-
def test_copy_numeric_scalar_with_func(self, func, shape, dtype):
289+
@pytest.mark.parametrize("func", [bool, float, int, complex])
290+
def test_copy_scalar_with_func(self, func, shape, dtype):
301291
try:
302292
X = dpt.usm_ndarray(shape, dtype=dtype)
303293
except dpctl.SyclDeviceCreationError:
@@ -312,18 +302,10 @@ def test_copy_numeric_scalar_with_func(self, func, shape, dtype):
312302
# 0D arrays are allowed to convert
313303
assert func(X) == func(Y)
314304

315-
def test_copy_bool_scalar_with_method(self, shape, dtype):
316-
try:
317-
X = dpt.usm_ndarray(shape, dtype=dtype)
318-
except dpctl.SyclDeviceCreationError:
319-
pytest.skip("No SYCL devices available")
320-
Y = np.arange(1, X.size + 1, dtype=dtype)
321-
X.usm_data.copy_from_host(Y.view("|u1"))
322-
Y = Y.reshape(())
323-
assert getattr(X, "__bool__")() == getattr(Y, "__bool__")()
324-
325-
@pytest.mark.parametrize("method", ["__float__", "__int__", "__complex__"])
326-
def test_copy_numeric_scalar_with_method(self, method, shape, dtype):
305+
@pytest.mark.parametrize(
306+
"method", ["__bool__", "__float__", "__int__", "__complex__"]
307+
)
308+
def test_copy_scalar_with_method(self, method, shape, dtype):
327309
try:
328310
X = dpt.usm_ndarray(shape, dtype=dtype)
329311
except dpctl.SyclDeviceCreationError:

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1430,11 +1430,12 @@ def test_nonzero_f_contig():
14301430
mask = dpt.zeros((5, 5), dtype="?", order="F")
14311431
mask[2, 3] = True
14321432

1433-
expected_res = (2, 3)
1434-
res = dpt.nonzero(mask)
1433+
expected_res = np.nonzero(dpt.asnumpy(mask))
1434+
result = dpt.nonzero(mask)
14351435

1436-
assert expected_res == res
1437-
assert mask[res]
1436+
for exp, res in zip(expected_res, result):
1437+
assert_array_equal(dpt.asnumpy(res), exp)
1438+
assert dpt.all(mask[result])
14381439

14391440

14401441
def test_nonzero_compacting():
@@ -1448,11 +1449,12 @@ def test_nonzero_compacting():
14481449
mask[3, 2, 1] = True
14491450
mask_view = mask[..., :3]
14501451

1451-
expected_res = (3, 2, 1)
1452-
res = dpt.nonzero(mask_view)
1452+
expected_res = np.nonzero(dpt.asnumpy(mask_view))
1453+
result = dpt.nonzero(mask_view)
14531454

1454-
assert expected_res == res
1455-
assert mask_view[res]
1455+
for exp, res in zip(expected_res, result):
1456+
assert_array_equal(dpt.asnumpy(res), exp)
1457+
assert dpt.all(mask_view[result])
14561458

14571459

14581460
def test_assign_scalar():

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,11 +1438,11 @@ def test_tile_size_1():
14381438
# test for gh-1627 behavior
14391439
res = dpt.tile(x1, reps)
14401440
assert x1.shape == res.shape
1441-
assert x1 == res
1441+
assert_array_equal(dpt.asnumpy(x1), dpt.asnumpy(res))
14421442

14431443
res = dpt.tile(x2, reps)
14441444
assert x2.shape == res.shape
1445-
assert x2 == res
1445+
assert_array_equal(dpt.asnumpy(x2), dpt.asnumpy(res))
14461446

14471447

14481448
def test_tile_prepends_axes():

dpctl/tests/test_usm_ndarray_operators.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,9 @@ def test_comp_ops(namespace):
134134
pytest.skip("No SYCL devices available")
135135
X._set_namespace(namespace)
136136
assert X.__array_namespace__() is namespace
137-
assert X.__gt__(-1)
138-
assert X.__ge__(-1)
139-
assert not X.__lt__(-1)
140-
assert not X.__le__(-1)
141-
assert not X.__eq__(-1)
142-
assert X.__ne__(-1)
137+
assert dpt.all(X.__gt__(-1))
138+
assert dpt.all(X.__ge__(-1))
139+
assert not dpt.all(X.__lt__(-1))
140+
assert not dpt.all(X.__le__(-1))
141+
assert not dpt.all(X.__eq__(-1))
142+
assert dpt.all(X.__ne__(-1))

dpctl/tests/test_usm_ndarray_sorting.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import numpy as np
2020
import pytest
21+
from numpy.testing import assert_array_equal
2122

2223
import dpctl.tensor as dpt
2324
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
@@ -345,11 +346,11 @@ def test_radix_sort_size_1_axis():
345346

346347
x1 = dpt.ones((), dtype="i1")
347348
r1 = dpt.sort(x1, kind="radixsort")
348-
assert r1 == x1
349+
assert_array_equal(dpt.asnumpy(r1), dpt.asnumpy(x1))
349350

350351
x2 = dpt.ones([1], dtype="i1")
351352
r2 = dpt.sort(x2, kind="radixsort")
352-
assert r2 == x2
353+
assert_array_equal(dpt.asnumpy(r2), dpt.asnumpy(x2))
353354

354355
x3 = dpt.reshape(dpt.arange(10, dtype="i1"), (10, 1))
355356
r3 = dpt.sort(x3, kind="radixsort")
@@ -369,7 +370,7 @@ def test_radix_argsort_size_1_axis():
369370

370371
x2 = dpt.ones([1], dtype="i1")
371372
r2 = dpt.argsort(x2, kind="radixsort")
372-
assert r2 == 0
373+
assert dpt.all(r2 == 0)
373374

374375
x3 = dpt.reshape(dpt.arange(10, dtype="i1"), (10, 1))
375376
r3 = dpt.argsort(x3, kind="radixsort")

0 commit comments

Comments
 (0)