Skip to content

Commit

Permalink
Merge pull request #18850 from jakevdp:nonzero-zerodim
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 588534749
  • Loading branch information
jax authors committed Dec 6, 2023
2 parents 1dd68c5 + 5196004 commit 5bdc303
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 26 deletions.
7 changes: 6 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1404,8 +1404,13 @@ def nonzero(a: ArrayLike, *, size: int | None = None,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
) -> tuple[Array, ...]:
util.check_arraylike("nonzero", a)
arr = atleast_1d(a)
arr = asarray(a)
del a
if ndim(arr) == 0:
# Added 2023 Dec 6
warnings.warn("Calling nonzero on 0d arrays is deprecated. Use `atleast_1d(arr).nonzero()",
DeprecationWarning, stacklevel=2)
arr = atleast_1d(arr)
mask = arr if arr.dtype == bool else (arr != 0)
calculated_size = mask.sum() if size is None else size
calculated_size = core.concrete_dim_or_error(calculated_size,
Expand Down
51 changes: 26 additions & 25 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,13 +280,10 @@ def testCountNonzero(self, shape, dtype, axis):
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
def testNonzero(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
np_fun = lambda x: np.nonzero(x)
np_fun = jtu.ignore_warning(
category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*")(np_fun)
jnp_fun = lambda x: jnp.nonzero(x)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CheckAgainstNumpy(np.nonzero, jnp.nonzero, args_maker, check_dtypes=False)

@jtu.sample_product(
[dict(shape=shape, fill_value=fill_value)
Expand All @@ -299,7 +296,6 @@ def testNonzero(self, shape, dtype):
def testNonzeroSize(self, shape, dtype, size, fill_value):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
@jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*")
def np_fun(x):
result = np.nonzero(x)
if size <= len(result[0]):
Expand All @@ -309,8 +305,10 @@ def np_fun(x):
return tuple(np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)])
for fval, arg in safe_zip(fillvals, result))
jnp_fun = lambda x: jnp.nonzero(x, size=size, fill_value=fill_value)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)
with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
def testFlatNonzero(self, shape, dtype):
Expand Down Expand Up @@ -350,17 +348,17 @@ def np_fun(x):
@jtu.sample_product(shape=all_shapes, dtype=all_dtypes)
def testArgWhere(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
np_fun = jtu.ignore_warning(
category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*")(np.argwhere)
jnp_fun = jnp.argwhere
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CheckAgainstNumpy(np.argwhere, jnp.argwhere, args_maker, check_dtypes=False)

# JIT compilation requires specifying a size statically. Full test of this
# behavior is in testNonzeroSize().
jnp_fun = lambda x: jnp.argwhere(x, size=np.size(x) // 2)
self._CompileAndCheck(jnp_fun, args_maker)
with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
[dict(shape=shape, fill_value=fill_value)
Expand All @@ -373,7 +371,6 @@ def testArgWhere(self, shape, dtype):
def testArgWhereSize(self, shape, dtype, size, fill_value):
rng = jtu.rand_some_zero(self.rng())
args_maker = lambda: [rng(shape, dtype)]
@jtu.ignore_warning(category=DeprecationWarning, message="Calling nonzero on 0d arrays.*")
def np_fun(x):
result = np.argwhere(x)
if size <= len(result):
Expand All @@ -383,8 +380,11 @@ def np_fun(x):
return np.empty((size, 0), dtype=int) if np.ndim(x) == 0 else np.stack([np.concatenate([arg, np.full(size - len(arg), fval, arg.dtype)])
for fval, arg in safe_zip(fillvals, result.T)]).T
jnp_fun = lambda x: jnp.argwhere(x, size=size, fill_value=fill_value)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)

with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
[dict(np_op=getattr(np, rec.name), jnp_op=getattr(jnp, rec.name),
Expand Down Expand Up @@ -4086,18 +4086,19 @@ def args_maker(): return []
)
def testWhereOneArgument(self, shape, dtype):
rng = jtu.rand_some_zero(self.rng())
np_fun = lambda x: np.where(x)
np_fun = jtu.ignore_warning(
category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*")(np_fun)
jnp_fun = lambda x: jnp.where(x)
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker, check_dtypes=False)

with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CheckAgainstNumpy(np.where, jnp.where, args_maker, check_dtypes=False)

# JIT compilation requires specifying a size statically. Full test of
# this behavior is in testNonzeroSize().
jnp_fun = lambda x: jnp.where(x, size=np.size(x) // 2)
self._CompileAndCheck(jnp_fun, args_maker)

with jtu.ignore_warning(category=DeprecationWarning,
message="Calling nonzero on 0d arrays.*"):
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
shapes=filter(_shapes_are_broadcast_compatible,
Expand Down

0 comments on commit 5bdc303

Please sign in to comment.