diff --git a/CHANGELOG.md b/CHANGELOG.md index 0121e67e4441..690c50baf993 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,8 @@ Remember to align the itemized text with the first line of an item within a list * The {func}`jax.numpy.hypot` function now issues a deprecation warning when passing complex-valued inputs to it. This will raise an error when the deprecation is completed. + * Scalar arguments to {func}`jax.numpy.nonzero`, {func}`jax.numpy.where`, and + related functions now raise an error, following a similar change in NumPy. ## jaxlib 0.4.27 diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 080d67591d3f..3767633deefc 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -1454,10 +1454,8 @@ def nonzero(a: ArrayLike, *, size: int | None = None, arr = asarray(a) del a if ndim(arr) == 0: - # Added 2023 Dec 6 - warnings.warn("Calling nonzero on 0d arrays is deprecated. Use `atleast_1d(arr).nonzero()", - DeprecationWarning, stacklevel=2) - arr = atleast_1d(arr) + raise ValueError("Calling nonzero on 0d arrays is not allowed. " + "Use jnp.atleast_1d(scalar).nonzero() instead.") mask = arr if arr.dtype == bool else (arr != 0) calculated_size = mask.sum() if size is None else size calculated_size = core.concrete_dim_or_error(calculated_size, diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 39cb1c8532b1..ddc599792e63 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -323,17 +323,15 @@ def testCountNonzero(self, shape, dtype, axis): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) self._CompileAndCheck(jnp_fun, args_maker) - @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) + @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes) def testNonzero(self, shape, dtype): rng = jtu.rand_some_zero(self.rng()) args_maker = lambda: [rng(shape, dtype)] - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False) + self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False) @jtu.sample_product( [dict(shape=shape, fill_value=fill_value) - for shape in nonempty_array_shapes + for shape in nonempty_nonscalar_array_shapes for fill_value in [None, -1, shape or (1,)] ], dtype=all_dtypes, @@ -351,17 +349,13 @@ def np_fun(x): return tuple(np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)]) for fval, arg in safe_zip(fillvals, result)) jnp_fun = lambda x: jnp.nonzero(x, size=size, fill_value=fill_value) - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) - @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) + @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes) def testFlatNonzero(self, shape, dtype): rng = jtu.rand_some_zero(self.rng()) - np_fun = jtu.ignore_warning( - category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*")(np.flatnonzero) + np_fun = np.flatnonzero jnp_fun = jnp.flatnonzero args_maker = lambda: [rng(shape, dtype)] self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) @@ -371,7 +365,7 @@ def testFlatNonzero(self, shape, dtype): self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - shape=nonempty_array_shapes, + shape=nonempty_nonscalar_array_shapes, dtype=all_dtypes, fill_value=[None, -1, 10, (-1,), (10,)], size=[1, 5, 10], @@ -379,7 +373,6 @@ def testFlatNonzero(self, shape, dtype): def testFlatNonzeroSize(self, shape, dtype, size, fill_value): rng = jtu.rand_some_zero(self.rng()) args_maker = lambda: [rng(shape, dtype)] - @jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*") def np_fun(x): result = np.flatnonzero(x) if size <= len(result): @@ -391,24 +384,20 @@ def np_fun(x): self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) self._CompileAndCheck(jnp_fun, args_maker) - @jtu.sample_product(shape=all_shapes, dtype=all_dtypes) + @jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes) def testArgWhere(self, shape, dtype): rng = jtu.rand_some_zero(self.rng()) args_maker = lambda: [rng(shape, dtype)] - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False) + self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False) # JIT compilation requires specifying a size statically. Full test of this # behavior is in testNonzeroSize(). jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2) - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CompileAndCheck(jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( [dict(shape=shape, fill_value=fill_value) - for shape in nonempty_array_shapes + for shape in nonempty_nonscalar_array_shapes for fill_value in [None, -1, shape or (1,)] ], dtype=all_dtypes, @@ -427,10 +416,8 @@ def np_fun(x): for fval, arg in safe_zip(fillvals, result.T)]).T jnp_fun = lambda x: jnp.argwhere(x, size=size, fill_value=fill_value) - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) - self._CompileAndCheck(jnp_fun, args_maker) + self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False) + self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( [dict(np_op=getattr(np, rec.name), jnp_op=getattr(jnp, rec.name), @@ -4490,24 +4477,20 @@ def args_maker(): return [] self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( - shape=all_shapes, + shape=nonzerodim_shapes, dtype=all_dtypes, ) def testWhereOneArgument(self, shape, dtype): rng = jtu.rand_some_zero(self.rng()) args_maker = lambda: [rng(shape, dtype)] - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False) + self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False) # JIT compilation requires specifying a size statically. Full test of # this behavior is in testNonzeroSize(). jnp_fun = lambda x: jnp.where(x, size=np.size(x) // 2) - with jtu.ignore_warning(category=DeprecationWarning, - message="Calling nonzero on 0d arrays.*"): - self._CompileAndCheck(jnp_fun, args_maker) + self._CompileAndCheck(jnp_fun, args_maker) @jtu.sample_product( shapes=filter(_shapes_are_broadcast_compatible,