From ca1844dd60ea527356028b0f3d91cf3366a02e35 Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Mon, 26 Feb 2024 13:25:01 -0800 Subject: [PATCH] [PJRT C API] Bump the minimum support libtpu version and clean up version check. PiperOrigin-RevId: 610508159 --- .github/workflows/cloud-tpu-ci-nightly.yml | 2 +- tests/memories_test.py | 25 ------------------ tests/pgle_test.py | 6 ----- tests/profiler_test.py | 30 ---------------------- 4 files changed, 1 insertion(+), 62 deletions(-) diff --git a/.github/workflows/cloud-tpu-ci-nightly.yml b/.github/workflows/cloud-tpu-ci-nightly.yml index 4ff929a62523..fdbe99e9dbf6 100644 --- a/.github/workflows/cloud-tpu-ci-nightly.yml +++ b/.github/workflows/cloud-tpu-ci-nightly.yml @@ -28,7 +28,7 @@ jobs: tpu-type: ["v3-8", "v4-8", "v5e-4"] name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu-type }})" env: - LIBTPU_OLDEST_VERSION_DATE: 20230927 + LIBTPU_OLDEST_VERSION_DATE: 20231030 ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }} runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"] timeout-minutes: 120 diff --git a/tests/memories_test.py b/tests/memories_test.py index 0cff66b95a73..d3784ab59209 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -677,10 +677,6 @@ def test_sharding_devices_indices_map_cache_hit(self): self.assertEqual(cache_info2.misses, cache_info1.misses) def test_device_put_host_to_hbm(self): - # TODO(jieying): remove after 12/26/2023. - if not jtu.pjrt_c_api_version_at_least(0, 32): - raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.") - mesh = jtu.create_global_mesh((4, 2), ("x", "y")) s_host = NamedSharding(mesh, P("y"), memory_kind="unpinned_host") np_inp = jnp.arange(16).reshape(8, 2) @@ -698,10 +694,6 @@ def f(x): out_on_hbm, np_inp, s_hbm, "device") def test_device_put_hbm_to_host(self): - # TODO(jieying): remove after 12/26/2023. - if not jtu.pjrt_c_api_version_at_least(0, 32): - raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.") - mesh = jtu.create_global_mesh((4, 2), ("x", "y")) s_host = NamedSharding(mesh, P("y"), memory_kind="unpinned_host") inp = jnp.arange(16).reshape(8, 2) @@ -718,9 +710,6 @@ def test_device_put_hbm_to_host(self): def test_device_put_different_device_and_memory_host_to_hbm(self): if jax.device_count() < 3: raise unittest.SkipTest("Test requires >=3 devices") - # TODO(jieying): remove after 12/26/2023. - if not jtu.pjrt_c_api_version_at_least(0, 32): - raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.") out_host0 = jax.device_put( jnp.arange(8), @@ -738,9 +727,6 @@ def test_device_put_different_device_and_memory_host_to_hbm(self): def test_device_put_different_device_and_memory_hbm_to_host(self): if jax.device_count() < 3: raise unittest.SkipTest("Test requires >=3 devices") - # TODO(jieying): remove after 12/26/2023. - if not jtu.pjrt_c_api_version_at_least(0, 32): - raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.") out_hbm0 = jnp.arange(8) @@ -758,9 +744,6 @@ def test_device_put_different_device_and_memory_hbm_to_host(self): def test_device_put_on_different_device_with_the_same_memory_kind(self): if len(jax.devices()) < 2: raise unittest.SkipTest("Test requires >=2 devices.") - # TODO(jieying): remove after 12/26/2023. - if not jtu.pjrt_c_api_version_at_least(0, 32): - raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.") np_inp = np.arange(16).reshape(8, 2) @@ -779,10 +762,6 @@ def test_device_put_on_different_device_with_the_same_memory_kind(self): out_host_dev_1, np_inp, s_host_dev_1, "unpinned_host") def test_device_put_resharding(self): - # TODO(jieying): remove after 12/26/2023. - if not jtu.pjrt_c_api_version_at_least(0, 32): - raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.") - mesh = jtu.create_global_mesh((2, 2), ("x", "y")) s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host") s_hbm = s_host.with_memory_kind("device") @@ -807,10 +786,6 @@ def test_device_put_resharding(self): out_sharded_hbm, np_inp, s_hbm, "device") def test_jit_host_inputs_via_device_put_outside(self): - # TODO(jieying): remove after 12/26/2023. - if not jtu.pjrt_c_api_version_at_least(0, 32): - raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.") - mesh = jtu.create_global_mesh((4, 2), ("x", "y")) s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host") s_hbm = s_host.with_memory_kind("device") diff --git a/tests/pgle_test.py b/tests/pgle_test.py index e0adf8963e29..466da7f27067 100644 --- a/tests/pgle_test.py +++ b/tests/pgle_test.py @@ -18,7 +18,6 @@ import math import os import tempfile -import unittest from absl.testing import absltest import jax @@ -37,11 +36,6 @@ class PgleTest(jtu.JaxTestCase): def testPassingFDOProfile(self): - # TODO(jieying): remove after 01/10/2023. - if not jtu.pjrt_c_api_version_at_least(0, 34): - raise unittest.SkipTest( - 'Profiler is not supported on PJRT C API version < 0.34.' - ) mesh = jtu.create_global_mesh((2,), ('x',)) @partial( jax.jit, diff --git a/tests/profiler_test.py b/tests/profiler_test.py index 325392bddf50..c232c3afd699 100644 --- a/tests/profiler_test.py +++ b/tests/profiler_test.py @@ -84,11 +84,6 @@ def testCantStopServerBeforeStartingServer(self): jax.profiler.stop_server() def testProgrammaticProfiling(self): - # TODO(jieying): remove after 01/10/2023. - if not jtu.pjrt_c_api_version_at_least(0, 34): - raise unittest.SkipTest( - "Profiler is not supported on PJRT C API version < 0.34." - ) with tempfile.TemporaryDirectory() as tmpdir: try: jax.profiler.start_trace(tmpdir) @@ -110,11 +105,6 @@ def testProgrammaticProfiling(self): self.assertIn(b"pxla.py", proto) def testProfilerGetFDOProfile(self): - # TODO(jieying): remove after 01/10/2023. - if not jtu.pjrt_c_api_version_at_least(0, 34): - raise unittest.SkipTest( - "Profiler is not supported on PJRT C API version < 0.34." - ) # Tests stop_and_get_fod_profile could run. try: jax.profiler.start_trace("test") @@ -127,11 +117,6 @@ def testProfilerGetFDOProfile(self): self.assertIn(b"copy", fdo_profile) def testProgrammaticProfilingErrors(self): - # TODO(jieying): remove after 01/10/2023. - if not jtu.pjrt_c_api_version_at_least(0, 34): - raise unittest.SkipTest( - "Profiler is not supported on PJRT C API version < 0.34." - ) with self.assertRaisesRegex(RuntimeError, "No profile started"): jax.profiler.stop_trace() @@ -147,11 +132,6 @@ def testProgrammaticProfilingErrors(self): jax.profiler.stop_trace() def testProgrammaticProfilingContextManager(self): - # TODO(jieying): remove after 01/10/2023. - if not jtu.pjrt_c_api_version_at_least(0, 34): - raise unittest.SkipTest( - "Profiler is not supported on PJRT C API version < 0.34." - ) with tempfile.TemporaryDirectory() as tmpdir: with jax.profiler.trace(tmpdir): jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')( @@ -208,11 +188,6 @@ def _check_xspace_pb_exist(self, logdir): @unittest.skipIf(not (portpicker and profiler_client and tf_profiler), "Test requires tensorflow.profiler and portpicker") def testSingleWorkerSamplingMode(self, delay_ms=None): - # TODO(jieying): remove after 01/10/2023. - if not jtu.pjrt_c_api_version_at_least(0, 34): - raise unittest.SkipTest( - "Profiler is not supported on PJRT C API version < 0.34." - ) def on_worker(port, worker_start): jax.profiler.start_server(port) worker_start.set() @@ -258,11 +233,6 @@ def on_profile(port, logdir, worker_start): "Test requires tensorflow.profiler, portpicker and " "tensorboard_profile_plugin") def test_remote_profiler(self): - # TODO(jieying): remove after 01/10/2023. - if not jtu.pjrt_c_api_version_at_least(0, 34): - raise unittest.SkipTest( - "Profiler is not supported on PJRT C API version < 0.34." - ) port = portpicker.pick_unused_port() jax.profiler.start_server(port)