Skip to content

Commit

Permalink
array api: add unique_* interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Dec 21, 2023
1 parent e3e26f2 commit dd0a341
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 34 deletions.
4 changes: 4 additions & 0 deletions docs/jax.numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,10 @@ namespace; they are listed below.
uint8
union1d
unique
unique_all
unique_counts
unique_inverse
unique_values
unpackbits
unravel_index
unsignedinteger
Expand Down
49 changes: 48 additions & 1 deletion jax/_src/numpy/setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import math
import operator
from textwrap import dedent as _dedent
from typing import cast
from typing import cast, NamedTuple

import numpy as np

Expand Down Expand Up @@ -338,3 +338,50 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
axis_int: int = core.concrete_or_error(operator.index, axis, "axis argument of jnp.unique()")
return _unique(arr, axis_int, return_index, return_inverse,
return_counts, equal_nan=equal_nan, size=size, fill_value=fill_value)


class _UniqueAllResult(NamedTuple):
values: Array
indices: Array
inverse_indices: Array
counts: Array


class _UniqueCountsResult(NamedTuple):
values: Array
counts: Array


class _UniqueInverseResult(NamedTuple):
values: Array
inverse_indices: Array


@_wraps(getattr(np, "unique_all", None))
def unique_all(x: ArrayLike, /) -> _UniqueAllResult:
check_arraylike("unique_all", x)
values, indices, inverse_indices, counts = unique(
x, return_index=True, return_inverse=True, return_counts=True, equal_nan=False)
inverse_indices = inverse_indices.reshape(np.shape(x))
return _UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts)


@_wraps(getattr(np, "unique_counts", None))
def unique_counts(x: ArrayLike, /) -> _UniqueCountsResult:
check_arraylike("unique_counts", x)
values, counts = unique(x, return_counts=True, equal_nan=False)
return _UniqueCountsResult(values=values, counts=counts)


@_wraps(getattr(np, "unique_inverse", None))
def unique_inverse(x: ArrayLike, /) -> _UniqueInverseResult:
check_arraylike("unique_inverse", x)
values, inverse_indices = unique(x, return_inverse=True, equal_nan=False)
inverse_indices = inverse_indices.reshape(np.shape(x))
return _UniqueInverseResult(values=values, inverse_indices=inverse_indices)


@_wraps(getattr(np, "unique_values", None))
def unique_values(x: ArrayLike, /) -> Array:
check_arraylike("unique_values", x)
return unique(x, equal_nan=False)
31 changes: 3 additions & 28 deletions jax/experimental/array_api/_set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,47 +12,22 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import NamedTuple
import jax


class UniqueAllResult(NamedTuple):
values: jax.Array
indices: jax.Array
inverse_indices: jax.Array
counts: jax.Array


class UniqueCountsResult(NamedTuple):
values: jax.Array
counts: jax.Array


class UniqueInverseResult(NamedTuple):
values: jax.Array
inverse_indices: jax.Array


def unique_all(x, /):
"""Returns the unique elements of an input array x, the first occurring indices for each unique element in x, the indices from the set of unique elements that reconstruct x, and the corresponding counts for each unique element in x."""
values, indices, inverse_indices, counts = jax.numpy.unique(
x, return_index=True, return_inverse=True, return_counts=True)
# jnp.unique() flattens inverse indices
inverse_indices = inverse_indices.reshape(x.shape)
return UniqueAllResult(values=values, indices=indices, inverse_indices=inverse_indices, counts=counts)
return jax.numpy.unique_all(x)


def unique_counts(x, /):
"""Returns the unique elements of an input array x and the corresponding counts for each unique element in x."""
values, counts = jax.numpy.unique(x, return_counts=True)
return UniqueCountsResult(values=values, counts=counts)
return jax.numpy.unique_counts(x)


def unique_inverse(x, /):
"""Returns the unique elements of an input array x and the indices from the set of unique elements that reconstruct x."""
values, inverse_indices = jax.numpy.unique(x, return_inverse=True)
inverse_indices = inverse_indices.reshape(x.shape)
return UniqueInverseResult(values=values, inverse_indices=inverse_indices)
return jax.numpy.unique_inverse(x)


def unique_values(x, /):
Expand Down
4 changes: 0 additions & 4 deletions jax/experimental/array_api/skips.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ array_api_tests/test_linalg.py::test_matrix_power
array_api_tests/test_linalg.py::test_solve

# JAX's NaN sorting doesn't match specification
array_api_tests/test_set_functions.py::test_unique_all
array_api_tests/test_set_functions.py::test_unique_counts
array_api_tests/test_set_functions.py::test_unique_inverse
array_api_tests/test_set_functions.py::test_unique_values
array_api_tests/test_sorting_functions.py::test_argsort

# fft test suite is buggy as of 83f0bcdc
Expand Down
4 changes: 4 additions & 0 deletions jax/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,10 @@
setxor1d as setxor1d,
union1d as union1d,
unique as unique,
unique_all as unique_all,
unique_counts as unique_counts,
unique_inverse as unique_inverse,
unique_values as unique_values,
)

from jax._src.numpy.ufuncs import (
Expand Down
17 changes: 16 additions & 1 deletion jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations


from typing import Any, Callable, Literal, Optional, Sequence, TypeVar, Union, overload
from typing import Any, Callable, Literal, NamedTuple, Optional, Sequence, TypeVar, Union, overload

from jax._src import core as _core
from jax._src import dtypes as _dtypes
Expand Down Expand Up @@ -792,11 +792,26 @@ def union1d(
size: Optional[int] = ...,
fill_value: Optional[ArrayLike] = ...,
) -> Array: ...
class _UniqueAllResult(NamedTuple):
values: Array
indices: Array
inverse_indices: Array
counts: Array
class _UniqueCountsResult(NamedTuple):
values: Array
counts: Array
class _UniqueInverseResult(NamedTuple):
values: Array
inverse_indices: Array
def unique(ar: ArrayLike, return_index: bool = ..., return_inverse: bool = ...,
return_counts: bool = ..., axis: Optional[int] = ...,
*, equal_nan: bool = ..., size: Optional[int] = ...,
fill_value: Optional[ArrayLike] = ...
): ...
def unique_all(x: ArrayLike, /) -> _UniqueAllResult: ...
def unique_counts(x: ArrayLike, /) -> _UniqueCountsResult: ...
def unique_inverse(x: ArrayLike, /) -> _UniqueInverseResult: ...
def unique_values(x: ArrayLike, /) -> Array: ...
def unpackbits(
a: ArrayLike,
axis: Optional[int] = ...,
Expand Down
45 changes: 45 additions & 0 deletions tests/lax_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1745,6 +1745,51 @@ def testUnique(self, shape, dtype, axis, return_index, return_inverse, return_co
jnp_fun = lambda x: jnp.unique(x, *extra_args, axis=axis)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)

@jtu.sample_product(shape=all_shapes, dtype=number_dtypes)
def testUniqueAll(self, shape, dtype):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
if jtu.numpy_version() < (2, 0, 0):
def np_fun(x):
values, indices, inverse_indices, counts = np.unique(
x, return_index=True, return_inverse=True, return_counts=True)
return values, indices, inverse_indices.reshape(np.shape(x)), counts
else:
np_fun = np.unique_all
self._CheckAgainstNumpy(jnp.unique_all, np_fun, args_maker)

@jtu.sample_product(shape=all_shapes, dtype=number_dtypes)
def testUniqueCounts(self, shape, dtype):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
if jtu.numpy_version() < (2, 0, 0):
np_fun = lambda x: np.unique(x, return_counts=True)
else:
np_fun = np.unique_counts
self._CheckAgainstNumpy(jnp.unique_counts, np_fun, args_maker)

@jtu.sample_product(shape=all_shapes, dtype=number_dtypes)
def testUniqueInverse(self, shape, dtype):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
if jtu.numpy_version() < (2, 0, 0):
def np_fun(x):
values, inverse_indices = np.unique(x, return_inverse=True)
return values, inverse_indices.reshape(np.shape(x))
else:
np_fun = np.unique_inverse
self._CheckAgainstNumpy(jnp.unique_inverse, np_fun, args_maker)

@jtu.sample_product(shape=all_shapes, dtype=number_dtypes)
def testUniqueValues(self, shape, dtype):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
if jtu.numpy_version() < (2, 0, 0):
np_fun = np.unique
else:
np_fun = np.unique_values
self._CheckAgainstNumpy(jnp.unique_values, np_fun, args_maker)

@jtu.sample_product(
[dict(shape=shape, axis=axis)
for shape in nonempty_array_shapes
Expand Down

0 comments on commit dd0a341

Please sign in to comment.