diff --git a/tests/BUILD b/tests/BUILD index d66342dc3c6c..7ab6cc136e97 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -1142,23 +1142,6 @@ jax_multiplatform_test( deps = ["//jax:ode"], ) -jax_multiplatform_test( - name = "host_callback_outfeed_test", - srcs = ["host_callback_test.py"], - args = ["--jax_host_callback_outfeed=true"], - shard_count = { - "tpu": 5, - }, - tags = [ - "noasan", # Times out. - ], - deps = [ - "//jax:experimental", - "//jax:experimental_host_callback", - "//jax:ode", - ], -) - jax_multiplatform_test( name = "host_callback_test", srcs = ["host_callback_test.py"], diff --git a/tests/host_callback_test.py b/tests/host_callback_test.py index 5e624ef9a83c..21cb31f693ef 100644 --- a/tests/host_callback_test.py +++ b/tests/host_callback_test.py @@ -15,16 +15,15 @@ from __future__ import annotations import contextlib -from collections.abc import Callable, Sequence +from collections.abc import Callable from functools import partial import itertools import logging import os import re -import threading import time import unittest -from unittest import skip, SkipTest +from unittest import SkipTest from absl.testing import absltest @@ -34,8 +33,6 @@ from jax import lax from jax import numpy as jnp from jax.experimental import host_callback as hcb -from jax.experimental import pjit -from jax.sharding import PartitionSpec as P from jax._src import core from jax._src import xla_bridge from jax._src import test_util as jtu @@ -264,10 +261,6 @@ def tearDown(self) -> None: hcb.barrier_wait("HostCallbackTapTest.tearDown") super().tearDown() - def supported_only_in_legacy_mode(self): - if not hcb._HOST_CALLBACK_LEGACY.value: - self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False") - def test_tap_eval(self): self.assertAllClose((5. * 2.) ** 2, fun1(5.)) hcb.barrier_wait() @@ -334,60 +327,6 @@ def func2(x): assertMultiLineStrippedEqual(self, "called tap_func with None", testing_stream.output) - def test_tap_with_device(self): - self.supported_only_in_legacy_mode() - def func2(x): - x1 = hcb_id_print((x * 2., x * 3.), result=x * 4., - output_stream=testing_stream, - tap_with_device=True) - return x1 - - self.assertEqual(3. * 4., func2(3.)) - hcb.barrier_wait() - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 - ( 6.00 9.00 )""") - - def test_tap_eval_exception(self): - self.supported_only_in_legacy_mode() - if not hcb._HOST_CALLBACK_OUTFEED.value: - raise SkipTest("TODO: implement error handling for customcall") - - # Simulate a tap error - def tap_err(*args, **kwargs): - raise ValueError("Some user message") - - def func(x): - x1 = hcb_id_print(x + 1, what="x1", output_stream=testing_stream) - x2 = hcb.id_tap(tap_err, x1 + 1) - x3 = hcb_id_print(x2 + 1, what="x3", output_stream=testing_stream) - return x3 - - if hcb._HOST_CALLBACK_LEGACY.value: - ctx = self.assertRaisesRegex( - hcb.CallbackException, - re.compile("There were exceptions during callback processing. Last one was:.*" - "ValueError: Some user message", re.DOTALL)) - else: - ctx = self.assertRaisesRegex(Exception, "Some user message") - - with ctx: - func(0) - hcb.barrier_wait() - - if hcb._HOST_CALLBACK_LEGACY.value: - # We should have received everything before the error - assertMultiLineStrippedEqual(self, """ - what: x1 - 1 - what: x3 - 3""", testing_stream.output) - else: - # We should have received everything before the error - assertMultiLineStrippedEqual(self, """ - what: x1 - 1""", testing_stream.output) - def test_tap_empty(self): """Tap empty arrays.""" hcb_id_print((), output_stream=testing_stream) @@ -519,26 +458,6 @@ def func_nested(x): where: 3 3""", testing_stream.output) - def test_tap_jit_devices(self): - """Running on multiple devices.""" - self.supported_only_in_legacy_mode() - logging.info("%s: has devices %s", self._testMethodName, local_devices()) - - def func(x, device_id): - x1 = hcb_id_print(x, dev=str(device_id), output_stream=testing_stream) - x2 = hcb_id_print(x1 + 1, dev=str(device_id), output_stream=testing_stream) - return x2 - - for d in local_devices(): - self.assertEqual(112, jax.jit(func, device=d, static_argnums=1)(111, d.id)) - hcb.barrier_wait() - logging.info("%s: found output %s", self._testMethodName, - testing_stream.output) - self.assertEqual( - len(local_devices()), len(re.findall(r"111", testing_stream.output))) - self.assertEqual( - len(local_devices()), len(re.findall(r"112", testing_stream.output))) - @jtu.sample_product(with_jit=[True, False]) def test_tap_pytree(self, with_jit=False): def func(x, what=""): @@ -571,79 +490,6 @@ def tap_func(a, _, *, what=""): hcb.barrier_wait() # Wait for receivers to be done self.assertEqual(3, tap_count) - @jtu.sample_product(concurrent=[True, False]) - def test_tap_multiple(self, concurrent=False): - """Call id_tap multiple times, concurrently or in sequence. """ - if concurrent and jtu.test_device_matches(["cpu", "gpu"]): - # TODO(necula): if there is device side concurrency, outfeeds from - # different computations can be interleaved. For example, it seems that - # on GPU if multiple host threads run a jit computation, the multiple - # computations are interleaved on the GPU. This can result in the outfeed - # trains being interleaved, which will trigger an error. - # The solution is to fix on GPU the receiving logic so that we can outfeed - # the train as one tuple, and receive it one piece as a time. Then the - # trains should be atomic. - # See also b/160692602. - raise SkipTest("concurrent id_tap not supported on CPU, GPU") - - self.supported_only_in_legacy_mode() - received = set() - count = 5 - - def pause_tap(idx, _): - received.add(int(idx)) - logging.info("Starting do_tap %s. Sleeping 1sec ...", idx) - time.sleep(0.3) - logging.info("Finish do_tap %s", idx) - - def do_tap(idx): - jax.jit(lambda idx: hcb.id_tap(pause_tap, idx))(idx) - - if concurrent: - threads = [ - threading.Thread( - name=f"enqueue_tap_{idx}", target=do_tap, args=(idx,)) - for idx in range(count) - ] - [t.start() for t in threads] - [t.join() for t in threads] - else: - for idx in range(count): - do_tap(idx) - - hcb.barrier_wait() - self.assertEqual(received, set(range(count))) - - # TODO(necula): see comment for test_multiple_tap. Here we disable also - # on TPU, because the barrier_wait runs on all devices, including on the CPU - # where it would run into concurrency problems. - @skip("Concurrency not supported") - def test_tap_multiple_barriers(self): - """Call barrier_wait concurrently.""" - - def pause_tap(*args, **kwargs): - logging.info("pause_tap waiting") - time.sleep(0.3) - logging.info("pause_tap done") - - def long_run(x): - return hcb.id_tap(pause_tap, x) - - jax.jit(long_run)(5.) - - def try_barrier(idx): - logging.info("Starting test barrier %s", idx) - hcb.barrier_wait() - logging.info("Finished test barrier %s", idx) - - threads = [ - threading.Thread( - name=f"barrier_{idx}", target=try_barrier, args=(idx,)) - for idx in range(3) - ] - [t.start() for t in threads] - [t.join() for t in threads] - @jtu.sample_product(with_jit=[True, False]) def test_tap_cond(self, with_jit=False): """A conditional""" @@ -852,39 +698,6 @@ def func(x, count): hcb.barrier_wait() self.assertEqual(100, count) - def test_tap_jit_tap_exception(self): - self.supported_only_in_legacy_mode() - if not hcb._HOST_CALLBACK_OUTFEED.value: - raise SkipTest("TODO: implement error handling for customcall") - # Simulate a tap error - def tap_err(*args, **kwargs): - raise NotImplementedError - - def func(x): - x1 = hcb_id_print(x + 1, what="x1", output_stream=testing_stream) - x2 = hcb.id_tap(tap_err, x1 + 1) - x3 = hcb_id_print(x2 + 1, what="x3", output_stream=testing_stream) - return x3 - - if hcb._HOST_CALLBACK_LEGACY.value: - res = jax.jit(func)(0) # No error yet - with self.assertRaises(hcb.CallbackException): - hcb.barrier_wait() - - # Even though the receiver thread raised, the main thread should still - # return 3. - self.assertEqual(3, res) - # We should have received all others - assertMultiLineStrippedEqual(self, """ - what: x1 - 1 - what: x3 - 3""", testing_stream.output) - else: - with self.assertRaisesRegex(Exception, "NotImplementedError"): - res = jax.jit(func)(0) - hcb.barrier_wait() - def test_tap_while(self): """Executing while, even without JIT uses compiled code""" y = jnp.ones(5) # captured const @@ -1470,231 +1283,6 @@ def power3_with_cotangents(x): self.assertMultiLineStrippedEqual(expected, testing_stream.output) testing_stream.reset() - def test_tap_pmap(self): - self.supported_only_in_legacy_mode() - if len(local_devices()) < 2: - raise SkipTest("test requires at least 2 devices") - - def power3(x): - y = x * x - # Print both 'x' and 'x^2'. Must pack as a tuple. - _, y = hcb_id_print((x, y), - what="x,x^2", - output_stream=testing_stream, - tap_with_device=True) - return y * x - - pmap_power3 = jax.pmap(power3, devices=local_devices()) - xv = np.array([3, 4], dtype=np.int32) - res = pmap_power3(xv) - hcb.barrier_wait() - self.assertAllClose(xv * xv * xv, res, check_dtypes=False) - # Assertion text is for 2 devices (also works for 1 device) - assertMultiDeviceOutputEqual( - self, """ - device: cpu:0 what: x,x^2 - ( 3 9 ) - device: cpu:1 what: x,x^2 - ( 4 16 )""") - - def test_tap_pmap_vmap(self): - self.supported_only_in_legacy_mode() - # A matrix M[ij] = i * 10 + j - nr_devices = len(local_devices()) - shape = (nr_devices, 3) - matrix = np.fromfunction(lambda i, j: 10. * i + j, shape, - dtype=np.int32) - - def fun1(x, do_print=False): # x: i32 - return maybe_print(do_print, x * 2, "x * 2", tap_with_device=True) - - pmap_vmap_fun1 = jax.pmap( - jax.vmap(partial(fun1, do_print=True)), devices=local_devices()) - - res = pmap_vmap_fun1(matrix) - hcb.barrier_wait() - expected_res = jax.pmap( - jax.vmap(partial(fun1, do_print=False)), devices=local_devices())( - matrix) - self.assertAllClose(expected_res, res, check_dtypes=False) - # Assertion text is for 2 devices (also works for 1 device) - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2 - [0.00 2.00 4.00] - device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2 - [20.00 22.00 24.00]""") - - def test_tap_pmap_pmap_vmap(self): - # A matrix M[ijk] = i * 100 + j * 10 + k - self.supported_only_in_legacy_mode() - nr_devices = len(local_devices()) - if nr_devices % 2 != 0: - raise SkipTest("test works only on even number of devices") - - shape = (2, nr_devices // 2, 3) - matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape, - dtype=np.float32) - - def fun1(x, do_print=False): # x: f32 - y = maybe_print(do_print, x * 2., "x * 2", tap_with_device=True) - return y ** 2 - - pmap_fun1 = jax.pmap( - jax.pmap(jax.vmap(partial(fun1, do_print=True))), - devices=local_devices()) - res = pmap_fun1(matrix) - hcb.barrier_wait() - expected_res = jax.pmap( - jax.pmap(jax.vmap(partial(fun1, do_print=False))), - devices=local_devices())( - matrix) - self.assertAllClose(expected_res, res, check_dtypes=False) - # Assertion text is for 2 devices (also works for 1 device) - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2 - [0.00 2.00 4.00] - device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2 - [200.00 202.00 204.00]""") - - @ignore_jit_of_pmap_warning() - def test_tap_pmap_pmap_extra(self): - """pmap of a pmap surrounded by extra code.""" - # A matrix M[ij] = i * 10 + j - self.supported_only_in_legacy_mode() - nr_devices = len(local_devices()) - if nr_devices != 2: - raise SkipTest("test works only on 2 devices") - shape = (2, 1, 3) - matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape, - dtype=np.float32) - - def fun(xv, do_print=False): - # This will be printed on all devices, with shape [1, 3] - xv = maybe_print(do_print, xv + 1., "before", tap_with_device=True) - res = jax.pmap(lambda x: maybe_print(do_print, x * 2., "inside", tap_with_device=True))(xv) - # This will be printed on all devices, with shape [1, 3] - return maybe_print(do_print, res + 1., "after", tap_with_device=True) - - res = jax.pmap(partial(fun, do_print=True))(matrix) - self.assertAllClose(fun(matrix, do_print=False), res, check_dtypes=False) - hcb.barrier_wait() - # Assertion text is for 2 devices (also works for 1 device) - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 what: before - [[1.00 2.00 3.00]] - device: cpu:0 what: inside - [2.00 4.00 6.00] - device: cpu:0 what: after - [[3.00 5.00 7.00]] - device: cpu:1 what: before - [[101.00 102.00 103.00]] - device: cpu:1 what: inside - [202.00 204.00 206.00] - device: cpu:1 what: after - [[203.00 205.00 207.00]]""") - - def test_tap_jvp_pmap_vmap(self): - self.supported_only_in_legacy_mode() - # A matrix M[ijk] = i * 100 + j * 10 * k - nr_devices = len(local_devices()) - shape = (nr_devices, 2, 3) - matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape, - dtype=np.float32) - - def fun(xv, do_print=False): - # x: f32[3] - return jax.jvp(jax.pmap(jax.vmap(lambda x: maybe_print(do_print, x * 2., "x * 2", tap_with_device=True))), - (xv,), (.1 * jnp.ones_like(xv),)) - - res = fun(matrix, do_print=True) - hcb.barrier_wait() - expected_res = fun(matrix, do_print=False) - self.assertAllClose(expected_res, res, check_dtypes=False) - # Assertion text is for 2 devices (also works for 1 device) - # Device 0 will get to execute jax.jvp(jax.vmap(...)) for matrix[0, :, :] - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2 - [[ 0.00 2.00 4.00] - [20.00 22.00 24.00]] - device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2 - [[200.00 202.00 204.00] - [220.00 222.00 224.00]]""") - - def test_tap_vmap_pmap(self): - self.supported_only_in_legacy_mode() - # A matrix M[ijk] = i * 100 + j * 10 * k - nr_devices = len(local_devices()) - shape = (2, nr_devices, 3) - matrix = np.fromfunction(lambda i, j, k: 100. * i + 10. * j + k, shape, - dtype=np.float32) - - def fun(xv, do_print=False): - # x: f32[3] - return jax.vmap(jax.pmap(lambda x: maybe_print(do_print, x * 2., "x * 2", tap_with_device=True)))(xv) - - res = fun(matrix, do_print=True) - hcb.barrier_wait() - expected_res = fun(matrix, do_print=False) - self.assertAllClose(expected_res, res, check_dtypes=False) - # Assertion text is for 2 devices (also works for 1 device) - # Device 0 will get to execute jax.jvp(jax.vmap(...)) for matrix[:, 0, :] - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2 - [[ 0.00 2.00 4.00] - [200.00 202.00 204.00]] - device: cpu:1 transforms: [('batch', {'batch_dims': (0,)})] what: x * 2 - [[ 20.00 22.00 24.00] - [220.00 222.00 224.00]]""") - - @ignore_jit_of_pmap_warning() - def test_tap_jit_pmap_extra(self): - """jit of a pmap surrounded by extra code.""" - self.supported_only_in_legacy_mode() - # A matrix M[ij] = i * 10 + j - nr_devices = len(local_devices()) - assert nr_devices in (1, 2) - shape = (nr_devices, 3) - matrix = np.fromfunction(lambda i, j: 10. * i + j, shape, - dtype=np.float32) - - def fun(xv, do_print=False): - # This will be printed on all devices with shape (nr_devices, 3) - xv = maybe_print(do_print, xv + 1., "before", tap_with_device=True) - res = jax.pmap(lambda x: maybe_print(do_print, x * 2., "inside", tap_with_device=True))(xv) - # This will be printed on all devices with shape (nr_devices, 3) - return maybe_print(do_print, res + 1., "after", tap_with_device=True) - - res = jax.jit(partial(fun, do_print=True))(matrix) - self.assertAllClose(fun(matrix, do_print=False), res, check_dtypes=False) - hcb.barrier_wait() - if len(local_devices()) == 2: - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 what: before - [[ 1.00 2.00 3.00] - [11.00 12.00 13.00]] - device: cpu:0 what: inside - [2.00 4.00 6.00] - device: cpu:0 what: after - [[ 3.00 5.00 7.00] - [23.00 25.00 27.00]] - device: cpu:1 what: before - [[ 1.00 2.00 3.00] - [11.00 12.00 13.00]] - device: cpu:1 what: inside - [22.00 24.00 26.00] - device: cpu:1 what: after - [[ 3.00 5.00 7.00] - [23.00 25.00 27.00]]""") - else: - assert len(local_devices()) == 1 - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 what: before - [[1.00 2.00 3.00]] - device: cpu:0 what: inside - [2.00 4.00 6.00] - device: cpu:0 what: after - [[3.00 5.00 7.00]]""") - @unittest.skip("cond of pmap does not work in JAX. Issue #5178.") def test_tap_cond_pmap(self): # A matrix M[ij] = i * 10 + j @@ -1716,146 +1304,6 @@ def fun2(cond, xv, do_print=False): assertMultiLineStrippedEqual(self, """ TBD""", testing_stream.output) - @jtu.sample_product(device_index=[0, 1]) - def test_tap_pjit(self, device_index=0): - self.supported_only_in_legacy_mode() - if (device_index != 0 and - not hcb._HOST_CALLBACK_OUTFEED.value and - jtu.test_device_matches(["cpu"])): - # See comment in host_callback.py. - raise SkipTest("device_index works only with outfeed on CPU") - - devices = np.array(local_devices()) - nr_devices = len(devices) - if nr_devices < 2: - raise SkipTest("test requires at least 2 devices") - - logging.info(f"test_tap_pjit is running on devices {devices}.") - # x: i32[D, 3] = [[0, 1, 2], [10, 11, 12], ...] - # y: i32[3, 4] - x = jnp.arange(100, dtype=jnp.int32).reshape((10, 10))[:nr_devices, :3] - y = jnp.ones((3, 4), np.int32) - - @partial(jax.named_call, name="fun1") # for xprof debugging - def fun1(x): - z = jnp.dot(x, y) - return hcb_id_print(z, what="z", - output_stream=testing_stream, - tap_with_device=True, device_index=device_index) - - pjit_fun1 = pjit.pjit(fun1, in_shardings=(P("d"),), out_shardings=P("d")) - - with jax.sharding.Mesh(devices, ["d"]): - # Print the internal IR - helper_log_ir( - f"{self._testMethodName}.pjit", - pjit_fun1, - x, - num_partitions=nr_devices) - res = pjit_fun1(x) - - self.assertAllClose(jnp.dot(x, y), res) - hcb.barrier_wait("before check") - - # Assertion text is for 2 devices (also works for 1 device) - # Note that a single call is made. - assertMultiDeviceOutputEqual( - self, f""" - device: cpu:{device_index} what: z - [[ 3 3 3 3] - [33 33 33 33]]""") - - def test_tap_scan_custom_jvp(self): - """custom JVP, inside scan. - This exercises the custom_jvp_call_jaxpr primitives.""" - self.supported_only_in_legacy_mode() - @jax.custom_jvp - def f(x): - return x * hcb_id_print(x, output_stream=testing_stream, what="x") - - @f.defjvp - def f_jvp(primals, tangents): - x, = primals - x_dot, = tangents - primal_out = f(x) - tangent_out = 3. * x * hcb_id_print(x_dot, output_stream=testing_stream, what="x_dot") - return primal_out, tangent_out - - def g(x): - # Sum f(x_i) - return lax.scan(lambda carry, inp: (carry + f(inp), 0.), - np.full(x.shape[1:], 0.), # Like x w/o leading dim - x)[0] - - arg = np.full((2,), 0.7) - self.assertAllClose(0.7 * 0.7 * 2, g(arg)) - hcb.barrier_wait() - self.assertMultiLineStrippedEqual(""" - what: x - 0.7 - what: x - 0.7""", testing_stream.output) - testing_stream.reset() - - self.assertAllClose(np.array([2.1, 2.1]), jax.grad(g)(arg), check_dtypes=False) - hcb.barrier_wait() - self.assertMultiLineStrippedEqual(""" - what: x - 0.7 - what: x - 0.7 - transforms: ['transpose'] what: x_dot - 2.1 - transforms: ['transpose'] what: x_dot - 2.1""", testing_stream.output) - - def test_tap_scan_custom_vjp(self): - """custom VJP, inside scan. - This exercises the custom_vjp_call_jaxpr primitives.""" - self.supported_only_in_legacy_mode() - @jax.custom_vjp - def f(x): - return x * hcb_id_print(x, output_stream=testing_stream, what="x") - - # f_fwd: a -> (b, residual) - def f_fwd(x): - return f(x), 3. * x - - # f_bwd: (residual, CT b) -> [CT a] - def f_bwd(residual, ct_b): - return residual * hcb_id_print(ct_b, output_stream=testing_stream, what="ct_b"), - - f.defvjp(f_fwd, f_bwd) - - def g(x): - # Sum f(x_i) - return lax.scan(lambda carry, inp: (carry + f(inp), 0.), - np.full(x.shape[1:], 0.), # Like x w/o leading dim - x)[0] - - arg = np.full((2,), 0.7) - - self.assertAllClose(0.7 * 0.7 * 2, g(arg)) - hcb.barrier_wait() - self.assertMultiLineStrippedEqual(""" - what: x - 0.7 - what: x - 0.7""", testing_stream.output) - testing_stream.reset() - - self.assertAllClose(np.array([2.1, 2.1]), jax.grad(g)(arg), check_dtypes=False) - hcb.barrier_wait() - self.assertMultiLineStrippedEqual(""" - what: x - 0.7 - what: x - 0.7 - what: ct_b - 1. - what: ct_b - 1.""", testing_stream.output) - def test_tap_callback_delay(self): hcb.callback_extra = lambda dev: time.sleep(1) @@ -1978,52 +1426,6 @@ def loss(k): 10""" self.assertMultiLineStrippedEqual(expected, testing_stream.output) - @jtu.sample_product( - use_result=[True, False], - grad_func=["grad", "value_and_grad"], - use_remat=["old", "new", "none"], - ) - def test_tap_remat(self, use_result=False, grad_func="grad", use_remat="new"): - self.supported_only_in_legacy_mode() - if use_remat == "old": raise SkipTest() - - def f(x): - id_print_result = hcb_id_print(x, output_stream=testing_stream) - if use_result: - x = id_print_result - return 3. * x - grad_f = jax.grad if grad_func == "grad" else jax.value_and_grad - if use_remat == "old": - trans_f = jax.remat(f) - elif use_remat == "new": - trans_f = ad_checkpoint.checkpoint(f) - else: - assert use_remat == "none" - trans_f = f - print(jax.make_jaxpr(grad_f(trans_f))(2.)) - grad_f(trans_f)(2.) - - hcb.barrier_wait() - - if use_remat == "none": - # GOOD: whether or not we use_result, we get the same callback. - expected = "2." - else: # use_remat - if use_result: - expected = """ - 2. - 2.""" - else: - if use_remat == "old": - # TODO: we should see two callbacks - expected = "" - else: - # Good: we see two callbacks, whether or not we use the result. - expected = """ - 2. - 2.""" - self.assertMultiLineStrippedEqual(expected, testing_stream.output) - def test_tap_named_call(self): def tap_scalar(init, do_print=False): @partial(jax.named_call, name="step") @@ -2066,10 +1468,6 @@ def tearDown(self) -> None: hcb.barrier_wait("HostCallbackCallTest.tearDown") super().tearDown() - def supported_only_in_legacy_mode(self): - if not hcb._HOST_CALLBACK_LEGACY.value: - self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False") - def call_log_testing_stream(self, func, arg, *, result_shape, name=""): """Call `func` and log inputs and outputs to the testing stream""" @@ -2314,44 +1712,6 @@ def fun2(m): m = np.ones((2,), np.float32) helper_print_optimized_hlo(fun2, m) - def test_call_with_device(self): - self.supported_only_in_legacy_mode() - def callback_func(x, device=None): - testing_stream.write(f"device: {device}\n Called with {x}") - return x - - def func(x): - return hcb.call(callback_func, x, - result_shape=x, - call_with_device=True) - - self.assertEqual(3., func(3.)) - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 - Called with 3.00""") - - def test_call_pmap(self): - self.supported_only_in_legacy_mode() - # Works for 1 or 2 devices - def callback_func(x, device=None): - testing_stream.write(f"device: {device}\n Called with {x}") - return x * np.array(3, np.int32) - - def fun(x): # x: i32 - return hcb.call(callback_func, x * 2, - result_shape=x, - call_with_device=True) - - xv = jnp.arange(len(local_devices()), dtype=jnp.int32) - res = jax.pmap(fun)(xv) - self.assertAllClose(jax.pmap(lambda x: x * 6)(xv), res) - # Assertion text is for 2 devices (also works for 1 device) - assertMultiDeviceOutputEqual(self, """ - device: cpu:0 - Called with 0 - device: cpu:1 - Called with 2""") - def test_call_vmap(self): def f_outside(x): return x @@ -2366,52 +1726,6 @@ def fun(x): else: jax.vmap(fun)(np.ones((2, 3))) - @jtu.sample_product(device_index=[0, 1]) - @jtu.skip_on_devices("cpu") # TODO: RET_CHECK failure - def test_call_pjit(self, device_index=0): - devices = np.array(local_devices()) - nr_devices = len(devices) - if nr_devices < 2: - raise SkipTest("test requires at least 2 devices") - - logging.info(f"test_call_pjit is running on devices {devices}.") - # x: i32[D, 3] = [[0, 1, 2], [10, 11, 12], ...] - # y: i32[3, 4] - x = jnp.arange(100, dtype=jnp.int32).reshape((10, 10))[:nr_devices, :3] - y = jnp.ones((3, 4), np.int32) - - def callback_x5_func(x, device=None): - testing_stream.write(f"device: {device}\n Called with {x}") - return x * np.array(5, np.int32) - - def fun(x): - xy = jnp.dot(x, y) - return hcb.call( - callback_x5_func, xy, result_shape=xy, call_with_device=True, - device_index=device_index) - - pjit_fun = pjit.pjit(fun, in_shardings=(P("d"),), out_shardings=P("d")) - with jax.sharding.Mesh(devices, ["d"]): - # Print the internal IR - helper_log_ir( - f"{self._testMethodName}.pjit", - pjit_fun, - x, - num_partitions=nr_devices) - - res = pjit_fun(x) - - expected_res = jnp.dot(x, y) * np.array(5, np.int32) - self.assertAllClose(expected_res, res, check_dtypes=False) - - hcb.barrier_wait("before assertion") - # Assertion text is for 2 devices (also works for 1 device) - assertMultiDeviceOutputEqual( - self, f""" - device: cpu:{device_index} - Called with [[ 3 3 3 3] - [33 33 33 33]]""") - def test_call_error_bad_result_shape(self): with self.assertRaisesRegex( ValueError, @@ -2452,36 +1766,6 @@ def helper_check_callback_errors(self, thunk: Callable, re.DOTALL)): hcb.barrier_wait("Waiting for error") - def test_call_error_callback_throws_exception(self): - self.supported_only_in_legacy_mode() - def f_outside(x): - raise ValueError("user exception") - def fun(x): - return hcb.call(f_outside, x, result_shape=x) - - self.helper_check_callback_errors(lambda: fun(3.), - "ValueError: user exception") - - def test_call_error_callback_returns_unexpected_shape(self): - self.supported_only_in_legacy_mode() - def fun(x): - return hcb.call(lambda x: (x, x), x, result_shape=x) - - self.helper_check_callback_errors(lambda: fun(3.), - "Callback func .* should have returned a result with pytree") - - def test_call_error_then_compute(self): - self.supported_only_in_legacy_mode() - # Continue computation on device after error - def f_outside(x): - raise ValueError("user exception") - def fun(x): - x1 = hcb.call(f_outside, x, result_shape=x) - return x1 - arg = np.arange(3, dtype=np.int32) - self.helper_check_callback_errors(lambda: self.assertAllClose(arg, fun(arg)), - "ValueError: user exception") - def call_jax_other_device( jax_outside_fun, arg, *, device, @@ -2602,491 +1886,5 @@ def f_outside(x): res_outside = jax.grad(jax.grad(f_outside))(5.) self.assertAllClose(res_jax, res_outside) - -class OutfeedRewriterTest(jtu.JaxTestCase): - - def setUp(self): - if jtu.test_device_matches(["gpu"]) and jax.device_count() > 1: - raise SkipTest("host_callback broken on multi-GPU platforms (#6447)") - if xla_bridge.using_pjrt_c_api(): - raise SkipTest("host_callback not implemented in PJRT C API") - super().setUp() - self.enter_context(jtu.ignore_warning( - category=DeprecationWarning, message="The host_callback APIs are deprecated")) - - def supported_only_in_legacy_mode(self): - if not hcb._HOST_CALLBACK_LEGACY.value: - self.skipTest("Not supported when JAX_HOST_CALLBACK_LEGACY=False") - - def assertRewrite(self, expected: str, func: Callable, args: Sequence, - has_input_token=True, has_output_token=True): - """Check that the rewrite of func(*args) matches expected.""" - jaxpr = jax.make_jaxpr(func)(*args) - rewritten = hcb._rewrite_closed_jaxpr(jaxpr, # noqa: F841 - has_input_token, has_output_token) - # Since it is somewhat annoying to update the Jaxpr assertions when we change - # the Jaxpr printing, we do not check these by default. It is recommended that - # before making changes to the code generation and Jaxpr rewriting, turn on - # the checking, update the expected Jaxpr, and then make the changes. - # assertMultiLineStrippedEqual(self, expected, str(rewritten)) - del rewritten - - def test_no_outfeed(self): - self.assertRewrite(""" - { lambda ; a. - let b = mul a a - c = add a b - in (c,) }""", lambda x: x + x * x, [0], has_input_token=False, - has_output_token=False) - self.assertRewrite(""" - { lambda ; a d e. - let b = mul a a - c = add a b - in (c,) }""", lambda x: x + x * x, [0], has_output_token=False) - self.assertRewrite(""" - { lambda ; a d e. - let b = mul a a - c = add a b - in (c, d, e) }""", lambda x: x + x * x, [0]) - - def test_simple_outfeed(self): - self.assertRewrite(""" - { lambda ; a d e. - let b = add a a - c f g = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] b d e - in (c, f, g) }""", lambda x: hcb_id_print(x + x), [0]) - - def test_simple_outfeed_without_input_token(self): - self.assertRewrite(""" - { lambda ; a b. - let e = create_token a b - f = create_token a b - c = add a b - d g h = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] c e f - in (d,) }""", lambda x1, x2: hcb_id_print(x1 + x2), [1, 2], - has_input_token=False, has_output_token=False) - - def test_simple_outfeed_without_input_token_nor_invars(self): - self.assertRewrite(""" - { lambda ; . - let b = create_token - c = create_token - a d e = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] 42 b c - in (a,) }""", lambda: hcb_id_print(42), [], - has_input_token=False, has_output_token=False) - - def test_multiple_tap_without_dependencies(self): - def f(x): - hcb_id_print(x, what="x") - hcb_id_print(x + 1, what="x + 1") - return 2 - - self.assertRewrite(""" - { lambda ; a c d. - let _ e f = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] a c d - b = add a 1 - _ g h = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] b e f - in (2, g, h) }""", f, [1]) - - def test_cond(self): - y = jnp.ones(5) # captured const - - def func(x, z): - return lax.cond(z > 0, (1, 2), lambda a: (a[0], jnp.zeros(5)), - z, lambda a: (hcb_id_print(a), y)) - - self.assertRewrite(""" - { lambda a ; b c h i. - let d = gt c 0 - e = convert_element_type[ new_dtype=int32 ] d - f g j k = - cond[ branches=( { lambda ; a b c d f g. - let e h i = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] d f g - in (e, a, h, i) } - { lambda ; f_ a b c g h. - let d = broadcast_in_dim[ broadcast_dimensions=( ) - shape=(5,) ] 0.00 - in (a, d, g, h) } ) ] e a 1 2 c h i - in (f, g, j, k) }""", func, [y, 5]) - - def test_while(self): - ct_body = jnp.ones(5, np.float32) # captured const for the body - ct_cond = jnp.ones(5, np.float32) # captured const for the conditional - - def func(x): - # x: f32[5] - # c: (f32[5], f32) - return lax.while_loop(lambda c: c[1] < jnp.sum(c[0] + ct_cond), - lambda c: (ct_body, hcb_id_print(c[1]) + 1.), - (x, np.float32(1.))) - - self.assertRewrite(""" - { lambda a b ; c f g. - let d e h i = - while[ body_jaxpr={ lambda ; a b c f g. - let d h i = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] c f g - e = add d 1.00 - in (a, e, h, i) } - body_nconsts=1 - cond_jaxpr={ lambda ; a b c g h. - let d = add b a - e = reduce_sum[ axes=(0,) ] d - f = lt c e - in (f,) } - cond_nconsts=1 ] a b c 1.00 f g - in (d, e, h, i) }""", func, [ct_body]) - - def test_while_pred_outfeed(self): - """A while with outfeed in the pred.""" - ct_body = jnp.ones(5) # captured const for the body - ct_cond = jnp.ones(2) # captured const for the conditional - - def func(x): - return lax.while_loop(lambda c: hcb_id_print(ct_cond, result=c[1]) < 5, - lambda c: (ct_body, hcb_id_print(c[1]) + 1), - (x, 1)) - - self.assertRewrite(""" - { lambda a b ; c f g. - let j k l = xla_call[ call_jaxpr={ lambda ; a b c g h. - let d i j = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] a g h - e = id_tap_dep c d - f = lt e 5 - in (f, i, j) } - donated_invars=(False, False, False, False, False) - name=cond_before ] a c 1 f g - bf d e h i = - while[ body_jaxpr={ lambda ; r s t u v w x. - let y z ba bb = - xla_call[ call_jaxpr={ lambda ; a b c f g. - let d h i = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] c f g - e = add d 1 - in (a, e, h, i) } - donated_invars=(False, False, False, False, False) - name=body ] s u v w x - bc bd be = - xla_call[ call_jaxpr={ lambda ; a b c g h. - let d i j = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] a g h - e = id_tap_dep c d - f = lt e 5 - in (f, i, j) } - donated_invars=(False, False, False, False, False) - name=cond_body ] r y z ba bb - in (bc, y, z, bd, be) } - body_nconsts=2 - cond_jaxpr={ lambda ; m n o p q. - let - in (m,) } - cond_nconsts=0 ] a b j c 1 k l - in (d, e, h, i) }""", func, [ct_body]) - - def test_scan(self): - y = jnp.ones(5) # captured const - - def func(x): - return lax.scan(lambda c, a: (hcb_id_print(c), y), (1, 2), x) - - self.assertRewrite(""" - { lambda a ; b f g. - let c d h i e = - scan[ jaxpr={ lambda ; a b c g h d. - let e f i j = - outside_call[ arg_treedef=PyTreeDef(tuple, [*,*]) - callback=... - has_token=True - identity=True ] b c g h - in (e, f, i, j, a) } - length=5 - linear=(False, False, False, False, False, False) - num_carry=4 - num_consts=1 - reverse=False - unroll=1 ] a 1 2 f g b - in (c, d, e, h, i) }""", func, [y]) - - def test_scan_custom_jvp(self): - """custom JVP, inside scan. - This exercises the custom_jvp_call_jaxpr primitives.""" - self.supported_only_in_legacy_mode() - @jax.custom_jvp - def f(x): - return x * hcb_id_print(x) - - @f.defjvp - def f_jvp(primals, tangents): - x, = primals - x_dot, = tangents - primal_out = f(x) - tangent_out = 3. * x * hcb_id_print(x_dot) - return primal_out, tangent_out - - def g(x): - # Sum f(x_i) - return lax.scan(lambda carry, inp: (carry + f(inp), 0.), - np.full(x.shape[1:], 0.), # Like x w/o leading dim - x)[0] - - arg = np.full((5,), 0.7) - self.assertRewrite(""" - { lambda ; a c d. - let b e f _ = - scan[ jaxpr={ lambda ; a e f b. - let c g h = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda ; a d e. - let b f g = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] a d e - c = mul a b - in (c, f, g) } - num_consts=0 ] b e f - d = add a c - in (d, g, h, 0.00) } - length=5 - linear=(False, False, False, False) - num_carry=3 - num_consts=0 - reverse=False - unroll=1 ] 0.00 c d a - in (b, e, f) }""", g, [arg]) - self.assertRewrite(""" - { lambda ; a d e. - let _ _ f g _ b = - scan[ jaxpr={ lambda ; a b h i c d. - let e j k = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda ; a d e. - let b f g = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] a d e - c = mul a b - in (c, f, g) } - num_consts=0 ] c h i - f = add a e - g = mul c 3.00 - in (f, *, j, k, 0.00, g) } - length=5 - linear=(False, True, False, False, False, True) - num_carry=4 - num_consts=0 - reverse=False - unroll=1 ] 0.00 * d e a * - _ _ h i _ c = - scan[ jaxpr={ lambda ; a b g h c d. - let e = mul b d - f i j = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True - transforms=(('transpose',),) ] e g h - in (*, b, i, j, *, f) } - length=5 - linear=(True, True, False, False, True, False) - num_carry=4 - num_consts=0 - reverse=True - unroll=1 ] * 1.00 f g * b - in (c, h, i) }""", jax.grad(g), [arg]) - - def test_scan_custom_vjp(self): - """custom VJP, inside scan. - This exercises the custom_vjp_call_jaxpr primitives.""" - self.supported_only_in_legacy_mode() - @jax.custom_vjp - def f(x): - return x * hcb_id_print(x) - - # f_fwd: a -> (b, residual) - def f_fwd(x): - return f(x), 3. * x - - # f_bwd: (residual, CT b) -> [CT a] - def f_bwd(residual, ct_b): - return residual * hcb_id_print(ct_b), - - f.defvjp(f_fwd, f_bwd) - - def g(x): - # Sum f(x_i) - return lax.scan(lambda carry, inp: (carry + f(inp), 0.), - np.full(x.shape[1:], 0.), # Like x w/o leading dim - x)[0] - - arg = np.full((2,), 0.7) - self.assertRewrite(""" - { lambda ; a c d. - let b e f _ = - scan[ jaxpr={ lambda ; a e f b. - let c g h = custom_vjp_call_jaxpr[ - fun_jaxpr={ lambda ; a d e. - let b f g = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] a d e - c = mul a b - in (c, f, g) } - num_consts=0 - ] b e f - d = add a c - in (d, g, h, 0.00) } - length=2 - linear=(False, False, False, False) - num_carry=3 - num_consts=0 - reverse=False - unroll=1 ] 0.00 c d a - in (b, e, f) }""", g, [arg]) - self.assertRewrite(""" - { lambda ; a d e. - let _ _ f g _ b = - scan[ jaxpr={ lambda ; a b h i c d. - let e j k = custom_vjp_call_jaxpr[ - fun_jaxpr={ lambda ; a d e. - let b f g = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] a d e - c = mul a b - in (c, f, g) } - num_consts=0 - ] c h i - f = add a e - g = mul c 3.00 - in (f, *, j, k, 0.00, g) } - length=2 - linear=(False, True, False, False, False, True) - num_carry=4 - num_consts=0 - reverse=False - unroll=1 ] 0.00 * d e a * - _ _ h i _ c = - scan[ jaxpr={ lambda ; a b g h c d. - let e i j = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] b g h - f = mul d e - in (*, b, i, j, *, f) } - length=2 - linear=(True, True, False, False, True, False) - num_carry=4 - num_consts=0 - reverse=True - unroll=1 ] * 1.00 f g * b - in (c, h, i) }""", jax.grad(g), [arg]) - - def test_remat_loop(self): - def f(k, x): - x = hcb_id_print(k + x) - return -k * x - - def loss(k): - return lax.fori_loop(0, 1, jax.remat(f), k) - - self.assertRewrite(""" - { lambda ; a c d. - let _ _ b e f = - while[ body_jaxpr={ lambda ; a b c f g. - let d = add a 1 - e h i = remat_call[ call_jaxpr={ lambda ; a b g h. - let c = add a b - d i j = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] c g h - e = neg a - f = mul e d - in (f, i, j) } - concrete=False - name=f ] a c f g - in (d, b, e, h, i) } - body_nconsts=0 - cond_jaxpr={ lambda ; a b c e f. - let d = lt a b - in (d,) } - cond_nconsts=0 ] 0 1 a c d - in (b, e, f) }""", loss, [2]) - - def test_named_call(self): - def tap_scalar(init, do_print=False): - @partial(jax.named_call, name="step") - def step(acc, step_nr): - acc = acc + step_nr - maybe_print(do_print, step_nr, what="step_nr") - return acc, None - - return lax.scan(step, init, np.arange(2, dtype=np.int32)) - - self.assertRewrite(""" - { lambda a ; b d e. - let c = scan[ jaxpr={ lambda ; a b. - let c = named_call[ call_jaxpr={ lambda ; a b. - let c = add a b - in (c,) } - name=step ] a b - in (c,) } - length=2 - linear=(False, False) - num_carry=1 - num_consts=0 - reverse=False - unroll=1 ] b a - in (c, d, e) }""", tap_scalar, [np.int32(3)]) - - def test_pmap(self): - self.supported_only_in_legacy_mode() - def f(xv): - jax.pmap(lambda x: jnp.sin(hcb_id_print(x, tap_with_device=True)), - axis_name="i")(xv) - - self.assertRewrite(""" - { lambda ; a b c. - let _ d e = xla_pmap[ axis_name=i - axis_size=1 - backend=None - call_jaxpr={ lambda ; a d e. - let b f g = outside_call[ arg_treedef=* - callback=... - has_token=True - identity=True ] a d e - c = sin b - in (c, f, g) } - devices=None - donated_invars=(False, False, False) - global_axis_size=None - in_axes=(0, 0, 0) - name= - out_axes=(0, 0, 0) ] a b c - in (d, e) }""", f, [np.array([2.], dtype=np.float32)]) - - if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())