From 28b81bef5fe29898c49bb8ee3133ad4990c8c3c9 Mon Sep 17 00:00:00 2001
From: jax authors
Date: Tue, 9 Apr 2024 10:10:13 -0700
Subject: [PATCH] [Pallas TPU] Pallas while loop -> fori test.

PiperOrigin-RevId: 623204164
---
Note: line breaks and indentation were restored from a whitespace-collapsed
copy of this patch. The leading `return` context line was dropped because its
indentation could not be recovered, so the hunk starts at line 1655. The
`index` hashes below are those of the original commit.

 tests/pallas/pallas_call_tpu_test.py | 56 ++++++++++++++++++++++++++++
 1 file changed, 56 insertions(+)

diff --git a/tests/pallas/pallas_call_tpu_test.py b/tests/pallas/pallas_call_tpu_test.py
index dbc3ad47f541..ddc5bc99ce3e 100644
--- a/tests/pallas/pallas_call_tpu_test.py
+++ b/tests/pallas/pallas_call_tpu_test.py
@@ -1655,5 +1655,61 @@ def _false():


+class PallasCallWhileLoopTest(PallasTPUTest):
+  """Tests lowering of `jax.lax.while_loop` inside Pallas TPU kernels."""
+
+  def setUp(self):
+    super().setUp()
+    if jtu.device_under_test() != 'tpu':
+      self.skipTest('Test only works on TPU')
+
+  def test_range_while_loop(self):
+    """Tests lowering of a while_loop which can reduce to a fori_loop."""
+
+    def kernel(x_ref, r_ref):
+      # Zero the (1, 1) accumulator on the first grid step.
+      @pl.when(pl.program_id(0) == 0)
+      def _():
+        pl.store(r_ref, (0, 0), 0)
+
+      def cond(carry):
+        i, j = carry
+        return i < j
+
+      def body(carry):
+        i, j = carry
+        # Split the flat index into (row, column) of the (8, 128) block.
+        row = jax.lax.div(i, 128)
+        col = jax.lax.rem(i, 128)
+        v = x_ref[0, row, col]
+        s = pl.load(r_ref, (0, 0))
+        pl.store(r_ref, (0, 0), s + v)
+        return i + 1, j
+
+      # The loop only updates r_ref; its final carry is not needed.
+      jax.lax.while_loop(cond, body, (0, 1024))
+
+    x = jnp.arange(4096)
+    x = jnp.reshape(x, [4, 8, 128])
+
+    r = pl.pallas_call(
+        kernel,
+        grid=(1,),
+        out_specs=pl.BlockSpec(block_shape=(1, 1), memory_space=pltpu.SMEM),
+        out_shape=jax.ShapeDtypeStruct([1, 1], jnp.int32),
+        in_specs=[
+            pl.BlockSpec(
+                lambda i: (i, 0, 0),
+                block_shape=(1, 8, 128),
+                memory_space=pltpu.SMEM,
+            )
+        ],
+    )(x)
+    # With grid=(1,), only the first (1, 8, 128) block of x, i.e. the values
+    # 0..1023, is summed.
+    expected = jnp.sum(jnp.arange(1024))
+    np.testing.assert_array_equal(r, expected)
+
+
 class PallasCallPipelineTest(parameterized.TestCase):

   def setUp(self):