Skip to content

Commit

Permalink
jax.numpy: implement scalar boolean indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Feb 8, 2024
1 parent 4c505f8 commit 0f68834
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 31 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.25

* Changes

* JAX now supports NumPy-style scalar boolean indexing, e.g. `x[True, :, False]`.

## jaxlib 0.4.25

## jax 0.4.24 (Feb 6, 2024)
Expand Down
7 changes: 0 additions & 7 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,13 +308,6 @@ def __getitem__(self, idx):
from jax._src.numpy import lax_numpy
self._check_if_deleted()

if isinstance(idx, tuple):
num_idx = sum(e is not None and e is not Ellipsis for e in idx)
if num_idx > self.ndim:
raise IndexError(
f"Too many indices for array: array has ndim of {self.ndim}, but "
f"was indexed with {num_idx} non-None/Ellipsis indices.")

if isinstance(self.sharding, PmapSharding):
if config.pmap_no_rank_reduction.value:
cidx = idx if isinstance(idx, tuple) else (idx,)
Expand Down
38 changes: 28 additions & 10 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4613,6 +4613,9 @@ def _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
if isinstance(fill_value, np.ndarray):
fill_value = fill_value.item()

if indexer.scalar_bool_dims:
y = lax.expand_dims(y, indexer.scalar_bool_dims)

# Avoid calling gather if the slice shape is empty, both as a fast path and to
# handle cases like zeros(0)[array([], int32)].
if core.is_empty_shape(indexer.slice_shape):
Expand Down Expand Up @@ -4657,6 +4660,10 @@ class _Indexer(NamedTuple):
# gathers and eliminated for scatters.
newaxis_dims: Sequence[int]

# Keep track of dimensions with scalar bool indices. These must be inserted
# for gathers before performing other index operations.
scalar_bool_dims: Sequence[int]


def _split_index_for_jit(idx, shape):
"""Splits indices into necessarily-static and dynamic parts.
Expand Down Expand Up @@ -4705,6 +4712,16 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
# Remove ellipses and add trailing slice(None)s.
idx = _canonicalize_tuple_index(len(x_shape), idx)

# Check for scalar boolean indexing: this requires inserting extra dimensions
# before performing the rest of the logic.
scalar_bool_dims: Sequence[int] = [n for n, i in enumerate(idx) if isinstance(i, bool)]
if scalar_bool_dims:
idx = tuple(np.arange(int(i)) if isinstance(i, bool) else i for i in idx)
x_shape = list(x_shape)
for i in sorted(scalar_bool_dims):
x_shape.insert(i, 1)
x_shape = tuple(x_shape)

# Check for advanced indexing:
# https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html#advanced-indexing

Expand Down Expand Up @@ -4805,8 +4822,8 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
# XLA gives error when indexing into an axis of size 0
raise IndexError(f"index is out of bounds for axis {x_axis} with size 0")
i = _normalize_index(i, x_shape[x_axis]) if normalize_indices else i
i = lax.convert_element_type(i, index_dtype)
gather_indices.append((i, len(gather_indices_shape)))
i_converted = lax.convert_element_type(i, index_dtype)
gather_indices.append((i_converted, len(gather_indices_shape)))
collapsed_slice_dims.append(x_axis)
gather_slice_shape.append(1)
start_index_map.append(x_axis)
Expand Down Expand Up @@ -4893,7 +4910,8 @@ def _index_to_gather(x_shape: Sequence[int], idx: Sequence[Any],
dnums=dnums,
gather_indices=gather_indices_array,
unique_indices=advanced_indexes is None,
indices_are_sorted=advanced_indexes is None)
indices_are_sorted=advanced_indexes is None,
scalar_bool_dims=scalar_bool_dims)

def _should_unpack_list_index(x):
"""Helper for _eliminate_deprecated_list_indexing."""
Expand Down Expand Up @@ -4959,7 +4977,7 @@ def _expand_bool_indices(idx, shape):
# TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
raise errors.NonConcreteBooleanIndexError(abstract_i)
elif _ndim(i) == 0:
raise TypeError("JAX arrays do not support boolean scalar indices")
out.append(bool(i))
else:
i_shape = _shape(i)
start = len(out) + ellipsis_offset - newaxis_offset
Expand Down Expand Up @@ -5010,21 +5028,21 @@ def _is_scalar(x):

def _canonicalize_tuple_index(arr_ndim, idx, array_name='array'):
"""Helper to remove Ellipsis and add in the implicit trailing slice(None)."""
len_without_none = sum(e is not None and e is not Ellipsis for e in idx)
if len_without_none > arr_ndim:
num_dimensions_consumed = sum(not (e is None or e is Ellipsis or isinstance(e, bool)) for e in idx)
if num_dimensions_consumed > arr_ndim:
raise IndexError(
f"Too many indices for {array_name}: {len_without_none} "
f"Too many indices for {array_name}: {num_dimensions_consumed} "
f"non-None/Ellipsis indices for dim {arr_ndim}.")
ellipses = (i for i, elt in enumerate(idx) if elt is Ellipsis)
ellipsis_index = next(ellipses, None)
if ellipsis_index is not None:
if next(ellipses, None) is not None:
raise IndexError(
f"Multiple ellipses (...) not supported: {list(map(type, idx))}.")
colons = (slice(None),) * (arr_ndim - len_without_none)
colons = (slice(None),) * (arr_ndim - num_dimensions_consumed)
idx = idx[:ellipsis_index] + colons + idx[ellipsis_index + 1:]
elif len_without_none < arr_ndim:
colons = (slice(None),) * (arr_ndim - len_without_none)
elif num_dimensions_consumed < arr_ndim:
colons = (slice(None),) * (arr_ndim - num_dimensions_consumed)
idx = tuple(idx) + colons
return idx

Expand Down
3 changes: 3 additions & 0 deletions jax/_src/ops/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
indexer = jnp._index_to_gather(jnp.shape(x), idx,
normalize_indices=normalize_indices)
# TODO(jakevdp): implement scalar boolean logic.
if indexer.scalar_bool_dims:
raise TypeError("Scalar boolean indices are not allowed in scatter.")

# Avoid calling scatter if the slice shape is empty, both as a fast path and
# to handle cases like zeros(0)[array([], int32)].
Expand Down
3 changes: 0 additions & 3 deletions jax/experimental/array_api/skips.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
# Known failures for the array api tests.

# JAX doesn't yet support scalar boolean indexing
array_api_tests/test_array_object.py::test_getitem_masking

# Test suite attempts in-place mutation:
array_api_tests/test_special_cases.py::test_binary
array_api_tests/test_special_cases.py::test_iop
Expand Down
35 changes: 24 additions & 11 deletions tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,15 +876,6 @@ def testBooleanIndexingDynamicShapeError(self):
i = np.array([True, True, False])
self.assertRaises(IndexError, lambda: jax.jit(lambda x, i: x[i])(x, i))

def testScalarBooleanIndexingNotImplemented(self):
msg = "JAX arrays do not support boolean scalar indices"
with self.assertRaisesRegex(TypeError, msg):
jnp.arange(4)[True]
with self.assertRaisesRegex(TypeError, msg):
jnp.arange(4)[False]
with self.assertRaisesRegex(TypeError, msg):
jnp.arange(4)[..., True]

def testIssue187(self):
x = jnp.ones((5, 5))
x[[0, 2, 4], [0, 2, 4]] # doesn't crash
Expand Down Expand Up @@ -1033,6 +1024,29 @@ def testNontrivialBooleanIndexing(self):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

@jtu.sample_product(
shape=[(2, 3, 4, 5)],
idx=[
np.index_exp[True],
np.index_exp[False],
np.index_exp[..., True],
np.index_exp[..., False],
np.index_exp[0, :2, True],
np.index_exp[0, :2, False],
np.index_exp[:2, 0, True],
np.index_exp[:2, 0, False],
np.index_exp[:2, np.array([0, 2]), True],
np.index_exp[np.array([1, 0]), :, True],
np.index_exp[True, :, True, :, np.array(True)],
]
)
def testScalarBooleanIndexing(self, shape, idx):
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, np.int32)]
np_fun = lambda x: np.asarray(x)[idx]
jnp_fun = lambda x: jnp.asarray(x)[idx]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)

def testFloatIndexingError(self):
BAD_INDEX_TYPE_ERROR = "Indexer must have integer or boolean type, got indexer with type"
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
Expand Down Expand Up @@ -1158,8 +1172,7 @@ def _check_raises(x_type, y_type, msg):
def testWrongNumberOfIndices(self):
with self.assertRaisesRegex(
IndexError,
"Too many indices for array: array has ndim of 1, "
"but was indexed with 2 non-None/Ellipsis indices"):
"Too many indices for array: 2 non-None/Ellipsis indices for dim 1."):
jnp.zeros(3)[:, 5]


Expand Down

0 comments on commit 0f68834

Please sign in to comment.