Skip to content

Commit

Permalink
Support scalar boolean indices in arr.at[idx].set(vals)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed May 20, 2024
1 parent 53ec2cd commit 8cc9b5f
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
8 changes: 5 additions & 3 deletions jax/_src/ops/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,15 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
idx = jnp._merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
indexer = jnp._index_to_gather(jnp.shape(x), idx,
normalize_indices=normalize_indices)
# TODO(jakevdp): implement scalar boolean logic.
if indexer.scalar_bool_dims:
raise TypeError("Scalar boolean indices are not allowed in scatter.")

# Avoid calling scatter if the slice shape is empty, both as a fast path and
# to handle cases like zeros(0)[array([], int32)].
if core.is_empty_shape(indexer.slice_shape):
return x

if indexer.scalar_bool_dims:
x = lax.expand_dims(x, indexer.scalar_bool_dims)

x, y = promote_dtypes(x, y)

# Broadcast `y` to the slice output shape.
Expand All @@ -133,6 +133,8 @@ def _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
indices_are_sorted=indexer.indices_are_sorted or indices_are_sorted,
unique_indices=indexer.unique_indices or unique_indices,
mode=mode)
if indexer.scalar_bool_dims:
out = lax.squeeze(out, indexer.scalar_bool_dims)
return lax_internal._convert_element_type(out, dtype, weak_type)


Expand Down
28 changes: 28 additions & 0 deletions tests/lax_numpy_indexing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1053,6 +1053,34 @@ def testScalarBooleanIndexing(self, shape, idx):
jnp_fun = lambda x: jnp.asarray(x)[idx]
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)

@jtu.sample_product(
shape=[(2, 3, 4, 5)],
update_ndim=[0, 1, 2],
idx=[
np.index_exp[True],
np.index_exp[False],
np.index_exp[..., True],
np.index_exp[..., False],
np.index_exp[0, :2, True],
np.index_exp[0, :2, False],
np.index_exp[:2, 0, True],
np.index_exp[:2, 0, False],
np.index_exp[:2, np.array([0, 2]), True],
np.index_exp[np.array([1, 0]), :, True],
np.index_exp[True, :, True, :, np.array(True)],
]
)
def testScalarBoolUpdate(self, shape, idx, update_ndim):
update_shape = np.zeros(shape)[idx].shape[-update_ndim:]
rng = jtu.rand_default(self.rng())
args_maker = lambda: [rng(shape, np.int32), rng(update_shape, np.int32)]
def np_fun(x, update):
x = np.array(x, copy=True)
x[idx] = update
return x
jnp_fun = lambda x, update: jnp.asarray(x).at[idx].set(update)
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)

def testFloatIndexingError(self):
BAD_INDEX_TYPE_ERROR = "Indexer must have integer or boolean type, got indexer with type"
with self.assertRaisesRegex(TypeError, BAD_INDEX_TYPE_ERROR):
Expand Down

0 comments on commit 8cc9b5f

Please sign in to comment.