Skip to content

Commit

Permalink
[PALLAS] add test for large indexing.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 611925093
  • Loading branch information
blakehechtman authored and jax authors committed Mar 2, 2024
1 parent 51a31e5 commit ab83469
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions tests/pallas/pallas_call_tpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,6 +1035,32 @@ def kernel(x_bbm_ref, y_ref, sem, dma_sem):
)(x)
np.testing.assert_array_equal(y, x)

def test_large_array_indexing(self):
n = 6
dtype = jnp.bfloat16
x = jax.lax.broadcasted_iota(dtype, (n, 1024 * 1024, 512), 0)

def kernel(index, x, y, sem):
pltpu.async_copy(x.at[index[0]], y.at[:], sem).wait()

run = pl.pallas_call(kernel,
grid_spec=pltpu.PrefetchScalarGridSpec(
num_scalar_prefetch=1,
in_specs=[
pl.BlockSpec(
memory_space=pltpu.TPUMemorySpace.ANY)],
out_specs=pl.BlockSpec(
memory_space=pltpu.TPUMemorySpace.ANY),
scratch_shapes=[pltpu.SemaphoreType.DMA],
),
out_shape=jax.ShapeDtypeStruct(x.shape[1:], dtype),
)

for i in range(x.shape[0]):
y = run(jnp.array([i], dtype=jnp.int32), x)
np.testing.assert_array_equal(y, i)
del y


class PallasCallRemoteDMATest(parameterized.TestCase):

Expand Down

0 comments on commit ab83469

Please sign in to comment.