[shape_poly] Add shape polymorphism support for TopK.
This relies on newly introduced support for the custom
call @stablehlo.dynamic_top_k.

PiperOrigin-RevId: 551833809
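
For context, a minimal sketch of the kind of program this change enables, adapted from the new test added below. The jax_export.export/poly_spec usage is an assumption based on the experimental jax_export API of this period, not part of the commit:

import numpy as np
from jax import lax
from jax.experimental.jax2tf import jax_export

def func(a):
  # With a polymorphic trailing dimension b, k is the symbolic expression b - 1.
  return lax.top_k(a, k=a.shape[-1] - 1)

# Leaving the trailing dimension symbolic ("_, b") selects the dynamic
# lowering path, which now emits the stablehlo.dynamic_top_k custom call.
exported = jax_export.export(func)(
    jax_export.poly_spec((4, 3), np.float32, "_, b"))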
gnecula authored and jax authors committed Jul 28, 2023
1 parent 3b28d4e commit 88e11ae
Showing 6 changed files with 122 additions and 19 deletions.
13 changes: 11 additions & 2 deletions jax/_src/lax/lax.py
@@ -4181,10 +4181,19 @@ def _top_k_translation_rule(ctx, avals_in, avals_out, x, *, k):
top_k_p.def_impl(partial(dispatch.apply_primitive, top_k_p))
top_k_p.def_abstract_eval(_top_k_abstract_eval)
 def _top_k_lower(ctx, operand, k):
-  if not core.is_constant_dim(k):
+  if core.is_constant_dim(k):
+    return chlo.TopKOp(operand, mlir.i64_attr(k)).results
+  if xla_client.mlir_api_version < 54:
     # TODO: https://github.com/openxla/stablehlo/issues/1396
     raise ValueError("native serialization with shape polymorphism not implemented for top_k")
-  return chlo.TopKOp(operand, mlir.i64_attr(k)).results
+  k_value, = mlir.eval_dynamic_shape_as_vals(ctx, (k,))
+  out_values_aval, out_indices_aval, = ctx.avals_out
+  return mlir.custom_call(
+      "stablehlo.dynamic_top_k",
+      [mlir.aval_to_ir_type(out_values_aval),
+       mlir.aval_to_ir_type(out_indices_aval)],
+      [operand, k_value]).results

mlir.register_lowering(top_k_p, _top_k_lower)
ad.primitive_jvps[top_k_p] = _top_k_jvp
batching.primitive_batchers[top_k_p] = _top_k_batch_rule
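
For reference, a NumPy sketch (an illustration, not part of the commit) of the (values, indices) pair that the stablehlo.dynamic_top_k custom call produces, with k supplied as a runtime value:

import numpy as np

def reference_dynamic_top_k(operand, k_value):
  # Indices of the k_value largest entries per row, in descending value order;
  # tie-breaking order is implementation-defined for the real custom call.
  indices = np.argsort(-operand, axis=-1)[..., :k_value].astype(np.int32)
  values = np.take_along_axis(operand, indices, axis=-1)
  return values, indices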
3 changes: 2 additions & 1 deletion jax/experimental/jax2tf/jax_export.py
@@ -697,12 +697,13 @@ def _check_lowering(lowering) -> None:
# ApproxTopK on TPU
"ApproxTopK",
"tf.call_tf_function", # From jax2tf.call_tf(func, call_tf_graph=True)
"tpu_custom_call", # Pallas kernels
"tpu_custom_call", # Pallas/TPU kernels
# TODO(burmako): maintain backwards compatibility for these, until they
# are upstreamed to StableHLO.
# See https://github.com/openxla/stablehlo/issues/8.
"stablehlo.dynamic_reduce_window",
"stablehlo.dynamic_rng_bit_generator",
"stablehlo.dynamic_top_k",
"shape_assertion", # Used by shape_poly to evaluate assertions
}

33 changes: 31 additions & 2 deletions jax/experimental/jax2tf/tests/back_compat_test.py
@@ -51,6 +51,7 @@
from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_Sharding
from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_stablehlo_dynamic_reduce_window
from jax.experimental.jax2tf.tests.back_compat_testdata import stablehlo_dynamic_rng_bit_generator
from jax.experimental.jax2tf.tests.back_compat_testdata import stablehlo_dynamic_top_k

from jax.experimental import pjit
from jax.experimental.shard_map import shard_map
@@ -114,7 +115,9 @@ def test_custom_call_coverage(self):
tpu_ApproxTopK.data_2023_05_16,
tpu_stablehlo_dynamic_reduce_window.data_unary_2023_06_17,
tpu_stablehlo_dynamic_reduce_window.data_variadic_2023_06_17,
-        stablehlo_dynamic_rng_bit_generator.data_2023_06_17,]
+        stablehlo_dynamic_rng_bit_generator.data_2023_06_17,
+        stablehlo_dynamic_top_k.data_2023_07_16,
+    ]
# Some of the above are nested structures.
covering_testdatas = itertools.chain(
*[self.load_testdata_nested(d) for d in covering_testdatas])
@@ -129,7 +132,10 @@
"shape_assertion",
})
not_covered = targets_to_cover.difference(covered_targets)
-    self.assertEmpty(not_covered)
+    self.assertEmpty(not_covered,
+                     msg=("The following custom call targets are declared "
+                          "stable but are not covered by any tests: "
+                          f"{not_covered}"))

def test_ducc_fft(self):
def func(x):
@@ -672,6 +678,29 @@ def func(key, a): # a is only used for its shape
finally:
jax.config.update("jax_default_prng_impl", prev_default_prng_impl)

def test_stablehlo_dynamic_top_k(self):
# stablehlo.dynamic_top_k is used temporarily for a top_k with dynamism
a = np.arange(12, dtype=np.float32).reshape((4, 3))

def func(a):
return lax.top_k(a, k=a.shape[-1] - 1)

data = self.load_testdata(stablehlo_dynamic_top_k.data_2023_07_16)
def check_top_k_results(res_run, res_expected, *, rtol, atol):
# The order of the results may differ, but they should be the same elements
values_expected, _ = res_expected
values_run, indices_run = res_run
# Check that indices are correct
self.assertAllClose(values_run,
a[np.arange(a.shape[0]).reshape(a.shape[0], 1),
indices_run], atol=atol, rtol=rtol)
self.assertAllClose(np.sort(values_run), np.sort(values_expected),
atol=atol, rtol=rtol)

self.run_one_test(func, data,
polymorphic_shapes=("_, b",),
check_results=check_top_k_results)
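
As a standalone illustration of why check_top_k_results gathers by the returned indices instead of comparing positions directly, here is a NumPy-only sketch using the expected outputs recorded in the testdata below:

import numpy as np

a = np.arange(12, dtype=np.float32).reshape((4, 3))
# Expected top-2 values and indices per row, from the testdata below.
values_run = np.array([[2., 1.], [5., 4.], [8., 7.], [11., 10.]], np.float32)
indices_run = np.array([[2, 1]] * 4, np.int32)
# Gathering a's entries at the returned indices must reproduce the returned
# values, even if another platform breaks ties in a different order.
rows = np.arange(a.shape[0]).reshape(-1, 1)
assert np.allclose(a[rows, indices_run], values_run)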


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())
21 changes: 11 additions & 10 deletions jax/experimental/jax2tf/tests/back_compat_test_util.py
@@ -40,19 +40,13 @@ def test_foo_call(self):
def func(...): ...
inputs = (...,) # Tuple of nd.array, keep it small, perhaps generate the
# inputs in `func`.
-  data = dataclasses.replace(self.load_testdata(bctu.dummy_data_dict),
-                             inputs=inputs,
-                             platform=self.default_jax_backend())
-  self.run_one_test(func, data,
-                    # Temporarily allow calls to "foo"
-                    allow_additional_custom_call_targets=("foo",))
+  data = self.starter_data(inputs)  # This is temporary, just for starting.
+  self.run_one_test(func, data)

The test will fail, but will save to a file the test data you will need. The
file name will be printed in the logs. Create a new
file ./back_compat_testdata/foo_call.py and paste the test data that
-you will see printed in the logs. You may want to
-edit the serialization string to remove any pathnames that may be included at
-the end, or gxxxxx3 at the beginning.
+you will see printed in the logs.
Name the literal `data_YYYY_MM_DD` to include the date of serialization
(for readability only). Then add to this file:
@@ -137,6 +131,13 @@ def default_jax_backend(self) -> str:
# Canonicalize to turn into "cuda" or "rocm"
return xb.canonicalize_platform(jax.default_backend())

def starter_data(self, inputs: Sequence[np.ndarray]) -> CompatTestData:
# Helper for starting a test, see module docstring.
assert isinstance(inputs, Sequence), f"{inputs}"
return dataclasses.replace(self.load_testdata(dummy_data_dict),
inputs=inputs,
platform=self.default_jax_backend())
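
For illustration, a hypothetical test using the new helper, following the updated module docstring above (the test name, func, and inputs are made up):

def test_foo_call(self):
  def func(x):
    return x + np.float32(1.)
  inputs = (np.arange(3, dtype=np.float32),)
  data = self.starter_data(inputs)  # Temporary, just for starting the test.
  self.run_one_test(func, data)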

def load_testdata(self, testdata_dict: dict[str, Any]) -> CompatTestData:
if testdata_dict["testdata_version"] == CURRENT_TESTDATA_VERSION:
return CompatTestData(**testdata_dict)
@@ -212,7 +213,7 @@ def run_one_test(self, func: Callable[..., jax.Array],
np.set_printoptions(threshold=sys.maxsize, floatmode="unique")
# Print the current test data to simplify updating the test.
updated_testdata = f"""
-# Pasted from the test output (see back_compat_test.py module docstring)
+# Pasted from the test output (see back_compat_test_util.py module docstring)
data_{datetime.date.today().strftime('%Y_%m_%d')} = dict(
testdata_version={CURRENT_TESTDATA_VERSION},
platform={repr(self.default_jax_backend())},
67 changes: 67 additions & 0 deletions jax/experimental/jax2tf/tests/back_compat_testdata/stablehlo_dynamic_top_k.py
@@ -0,0 +1,67 @@
# Copyright 2023 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

import datetime
from numpy import array, float32, int32


# Pasted from the test output (see back_compat_test_util.py module docstring)
data_2023_07_16 = dict(
testdata_version=1,
platform='cpu',
custom_call_targets=['stablehlo.dynamic_top_k'],
serialized_date=datetime.date(2023, 7, 16),
inputs=(array([[ 0., 1., 2.],
[ 3., 4., 5.],
[ 6., 7., 8.],
[ 9., 10., 11.]], dtype=float32),),
expected_outputs=(array([[ 2., 1.],
[ 5., 4.],
[ 8., 7.],
[11., 10.]], dtype=float32), array([[2, 1],
[2, 1],
[2, 1],
[2, 1]], dtype=int32)),
mlir_module_text=r"""
#loc = loc(unknown)
module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<4x?xf32> {jax.arg_info = "a", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x?xf32> {jax.result_info = "[0]"}, tensor<4x?xi32> {jax.result_info = "[1]"}) {
%0 = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor<4x?xf32>) -> tensor<i32> loc(#loc3)
%1 = stablehlo.convert %0 : (tensor<i32>) -> tensor<i64> loc(#loc3)
%2 = stablehlo.constant dense<> : tensor<0xi1> loc(#loc)
%3 = stablehlo.convert %arg0 : tensor<4x?xf32> loc(#loc)
%4:2 = call @_wrapped_jax_export_main(%1, %3) : (tensor<i64>, tensor<4x?xf32>) -> (tensor<4x?xf32>, tensor<4x?xi32>) loc(#loc)
return %4#0, %4#1 : tensor<4x?xf32>, tensor<4x?xi32> loc(#loc)
} loc(#loc)
func.func private @_wrapped_jax_export_main(%arg0: tensor<i64> loc(unknown), %arg1: tensor<4x?xf32> {jax.arg_info = "a", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x?xf32> {jax.result_info = "[0]"}, tensor<4x?xi32> {jax.result_info = "[1]"}) {
%0 = stablehlo.convert %arg0 : tensor<i64> loc(#loc4)
%1 = stablehlo.constant dense<-1> : tensor<i64> loc(#loc5)
%2 = stablehlo.add %0, %1 : tensor<i64> loc(#loc6)
%3 = stablehlo.convert %2 : (tensor<i64>) -> tensor<i32> loc(#loc5)
%4:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg1, %3) {api_version = 2 : i32} : (tensor<4x?xf32>, tensor<i32>) -> (tensor<4x?xf32>, tensor<4x?xi32>) loc(#loc5)
return %4#0, %4#1 : tensor<4x?xf32>, tensor<4x?xi32> loc(#loc)
} loc(#loc)
} loc(#loc)
#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":621:0)
#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":613:0)
#loc3 = loc("/dimension_size[dimension=1]"(#loc1))
#loc4 = loc("jit(func)/jit(main)/convert_element_type[new_dtype=int64 weak_type=False]"(#loc2))
#loc5 = loc("jit(func)/jit(main)/top_k[k=b + -1]"(#loc2))
#loc6 = loc("jit(func)/jit(main)/add"(#loc2))
""",
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01!\x05\x01\x03\x01\x03\x05\x03\x11\x07\t\x0b\r\x0f\x11\x13\x15\x03\xb5\x89\x19\x01Q\x07\x0b\x17\x0f\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0f#\x0b\x0b\x0b33\x0f\x0b\x13\x0b\x0f\x0bK\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0b\x17\x13\x13\x0b\x039\x0b\x1b\x13\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x13\x0b\x0b\x0b\x13\x0b\x0b\x0b/\x0b\x0b\x0b\x0b\x0f\x0f\x01\x03\x0f\x03\x177\x0f7\x07\x07\x0f\x13\x1b\x07\x1f\x07\x02R\x04\x1f\x05\x17\x17\x13\x96\t\x01\x1d+\x05\x11\x01\x05\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x1dGI\x03\x07\x1b\t\x1d\t\x03\x1f\x05%\x05'\x05)\x03\x0b\x0b[\re\x0fU\x03o\x11q\x03\x0b\x0bs\rw\x0fU\x03Y\x11y\x1d'\x05\x05+\x03\x03\x15{\x05-\x1d/\x05\x05/\x03\x113}5\x7f7\x819Q;\x83=Q?QAQ\x051\x053\x055\x057\x059\x05;\x05=\x05?\x03\x03E\x85\x05A\x05C\x17\x13\xb6\t\x01\x03\x03\x15\x87\x03\x03OY\x05E\x03\x01\r\x05]_ac\x03\x05gk\x1dG\x1dI\x03\x03S\x1dK\x1dM\x1dO\x1dQ#\x11\r\x03Wi\x1dS\r\x03Wm\x1dU\x1dW\x1dY\x03\x05uS\r\x01#\x15\x1d[\x1f\x05\x11\xff\xff\xff\xff\xff\xff\xff\xff\x0b\x05\x1d]\x1d_\x05\x01\x13\x0b\x05\x1f\x0f\x01\x01\x02\x02)\x05\x11\x00\xff\xff\xff\xff\xff\xff\xff\xff\x13)\x01\x0b)\x05\x11\x00\xff\xff\xff\xff\xff\xff\xff\xff\t\x1b\x1d)\x01\t)\x03\x01\x17\x11\x03\x03\x05\x03\x07\t\x11\x05\x05\x03\x05\x03\x07\x01\x04\xf3\x05\x01\x11\x01\x19\x07\x03\x01\t\x05\x11\x01!\x05\x03\x0f\x1b\x03\x03\x01\x0f\x07\x17C\x03\r\x03\x01\x03\x06\x17\x03\x05\x03\x03\x07\x03\x01K\x03\x0f\x03\x06\x01\x03\x03\x03\x01\x11\x07\x01M\x05\x03\x07\x05\x05\t\t\x04\x01\x05\x0b\r\x05\x11\x01#\x05\x03\x11\x1b\x05\x05\x01\x03\x01\x03\x06%\x03\x05\x03\x01\x07\x03\x07)\x03\x05\x0b\x06-\x03\x05\x05\x05\x07\x03\x06\x07\x03\r\x03\t\r\x07\x071\x05\x03\x07\x05\x03\x0b\t\x04\x01\x05\r\x0f\x06\x03\x01\x05\x01\x00R\x0ca1\x03\x11\x0f\x0b\t\t\x1b\x1d\x05\x1b3!\x0f;\x15\x1f/!!)#\x1f\x191I\x95\x13%)\r\x83\x1f\x15\x1d\x15\x13\x11-\x1f\x0f\x15\x19\x11\x17\x0f\x0b\x11builtin\x00vhlo\x00module\x00convert_v1\x00func_v1\x00constant_v1\x00return_v1\x00add_v1\x00custom_call_v1\x00get_dimension_size_v1\x00call_v1\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00value\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/convert_element_type[new_dtype=int64 weak_type=False]\x00jit(func)/jit(main)/top_k[k=b + -1]\x00jit(func)/jit(main)/add\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00dimension\x00/dimension_size[dimension=1]\x00callee\x00jax.result_info\x00_wrapped_jax_export_main\x00jax.arg_info\x00a\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00stablehlo.dynamic_top_k\x00",
xla_call_module_version=6,
) # End paste
4 changes: 0 additions & 4 deletions jax/experimental/jax2tf/tests/shape_poly_test.py
@@ -3097,10 +3097,6 @@ def test_harness(self, harness: PolyHarness):
raise unittest.SkipTest(
"native lowering with shape polymorphism requires additional StableHLO feature support")

if "top_k" in harness.fullname and "approx_top_k" not in harness.fullname:
# https://github.com/openxla/stablehlo/issues/1255: need DynamicTopK
raise unittest.SkipTest("native lowering with shape polymorphism not implemented for top_k")

# Some tests need the latest jaxlib
need_new_jaxlib = []
if jaxlib_version < (0, 4, 13):
