From f5690314563cfa2f98b2c85fd0c600619f9f4687 Mon Sep 17 00:00:00 2001 From: Junwhan Ahn Date: Fri, 15 Mar 2024 11:55:32 -0700 Subject: [PATCH] Reverts 55394a0914dc0583427a4ceb73dac56348911d15 PiperOrigin-RevId: 616201321 --- jax/_src/api.py | 24 +++++++++++++++++++++++- tests/api_test.py | 14 ++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/jax/_src/api.py b/jax/_src/api.py index 597d0c057844..ee5290080f3a 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2958,7 +2958,29 @@ def try_to_block(x): return x.block_until_ready() except AttributeError: return x - return tree_map(try_to_block, x) + + if xla_extension_version < 246: + return tree_map(try_to_block, x) + + arrays = [] + for leaf in tree_leaves(x): + if isinstance(leaf, array.ArrayImpl): + arrays.append(leaf) + else: + try_to_block(leaf) + + if not arrays: + # `arrays` will be empty if tree_leaves(x) is empty or all leaves are not + # jax.Array. + pass + elif len(arrays) == 1: + # Fast path for single array. + try_to_block(arrays[0]) + else: + # Optimized for multiple arrays. + xc.batched_block_until_ready(arrays) + + return x def clear_backends(): diff --git a/tests/api_test.py b/tests/api_test.py index 6db67205534d..9e7f87ca759b 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2374,6 +2374,20 @@ def test_block_until_ready_function(self): self.assertAllClose(pytree[0], jnp.array(1.), check_dtypes=False) self.assertAllClose(pytree[1], np.ones(3), check_dtypes=False) + def test_block_until_ready_numpy_arrays(self): + pytree = (np.ones(1), np.ones(2)) + pytree = jax.block_until_ready(pytree) + self.assertAllClose(pytree[0], np.ones(1), check_dtypes=False) + self.assertAllClose(pytree[1], np.ones(2), check_dtypes=False) + + def test_block_until_ready_mixed(self): + pytree = (device_put(1.), device_put(2.), np.ones(3), 4) + pytree = jax.block_until_ready(pytree) + self.assertAllClose(pytree[0], jnp.array(1.), check_dtypes=False) + self.assertAllClose(pytree[1], jnp.array(2.), check_dtypes=False) + self.assertAllClose(pytree[2], np.ones(3), check_dtypes=False) + self.assertEqual(pytree[3], 4) + def test_devicearray_weakref_friendly(self): x = device_put(1.) y = weakref.ref(x)