Skip to content

Commit

Permalink
Add zeros initialization to failing smem-hbm copy test.
Browse files Browse the repository at this point in the history
  • Loading branch information
justinjfu committed May 9, 2024
1 parent 9b79f65 commit 7245714
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tests/pallas/pallas_call_tpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -900,9 +900,10 @@ def body(x_ref, sem):
def test_smem_hbm_dma(self):
def kernel(x_ref, y_hbm_ref):
def body(y_ref, sem):
y_ref[0, 0] = x_ref[4, 4]
y_ref[0, 0] = 0.0
y_ref[0, 1] = x_ref[4, 4]
pltpu.async_copy(y_ref, y_hbm_ref, sem).wait()
pltpu.run_scoped(body, pltpu.SMEM((8, 128), jnp.float32),
pltpu.run_scoped(body, pltpu.SMEM((1, 2), jnp.float32),
pltpu.SemaphoreType.DMA)
x = jnp.arange(8 * 128.).reshape((8, 128))
y = pl.pallas_call(
Expand All @@ -911,9 +912,9 @@ def body(y_ref, sem):
pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.SMEM),
],
out_specs=pl.BlockSpec(memory_space=pltpu.TPUMemorySpace.ANY),
out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
out_shape=jax.ShapeDtypeStruct((1, 2), jnp.float32),
)(x)
expected = jnp.zeros_like(x).at[0, 0].set(x[4, 4])
expected = jnp.zeros_like(x[0:1, 0:2]).at[0, 1].set(x[4, 4])
np.testing.assert_allclose(y, expected)

def test_vmem_vmem_dma(self):
Expand Down

0 comments on commit 7245714

Please sign in to comment.