Skip to content

Commit

Permalink
Merge pull request #509 from neutrinoceros/compat/np_unstack
Browse files Browse the repository at this point in the history
TST: declare `np.unstack` as subclass-safe (fix incompatibility with Numpy 2.1)
  • Loading branch information
jzuhone authored Jul 2, 2024
2 parents 4dc3844 + c276c79 commit 55c5286
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions unyt/tests/test_array_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@


if NUMPY_VERSION >= Version("2.0.0dev0"):
# the followin all work out of the box (tested)
# the following all work out of the box (tested)
NOOP_FUNCTIONS |= {
np.linalg.cross,
np.linalg.diagonal,
Expand All @@ -184,6 +184,11 @@
np.vecdot,
}

if NUMPY_VERSION >= Version("2.1.0dev0"):
NOOP_FUNCTIONS |= {
np.unstack,
}

# Functions for which behaviour is intentionally left to default
IGNORED_FUNCTIONS = {
np.i0,
Expand Down Expand Up @@ -1223,16 +1228,25 @@ def test_broadcast_arrays():


@pytest.mark.parametrize(
"func, args",
"func_name, args",
[
(np.split, (3, 2)),
(np.dsplit, (3,)),
(np.hsplit, (2,)),
(np.vsplit, (1,)),
(np.array_split, (3,)),
("split", (3, 2)),
("dsplit", (3,)),
("hsplit", (2,)),
("vsplit", (1,)),
("array_split", (3,)),
pytest.param(
"unstack",
(),
marks=pytest.mark.skipif(
NUMPY_VERSION < Version("2.1.0dev0"),
reason="np.unstack is new in NumPy 2.1",
),
),
],
)
def test_xsplit(func, args):
def test_xsplit(func_name, args):
func = getattr(np, func_name)
x = [
[
[
Expand Down

0 comments on commit 55c5286

Please sign in to comment.