diff --git a/jax/_src/api.py b/jax/_src/api.py index 2fecc4fd78db..597d0c057844 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -589,10 +589,14 @@ def computation_maker(*args, **kwargs): arg_shardings=None, result_shardings=None, lowering_parameters=mlir.LoweringParameters()) + + if xla_extension_version >= 244: + m = mlir.module_to_bytecode(lowering_result.module) + else: + m = mlir.module_to_string(lowering_result.module) + built = xc._xla.mlir.mlir_module_to_xla_computation( - mlir.module_to_string(lowering_result.module), - use_tuple_args=tuple_args, - return_tuple=True) + m, use_tuple_args=tuple_args, return_tuple=True) out_shapes_flat = [ ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals] out_shape = tree_unflatten(out_tree(), out_shapes_flat) diff --git a/jax/_src/stages.py b/jax/_src/stages.py index c05007ccacea..2e36e92276fa 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -48,6 +48,7 @@ from jax._src.interpreters import mlir from jax._src.lib.mlir import ir from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version source_info_util.register_exclusion(__file__) @@ -314,9 +315,14 @@ class XlaLowering(Lowering): def hlo(self) -> xc.XlaComputation: """Return an HLO representation of this computation.""" + hlo = self.stablehlo() + m: Union[str, bytes] + if xla_extension_version >= 244: + m = mlir.module_to_bytecode(hlo) + else: + m = mlir.module_to_string(hlo) return xla_extension.mlir.mlir_module_to_xla_computation( - mlir.module_to_string(self.stablehlo()), - use_tuple_args=self.compile_args["tuple_args"]) + m, use_tuple_args=self.compile_args["tuple_args"]) def mhlo(self) -> ir.Module: """Return an MHLO representation of this computation."""