Skip to content

Commit

Permalink
[Pallas TPU] Pallas while loop -> fori test.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 623204164
  • Loading branch information
jax authors committed Apr 9, 2024
1 parent 77db7a6 commit 28b81be
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions tests/pallas/pallas_call_tpu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1654,6 +1654,58 @@ def _false():
return


class PallasCallWhileLoopTest(PallasTPUTest):

def setUp(self):
super().setUp()
if jtu.device_under_test() != 'tpu':
self.skipTest('Test only works on TPU')

def test_range_while_loop(self):
"""Tests lowering of a while_loop which can reduce to a fori_loop."""

def kernel(x_ref, r_ref):
@pl.when(pl.program_id(0) == 0)
def _():
pl.store(r_ref, (0, 0), 0)

def cond(carry):
i, j = carry
return i < j

def body(carry):
i, j = carry
sl = sl = jax.lax.div(i, 128)
l = jax.lax.rem(i, 128)
v = x_ref[0, sl, l]
s = pl.load(r_ref, (0, 0))
pl.store(r_ref, (0, 0), s + v)
return i + 1, j

i = 0
j = 1024
i, j = jax.lax.while_loop(cond, body, (i, j))

x = jnp.arange(4096)
x = jnp.reshape(x, [4, 8, 128])

r = pl.pallas_call(
kernel,
grid=(1,),
out_specs=pl.BlockSpec(block_shape=(1, 1), memory_space=pltpu.SMEM),
out_shape=jax.ShapeDtypeStruct([1, 1], jnp.int32),
in_specs=[
pl.BlockSpec(
lambda i: (i, 0, 0),
block_shape=(1, 8, 128),
memory_space=pltpu.SMEM,
)
],
)(x)
expected = jnp.sum(jnp.arange(1024))
np.testing.assert_array_equal(r, expected)


class PallasCallPipelineTest(parameterized.TestCase):

def setUp(self):
Expand Down

0 comments on commit 28b81be

Please sign in to comment.