From ab2e9063234507acae8544f46b950ed7db934cd8 Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Fri, 15 Mar 2024 16:58:38 -0700 Subject: [PATCH] Fix the indentation of the physical_hlo_sharding function PiperOrigin-RevId: 616280971 --- jax/_src/prng.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 00d3b1bf7d38..f495e1c99d26 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -361,14 +361,13 @@ def get_logical_gspmd_sharding(aval, phys_sharding): def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding: - key_shape = aval.dtype._impl.key_shape - new_op_sharding = hlo_sharding.to_proto().clone() # type: ignore - partitions, num_replicas = op_shardings.get_num_ways_dim_sharded( - hlo_sharding) - suffix = [] if num_replicas == 1 else [num_replicas] - tad = partitions + [1] * len(key_shape) + suffix - new_op_sharding.tile_assignment_dimensions = tad - return xc.HloSharding.from_proto(new_op_sharding) + key_shape = aval.dtype._impl.key_shape + new_op_sharding = hlo_sharding.to_proto().clone() # type: ignore + partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(hlo_sharding) + suffix = [] if num_replicas == 1 else [num_replicas] + tad = partitions + [1] * len(key_shape) + suffix + new_op_sharding.tile_assignment_dimensions = tad + return xc.HloSharding.from_proto(new_op_sharding) class KeyTyRules: