From 43cbc9454c4c8a80ab505977269fb5cc00f5d45e Mon Sep 17 00:00:00 2001
From: jax authors
Date: Tue, 12 Mar 2024 22:22:00 -0700
Subject: [PATCH] [XLA:MHLO->HLO] Allow partially-set parameter tuple sharding
 to exist by filling in the missing sharding elements with replicated
 sharding. (This is what is done for the missing shardings in the result
 tuple.)

Before this change, if an element of a tuple parameter did not have a
sharding, MHLO->HLO conversion dropped the existing annotations on the
parameter. This caused the parameter sharding for a model to disappear,
which then resulted in an OOM.

PiperOrigin-RevId: 615287423
---
 jax/_src/interpreters/pxla.py |  4 ++--
 tests/pjit_test.py            | 16 ++++++++++++++++
 2 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index cd0799297793..b18af885d3eb 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2076,8 +2076,8 @@ def lower_sharding_computation(
       any(not is_unspecified(o) for o in out_shardings))
 
   gs = GSPMDSharding.get_replicated(device_assignment)
-  # if xla_extension_version < 241 or hasattr(backend, "compile_replicated"):
-  in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
+  if xla_extension_version < 241 or hasattr(backend, "compile_replicated"):
+    in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
 
   da_object = _create_da_object(tuple(device_assignment))
 
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index a3f17e7b6637..eccbd1289736 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -3942,6 +3942,22 @@ def f(x, y, z, a, b):
     self.assertArraysEqual(out4, np_inp * 3)
     self.assertArraysEqual(out5, np_inp.T)
 
+  def test_parameter_tupled_jit(self):
+    if not jtu.test_device_matches(["tpu"]):
+      self.skipTest('Parameters are tupled only on TPU if >2000 parameters')
+
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    s = NamedSharding(mesh, P('x'))
+
+    @jax.jit
+    def f(*args):
+      return args * 2
+
+    inp = np.arange(8)
+    arr = jax.device_put(inp, s)
+    inps = [arr, *[inp] * 2001]
+    f(inps)  # doesn't crash
+
 
 class TempSharding(Sharding):
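
For readers following the pxla.py hunk above, the sketch below illustrates the
"fill unspecified shardings with replicated sharding" idea using JAX's public
sharding API. It is an illustrative sketch only, not the code the patch
touches: the helper name fill_unspecified_with_replicated is invented for the
example, while the real code path uses GSPMDSharding.get_replicated and
is_unspecified inside lower_sharding_computation.

# Minimal sketch (not part of the patch), assuming a fully replicated
# NamedSharding(mesh, P()) is an acceptable stand-in for
# GSPMDSharding.get_replicated(device_assignment).
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def fill_unspecified_with_replicated(shardings, mesh):
  # P() partitions no axes, so this NamedSharding is fully replicated.
  replicated = NamedSharding(mesh, P())
  # Keep explicit shardings; replace unspecified (None) entries.
  return tuple(replicated if s is None else s for s in shardings)

mesh = Mesh(np.array(jax.devices()), ('x',))
partial = (NamedSharding(mesh, P('x')), None)  # second entry unspecified
print(fill_unspecified_with_replicated(partial, mesh))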