Skip to content

Commit

Permalink
[JAX] Convert stablehlo to MLIR bytecode, not an MLIR string.
Browse files Browse the repository at this point in the history
Bytecode is considerably more compact.

PiperOrigin-RevId: 615386276
  • Loading branch information
hawkinsp authored and jax authors committed Mar 13, 2024
1 parent f0c5051 commit 642f20d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
10 changes: 7 additions & 3 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,10 +589,14 @@ def computation_maker(*args, **kwargs):
arg_shardings=None,
result_shardings=None,
lowering_parameters=mlir.LoweringParameters())

if xla_extension_version >= 244:
m = mlir.module_to_bytecode(lowering_result.module)
else:
m = mlir.module_to_string(lowering_result.module)

built = xc._xla.mlir.mlir_module_to_xla_computation(
mlir.module_to_string(lowering_result.module),
use_tuple_args=tuple_args,
return_tuple=True)
m, use_tuple_args=tuple_args, return_tuple=True)
out_shapes_flat = [
ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
out_shape = tree_unflatten(out_tree(), out_shapes_flat)
Expand Down
10 changes: 8 additions & 2 deletions jax/_src/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version


source_info_util.register_exclusion(__file__)
Expand Down Expand Up @@ -314,9 +315,14 @@ class XlaLowering(Lowering):

def hlo(self) -> xc.XlaComputation:
"""Return an HLO representation of this computation."""
hlo = self.stablehlo()
m: Union[str, bytes]
if xla_extension_version >= 244:
m = mlir.module_to_bytecode(hlo)
else:
m = mlir.module_to_string(hlo)
return xla_extension.mlir.mlir_module_to_xla_computation(
mlir.module_to_string(self.stablehlo()),
use_tuple_args=self.compile_args["tuple_args"])
m, use_tuple_args=self.compile_args["tuple_args"])

def mhlo(self) -> ir.Module:
"""Return an MHLO representation of this computation."""
Expand Down

0 comments on commit 642f20d

Please sign in to comment.