Skip to content

Commit

Permalink
Thread out_avals to MeshExecutable
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 612037684
  • Loading branch information
yashk2810 authored and jax authors committed Mar 2, 2024
1 parent 8569b89 commit 0b70244
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -2789,7 +2789,7 @@ def build_unsafe_call(self):

def load(self) -> MeshExecutable:
return MeshExecutable(self.xla_executable, self.build_unsafe_call,
self.input_avals,
self.input_avals, self.output_avals,
self.input_shardings, self.output_shardings,
self.auto_spmd_lowering, self.kept_var_idx,
self.in_layouts, self.out_layouts,
Expand Down Expand Up @@ -2942,12 +2942,13 @@ def reflatten_outputs_for_dispatch(out_tree, out_flat):
class MeshExecutable(stages.XlaExecutable):
__slots__ = [
"xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
"_in_shardings", "_out_shardings", "_auto_spmd_lowering", "_kept_var_idx",
"_in_layouts", "_out_layouts", "_all_args_info", "_unloaded_executable",
"out_avals", "_in_shardings", "_out_shardings", "_auto_spmd_lowering",
"_kept_var_idx", "_in_layouts", "_out_layouts", "_all_args_info",
"_unloaded_executable",
]

def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
out_shardings, auto_spmd_lowering, kept_var_idx,
def __init__(self, xla_executable, build_unsafe_call, in_avals, out_avals,
in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx,
in_layouts, out_layouts,
all_args_info: AllArgsInfo | None = None,
unloaded_executable=None):
Expand All @@ -2956,6 +2957,7 @@ def __init__(self, xla_executable, build_unsafe_call, in_avals, in_shardings,
# in_avals is a list of global and local avals. Aval is global if input
# is a GDA or jax.Array else local.
self.in_avals = in_avals
self.out_avals = out_avals
self._unsafe_call = None
self._in_shardings = in_shardings
self._out_shardings = out_shardings
Expand Down Expand Up @@ -3118,8 +3120,9 @@ def _compile_replicated_mesh_executable_from_hlo(
committed=committed, pmap_nreps=pmap_nreps)
xla_executable = None
return MeshExecutable(xla_executable, lambda: unsafe_call, global_in_avals,
in_shardings, out_shardings, auto_spmd_lowering,
kept_var_idx, (None,) * len(global_in_avals),
global_out_avals, in_shardings, out_shardings,
auto_spmd_lowering, kept_var_idx,
(None,) * len(global_in_avals),
(None,) * len(global_out_avals))


Expand Down

0 comments on commit 0b70244

Please sign in to comment.