Skip to content

Commit

Permalink
[PJRT C API] Bump the minimum support libtpu version and clean up version check.
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 610508159
  • Loading branch information
Jieying Luo authored and jax authors committed Feb 26, 2024
1 parent 3166cc3 commit ca1844d
Show file tree
Hide file tree
Showing 4 changed files with 1 addition and 62 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cloud-tpu-ci-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
tpu-type: ["v3-8", "v4-8", "v5e-4"]
name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu-type }})"
env:
LIBTPU_OLDEST_VERSION_DATE: 20230927
LIBTPU_OLDEST_VERSION_DATE: 20231030
ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"]
timeout-minutes: 120
Expand Down
25 changes: 0 additions & 25 deletions tests/memories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,10 +677,6 @@ def test_sharding_devices_indices_map_cache_hit(self):
self.assertEqual(cache_info2.misses, cache_info1.misses)

def test_device_put_host_to_hbm(self):
# TODO(jieying): remove after 12/26/2023.
if not jtu.pjrt_c_api_version_at_least(0, 32):
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")

mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
s_host = NamedSharding(mesh, P("y"), memory_kind="unpinned_host")
np_inp = jnp.arange(16).reshape(8, 2)
Expand All @@ -698,10 +694,6 @@ def f(x):
out_on_hbm, np_inp, s_hbm, "device")

def test_device_put_hbm_to_host(self):
# TODO(jieying): remove after 12/26/2023.
if not jtu.pjrt_c_api_version_at_least(0, 32):
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")

mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
s_host = NamedSharding(mesh, P("y"), memory_kind="unpinned_host")
inp = jnp.arange(16).reshape(8, 2)
Expand All @@ -718,9 +710,6 @@ def test_device_put_hbm_to_host(self):
def test_device_put_different_device_and_memory_host_to_hbm(self):
if jax.device_count() < 3:
raise unittest.SkipTest("Test requires >=3 devices")
# TODO(jieying): remove after 12/26/2023.
if not jtu.pjrt_c_api_version_at_least(0, 32):
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")

out_host0 = jax.device_put(
jnp.arange(8),
Expand All @@ -738,9 +727,6 @@ def test_device_put_different_device_and_memory_host_to_hbm(self):
def test_device_put_different_device_and_memory_hbm_to_host(self):
if jax.device_count() < 3:
raise unittest.SkipTest("Test requires >=3 devices")
# TODO(jieying): remove after 12/26/2023.
if not jtu.pjrt_c_api_version_at_least(0, 32):
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")

out_hbm0 = jnp.arange(8)

Expand All @@ -758,9 +744,6 @@ def test_device_put_different_device_and_memory_hbm_to_host(self):
def test_device_put_on_different_device_with_the_same_memory_kind(self):
if len(jax.devices()) < 2:
raise unittest.SkipTest("Test requires >=2 devices.")
# TODO(jieying): remove after 12/26/2023.
if not jtu.pjrt_c_api_version_at_least(0, 32):
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")

np_inp = np.arange(16).reshape(8, 2)

Expand All @@ -779,10 +762,6 @@ def test_device_put_on_different_device_with_the_same_memory_kind(self):
out_host_dev_1, np_inp, s_host_dev_1, "unpinned_host")

def test_device_put_resharding(self):
# TODO(jieying): remove after 12/26/2023.
if not jtu.pjrt_c_api_version_at_least(0, 32):
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")

mesh = jtu.create_global_mesh((2, 2), ("x", "y"))
s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host")
s_hbm = s_host.with_memory_kind("device")
Expand All @@ -807,10 +786,6 @@ def test_device_put_resharding(self):
out_sharded_hbm, np_inp, s_hbm, "device")

def test_jit_host_inputs_via_device_put_outside(self):
# TODO(jieying): remove after 12/26/2023.
if not jtu.pjrt_c_api_version_at_least(0, 32):
raise unittest.SkipTest("CopyToMemorySpace is not supported on PJRT C API version < 0.32.")

mesh = jtu.create_global_mesh((4, 2), ("x", "y"))
s_host = NamedSharding(mesh, P("x", "y"), memory_kind="unpinned_host")
s_hbm = s_host.with_memory_kind("device")
Expand Down
6 changes: 0 additions & 6 deletions tests/pgle_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import math
import os
import tempfile
import unittest

from absl.testing import absltest
import jax
Expand All @@ -37,11 +36,6 @@
class PgleTest(jtu.JaxTestCase):

def testPassingFDOProfile(self):
# TODO(jieying): remove after 01/10/2023.
if not jtu.pjrt_c_api_version_at_least(0, 34):
raise unittest.SkipTest(
'Profiler is not supported on PJRT C API version < 0.34.'
)
mesh = jtu.create_global_mesh((2,), ('x',))
@partial(
jax.jit,
Expand Down
30 changes: 0 additions & 30 deletions tests/profiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,6 @@ def testCantStopServerBeforeStartingServer(self):
jax.profiler.stop_server()

def testProgrammaticProfiling(self):
# TODO(jieying): remove after 01/10/2023.
if not jtu.pjrt_c_api_version_at_least(0, 34):
raise unittest.SkipTest(
"Profiler is not supported on PJRT C API version < 0.34."
)
with tempfile.TemporaryDirectory() as tmpdir:
try:
jax.profiler.start_trace(tmpdir)
Expand All @@ -110,11 +105,6 @@ def testProgrammaticProfiling(self):
self.assertIn(b"pxla.py", proto)

def testProfilerGetFDOProfile(self):
# TODO(jieying): remove after 01/10/2023.
if not jtu.pjrt_c_api_version_at_least(0, 34):
raise unittest.SkipTest(
"Profiler is not supported on PJRT C API version < 0.34."
)
# Tests stop_and_get_fdo_profile could run.
try:
jax.profiler.start_trace("test")
Expand All @@ -127,11 +117,6 @@ def testProfilerGetFDOProfile(self):
self.assertIn(b"copy", fdo_profile)

def testProgrammaticProfilingErrors(self):
# TODO(jieying): remove after 01/10/2023.
if not jtu.pjrt_c_api_version_at_least(0, 34):
raise unittest.SkipTest(
"Profiler is not supported on PJRT C API version < 0.34."
)
with self.assertRaisesRegex(RuntimeError, "No profile started"):
jax.profiler.stop_trace()

Expand All @@ -147,11 +132,6 @@ def testProgrammaticProfilingErrors(self):
jax.profiler.stop_trace()

def testProgrammaticProfilingContextManager(self):
# TODO(jieying): remove after 01/10/2023.
if not jtu.pjrt_c_api_version_at_least(0, 34):
raise unittest.SkipTest(
"Profiler is not supported on PJRT C API version < 0.34."
)
with tempfile.TemporaryDirectory() as tmpdir:
with jax.profiler.trace(tmpdir):
jax.pmap(lambda x: jax.lax.psum(x + 1, 'i'), axis_name='i')(
Expand Down Expand Up @@ -208,11 +188,6 @@ def _check_xspace_pb_exist(self, logdir):
@unittest.skipIf(not (portpicker and profiler_client and tf_profiler),
"Test requires tensorflow.profiler and portpicker")
def testSingleWorkerSamplingMode(self, delay_ms=None):
# TODO(jieying): remove after 01/10/2023.
if not jtu.pjrt_c_api_version_at_least(0, 34):
raise unittest.SkipTest(
"Profiler is not supported on PJRT C API version < 0.34."
)
def on_worker(port, worker_start):
jax.profiler.start_server(port)
worker_start.set()
Expand Down Expand Up @@ -258,11 +233,6 @@ def on_profile(port, logdir, worker_start):
"Test requires tensorflow.profiler, portpicker and "
"tensorboard_profile_plugin")
def test_remote_profiler(self):
# TODO(jieying): remove after 01/10/2023.
if not jtu.pjrt_c_api_version_at_least(0, 34):
raise unittest.SkipTest(
"Profiler is not supported on PJRT C API version < 0.34."
)
port = portpicker.pick_unused_port()
jax.profiler.start_server(port)

Expand Down

0 comments on commit ca1844d

Please sign in to comment.