diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index b3e29eb95695..2f98b6daae6f 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -4181,10 +4181,19 @@ def _top_k_translation_rule(ctx, avals_in, avals_out, x, *, k): top_k_p.def_impl(partial(dispatch.apply_primitive, top_k_p)) top_k_p.def_abstract_eval(_top_k_abstract_eval) def _top_k_lower(ctx, operand, k): - if not core.is_constant_dim(k): + if core.is_constant_dim(k): + return chlo.TopKOp(operand, mlir.i64_attr(k)).results + if xla_client.mlir_api_version < 54: # TODO: https://github.com/openxla/stablehlo/issues/1396 raise ValueError("native serialization with shape polymorphism not implemented for top_k") - return chlo.TopKOp(operand, mlir.i64_attr(k)).results + k_value, = mlir.eval_dynamic_shape_as_vals(ctx, (k,)) + out_values_aval, out_indices_aval, = ctx.avals_out + return mlir.custom_call( + "stablehlo.dynamic_top_k", + [mlir.aval_to_ir_type(out_values_aval), + mlir.aval_to_ir_type(out_indices_aval)], + [operand, k_value]).results + mlir.register_lowering(top_k_p, _top_k_lower) ad.primitive_jvps[top_k_p] = _top_k_jvp batching.primitive_batchers[top_k_p] = _top_k_batch_rule diff --git a/jax/experimental/jax2tf/jax_export.py b/jax/experimental/jax2tf/jax_export.py index be2fb18dd348..39e621418678 100644 --- a/jax/experimental/jax2tf/jax_export.py +++ b/jax/experimental/jax2tf/jax_export.py @@ -697,12 +697,13 @@ def _check_lowering(lowering) -> None: # ApproxTopK on TPU "ApproxTopK", "tf.call_tf_function", # From jax2tf.call_tf(func, call_tf_graph=True) - "tpu_custom_call", # Pallas kernels + "tpu_custom_call", # Pallas/TPU kernels # TODO(burmako): maintain backwards compatibility for these, until they # are upstreamed to StableHLO. # See https://github.com/openxla/stablehlo/issues/8. "stablehlo.dynamic_reduce_window", "stablehlo.dynamic_rng_bit_generator", + "stablehlo.dynamic_top_k", "shape_assertion", # Used by shape_poly to evaluate assertions } diff --git a/jax/experimental/jax2tf/tests/back_compat_test.py b/jax/experimental/jax2tf/tests/back_compat_test.py index 1a56b5925227..69274c839301 100644 --- a/jax/experimental/jax2tf/tests/back_compat_test.py +++ b/jax/experimental/jax2tf/tests/back_compat_test.py @@ -51,6 +51,7 @@ from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_Sharding from jax.experimental.jax2tf.tests.back_compat_testdata import tpu_stablehlo_dynamic_reduce_window from jax.experimental.jax2tf.tests.back_compat_testdata import stablehlo_dynamic_rng_bit_generator +from jax.experimental.jax2tf.tests.back_compat_testdata import stablehlo_dynamic_top_k from jax.experimental import pjit from jax.experimental.shard_map import shard_map @@ -114,7 +115,9 @@ def test_custom_call_coverage(self): tpu_ApproxTopK.data_2023_05_16, tpu_stablehlo_dynamic_reduce_window.data_unary_2023_06_17, tpu_stablehlo_dynamic_reduce_window.data_variadic_2023_06_17, - stablehlo_dynamic_rng_bit_generator.data_2023_06_17,] + stablehlo_dynamic_rng_bit_generator.data_2023_06_17, + stablehlo_dynamic_top_k.data_2023_07_16, + ] # Some of the above are nested structures. 
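# A minimal sketch (not part of this diff) of what the new lowering above
# enables: under jax2tf with polymorphic shapes, k may be a symbolic
# expression, and native serialization now emits the stablehlo.dynamic_top_k
# custom call instead of raising. Assumes TensorFlow is installed;
# jax2tf.convert and its native_serialization option are the public entry
# points exercised by the tests below.
import numpy as np
from jax import lax
from jax.experimental import jax2tf

def f(a):
  # k == b - 1 is a symbolic expression when the last dimension is symbolic b.
  return lax.top_k(a, k=a.shape[-1] - 1)

f_tf = jax2tf.convert(f, polymorphic_shapes=["_, b"],
                      native_serialization=True)
values, indices = f_tf(np.arange(12, dtype=np.float32).reshape((4, 3)))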
covering_testdatas = itertools.chain( *[self.load_testdata_nested(d) for d in covering_testdatas]) @@ -129,7 +132,10 @@ def test_custom_call_coverage(self): "shape_assertion", }) not_covered = targets_to_cover.difference(covered_targets) - self.assertEmpty(not_covered) + self.assertEmpty(not_covered, + msg=("The following custom call targets are declared " + "stable but are not covered by any tests: " + f"{not_covered}")) def test_ducc_fft(self): def func(x): @@ -672,6 +678,29 @@ def func(key, a): # a is only used for its shape finally: jax.config.update("jax_default_prng_impl", prev_default_prng_impl) + def test_stablehlo_dynamic_top_k(self): + # stablehlo.dynamic_top_k is used temporarily for top_k with a dynamic k + a = np.arange(12, dtype=np.float32).reshape((4, 3)) + + def func(a): + return lax.top_k(a, k=a.shape[-1] - 1) + + data = self.load_testdata(stablehlo_dynamic_top_k.data_2023_07_16) + def check_top_k_results(res_run, res_expected, *, rtol, atol): + # The order of the results may differ, but the values should be the same + values_expected, _ = res_expected + values_run, indices_run = res_run + # Check that indices are correct + self.assertAllClose(values_run, + a[np.arange(a.shape[0]).reshape(a.shape[0], 1), + indices_run], atol=atol, rtol=rtol) + self.assertAllClose(np.sort(values_run), np.sort(values_expected), + atol=atol, rtol=rtol) + + self.run_one_test(func, data, + polymorphic_shapes=("_, b",), + check_results=check_top_k_results) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader()) diff --git a/jax/experimental/jax2tf/tests/back_compat_test_util.py b/jax/experimental/jax2tf/tests/back_compat_test_util.py index 6f6fa97a9f4b..797c9fd6de35 100644 --- a/jax/experimental/jax2tf/tests/back_compat_test_util.py +++ b/jax/experimental/jax2tf/tests/back_compat_test_util.py @@ -40,19 +40,13 @@ def test_foo_call(self): def func(...): ... inputs = (...,) # Tuple of nd.array, keep it small, perhaps generate the # inputs in `func`. - data = dataclasses.replace(self.load_testdata(bctu.dummy_data_dict), - inputs=inputs, - platform=self.default_jax_backend()) - self.run_one_test(func, data, - # Temporarily allow calls to "foo" - allow_additional_custom_call_targets=("foo",)) + data = self.starter_data(inputs) # This is temporary, just for starting. + self.run_one_test(func, data) The test will fail, but it will save the test data you will need to a file. The file name will be printed in the logs. Create a new file ./back_compat_testdata/foo_call.py and paste the test data that -you will see printed in the logs. You may want to -edit the serialization string to remove any pathnames that may be included at -the end, or gxxxxx3 at the beginning. +you will see printed in the logs. Name the literal `data_YYYY_MM_DD` to include the date of serialization (for readability only). Then add to this file: @@ -137,6 +131,13 @@ def default_jax_backend(self) -> str: # Canonicalize to turn into "cuda" or "rocm" return xb.canonicalize_platform(jax.default_backend()) + def starter_data(self, inputs: Sequence[np.ndarray]) -> CompatTestData: + # Helper for starting a test, see module docstring.
+ assert isinstance(inputs, Sequence), f"{inputs}" + return dataclasses.replace(self.load_testdata(dummy_data_dict), + inputs=inputs, + platform=self.default_jax_backend()) + def load_testdata(self, testdata_dict: dict[str, Any]) -> CompatTestData: if testdata_dict["testdata_version"] == CURRENT_TESTDATA_VERSION: return CompatTestData(**testdata_dict) @@ -212,7 +213,7 @@ def run_one_test(self, func: Callable[..., jax.Array], np.set_printoptions(threshold=sys.maxsize, floatmode="unique") # Print the current test data to simplify updating the test. updated_testdata = f""" -# Pasted from the test output (see back_compat_test.py module docstring) +# Pasted from the test output (see back_compat_test_util.py module docstring) data_{datetime.date.today().strftime('%Y_%m_%d')} = dict( testdata_version={CURRENT_TESTDATA_VERSION}, platform={repr(self.default_jax_backend())}, diff --git a/jax/experimental/jax2tf/tests/back_compat_testdata/stablehlo_dynamic_top_k.py b/jax/experimental/jax2tf/tests/back_compat_testdata/stablehlo_dynamic_top_k.py new file mode 100644 index 000000000000..ce4c3151a2ff --- /dev/null +++ b/jax/experimental/jax2tf/tests/back_compat_testdata/stablehlo_dynamic_top_k.py @@ -0,0 +1,67 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
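# The data_2023_07_16 literal in the new testdata file below follows the
# naming convention printed by run_one_test above (the f-string in the
# updated_testdata template). A tiny runnable illustration of that scheme:
import datetime
# For a module serialized on July 16, 2023 this yields "data_2023_07_16".
name = f"data_{datetime.date(2023, 7, 16).strftime('%Y_%m_%d')}"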
+ +# flake8: noqa + +import datetime +from numpy import array, float32, int32 + + +# Pasted from the test output (see back_compat_test_util.py module docstring) +data_2023_07_16 = dict( + testdata_version=1, + platform='cpu', + custom_call_targets=['stablehlo.dynamic_top_k'], + serialized_date=datetime.date(2023, 7, 16), + inputs=(array([[ 0., 1., 2.], + [ 3., 4., 5.], + [ 6., 7., 8.], + [ 9., 10., 11.]], dtype=float32),), + expected_outputs=(array([[ 2., 1.], + [ 5., 4.], + [ 8., 7.], + [11., 10.]], dtype=float32), array([[2, 1], + [2, 1], + [2, 1], + [2, 1]], dtype=int32)), + mlir_module_text=r""" +#loc = loc(unknown) +module @jit_func attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { + func.func public @main(%arg0: tensor<4x?xf32> {jax.arg_info = "a", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x?xf32> {jax.result_info = "[0]"}, tensor<4x?xi32> {jax.result_info = "[1]"}) { + %0 = stablehlo.get_dimension_size %arg0, dim = 1 : (tensor<4x?xf32>) -> tensor loc(#loc3) + %1 = stablehlo.convert %0 : (tensor) -> tensor loc(#loc3) + %2 = stablehlo.constant dense<> : tensor<0xi1> loc(#loc) + %3 = stablehlo.convert %arg0 : tensor<4x?xf32> loc(#loc) + %4:2 = call @_wrapped_jax_export_main(%1, %3) : (tensor, tensor<4x?xf32>) -> (tensor<4x?xf32>, tensor<4x?xi32>) loc(#loc) + return %4#0, %4#1 : tensor<4x?xf32>, tensor<4x?xi32> loc(#loc) + } loc(#loc) + func.func private @_wrapped_jax_export_main(%arg0: tensor loc(unknown), %arg1: tensor<4x?xf32> {jax.arg_info = "a", mhlo.sharding = "{replicated}"} loc(unknown)) -> (tensor<4x?xf32> {jax.result_info = "[0]"}, tensor<4x?xi32> {jax.result_info = "[1]"}) { + %0 = stablehlo.convert %arg0 : tensor loc(#loc4) + %1 = stablehlo.constant dense<-1> : tensor loc(#loc5) + %2 = stablehlo.add %0, %1 : tensor loc(#loc6) + %3 = stablehlo.convert %2 : (tensor) -> tensor loc(#loc5) + %4:2 = stablehlo.custom_call @stablehlo.dynamic_top_k(%arg1, %3) {api_version = 2 : i32} : (tensor<4x?xf32>, tensor) -> (tensor<4x?xf32>, tensor<4x?xi32>) loc(#loc5) + return %4#0, %4#1 : tensor<4x?xf32>, tensor<4x?xi32> loc(#loc) + } loc(#loc) +} loc(#loc) +#loc1 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":621:0) +#loc2 = loc("third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py":613:0) +#loc3 = loc("/dimension_size[dimension=1]"(#loc1)) +#loc4 = loc("jit(func)/jit(main)/convert_element_type[new_dtype=int64 weak_type=False]"(#loc2)) +#loc5 = loc("jit(func)/jit(main)/top_k[k=b + -1]"(#loc2)) +#loc6 = loc("jit(func)/jit(main)/add"(#loc2)) +""", + 
mlir_module_serialized=b"ML\xefR\x01StableHLO_v0.9.0\x00\x01!\x05\x01\x03\x01\x03\x05\x03\x11\x07\t\x0b\r\x0f\x11\x13\x15\x03\xb5\x89\x19\x01Q\x07\x0b\x17\x0f\x0f\x0b\x0b\x0b\x0b\x0b\x0b\x0f#\x0b\x0b\x0b33\x0f\x0b\x13\x0b\x0f\x0bK\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x0b\x13\x0b\x0b\x17\x13\x13\x0b\x039\x0b\x1b\x13\x0b\x0b\x0f\x0b\x0b\x0b\x0b\x0b\x13\x0b\x13\x0b\x0b\x0b\x13\x0b\x0b\x0b/\x0b\x0b\x0b\x0b\x0f\x0f\x01\x03\x0f\x03\x177\x0f7\x07\x07\x0f\x13\x1b\x07\x1f\x07\x02R\x04\x1f\x05\x17\x17\x13\x96\t\x01\x1d+\x05\x11\x01\x05\x05\x19\x05\x1b\x05\x1d\x05\x1f\x05!\x05#\x1dGI\x03\x07\x1b\t\x1d\t\x03\x1f\x05%\x05'\x05)\x03\x0b\x0b[\re\x0fU\x03o\x11q\x03\x0b\x0bs\rw\x0fU\x03Y\x11y\x1d'\x05\x05+\x03\x03\x15{\x05-\x1d/\x05\x05/\x03\x113}5\x7f7\x819Q;\x83=Q?QAQ\x051\x053\x055\x057\x059\x05;\x05=\x05?\x03\x03E\x85\x05A\x05C\x17\x13\xb6\t\x01\x03\x03\x15\x87\x03\x03OY\x05E\x03\x01\r\x05]_ac\x03\x05gk\x1dG\x1dI\x03\x03S\x1dK\x1dM\x1dO\x1dQ#\x11\r\x03Wi\x1dS\r\x03Wm\x1dU\x1dW\x1dY\x03\x05uS\r\x01#\x15\x1d[\x1f\x05\x11\xff\xff\xff\xff\xff\xff\xff\xff\x0b\x05\x1d]\x1d_\x05\x01\x13\x0b\x05\x1f\x0f\x01\x01\x02\x02)\x05\x11\x00\xff\xff\xff\xff\xff\xff\xff\xff\x13)\x01\x0b)\x05\x11\x00\xff\xff\xff\xff\xff\xff\xff\xff\t\x1b\x1d)\x01\t)\x03\x01\x17\x11\x03\x03\x05\x03\x07\t\x11\x05\x05\x03\x05\x03\x07\x01\x04\xf3\x05\x01\x11\x01\x19\x07\x03\x01\t\x05\x11\x01!\x05\x03\x0f\x1b\x03\x03\x01\x0f\x07\x17C\x03\r\x03\x01\x03\x06\x17\x03\x05\x03\x03\x07\x03\x01K\x03\x0f\x03\x06\x01\x03\x03\x03\x01\x11\x07\x01M\x05\x03\x07\x05\x05\t\t\x04\x01\x05\x0b\r\x05\x11\x01#\x05\x03\x11\x1b\x05\x05\x01\x03\x01\x03\x06%\x03\x05\x03\x01\x07\x03\x07)\x03\x05\x0b\x06-\x03\x05\x05\x05\x07\x03\x06\x07\x03\r\x03\t\r\x07\x071\x05\x03\x07\x05\x03\x0b\t\x04\x01\x05\r\x0f\x06\x03\x01\x05\x01\x00R\x0ca1\x03\x11\x0f\x0b\t\t\x1b\x1d\x05\x1b3!\x0f;\x15\x1f/!!)#\x1f\x191I\x95\x13%)\r\x83\x1f\x15\x1d\x15\x13\x11-\x1f\x0f\x15\x19\x11\x17\x0f\x0b\x11builtin\x00vhlo\x00module\x00convert_v1\x00func_v1\x00constant_v1\x00return_v1\x00add_v1\x00custom_call_v1\x00get_dimension_size_v1\x00call_v1\x00sym_name\x00arg_attrs\x00function_type\x00res_attrs\x00sym_visibility\x00third_party/py/jax/experimental/jax2tf/tests/back_compat_test.py\x00value\x00mhlo.num_partitions\x00mhlo.num_replicas\x00jit_func\x00jit(func)/jit(main)/convert_element_type[new_dtype=int64 weak_type=False]\x00jit(func)/jit(main)/top_k[k=b + -1]\x00jit(func)/jit(main)/add\x00api_version\x00backend_config\x00call_target_name\x00called_computations\x00has_side_effect\x00operand_layouts\x00output_operand_aliases\x00result_layouts\x00dimension\x00/dimension_size[dimension=1]\x00callee\x00jax.result_info\x00_wrapped_jax_export_main\x00jax.arg_info\x00a\x00mhlo.sharding\x00{replicated}\x00[0]\x00[1]\x00main\x00public\x00private\x00\x00stablehlo.dynamic_top_k\x00", + xla_call_module_version=6, +) # End paste diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 59fed2b5d107..80c737f59cd8 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -3097,10 +3097,6 @@ def test_harness(self, harness: PolyHarness): raise unittest.SkipTest( "native lowering with shape polymorphism requires additional StableHLO feature support") - if "top_k" in harness.fullname and "approx_top_k" not in harness.fullname: - # https://github.com/openxla/stablehlo/issues/1255: need DynamicTopK - raise unittest.SkipTest("native lowering with shape polymorphism not implemented for top_k") - # Some tests need 
the latest jaxlib need_new_jaxlib = [] if jaxlib_version < (0, 4, 13):
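With the skip removed, the shape-polymorphic top_k harnesses now run end to end. A standalone sketch (NumPy-level, not part of this diff) of the invariant that check_top_k_results above relies on: gathering with the returned indices must reproduce the returned values, and the row-wise sorted values must match the expected ones even when tie-breaking order differs across platforms.

import numpy as np
from jax import lax

a = np.arange(12, dtype=np.float32).reshape((4, 3))
values, indices = lax.top_k(a, k=2)
values, indices = np.asarray(values), np.asarray(indices)
# The indices must gather the returned values out of the input.
np.testing.assert_allclose(values, a[np.arange(a.shape[0])[:, None], indices])
# Row-wise sorted values match the expected top-2 entries of each row
# (the same values as in expected_outputs in the testdata above).
np.testing.assert_allclose(np.sort(values),
                           np.array([[1., 2.], [4., 5.], [7., 8.], [10., 11.]],
                                    dtype=np.float32))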