diff --git a/tests/pallas/pallas_call_tpu_test.py b/tests/pallas/pallas_call_tpu_test.py index fceceda0dfb2..2e74ae45e67e 100644 --- a/tests/pallas/pallas_call_tpu_test.py +++ b/tests/pallas/pallas_call_tpu_test.py @@ -1035,6 +1035,32 @@ def kernel(x_bbm_ref, y_ref, sem, dma_sem): )(x) np.testing.assert_array_equal(y, x) + def test_large_array_indexing(self): + n = 6 + dtype = jnp.bfloat16 + x = jax.lax.broadcasted_iota(dtype, (n, 1024 * 1024, 512), 0) + + def kernel(index, x, y, sem): + pltpu.async_copy(x.at[index[0]], y.at[:], sem).wait() + + run = pl.pallas_call(kernel, + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=1, + in_specs=[ + pl.BlockSpec( + memory_space=pltpu.TPUMemorySpace.ANY)], + out_specs=pl.BlockSpec( + memory_space=pltpu.TPUMemorySpace.ANY), + scratch_shapes=[pltpu.SemaphoreType.DMA], + ), + out_shape=jax.ShapeDtypeStruct(x.shape[1:], dtype), + ) + + for i in range(x.shape[0]): + y = run(jnp.array([i], dtype=jnp.int32), x) + np.testing.assert_array_equal(y, i) + del y + class PallasCallRemoteDMATest(parameterized.TestCase):