Skip to content

Commit

Permalink
[XLA:MHLO->HLO] Allow partially-set parameter tuple sharding to exist…
Browse files Browse the repository at this point in the history
… by filling in the missing sharding elements with replicated sharding. (This is what is done for the missing shardings in the result tuple.)

Before this change, if an element of a tuple parameter did not have a sharding, MHLO->HLO conversion dropped the existing annotations on the parameter. This issue caused the disappearing of the parameter sharding for a model, which then resulted in an OOM.

PiperOrigin-RevId: 615287423
  • Loading branch information
jax authors committed Mar 14, 2024
1 parent 9a00721 commit 43cbc94
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -2076,8 +2076,8 @@ def lower_sharding_computation(
any(not is_unspecified(o) for o in out_shardings))

gs = GSPMDSharding.get_replicated(device_assignment)
# if xla_extension_version < 241 or hasattr(backend, "compile_replicated"):
in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
if xla_extension_version < 241 or hasattr(backend, "compile_replicated"):
in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)

da_object = _create_da_object(tuple(device_assignment))

Expand Down
16 changes: 16 additions & 0 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3942,6 +3942,22 @@ def f(x, y, z, a, b):
self.assertArraysEqual(out4, np_inp * 3)
self.assertArraysEqual(out5, np_inp.T)

def test_parameter_tupled_jit(self):
if not jtu.test_device_matches(["tpu"]):
self.skipTest('Parameters are tupled only on TPU if >2000 parameters')

mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x'))

@jax.jit
def f(*args):
return args * 2

inp = np.arange(8)
arr = jax.device_put(inp, s)
inps = [arr, *[inp] * 2001]
f(inps) # doesn't crash


class TempSharding(Sharding):

Expand Down

0 comments on commit 43cbc94

Please sign in to comment.