diff --git a/jax/_src/numpy/setops.py b/jax/_src/numpy/setops.py index 6ac7ce804d8f..2e953c67abd6 100644 --- a/jax/_src/numpy/setops.py +++ b/jax/_src/numpy/setops.py @@ -490,8 +490,6 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo inv_idx = inv_idx.at[perm].set(imask) else: inv_idx = zeros(ar.shape[axis], dtype=int) - if ar.ndim > 1: - inv_idx = lax.expand_dims(inv_idx, [i for i in range(ar.ndim) if i != axis],) ret += (inv_idx,) if return_counts: if aux.size: @@ -550,7 +548,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal ``ar[unique_index]`` is equivalent to ``unique_values``. - ``unique_inverse``: *(returned only if return_inverse is True)* An array of shape ``(ar.size,)`` if ``axis`` - is None, or of shape ``(1, 1, ..., ar.shape[axis], 1, ... 1)`` if ``axis`` is specified. + is None, or of shape ``(ar.shape[axis],)`` if ``axis`` is specified. Contains the indices within ``unique_values`` of each value in ``ar``. For 1D inputs, ``unique_values[unique_inverse]`` is equivalent to ``ar``. - ``unique_counts``: @@ -652,10 +650,10 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal Array(True, dtype=bool) In multiple dimensions, the input can be reconstructed using - :func:`jax.numpy.take_along_axis`: + :func:`jax.numpy.take`: >>> values, inverse = jnp.unique(M, axis=0, return_inverse=True) - >>> jnp.all(jnp.take_along_axis(values, inverse, axis=0) == M) + >>> jnp.all(jnp.take(values, inverse, axis=0) == M) Array(True, dtype=bool) **Returning counts** diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index dc0a12f5cc7d..634b8657e35a 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -91,15 +91,15 @@ def np_unique_backport(ar, return_index=False, return_inverse=False, return_coun # Wrapper for np.unique, handling the change to inverse_indices in numpy 2.0 result = np.unique(ar, return_index=return_index, return_inverse=return_inverse, return_counts=return_counts, axis=axis, **kwds) - if jtu.numpy_version() >= (2, 0, 0) or np.ndim(ar) == 1 or not return_inverse: + if jtu.numpy_version() >= (2, 0, 1) or np.ndim(ar) == 1 or not return_inverse: return result idx = 2 if return_index else 1 inverse_indices = result[idx] if axis is None: inverse_indices = inverse_indices.reshape(np.shape(ar)) - else: - inverse_indices = np.expand_dims(inverse_indices, [i for i in range(np.ndim(ar)) if i != axis]) + elif jtu.numpy_version() == (2, 0, 0): + inverse_indices = inverse_indices.reshape(-1) return (*result[:idx], inverse_indices, *result[idx + 1:])