diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index cd0799297793..b18af885d3eb 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -2076,8 +2076,8 @@ def lower_sharding_computation(
       any(not is_unspecified(o) for o in out_shardings))
 
   gs = GSPMDSharding.get_replicated(device_assignment)
-  # if xla_extension_version < 241 or hasattr(backend, "compile_replicated"):
-  in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
+  if xla_extension_version < 241 or hasattr(backend, "compile_replicated"):
+    in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
 
   da_object = _create_da_object(tuple(device_assignment))
 
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index a3f17e7b6637..eccbd1289736 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -3942,6 +3942,22 @@ def f(x, y, z, a, b):
     self.assertArraysEqual(out4, np_inp * 3)
     self.assertArraysEqual(out5, np_inp.T)
 
+  def test_parameter_tupled_jit(self):
+    if not jtu.test_device_matches(["tpu"]):
+      self.skipTest('Parameters are tupled only on TPU if >2000 parameters')
+
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    s = NamedSharding(mesh, P('x'))
+
+    @jax.jit
+    def f(*args):
+      return args * 2
+
+    inp = np.arange(8)
+    arr = jax.device_put(inp, s)
+    inps = [arr, *[inp] * 2001]
+    f(inps)  # doesn't crash
+
 
 class TempSharding(Sharding):
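
For context, here is a minimal standalone sketch of what the restored guard in `lower_sharding_computation` does. It is an illustration only, not the real `pxla` internals: `UNSPECIFIED`, `PlainBackend`, and `normalize_in_shardings` are hypothetical stand-ins for the names appearing in the diff.

```python
# Standalone sketch of the restored guard (hypothetical names, not jax._src code).
UNSPECIFIED = object()

def is_unspecified(s):
  return s is UNSPECIFIED

def normalize_in_shardings(in_shardings, replicated, backend,
                           xla_extension_version):
  # The guard the patch reinstates: only force an explicit replicated
  # sharding onto unspecified inputs on older runtimes (< 241) or on
  # backends that expose `compile_replicated`. Everywhere else,
  # unspecified inputs pass through untouched.
  if xla_extension_version < 241 or hasattr(backend, "compile_replicated"):
    return tuple(replicated if is_unspecified(i) else i for i in in_shardings)
  return tuple(in_shardings)

class PlainBackend:  # hypothetical backend without compile_replicated
  pass

# On a new-enough runtime with a plain backend, unspecified shardings
# are left as-is rather than being replaced with the replicated sharding:
out = normalize_in_shardings((UNSPECIFIED, "sharded"), "replicated",
                             PlainBackend(), xla_extension_version=241)
assert out == (UNSPECIFIED, "sharded")
```

Judging by the `test_parameter_tupled_jit` test added in this patch, letting unspecified shardings pass through on newer runtimes matters for the TPU path that tuples arguments when a jitted function takes more than 2000 parameters, which previously crashed.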