From 0b70244b1c7325287df406e794ffd0cc7e97a641 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Sat, 2 Mar 2024 13:34:46 -0800 Subject: [PATCH] Thread out_avals to MeshExecutable PiperOrigin-RevId: 612037684 --- jax/_src/interpreters/pxla.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index f9cfbfc9ead3..e8ddc67b20f2 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2789,7 +2789,7 @@ def build_unsafe_call(self): def load(self) -> MeshExecutable: return MeshExecutable(self.xla_executable, self.build_unsafe_call, - self.input_avals, + self.input_avals, self.output_avals, self.input_shardings, self.output_shardings, self.auto_spmd_lowering, self.kept_var_idx, self.in_layouts, self.out_layouts, @@ -2942,12 +2942,13 @@ def reflatten_outputs_for_dispatch(out_tree, out_flat): class MeshExecutable(stages.XlaExecutable): __slots__ = [ "xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals", - "_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx", - "_in_layouts", "_out_layouts", "_all_args_info", "_unloaded_executable", + "out_avals", "_in_shardings", "_out_shardings", "_auto_spmd_lowering", + "_kept_var_idx", "_in_layouts", "_out_layouts", "_all_args_info", + "_unloaded_executable", ] - def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings, - out_shardings, auto_spmd_lowering, kept_var_idx, + def __init__(self, xla_executable, build_unsafe_call, in_avals, out_avals, + in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx, in_layouts, out_layouts, all_args_info: AllArgsInfo | None = None, unloaded_executable=None): @@ -2956,6 +2957,7 @@ def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings, # in_avals is a list of global and local avals. Aval is global if input # is a GDA or jax.Array else local. self.in_avals = in_avals + self.out_avals = out_avals self._unsafe_call = None self._in_shardings = in_shardings self._out_shardings = out_shardings @@ -3118,8 +3120,9 @@ def _compile_replicated_mesh_executable_from_hlo( committed=committed, pmap_nreps=pmap_nreps) xla_executable = None return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals, - in_shardings, out_shardings, auto_spmd_lowering, - kept_var_idx, (None,) * len(global_in_avals), + global_out_avals, in_shardings, out_shardings, + auto_spmd_lowering, kept_var_idx, + (None,) * len(global_in_avals), (None,) * len(global_out_avals))