Skip to content

Commit

Permalink
Bump minimum jaxlib version to 0.4.31. The corresponding xla_extensio…
Browse files Browse the repository at this point in the history
…n_version is 279 and mlir_api_version is 57

PiperOrigin-RevId: 657400413
  • Loading branch information
yashk2810 authored and jax authors committed Jul 30, 2024
1 parent 2106a25 commit 3003754
Show file tree
Hide file tree
Showing 12 changed files with 59 additions and 128 deletions.
2 changes: 1 addition & 1 deletion docs/export/export.md
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ as in the following example:
>>> print(jax.jit(new_prim.bind).lower(1.).compiler_ir())
module @jit_bind attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32> {mhlo.layout_mode = "default"}) -> (tensor<f32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32} : (tensor<f32>) -> tensor<f32>
%0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32, backend_config = ""} : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}
}
Expand Down
8 changes: 2 additions & 6 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
from jax._src.interpreters import xla
from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
Expand Down Expand Up @@ -2492,11 +2491,8 @@ def maybe_recover_user_shardings(

def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout,
xl: DeviceLocalLayout) -> bool:
if xla_extension_version >= 274:
if isinstance(ul, DeviceLocalLayout) and ul._tiling is None:
return ul.major_to_minor == xl.major_to_minor
else:
return ul == xl
if isinstance(ul, DeviceLocalLayout) and ul._tiling is None:
return ul.major_to_minor == xl.major_to_minor
else:
return ul == xl

Expand Down
141 changes: 54 additions & 87 deletions jax/_src/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from jax._src.sharding import Sharding
from jax._src.sharding_impls import AUTO as AutoSharding, is_auto
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version

Shape = tuple[int, ...]

Expand All @@ -31,96 +30,64 @@ def __repr__(self):
return "AUTO"


if xla_extension_version >= 274:
class DeviceLocalLayout:
major_to_minor: tuple[int, ...]
_tiling: tuple[tuple[int, ...], ...] | None
_sub_byte_element_size_in_bits: int

AUTO = AutoLayout()

def __init__(self, major_to_minor: tuple[int, ...],
_tiling: tuple[tuple[int, ...], ...] | None = None,
_sub_byte_element_size_in_bits: int = 0):
self.major_to_minor = tuple(major_to_minor)
self._tiling = None if _tiling is None else tuple(map(tuple, _tiling))
self._sub_byte_element_size_in_bits = _sub_byte_element_size_in_bits

@staticmethod
def from_pjrt_layout(pjrt_layout: xc.PjRtLayout):
xla_layout = pjrt_layout._xla_layout()
return DeviceLocalLayout(xla_layout.minor_to_major()[::-1], # pytype: disable=wrong-arg-types
xla_layout.tiling(),
xla_layout.element_size_in_bits())

def __repr__(self):
return (
f'DeviceLocalLayout(major_to_minor={self.major_to_minor},'
f' _tiling={self._tiling},'
f' _sub_byte_element_size_in_bits={self._sub_byte_element_size_in_bits})'
)
class DeviceLocalLayout:
major_to_minor: tuple[int, ...]
_tiling: tuple[tuple[int, ...], ...] | None
_sub_byte_element_size_in_bits: int

AUTO = AutoLayout()

def __init__(self, major_to_minor: tuple[int, ...],
_tiling: tuple[tuple[int, ...], ...] | None = None,
_sub_byte_element_size_in_bits: int = 0):
self.major_to_minor = tuple(major_to_minor)
self._tiling = None if _tiling is None else tuple(map(tuple, _tiling))
self._sub_byte_element_size_in_bits = _sub_byte_element_size_in_bits

def __hash__(self):
return hash((self.major_to_minor, self._tiling,
self._sub_byte_element_size_in_bits))
@staticmethod
def from_pjrt_layout(pjrt_layout: xc.PjRtLayout):
xla_layout = pjrt_layout._xla_layout()
return DeviceLocalLayout(xla_layout.minor_to_major()[::-1], # pytype: disable=wrong-arg-types
xla_layout.tiling(),
xla_layout.element_size_in_bits())

def __repr__(self):
return (
f'DeviceLocalLayout(major_to_minor={self.major_to_minor},'
f' _tiling={self._tiling},'
f' _sub_byte_element_size_in_bits={self._sub_byte_element_size_in_bits})'
)

def __eq__(self, other):
if not isinstance(other, DeviceLocalLayout):
return False
return (self.major_to_minor == other.major_to_minor and
self._tiling == other._tiling and
self._sub_byte_element_size_in_bits == other._sub_byte_element_size_in_bits)
def __hash__(self):
return hash((self.major_to_minor, self._tiling,
self._sub_byte_element_size_in_bits))

def _to_xla_layout(self, dtype) -> str:
if self._tiling is None:
xla_layout = xc.Layout(self.major_to_minor[::-1])
def __eq__(self, other):
if not isinstance(other, DeviceLocalLayout):
return False
return (self.major_to_minor == other.major_to_minor and
self._tiling == other._tiling and
self._sub_byte_element_size_in_bits == other._sub_byte_element_size_in_bits)

def _to_xla_layout(self, dtype) -> str:
if self._tiling is None:
xla_layout = xc.Layout(self.major_to_minor[::-1])
else:
if self._sub_byte_element_size_in_bits != 0:
sub_byte_size = self._sub_byte_element_size_in_bits
elif issubdtype(dtype, np.integer):
sub_byte_size = iinfo(dtype).bits if iinfo(dtype).bits < 8 else 0
else:
if self._sub_byte_element_size_in_bits != 0:
sub_byte_size = self._sub_byte_element_size_in_bits
elif issubdtype(dtype, np.integer):
sub_byte_size = iinfo(dtype).bits if iinfo(dtype).bits < 8 else 0
else:
sub_byte_size = 0
xla_layout = xc.Layout(self.major_to_minor[::-1], self._tiling, # type: ignore
sub_byte_size)
return str(xla_layout)

def check_compatible_aval(self, aval_shape: Shape):
if len(self.major_to_minor) != len(aval_shape):
raise ValueError(
f'Length of major_to_minor and the rank of the value should match.'
f' Got major_to_minor={self.major_to_minor} and shape={aval_shape}')

else:
class DeviceLocalLayout: # type: ignore
layout: xc.PjRtLayout

AUTO = AutoLayout()

def __init__(self, layout: xc.PjRtLayout):
self._layout = layout
self._layout_str = str(self._layout)

@staticmethod
def from_pjrt_layout(pjrt_layout: xc.PjRtLayout):
return DeviceLocalLayout(pjrt_layout) # type: ignore

def __repr__(self):
return f'DeviceLocalLayout({self._layout_str})'

def __hash__(self):
return hash(self._layout)

def __eq__(self, other):
if not isinstance(other, DeviceLocalLayout):
return False
return self._layout == other._layout

def _to_xla_layout(self, dtype) -> str:
return self._layout_str

def check_compatible_aval(self, aval_shape: Shape):
pass
sub_byte_size = 0
xla_layout = xc.Layout(self.major_to_minor[::-1], self._tiling, # type: ignore
sub_byte_size)
return str(xla_layout)

def check_compatible_aval(self, aval_shape: Shape):
if len(self.major_to_minor) != len(aval_shape):
raise ValueError(
f'Length of major_to_minor and the rank of the value should match.'
f' Got major_to_minor={self.major_to_minor} and shape={aval_shape}')


LayoutOptions = Union[DeviceLocalLayout, None, AutoLayout] # pytype: disable=invalid-annotation
Expand Down
3 changes: 1 addition & 2 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src import sharding
from jax._src.sharding_impls import (
NamedSharding, GSPMDSharding,
Expand Down Expand Up @@ -714,7 +713,7 @@ def _infer_params(
resource_env = None
pjit_mesh = None

skip_cache = xla_extension_version < 273 or config.dynamic_shapes.value
skip_cache = config.dynamic_shapes.value
if not skip_cache:
signature, dynargs = jax_jit.parse_arguments(
args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
Expand Down
3 changes: 0 additions & 3 deletions jax/_src/sharding_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from jax._src import util
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.op_shardings import (
are_op_shardings_equal, get_num_ways_dim_sharded, is_op_sharding_replicated)
from jax._src.partition_spec import PartitionSpec
Expand Down Expand Up @@ -1065,8 +1064,6 @@ def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
_check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
return parsed_pspec

if xla_extension_version < 279:
preprocess_with_manual = preprocess

def prepare_axis_resources(axis_resources,
arg_name,
Expand Down
2 changes: 1 addition & 1 deletion jax/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def make_release_tree(self, base_dir, files):


__version__ = _get_version_string()
_minimum_jaxlib_version = "0.4.30"
_minimum_jaxlib_version = "0.4.31"

def _version_as_tuple(version_str):
return tuple(int(i) for i in version_str.split(".") if i.isdigit())
Expand Down
3 changes: 0 additions & 3 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@
from jax._src.compilation_cache import is_persistent_cache_enabled
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
import jax._src.util as jax_util
from jax.ad_checkpoint import checkpoint_name, checkpoint as new_checkpoint
import jax.custom_batching
Expand Down Expand Up @@ -2641,7 +2640,6 @@ def f(x):

self.assertEqual(count[0], 1)

@unittest.skipIf(xla_extension_version <= 273, "requires jaxlib 0.4.31")
def test_jit_infer_params_cache(self):
def f(x):
return x
Expand Down Expand Up @@ -4427,7 +4425,6 @@ def f(x, y):
g = jax.grad(f, argnums=-1)
g(x, y) # doesn't crash

@unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31")
def test_jit_negative_static_argnums(self):
@partial(jax.jit, static_argnums=-1)
def g(x, y):
Expand Down
4 changes: 0 additions & 4 deletions tests/device_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from absl.testing import absltest
import jax
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version

jax.config.parse_flags_with_absl()

Expand All @@ -27,9 +26,6 @@ def test_repr(self):

# TODO(pobudzey): Add a test for rocm devices when available.
if jtu.is_device_cuda():
if xla_extension_version < 276:
self.skipTest('requires jaxlib 0.4.31')

self.assertEqual(device.platform, 'gpu')
self.assertEqual(repr(device), 'CudaDevice(id=0)')
elif jtu.test_device_matches(['tpu']):
Expand Down
4 changes: 0 additions & 4 deletions tests/export_back_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import dataclasses
from functools import partial
import itertools
import logging
import math

from absl.testing import absltest, parameterized
Expand Down Expand Up @@ -64,7 +63,6 @@
from jax._src import test_util as jtu
from jax._src.lib import cuda_versions
from jax._src.lib import version as jaxlib_version
from jax._src.lib import xla_extension_version

config.parse_flags_with_absl()

Expand Down Expand Up @@ -591,8 +589,6 @@ def func():
self.run_one_test(func, data)

def test_cuda_threefry2x32(self):
logging.info("test_cuda_threefry2x32: xla_extension_version: %s",
xla_extension_version)
def func(x):
return jax.random.uniform(x, (2, 4), dtype=np.float32)

Expand Down
4 changes: 0 additions & 4 deletions tests/extend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import os

import numpy as np
import unittest
from absl.testing import absltest, parameterized

import jax
Expand All @@ -30,7 +29,6 @@
from jax._src import prng
from jax._src import test_util as jtu
from jax._src.interpreters import mlir
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.extend import ffi

Expand Down Expand Up @@ -124,7 +122,6 @@ def testParams(self, param):
dtype=(np.int32,),
)
@jtu.run_on_devices("gpu")
@unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31")
def testFfiCall(self, shape, dtype):
pivots_size = shape[-1]
permutation_size = 2 * pivots_size
Expand All @@ -140,7 +137,6 @@ def testFfiCall(self, shape, dtype):
vectorized=(False, True),
)
@jtu.run_on_devices("gpu")
@unittest.skipIf(xla_extension_version < 272, "requires jaxlib 0.4.31")
def testFfiCallBatching(self, shape, dtype, vectorized):
shape = (10,) + shape
pivots_size = shape[-1]
Expand Down
10 changes: 0 additions & 10 deletions tests/layout_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from jax._src.layout import Layout, DeviceLocalLayout as DLL
from jax._src import test_util as jtu
from jax._src.util import safe_zip
from jax._src.lib import xla_extension_version

config.parse_flags_with_absl()

Expand Down Expand Up @@ -405,9 +404,6 @@ def f(x):
self.assertArraysEqual(out, inp.T)

def test_device_put_user_concrete_layout(self):
if xla_extension_version < 274:
self.skipTest('Requires xla_extension_version >= 274')

shape = (8, 128)
np_inp = np.arange(math.prod(shape)).reshape(shape)
dll = DLL(major_to_minor=(1, 0))
Expand Down Expand Up @@ -437,9 +433,6 @@ def f(x):
custom_dll.major_to_minor)

def test_compatible_aval_error(self):
if xla_extension_version < 274:
self.skipTest('Requires xla_extension_version >= 274')

custom_dll = DLL(major_to_minor=(0, 1, 2))
l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0]))
inp = np.arange(8)
Expand All @@ -454,9 +447,6 @@ def f(x):
f(inp)

def test_incompatible_aval_error_device_put(self):
if xla_extension_version < 274:
self.skipTest('Requires xla_extension_version >= 274')

custom_dll = DLL(major_to_minor=(0, 1, 2))
l = Layout(custom_dll, SingleDeviceSharding(jax.devices()[0]))
inp = np.arange(8)
Expand Down
3 changes: 0 additions & 3 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

from functools import partial
import itertools
import unittest

import numpy as np
import scipy
Expand All @@ -34,7 +33,6 @@
from jax._src.lax import linalg as lax_linalg
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.lib import xla_extension_version
from jax._src.numpy.util import promote_dtypes_inexact

config.parse_flags_with_absl()
Expand Down Expand Up @@ -1625,7 +1623,6 @@ def testTriangularSolveGradPrecision(self):
(a, b),
(a, b))

@unittest.skipIf(xla_extension_version < 277, "Requires jaxlib > 0.4.30")
def testTriangularSolveSingularBatched(self):
x = jnp.array([[1, 1], [0, 0]], dtype=np.float32)
y = jnp.array([[1], [1.]], dtype=np.float32)
Expand Down

0 comments on commit 3003754

Please sign in to comment.