Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jax.numpy: implement scalar boolean indexing #19722

Merged
merged 1 commit into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Remember to align the itemized text with the first line of an item within a list
* Added [CUDA Array
Interface](https://numba.readthedocs.io/en/stable/cuda/cuda_array_interface.html)
import support (requires jaxlib 0.4.24).
* JAX arrays now support NumPy-style scalar boolean indexing, e.g. `x[True]` or `x[False]`.

* Deprecations & Removals
* {func}`jax.numpy.linalg.solve` now shows a deprecation warning for batched 1D
Expand Down
7 changes: 0 additions & 7 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,6 @@ def __getitem__(self, idx):
from jax._src.numpy import lax_numpy
self._check_if_deleted()

if isinstance(idx, tuple):
num_idx = sum(e is not None and e is not Ellipsis for e in idx)
if num_idx > self.ndim:
raise IndexError(
f"Too many indices for array: array has ndim of {self.ndim}, but "
f"was indexed with {num_idx} non-None/Ellipsis indices.")

Copy link
Collaborator Author

@jakevdp jakevdp Feb 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed this check because its logic is wrong: a boolean index `idx` consumes `idx.ndim` dimensions rather than exactly one (so a scalar boolean consumes none). The condition is already checked correctly in `_index_to_gather`.

if isinstance(self.sharding, PmapSharding):
if config.pmap_no_rank_reduction.value:
cidx = idx if isinstance(idx, tuple) else (idx,)
Expand Down
38 changes: 28 additions & 10 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4613,6 +4613,9 @@ def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
if isinstance(fill_value, np.ndarray):
fill_value = fill_value.item()

if indexer.scalar_bool_dims:
y = lax.expand_dims(y, indexer.scalar_bool_dims)

# Avoid calling gather if the slice shape is empty, both as a fast path and to
# handle cases like zeros(0)[array([], int32)].
if core.is_empty_shape(indexer.slice_shape):
Expand Down Expand Up @@ -4657,6 +4660,10 @@ class _Indexer(NamedTuple):
# gathers and eliminated for scatters.
newaxis_dims: Sequence[int]

# Keep track of dimensions with scalar bool indices. These must be inserted
# for gathers before performing other index operations.
scalar_bool_dims: Sequence[int]


def _split_index_for_jit(idx, shape):
"""Splits indices into necessarily-static and dynamic parts.
Expand Down Expand Up @@ -4705,6 +4712,16 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
# Remove ellipses and add trailing slice(None)s.
idx = _canonicalize_tuple_index(len(x_shape), idx)

# Check for scalar boolean indexing: this requires inserting extra dimensions
# before performing the rest of the logic.
scalar_bool_dims: Sequence[int] = [n for n, i in enumerate(idx) if isinstance(i, bool)]
if scalar_bool_dims:
idx = tuple(np.arange(int(i)) if isinstance(i, bool) else i for i in idx)
x_shape = list(x_shape)
for i in sorted(scalar_bool_dims):
x_shape.insert(i, 1)
x_shape = tuple(x_shape)

# Check for advanced indexing:
# https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing

Expand Down Expand Up @@ -4805,8 +4822,8 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
# XLA gives error when indexing into an axis of size 0
raise IndexError(f"index is out of bounds for axis {x_axis} with size 0")
i = _normalize_index(i, x_shape[x_axis]) if normalize_indices else i
i = lax.convert_element_type(i, index_dtype)
gather_indices.append((i, len(gather_indices_shape)))
i_converted = lax.convert_element_type(i, index_dtype)
gather_indices.append((i_converted, len(gather_indices_shape)))
collapsed_slice_dims.append(x_axis)
gather_slice_shape.append(1)
start_index_map.append(x_axis)
Expand Down Expand Up @@ -4893,7 +4910,8 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
dnums=dnums,
gather_indices=gather_indices_array,
unique_indices=advanced_indexes is None,
indices_are_sorted=advanced_indexes is None)
indices_are_sorted=advanced_indexes is None,
scalar_bool_dims=scalar_bool_dims)

def _should_unpack_list_index(x):
"""Helper for _eliminate_deprecated_list_indexing."""
Expand Down Expand Up @@ -4959,7 +4977,7 @@ def _expand_bool_indices(idx, shape):
# TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
raise errors.NonConcreteBooleanIndexError(abstract_i)
elif _ndim(i) == 0:
raise TypeError("JAX arrays do not support boolean scalar indices")
out.append(bool(i))
else:
i_shape = _shape(i)
start = len(out) + ellipsis_offset - newaxis_offset
Expand Down Expand Up @@ -5010,21 +5028,21 @@ def _is_scalar(x):

def _canonicalize_tuple_index(arr_ndim, idx, array_name='array'):
"""Helper to remove Ellipsis and add in the implicit trailing slice(None)."""
len_without_none = sum(e is not None and e is not Ellipsis for e in idx)
if len_without_none > arr_ndim:
num_dimensions_consumed = sum(not (e is None or e is Ellipsis or isinstance(e, bool)) for e in idx)
if num_dimensions_consumed > arr_ndim:
raise IndexError(
f"Too many indices for {array_name}: {len_without_none} "
f"Too many indices for {array_name}: {num_dimensions_consumed} "
f"non-None/Ellipsis indices for dim {arr_ndim}.")
ellipses = (i for i, elt in enumerate(idx) if elt is Ellipsis)
ellipsis_index = next(ellipses, None)
if ellipsis_index is not None:
if next(ellipses, None) is not None:
raise IndexError(
f"Multiple ellipses (...) not supported: {list(map(type, idx))}.")
colons = (slice(None),) * (arr_ndim - len_without_none)
colons = (slice(None),) * (arr_ndim - num_dimensions_consumed)
idx = idx[:ellipsis_index] + colons + idx[ellipsis_index + 1:]
elif len_without_none < arr_ndim:
colons = (slice(None),) * (arr_ndim - len_without_none)
elif num_dimensions_consumed < arr_ndim:
colons = (slice(None),) * (arr_ndim - num_dimensions_consumed)
idx = tuple(idx) + colons
return idx

Expand Down
3 changes: 3 additions & 0 deletions jax/_src/ops/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
indexer = jnp._index_to_gather(jnp.shape(x), idx,
normalize_indices=normalize_indices)
# TODO(jakevdp): implement scalar boolean logic.
if indexer.scalar_bool_dims:
raise TypeError("Scalar boolean indices are not allowed in scatter.")

# Avoid calling scatter if the slice shape is empty, both as a fast path and
# to handle cases like zeros(0)[array([], int32)].
Expand Down
3 changes: 0 additions & 3 deletions jax/experimental/array_api/skips.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
# Known failures for the array api tests.

# JAX doesn't yet support scalar boolean indexing
array_api_tests/test_array_object.py::test_getitem_masking

# Test suite attempts in-place mutation:
array_api_tests/test_special_cases.py::test_binary
array_api_tests/test_special_cases.py::test_iop
Expand Down
35 changes: 24 additions & 11 deletions tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,15 +876,6 @@ def testBooleanIndexingDynamicShapeError(self):
i = np.array([True, True, False])
self.assertRaises(IndexError, lambda: jax.jit(lambda x, i: x[i])(x, i))

def testScalarBooleanIndexingNotImplemented(self):
msg = "JAX arrays do not support boolean scalar indices"
with self.assertRaisesRegex(TypeError, msg):
jnp.arange(4)[True]
with self.assertRaisesRegex(TypeError, msg):
jnp.arange(4)[False]
with self.assertRaisesRegex(TypeError, msg):
jnp.arange(4)[..., True]

def testIssue187(self):
x = jnp.ones((5, 5))
x[[0, 2, 4], [0, 2, 4]] # doesn't crash
Expand Down Expand Up @@ -1033,6 +1024,29 @@ def testNontrivialBooleanIndexing(self):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
shape=[(2, 3, 4, 5)],
idx=[
np.index_exp[True],
np.index_exp[False],
np.index_exp[..., True],
np.index_exp[..., False],
np.index_exp[0, :2, True],
np.index_exp[0, :2, False],
np.index_exp[:2, 0, True],
np.index_exp[:2, 0, False],
np.index_exp[:2, np.array([0, 2]), True],
np.index_exp[np.array([1, 0]), :, True],
np.index_exp[True, :, True, :, np.array(True)],
]
)
def testScalarBooleanIndexing(self, shape, idx):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, np.int32)]
np_fun = lambda x: np.asarray(x)[idx]
jnp_fun = lambda x: jnp.asarray(x)[idx]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)

def testFloatIndexingError(self):
BAD_INDEX_TYPE_ERROR = "Indexer must have integer or boolean type, got indexer with type"
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
Expand Down Expand Up @@ -1158,8 +1172,7 @@ def _check_raises(x_type, y_type, msg):
def testWrongNumberOfIndices(self):
with self.assertRaisesRegex(
IndexError,
"Too many indices for array: array has ndim of 1, "
"but was indexed with 2 non-None/Ellipsis indices"):
"Too many indices for array: 2 non-None/Ellipsis indices for dim 1."):
jnp.zeros(3)[:, 5]


Expand Down
Loading