Skip to content

Commit

Permalink
[XLA:SPMD] Do not propagate sharding to parameter if it does not even…
Browse files Browse the repository at this point in the history
…ly partition the parameter.

PiperOrigin-RevId: 612601176
  • Loading branch information
jax authors committed Mar 4, 2024
1 parent feda85d commit f5e8ea9
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -2048,10 +2048,8 @@ def lower_sharding_computation(
any(not is_unspecified(o) for o in out_shardings))

gs = GSPMDSharding.get_replicated(device_assignment)
# TODO(yashkatariya): Enable this when the SPMD chooses correct shardings for
# shapes indivisible by shard_shape.
# if xla_extension_version < 239 or hasattr(backend, "compile_replicated"):
in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)
if xla_extension_version < 239 or hasattr(backend, "compile_replicated"):
in_shardings = tuple(gs if is_unspecified(i) else i for i in in_shardings)

da_object = _create_da_object(tuple(device_assignment))

Expand Down

0 comments on commit f5e8ea9

Please sign in to comment.