Skip to content

Commit

Permalink
Merge pull request #22445 from dfm:numpy-nightly-unique
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 652520243
  • Loading branch information
jax authors committed Jul 15, 2024
2 parents cab1f85 + 7857bd3 commit 26ec43f
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
8 changes: 3 additions & 5 deletions jax/_src/numpy/setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,6 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
inv_idx = inv_idx.at[perm].set(imask)
else:
inv_idx = zeros(ar.shape[axis], dtype=int)
if ar.ndim > 1:
inv_idx = lax.expand_dims(inv_idx, [i for i in range(ar.ndim) if i != axis],)
ret += (inv_idx,)
if return_counts:
if aux.size:
Expand Down Expand Up @@ -550,7 +548,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
``ar[unique_index]`` is equivalent to ``unique_values``.
- ``unique_inverse``:
*(returned only if return_inverse is True)* An array of shape ``(ar.size,)`` if ``axis``
is None, or of shape ``(1, 1, ..., ar.shape[axis], 1, ... 1)`` if ``axis`` is specified.
is None, or of shape ``(ar.shape[axis],)`` if ``axis`` is specified.
Contains the indices within ``unique_values`` of each value in ``ar``. For 1D inputs,
``unique_values[unique_inverse]`` is equivalent to ``ar``.
- ``unique_counts``:
Expand Down Expand Up @@ -652,10 +650,10 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
Array(True, dtype=bool)
In multiple dimensions, the input can be reconstructed using
:func:`jax.numpy.take_along_axis`:
:func:`jax.numpy.take`:
>>> values, inverse = jnp.unique(M, axis=0, return_inverse=True)
>>> jnp.all(jnp.take_along_axis(values, inverse, axis=0) == M)
>>> jnp.all(jnp.take(values, inverse, axis=0) == M)
Array(True, dtype=bool)
**Returning counts**
Expand Down
6 changes: 3 additions & 3 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,15 @@ def np_unique_backport(ar, return_index=False, return_inverse=False, return_coun
# Wrapper for np.unique, handling the change to inverse_indices in numpy 2.0
result = np.unique(ar, return_index=return_index, return_inverse=return_inverse,
return_counts=return_counts, axis=axis, **kwds)
if jtu.numpy_version() >= (2, 0, 0) or np.ndim(ar) == 1 or not return_inverse:
if jtu.numpy_version() >= (2, 0, 1) or np.ndim(ar) == 1 or not return_inverse:
return result

idx = 2 if return_index else 1
inverse_indices = result[idx]
if axis is None:
inverse_indices = inverse_indices.reshape(np.shape(ar))
else:
inverse_indices = np.expand_dims(inverse_indices, [i for i in range(np.ndim(ar)) if i != axis])
elif jtu.numpy_version() == (2, 0, 0):
inverse_indices = inverse_indices.reshape(-1)
return (*result[:idx], inverse_indices, *result[idx + 1:])


Expand Down

0 comments on commit 26ec43f

Please sign in to comment.