From f5e8ea9cb2fe27ac0cfa82a9d5f31e56314d5828 Mon Sep 17 00:00:00 2001 From: jax authors Date: Mon, 4 Mar 2024 15:09:02 -0800 Subject: [PATCH] [XLA:SPMD] Do not propagate sharding to parameter if it does not evenly partition the parameter. PiperOrigin-RevId: 612601176 --- jax/_src/interpreters/pxla.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 95c4c499e86e..950ee61a060e 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -2048,10 +2048,8 @@ def lower_sharding_computation( any(not is_unspecified(o) for o in out_shardings)) gs = GSPMDSharding.get_replicated(device_assignment) - # TODO(yashkatariya): Enable this when the SPMD chooses correct shardings for - # shapes indivisible by shard_shape. - # if xla_extension_version < 239 or hasattr(backend, "compile_replicated"): - in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings) + if xla_extension_version < 239 or hasattr(backend, "compile_replicated"): + in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings) da_object = _create_da_object(tuple(device_assignment))