diff --git a/tests/pgle_test.py b/tests/pgle_test.py index 466da7f27067..188b56c8cf78 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -46,7 +46,7 @@ def f(x, y): z = x @ y return z @ y - shape = (8, 8) + shape = (16, 16) x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32) y = x + 1 f_lowered = f.lower(x, y)