Skip to content

Commit

Permalink
Finalize deprecation of zero-dimensional inputs to jnp.nonzero
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 626299531
  • Loading branch information
Jake VanderPlas authored and jax authors committed Apr 19, 2024
1 parent 837f0bb commit 41fa67c
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 38 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ Remember to align the itemized text with the first line of an item within a list
* The {func}`jax.numpy.hypot` function now issues a deprecation warning when
passing complex-valued inputs to it. This will raise an error when the
deprecation is completed.
* Scalar arguments to {func}`jax.numpy.nonzero`, {func}`jax.numpy.where`, and
related functions now raise an error, following a similar change in NumPy.

## jaxlib 0.4.27

Expand Down
6 changes: 2 additions & 4 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1454,10 +1454,8 @@ def nonzero(a: ArrayLike, *, size: int | None = None,
arr = asarray(a)
del a
if ndim(arr) == 0:
# Added 2023 Dec 6
warnings.warn("Calling nonzero on 0d arrays is deprecated. Use `atleast_1d(arr).nonzero()",
DeprecationWarning, stacklevel=2)
arr = atleast_1d(arr)
raise ValueError("Calling nonzero on 0d arrays is not allowed. "
"Use jnp.atleast_1d(scalar).nonzero() instead.")
mask = arr if arr.dtype == bool else (arr != 0)
calculated_size = mask.sum() if size is None else size
calculated_size = core.concrete_dim_or_error(calculated_size,
Expand Down
51 changes: 17 additions & 34 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,17 +323,15 @@ def testCountNonzero(self, shape, dtype, axis):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
def testNonzero(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False)
self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False)

@jtu.sample_product(
[dict(shape=shape, fill_value=fill_value)
for shape in nonempty_array_shapes
for shape in nonempty_nonscalar_array_shapes
for fill_value in [None, -1, shape or (1,)]
],
dtype=all_dtypes,
Expand All @@ -351,17 +349,13 @@ def np_fun(x):
return tuple(np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)])
for fval, arg in safe_zip(fillvals, result))
jnp_fun = lambda x: jnp.nonzero(x, size=size, fill_value=fill_value)
with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
def testFlatNonzero(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
np_fun = jtu.ignore_warning(
category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*")(np.flatnonzero)
np_fun = np.flatnonzero
jnp_fun = jnp.flatnonzero
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
Expand All @@ -371,15 +365,14 @@ def testFlatNonzero(self, shape, dtype):
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
shape=nonempty_array_shapes,
shape=nonempty_nonscalar_array_shapes,
dtype=all_dtypes,
fill_value=[None, -1, 10, (-1,), (10,)],
size=[1, 5, 10],
)
def testFlatNonzeroSize(self, shape, dtype, size, fill_value):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
@jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*")
def np_fun(x):
result = np.flatnonzero(x)
if size <= len(result):
Expand All @@ -391,24 +384,20 @@ def np_fun(x):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
@jtu.sample_product(shape=nonzerodim_shapes, dtype=all_dtypes)
def testArgWhere(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False)
self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False)

# JIT compilation requires specifying a size statically. Full test of this
# behavior is in testNonzeroSize().
jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2)
with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CompileAndCheck(jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
[dict(shape=shape, fill_value=fill_value)
for shape in nonempty_array_shapes
for shape in nonempty_nonscalar_array_shapes
for fill_value in [None, -1, shape or (1,)]
],
dtype=all_dtypes,
Expand All @@ -427,10 +416,8 @@ def np_fun(x):
for fval, arg in safe_zip(fillvals, result.T)]).T
jnp_fun = lambda x: jnp.argwhere(x, size=size, fill_value=fill_value)

with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
[dict(np_op=getattr(np, rec.name), jnp_op=getattr(jnp, rec.name),
Expand Down Expand Up @@ -4490,24 +4477,20 @@ def args_maker(): return []
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
shape=all_shapes,
shape=nonzerodim_shapes,
dtype=all_dtypes,
)
def testWhereOneArgument(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]

with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False)
self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False)

# JIT compilation requires specifying a size statically. Full test of
# this behavior is in testNonzeroSize().
jnp_fun = lambda x: jnp.where(x, size=np.size(x) // 2)

with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CompileAndCheck(jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
shapes=filter(_shapes_are_broadcast_compatible,
Expand Down

0 comments on commit 41fa67c

Please sign in to comment.