diff --git a/tests/pallas/pallas_call_tpu_test.py b/tests/pallas/pallas_call_tpu_test.py index 6f3081e1a817..2d7217c367ec 100644 --- a/tests/pallas/pallas_call_tpu_test.py +++ b/tests/pallas/pallas_call_tpu_test.py @@ -900,9 +900,10 @@ def body(x_ref, sem): def test_smem_hbm_dma(self): def kernel(x_ref, y_hbm_ref): def body(y_ref, sem): - y_ref[0, 0] = x_ref[4, 4] + y_ref[0, 0] = 0.0 + y_ref[0, 1] = x_ref[4, 4] pltpu.async_copy(y_ref, y_hbm_ref, sem).wait() - pltpu.run_scoped(body, pltpu.SMEM((8, 128), jnp.float32), + pltpu.run_scoped(body, pltpu.SMEM((1, 2), jnp.float32), pltpu.SemaphoreType.DMA) x = jnp.arange(8 * 128.).reshape((8, 128)) y = pl.pallas_call( @@ -911,9 +912,9 @@ def body(y_ref, sem): pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM), ], out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY), - out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32), + out_shape=jax.ShapeDtypeStruct((1, 2), jnp.float32), )(x) - expected = jnp.zeros_like(x).at[0, 0].set(x[4, 4]) + expected = jnp.zeros_like(x[0:1, 0:2]).at[0, 1].set(x[4, 4]) np.testing.assert_allclose(y, expected) def test_vmem_vmem_dma(self):