diff --git a/requirements/requirements-test.txt b/requirements/requirements-test.txt index 098e5ae..54a46f7 100644 --- a/requirements/requirements-test.txt +++ b/requirements/requirements-test.txt @@ -1,4 +1,4 @@ coverage==7.4.3 -pytest==7.4.4 +pytest==8.0.2 pytest-cov==4.1.0 pytest-xdist==3.5.0 diff --git a/synced_collections/numpy_utils.py b/synced_collections/numpy_utils.py index 1ba37c4..0ef141c 100644 --- a/synced_collections/numpy_utils.py +++ b/synced_collections/numpy_utils.py @@ -84,11 +84,11 @@ def _is_numpy_scalar(data): def _is_complex(data): """Check if an object is complex. - This function works for both numpy raw Python data types. + This function works for numpy arrays, scalars, and raw Python data types. Returns ------- bool Whether or not the input is a complex number. """ - return (NUMPY and numpy.iscomplex(data).any()) or (isinstance(data, complex)) + return (NUMPY and numpy.iscomplexobj(data)) or isinstance(data, complex) diff --git a/tests/synced_collection_test.py b/tests/synced_collection_test.py index 1fab325..c64efd5 100644 --- a/tests/synced_collection_test.py +++ b/tests/synced_collection_test.py @@ -3,12 +3,14 @@ # This software is licensed under the BSD 3-Clause License. import platform from collections.abc import MutableMapping, MutableSequence +from contextlib import nullcontext from copy import deepcopy from typing import Any, Tuple, Type import pytest from synced_collections import SyncedCollection +from synced_collections.backends.collection_zarr import ZarrCollection from synced_collections.errors import KeyTypeError from synced_collections.numpy_utils import NumpyConversionWarning @@ -526,15 +528,16 @@ def test_set_get_numpy_float_data(self, synced_collection, dtype, shape): # should fail to set correctly. raw_value = value.item() if shape is None else value.tolist() test_value = value.item(0) if isinstance(raw_value, list) else raw_value - has_corresponding_python_type = isinstance( - test_value, (numpy.number, numpy.bool_) - ) + should_fail = isinstance(test_value, (numpy.number, numpy.bool_)) - if has_corresponding_python_type: - with pytest.raises((ValueError, TypeError)), pytest.warns( - NumpyConversionWarning - ): - synced_collection["numpy_dtype_val"] = value + maybe_warn = nullcontext() + if isinstance(synced_collection, ZarrCollection): + maybe_warn = pytest.warns(NumpyConversionWarning) + + if should_fail: + with pytest.raises((ValueError, TypeError)): + with maybe_warn: + synced_collection["numpy_dtype_val"] = value else: with pytest.warns(NumpyConversionWarning): synced_collection["numpy_dtype_val"] = value @@ -553,10 +556,13 @@ def test_set_get_numpy_complex_data(self, synced_collection, dtype, shape): # not a priority to test here). value = dtype(random_sample(shape)) - with pytest.raises((ValueError, TypeError)), pytest.warns( - NumpyConversionWarning - ): - synced_collection["numpy_dtype_val"] = value + maybe_warn = nullcontext() + if isinstance(synced_collection, ZarrCollection): + maybe_warn = pytest.warns(NumpyConversionWarning) + + with pytest.raises((ValueError, TypeError)): + with maybe_warn: + synced_collection["numpy_dtype_val"] = value class SyncedListTest(SyncedCollectionTest): @@ -813,11 +819,14 @@ def test_set_get_numpy_float_data(self, synced_collection, dtype, shape): test_value = value.item(0) if isinstance(raw_value, list) else raw_value should_fail = isinstance(test_value, (numpy.number, numpy.bool_)) + maybe_warn = nullcontext() + if isinstance(synced_collection, ZarrCollection): + maybe_warn = pytest.warns(NumpyConversionWarning) + if should_fail: - with pytest.raises((ValueError, TypeError)), pytest.warns( - NumpyConversionWarning - ): - synced_collection.append(value) + with pytest.raises((ValueError, TypeError)): + with maybe_warn: + synced_collection.append(value) else: with pytest.warns(NumpyConversionWarning): synced_collection.append(value) @@ -840,14 +849,15 @@ def test_set_get_numpy_complex_data(self, synced_collection, dtype, shape): # not a priority to test here). value = dtype(random_sample(shape)) - with pytest.raises((ValueError, TypeError)), pytest.warns( - NumpyConversionWarning - ): - synced_collection.append(value) + maybe_warn = nullcontext() + if isinstance(synced_collection, ZarrCollection): + maybe_warn = pytest.warns(NumpyConversionWarning) + + with pytest.raises((ValueError, TypeError)): + with maybe_warn: + synced_collection.append(value) - with pytest.raises((ValueError, TypeError)), pytest.warns( - NumpyConversionWarning - ): + with pytest.raises((ValueError, TypeError)): synced_collection[-1] = value @pytest.mark.parametrize("dtype", NUMPY_INT_TYPES)