Skip to content

Commit

Permalink
Set out_mut to None as default on from_hlo instead of in `__init_…
Browse files Browse the repository at this point in the history
…_` of `MeshComputation` and correct the types too.

PiperOrigin-RevId: 611814102
  • Loading branch information
yashk2810 authored and jax authors committed Mar 1, 2024
1 parent cfeb113 commit 2761f26
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,7 @@ def build_execute_fun(self):
self.unordered_effects,
self.ordered_effects, self.keepalive,
bool(self.host_callbacks),
set(range(len(input_indices))), [])
set(range(len(input_indices))), None)
return execute_fun

def load(self) -> PmapExecutable:
Expand Down Expand Up @@ -1155,7 +1155,7 @@ def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
unordered_effects: list[core.Effect],
ordered_effects: list[core.Effect], keepalive: Any,
has_host_callbacks: bool, kept_var_idx: set[int],
out_mut: Sequence[int | None]):
out_mut: Sequence[int | None] | None):
self.xla_executable = xla_executable
self.name = name
self.backend = backend
Expand Down Expand Up @@ -1210,7 +1210,7 @@ def __call__(self, *args):
out = self.out_handler(out_arrays)
else:
out = results.consume_with_handlers(self.out_handler.handlers)
if not self.out_mut:
if self.out_mut is None:
return out
else:
out_ = []
Expand Down Expand Up @@ -2282,7 +2282,6 @@ def lower_mesh_computation(
host_callbacks=lowering_result.host_callbacks,
keepalive=lowering_result.keepalive,
kept_var_idx=set(range(len(global_in_avals))),
out_mut=None,
backend=backend,
device_assignment=_create_da_object(tuple(mesh.devices.flat)),
committed=True,
Expand All @@ -2297,7 +2296,6 @@ class MeshComputation(stages.XlaLowering):

def __init__(self, name: str, hlo: ir.Module | None,
donated_invars: Sequence[bool], **compile_args):
compile_args.setdefault('out_mut', None) # TODO(mattjj): remove default
self._name = name
self._hlo = hlo
self._donated_invars = donated_invars
Expand Down Expand Up @@ -2763,7 +2761,7 @@ class UnloadedMeshExecutable:
keepalive: Sequence[Any]
host_callbacks: Sequence[Any]
kept_var_idx: set[int]
out_mut: Sequence[None | int]
out_mut: Sequence[None | int] | None
auto_spmd_lowering: bool
in_layouts: Sequence[SpecifiedLayout | None]
out_layouts: Sequence[SpecifiedLayout | None]
Expand Down Expand Up @@ -2802,7 +2800,7 @@ def from_hlo(name: str,
global_out_avals: Sequence[ShapedArray],
in_shardings: Sequence[sharding_impls.XLACompatibleSharding | AUTO],
out_shardings: Sequence[(sharding_impls.XLACompatibleSharding | AUTO |
UnspecifiedValue)],
UnspecifiedValue)],
spmd_lowering: bool,
tuple_args: bool,
auto_spmd_lowering: bool,
Expand All @@ -2811,13 +2809,13 @@ def from_hlo(name: str,
host_callbacks: list[Any],
keepalive: Any,
kept_var_idx: set[int],
out_mut: Sequence[None | int],
backend: xb.XlaBackend,
device_assignment: xc.DeviceList | Sequence[xc.Device], # type: ignore
committed: bool,
in_layouts: MaybeLayout,
out_layouts: MaybeLayout,
pmap_nreps: int = 1,
out_mut: Sequence[None | int] | None = None,
shape_poly_state: mlir.ShapePolyLoweringState | None = None,
all_default_mem_kind: bool = True,
all_args_info: AllArgsInfo | None = None,
Expand Down

0 comments on commit 2761f26

Please sign in to comment.