Skip to content

Commit

Permalink
Fix the indentation of the physical_hlo_sharding function
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 616280971
  • Loading branch information
yashk2810 authored and jax authors committed Mar 15, 2024
1 parent cd1e55a commit ab2e906
Showing 1 changed file with 7 additions and 8 deletions.
15 changes: 7 additions & 8 deletions jax/_src/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,14 +361,13 @@ def get_logical_gspmd_sharding(aval, phys_sharding):


def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
key_shape = aval.dtype._impl.key_shape
new_op_sharding = hlo_sharding.to_proto().clone() # type: ignore
partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
hlo_sharding)
suffix = [] if num_replicas == 1 else [num_replicas]
tad = partitions + [1] * len(key_shape) + suffix
new_op_sharding.tile_assignment_dimensions = tad
return xc.HloSharding.from_proto(new_op_sharding)
key_shape = aval.dtype._impl.key_shape
new_op_sharding = hlo_sharding.to_proto().clone() # type: ignore
partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(hlo_sharding)
suffix = [] if num_replicas == 1 else [num_replicas]
tad = partitions + [1] * len(key_shape) + suffix
new_op_sharding.tile_assignment_dimensions = tad
return xc.HloSharding.from_proto(new_op_sharding)


class KeyTyRules:
Expand Down

0 comments on commit ab2e906

Please sign in to comment.