From d3f63a66b8a060a62045617294bfb7de690dce52 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 4 Oct 2024 11:14:01 -0400 Subject: [PATCH] Remove code to support jaxlib <= 0.4.33. --- jax/_src/interpreters/pxla.py | 12 +-- jax/_src/lax/linalg.py | 11 +-- jax/_src/pjit.py | 107 ++++++++-------------- jax/experimental/host_callback.py | 144 ++---------------------------- tests/export_back_compat_test.py | 16 ++-- tests/lax_test.py | 13 +-- tests/layout_test.py | 16 ---- tests/memories_test.py | 17 ---- tests/pjit_test.py | 8 +- tests/tree_util_test.py | 7 +- 10 files changed, 60 insertions(+), 291 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index d531727f410f..4aa05010cd7a 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -62,7 +62,6 @@ from jax._src.interpreters import xla from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec @@ -3055,14 +3054,9 @@ def aot_cache_miss(*args, **kwargs): fastpath_data = None return outs, fastpath_data, False # Do not remove cache entry - if xla_extension_version >= 286: - return xc._xla.pjit( - self.unsafe_call.name, None, aot_cache_miss, [], [], - JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg) - else: - return xc._xla.pjit( - self.unsafe_call.name, None, aot_cache_miss, [], [], [], - tree_util.dispatch_registry, cc_shard_arg) + return xc._xla.pjit( + self.unsafe_call.name, None, aot_cache_miss, [], [], + JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg) def cc_shard_arg(x, sharding, layout): return shard_args([sharding], [layout], [x])[0] diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 1ae11bee30c9..ec2dd91b258a 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -47,7 +47,6 @@ from jax._src.lib import gpu_solver from jax._src.lib import gpu_sparse from jax._src.lib import lapack -from jax._src.lib import version as jaxlib_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo from jax._src.lib.mlir.dialects import hlo @@ -709,8 +708,7 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors, out_aval = ctx.avals_out[0] batch_dims = operand_aval.shape[:-2] op_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - ctx_args = (ctx,) - w, vl, vr, info = lapack.geev_hlo(*ctx_args, operand_aval.dtype, operand, + w, vl, vr, info = lapack.geev_hlo(ctx, operand_aval.dtype, operand, input_shape_vals=op_shape_vals, jobvl=compute_left_eigenvectors, jobvr=compute_right_eigenvectors) @@ -2033,8 +2031,7 @@ def _svd_cpu_gpu_lowering( compute_uv=compute_uv) else: a_shape_vals = mlir.eval_dynamic_shape_as_ivals(ctx, operand_aval.shape) - ctx_args = (ctx,) - s, u, vt, info = gesvd_impl(*ctx_args, operand_aval.dtype, operand, + s, u, vt, info = gesvd_impl(ctx, operand_aval.dtype, operand, full_matrices=full_matrices, compute_uv=compute_uv, a_shape_vals=a_shape_vals) @@ -2477,9 +2474,7 @@ def _hessenberg_batching_rule(batched_args, batch_dims): def _hessenberg_cpu_hlo(ctx, a): a_aval, = ctx.avals_in batch_dims = a_aval.shape[:-2] - # TODO(b/344892332): Remove the conditional after the compatibility period. - ctx_args = (ctx,) if jaxlib_version >= (0, 4, 34) else () - a, taus, info = lapack.gehrd_hlo(*ctx_args, a_aval.dtype, a) + a, taus, info = lapack.gehrd_hlo(ctx, a_aval.dtype, a) ok = mlir.compare_hlo( info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))), "EQ", "SIGNED") diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 0a75128477ce..5a08a4d414d8 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -62,7 +62,6 @@ from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc -from jax._src.lib import xla_extension_version from jax._src import sharding from jax._src.mesh import AbstractMesh from jax._src.sharding_impls import ( @@ -322,28 +321,11 @@ def _cpp_pjit_evict_fn(self): _cpp_pjit_cache_explicit_attributes = xc._xla.PjitFunctionCache(capacity=8192) -if xla_extension_version < 286: - def _get_cpp_global_cache(pjit_has_explicit_sharding): - if pjit_has_explicit_sharding: - return xc._xla.PjitFunctionCache() - else: - return _cpp_pjit_cache_fun_only - - def _pjit_explicit_sharding_and_layout( - in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat, - device, backend) -> bool: - return (device is not None or - backend is not None or - any(not is_unspecified(i) for i in in_shardings_flat) or - any(not is_unspecified(o) for o in out_shardings_flat) or - any(i is not None for i in in_layouts_flat) or - any(o is not None for o in out_layouts_flat)) -else: - def _get_cpp_global_cache(contains_explicit_attributes: bool): # type: ignore - if contains_explicit_attributes: - return _cpp_pjit_cache_explicit_attributes - else: - return _cpp_pjit_cache_fun_only +def _get_cpp_global_cache(contains_explicit_attributes: bool): + if contains_explicit_attributes: + return _cpp_pjit_cache_explicit_attributes + else: + return _cpp_pjit_cache_fun_only def _cpp_pjit(fun: Callable, jit_info: PjitInfo): @@ -364,35 +346,24 @@ def cache_miss(*args, **kwargs): return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) - if xla_extension_version >= 286: - cache_key = pxla.JitGlobalCppCacheKeys( - donate_argnums=jit_info.donate_argnums, - donate_argnames=jit_info.donate_argnames, - device=jit_info.device, backend=jit_info.backend, - in_shardings_treedef=jit_info.in_shardings_treedef, - in_shardings_leaves=jit_info.in_shardings_leaves, - out_shardings_treedef=jit_info.out_shardings_treedef, - out_shardings_leaves=jit_info.out_shardings_leaves, - in_layouts_treedef=jit_info.in_layouts_treedef, - in_layouts_leaves=jit_info.in_layouts_leaves, - out_layouts_treedef=jit_info.out_layouts_treedef, - out_layouts_leaves=jit_info.out_layouts_leaves, - use_resource_env=jit_info.use_resource_env) - cpp_pjit_f = xc._xla.pjit( - fun_name(fun), fun, cache_miss, jit_info.static_argnums, - jit_info.static_argnames, cache_key, tree_util.dispatch_registry, # type: ignore - pxla.cc_shard_arg, - _get_cpp_global_cache(cache_key.contains_explicit_attributes)) - else: - has_explicit_sharding = _pjit_explicit_sharding_and_layout( - jit_info.in_shardings_leaves, jit_info.out_shardings_leaves, - jit_info.in_layouts_leaves, jit_info.out_layouts_leaves, - jit_info.device, jit_info.backend) - cpp_pjit_f = xc._xla.pjit( - fun_name(fun), fun, cache_miss, jit_info.static_argnums, - jit_info.static_argnames, jit_info.donate_argnums, - tree_util.dispatch_registry, pxla.cc_shard_arg, - _get_cpp_global_cache(has_explicit_sharding)) + cache_key = pxla.JitGlobalCppCacheKeys( + donate_argnums=jit_info.donate_argnums, + donate_argnames=jit_info.donate_argnames, + device=jit_info.device, backend=jit_info.backend, + in_shardings_treedef=jit_info.in_shardings_treedef, + in_shardings_leaves=jit_info.in_shardings_leaves, + out_shardings_treedef=jit_info.out_shardings_treedef, + out_shardings_leaves=jit_info.out_shardings_leaves, + in_layouts_treedef=jit_info.in_layouts_treedef, + in_layouts_leaves=jit_info.in_layouts_leaves, + out_layouts_treedef=jit_info.out_layouts_treedef, + out_layouts_leaves=jit_info.out_layouts_leaves, + use_resource_env=jit_info.use_resource_env) + cpp_pjit_f = xc._xla.pjit( + fun_name(fun), fun, cache_miss, jit_info.static_argnums, + jit_info.static_argnames, cache_key, tree_util.dispatch_registry, # type: ignore + pxla.cc_shard_arg, + _get_cpp_global_cache(cache_key.contains_explicit_attributes)) cpp_pjitted_f = wraps(fun)(cpp_pjit_f) cpp_pjitted_f._fun = fun @@ -1752,26 +1723,18 @@ def call_impl_cache_miss(*args_, **kwargs_): jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline) donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d) - if xla_extension_version >= 286: - cache_key = pxla.JitGlobalCppCacheKeys( - donate_argnums=donated_argnums, donate_argnames=None, - device=None, backend=None, - in_shardings_treedef=None, in_shardings_leaves=in_shardings, - out_shardings_treedef=None, out_shardings_leaves=out_shardings, - in_layouts_treedef=None, in_layouts_leaves=in_layouts, - out_layouts_treedef=None, out_layouts_leaves=out_layouts, - use_resource_env=resource_env is not None) - return xc._xla.pjit( - name, f, call_impl_cache_miss, [], [], cache_key, - tree_util.dispatch_registry, pxla.cc_shard_arg, - _get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args) - else: - has_explicit_sharding = _pjit_explicit_sharding_and_layout( - in_shardings, out_shardings, in_layouts, out_layouts, None, None) - return xc._xla.pjit( - name, f, call_impl_cache_miss, [], [], donated_argnums, - tree_util.dispatch_registry, pxla.cc_shard_arg, - _get_cpp_global_cache(has_explicit_sharding))(*args) + cache_key = pxla.JitGlobalCppCacheKeys( + donate_argnums=donated_argnums, donate_argnames=None, + device=None, backend=None, + in_shardings_treedef=None, in_shardings_leaves=in_shardings, + out_shardings_treedef=None, out_shardings_leaves=out_shardings, + in_layouts_treedef=None, in_layouts_leaves=in_layouts, + out_layouts_treedef=None, out_layouts_leaves=out_layouts, + use_resource_env=resource_env is not None) + return xc._xla.pjit( + name, f, call_impl_cache_miss, [], [], cache_key, + tree_util.dispatch_registry, pxla.cc_shard_arg, + _get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args) pjit_p.def_impl(_pjit_call_impl) diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index bc5477ebc766..1ab44a4fd586 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -536,7 +536,6 @@ def power3_with_cotangents(x): from jax._src import xla_bridge as xb from jax._src.lib import xla_client from jax._src.lib import xla_extension -from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo @@ -1079,117 +1078,6 @@ def _outside_call_impl(*args, **params): outside_call_p.def_impl(_outside_call_impl) -def _with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs): - """Builds op_fn(*args, **kwargs) with sharding annotation.""" - builder.set_sharding(sharding_proto) - try: - return op_fn(*args, **kwargs) - finally: - builder.clear_sharding() - -def _outside_call_translation_rule(ctx, - avals_in, - avals_out, - *args_op: XlaOp, - has_token, - identity, - device_index, - flat_results_aval=(), - **params): - # We expect the current tokens at the end, inserted by _rewrite_jaxpr. - assert has_token - use_outfeed = _use_outfeed(ctx.platform) - assert use_outfeed, 'Should be using MLIR path for `CustomCall` lowering' - current_token = args_op[-2] - current_itoken = args_op[-1] - comp = ctx.builder - assert comp.get_shape(current_token).is_token() and comp.get_shape(current_itoken).is_token(), ( - "The last two arguments must be tokens") - - args_to_outfeed = args_op[:-2] - # Some platforms refuse to infeed empty arrays. We generate constants - # instead. - non_empty_flat_results_aval = list(filter(lambda aval: not (_aval_is_empty(aval)), - flat_results_aval)) - need_callback_results_on_device = (not identity and - len(non_empty_flat_results_aval) > 0) - send_infeed = use_outfeed and need_callback_results_on_device - generated_infeed = False # Keep track if we emitted an infeed op - - _raise_if_using_outfeed_with_pjrt_c_api(xb.get_backend(ctx.platform)) - callback_id = _register_callback( - functools.partial( - _outside_call_run_callback, - send_infeed=send_infeed, - identity=identity, - flat_results_aval=flat_results_aval, - **params)) - next_token = _callback_handler_data.receiver.add_outfeed( - comp, current_token, callback_id, args_to_outfeed, device_index) - if identity: - results = list(args_to_outfeed) - next_itoken = current_itoken - else: - empty_results = [ - xops.ConstantLiteral(comp, np.zeros(aval.shape, aval.dtype)) - for aval in flat_results_aval - if _aval_is_empty(aval) - ] - if non_empty_flat_results_aval: - assert need_callback_results_on_device - after_outfeed_itoken = xops.AfterAll(comp, [current_itoken, next_token]) - # We shard the infeed as AssignedDevice(device_index). This must match the - # outfeed (from outfeed_receiver.cc). Since `lax.infeed` does not support - # this kind of sharding, we use a custom translation for infeed. - array_sharding_proto = xla_client.OpSharding() - array_sharding_proto.type = xla_client.OpSharding.Type.MAXIMAL - array_sharding_proto.tile_assignment_dimensions = [1] - array_sharding_proto.tile_assignment_devices = [device_index] - - token_sharding_proto = xla_client.OpSharding() - token_sharding_proto.type = xla_client.OpSharding.Type.REPLICATED - infeed_sharding_proto = xla.tuple_sharding_proto( - [array_sharding_proto] * len(non_empty_flat_results_aval) + - [token_sharding_proto]) - - shape = [ - shape.with_major_to_minor_layout_if_absent() - for x in non_empty_flat_results_aval - for shape in xla.aval_to_xla_shapes(x) - ] - - build_infeed = functools.partial(xops.InfeedWithToken, - after_outfeed_itoken, - xla_client.Shape.tuple_shape(shape)) - outs_and_token = _with_sharding_proto(comp, infeed_sharding_proto, - build_infeed) - outs = xops.GetTupleElement(outs_and_token, 0) - next_itoken = xops.GetTupleElement(outs_and_token, 1) - non_empty_results = [ - xops.GetTupleElement(outs, i) - for i in range(len(non_empty_flat_results_aval)) - ] - generated_infeed = True - results = [ - empty_results.pop(0) - if _aval_is_empty(result_aval) else non_empty_results.pop(0) - for result_aval in flat_results_aval - ] - else: - results = empty_results - next_itoken = current_itoken - - assert generated_infeed == send_infeed, ( - f"generated_infeed ({generated_infeed}) != send_infeed ({send_infeed})") - assert identity or len(results) == len(flat_results_aval), ( - f"got {len(results)} but expected {len(flat_results_aval)}. " - f"identity = {identity}") - return results + [next_token, next_itoken] - -if xla_extension_version < 287: - xla.register_translation(outside_call_p, _outside_call_translation_rule) - - def _outside_call_outfeed_lowering(ctx: mlir.LoweringRuleContext, *args_op, identity, @@ -1318,25 +1206,14 @@ def _outside_call_lowering(ctx: mlir.LoweringRuleContext, platform = ctx.module_context.platforms[0] use_outfeed = _use_outfeed(platform) if use_outfeed: - if xla_extension_version < 287: - return mlir.xla_fallback_lowering(outside_call_p)( - ctx, - *args, - has_token=has_token, - identity=identity, - device_index=device_index, - flat_results_aval=flat_results_aval, - **params, - ) - else: - return _outside_call_outfeed_lowering( - ctx, *args, - has_token=has_token, - identity=identity, - flat_results_aval=flat_results_aval, - device_index=device_index, - **params, - ) + return _outside_call_outfeed_lowering( + ctx, *args, + has_token=has_token, + identity=identity, + flat_results_aval=flat_results_aval, + device_index=device_index, + **params, + ) else: # TODO(necula): It seems that on CPU, with custom call, the device_index # does not work, and the callback is always run on device_index=0 @@ -1405,10 +1282,7 @@ def wrapped_callback(*args): f"identity = {identity}") return list(results) + [next_token, next_itoken] -if xla_extension_version < 287: - mlir.register_lowering(outside_call_p, _outside_call_lowering, platform="cpu") -else: - mlir.register_lowering(outside_call_p, _outside_call_lowering) +mlir.register_lowering(outside_call_p, _outside_call_lowering) def _outside_call_run_callback( arrays, device, *, diff --git a/tests/export_back_compat_test.py b/tests/export_back_compat_test.py index 66e82cacbad8..e261e1dfce83 100644 --- a/tests/export_back_compat_test.py +++ b/tests/export_back_compat_test.py @@ -67,7 +67,6 @@ from jax._src import config from jax._src import test_util as jtu from jax._src.lib import cuda_versions -from jax._src.lib import version as jaxlib_version config.parse_flags_with_absl() @@ -705,15 +704,12 @@ def func(): cpu_hessenberg_lapack_gehrd.data_2024_08_30[dtype_name] ) self.run_one_test(func, data, rtol=rtol, atol=atol) - # TODO(b/344892332): Remove the check after the compatibility period. - has_xla_ffi_support = jaxlib_version >= (0, 4, 34) - if has_xla_ffi_support: - with config.export_ignore_forward_compatibility(True): - # FFI Kernel test - data = self.load_testdata( - cpu_hessenberg_lapack_gehrd.data_2024_08_31[dtype_name] - ) - self.run_one_test(func, data, rtol=rtol, atol=atol) + with config.export_ignore_forward_compatibility(True): + # FFI Kernel test + data = self.load_testdata( + cpu_hessenberg_lapack_gehrd.data_2024_08_31[dtype_name] + ) + self.run_one_test(func, data, rtol=rtol, atol=atol) def test_approx_top_k(self): def func(): diff --git a/tests/lax_test.py b/tests/lax_test.py index c75058194653..ab00bf06bf0b 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3975,8 +3975,7 @@ def _testOnComplexPlaneWorker(self, name, dtype, kind): size_im = 11 atol = None - if (name in {"arccos", "arcsin", "arcsinh", "arccosh"} - or name in {"arctan", "arctanh"} and jax._src.lib.version > (0, 4, 31)): + if name in {"arccos", "arcsin", "arcsinh", "arccosh", "arctan", "arctanh"}: # TODO(pearu): eliminate this if-block when a fix to mpmath#787 # becomes available extra_prec_multiplier = 20 @@ -4132,16 +4131,6 @@ def regions_with_inaccuracies_keep(*to_keep): elif name == 'arccos': regions_with_inaccuracies_keep('q4.imag', 'ninf', 'pinf', 'ninfj', 'pinfj.real') - elif name == 'arctan' and jax._src.lib.version <= (0, 4, 31): - if dtype == np.complex64: - regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', - 'mq1.imag', 'mq2.imag', 'mq3.imag', 'mq4.imag', 'mnegj.real', 'mnegj.imag', 'mposj.imag') - if dtype == np.complex128: - regions_with_inaccuracies_keep('q1', 'q2', 'q3', 'q4', 'neg', 'pos', 'negj', 'posj', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mnegj.real') - - elif name == 'arctanh' and jax._src.lib.version <= (0, 4, 31): - regions_with_inaccuracies_keep('pos.imag', 'ninf', 'pinf', 'ninfj', 'pinfj', 'mpos.imag') - elif name in {'cos', 'sin'}: regions_with_inaccuracies_keep('ninf.imag', 'pinf.imag') diff --git a/tests/layout_test.py b/tests/layout_test.py index 1d18179ccfee..699e36409d77 100644 --- a/tests/layout_test.py +++ b/tests/layout_test.py @@ -25,7 +25,6 @@ from jax._src.layout import Layout, DeviceLocalLayout as DLL from jax._src import test_util as jtu from jax._src.util import safe_zip -from jax._src.lib import xla_extension_version config.parse_flags_with_absl() @@ -46,9 +45,6 @@ def setUp(self): super().setUp() def test_auto_layout(self): - # Remove this condition when xla_extension_version >= 285 - if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285: - self.skipTest("Requires xla_extension_version >= 285 for GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape1 = (128, 128) shape2 = (128, 128) @@ -114,9 +110,6 @@ def init(x, y): self.assertArraysEqual(apply_out[1], (np_inp2 * 2).T) def test_default_layout(self): - # Remove this condition when xla_extension_version >= 285 - if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285: - self.skipTest("Requires xla_extension_version >= 285 for GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (4, 4, 2) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -156,9 +149,6 @@ def f(x): out_shardings=DLL.AUTO).lower(sds).compile() def test_in_layouts_out_layouts(self): - # Remove this condition when xla_extension_version >= 285 - if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285: - self.skipTest("Requires xla_extension_version >= 285 for GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (8, 8) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -183,9 +173,6 @@ def f(x): self.assertEqual(out.sharding, NamedSharding(mesh, P('y', 'x'))) def test_sharding_and_layouts(self): - # Remove this condition when xla_extension_version >= 285 - if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285: - self.skipTest("Requires xla_extension_version >= 285 for GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) shape = (4, 8) np_inp = np.arange(math.prod(shape)).reshape(shape) @@ -477,9 +464,6 @@ def test_incompatible_aval_error_device_put(self): jax.device_put(inp, l) def test_concrete_layout_in_shardings(self): - # Remove this condition when xla_extension_version >= 285 - if jtu.test_device_matches(["gpu"]) and xla_extension_version < 285: - self.skipTest("Requires xla_extension_version >= 285 for GPU backend.") mesh = jtu.create_mesh((2, 2), ('x', 'y')) s = NamedSharding(mesh, P('x', 'y')) shape = (16, 128) diff --git a/tests/memories_test.py b/tests/memories_test.py index 7f4b75c8dbe3..1dbd6298d3ff 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -35,7 +35,6 @@ TransferToMemoryKind, PartitionSpec as P) from jax.experimental.compute_on import compute_on from jax.experimental.shard_map import shard_map -from jax._src.lib import xla_extension_version import numpy as np config.parse_flags_with_absl() @@ -416,8 +415,6 @@ def f(a, b): out, np_inp * np_inp, s_dev, "device") def test_parameter_streaming(self): - if jtu.test_device_matches(["gpu"]) and xla_extension_version < 289: - self.skipTest("Requires xla_extension_version >= 289") _, s_host, np_inp, inp_host = _create_inputs( (8, 2), P("x", "y"), mem_kind="pinned_host") s_dev = s_host.with_memory_kind('device') @@ -461,8 +458,6 @@ def f(a): out, np_inp, s_host, 'pinned_host') def test_parameter_streaming_with_scalar_and_constant(self): - if jtu.test_device_matches(["gpu"]) and xla_extension_version < 289: - self.skipTest("Requires xla_extension_version >= 289") mesh = jtu.create_mesh((2, 2), ("x", "y")) scalar_inp = 1 s_host = NamedSharding(mesh, P(), memory_kind="pinned_host") @@ -512,8 +507,6 @@ def f(x): ) def test_parameter_and_output_streaming_with_scalar(self): - if jtu.test_device_matches(["gpu"]) and xla_extension_version < 289: - self.skipTest("Requires xla_extension_version >= 289") if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: self.skipTest("This test requires an xla_version >= 2.") @@ -581,8 +574,6 @@ def body(carry, x): self.assertEqual(out_hbm.sharding, out_s) def test_output_streaming(self): - if jtu.test_device_matches(["gpu"]) and xla_extension_version < 289: - self.skipTest("Requires xla_extension_version >= 289") mesh = jtu.create_mesh((1, 1), ("x", "y")) np_inp = np.arange(16.0).reshape(8, 2) s_hbm = NamedSharding(mesh, P("x", "y"), memory_kind="device") @@ -599,8 +590,6 @@ def f(xs): self.assertEqual(out_host.sharding, s_host) def test_weight_offload_with_dp_on_output(self): - if jtu.test_device_matches(["gpu"]) and xla_extension_version < 289: - self.skipTest("Requires xla_extension_version >= 289") _, s_dev, np_inp, inp_dev = _create_inputs( (8, 2), P("x", "y"), mem_kind="device") s_host = s_dev.with_memory_kind('pinned_host') @@ -616,8 +605,6 @@ def f(x): out_host, np_inp * 2, s_host, 'pinned_host') def test_output_streaming_inside_scan(self): - if jtu.test_device_matches(["gpu"]) and xla_extension_version < 289: - self.skipTest("Requires xla_extension_version >= 289") if xb.backend_xla_version() is not None and xb.backend_xla_version() < 2: self.skipTest("This test requires an xla_version >= 2.") mesh = jtu.create_mesh((1, 1, 2), ("x", "y", "z")) @@ -650,8 +637,6 @@ def test_deepcopy(self): self.assertEqual(t.shape, t_copy.shape) def test_close_over_host_constant_and_stream(self): - if jtu.test_device_matches(["gpu"]) and xla_extension_version < 289: - self.skipTest("Requires xla_extension_version >= 289") _, s_host, np_inp, inp_host = _create_inputs( (8, 2), P("x", "y"), mem_kind="pinned_host") @@ -1562,8 +1547,6 @@ def g(ys, _): self.assertGreater(compiled_stats.host_temp_size_in_bytes, 0) def test_remat_scan_layout_change_offloadable(self): - if jtu.test_device_matches(["gpu"]) and xla_extension_version < 289: - self.skipTest("Requires xla_extension_version >= 289") mesh = jtu.create_mesh((2,), ("x",)) shape = (256, 128) np_inp = np.arange(math.prod(shape), dtype=np.float32).reshape(shape) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 0dc5284d5149..9ca4185bb0da 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -59,7 +59,6 @@ from jax._src import xla_bridge from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension -from jax._src.lib import xla_extension_version from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -661,10 +660,7 @@ def testAutodiffCache(self): jax.grad(f)(x) # Warm up the cache. with jtu.count_pjit_cpp_cache_miss() as count: jax.grad(f)(x) - if xla_extension_version >= 286: - self.assertEqual(count[0], 0) # no cache miss i.e. cache hit - else: - self.assertEqual(count[0], 2) + self.assertEqual(count[0], 0) # no cache miss i.e. cache hit @jtu.with_mesh([('x', 2), ('y', 1)]) def testEvalJaxpr(self): @@ -4590,8 +4586,6 @@ def test_wsc_abstract_mesh_errors(self): ' match the mesh shape of the target sharding.*'): with_sharding_constraint(arr, NamedSharding(abs_mesh2, P('y'))) - @unittest.skipIf(xla_extension_version < 286, - "Requires xla_extension_version >= 286") def test_global_jit_cpp_cache_hit_out_shardings(self): mesh = jtu.create_mesh((2,), 'x') s = NamedSharding(mesh, P('x')) diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index c5342a99365d..a8e54537d14f 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -24,7 +24,6 @@ import jax from jax import flatten_util from jax import tree_util -from jax._src.lib import xla_extension_version from jax._src import test_util as jtu from jax._src.tree_util import flatten_one_level, prefix_errors import jax.numpy as jnp @@ -485,10 +484,8 @@ def testFlattenUpTo(self, tree, xs, expected): [([1], (2,), {"a": [1]})], re.escape("Custom node type mismatch"), ), - *( - [] - if xla_extension_version < 288 - else [(None, [2], re.escape("Expected None, got [2]."))] + ( + (None, [2], re.escape("Expected None, got [2].")) ), ) def testFlattenUpToErrors(self, tree, xs, error):