Merge pull request #556 from google:ss-sharding-utest
PiperOrigin-RevId: 621584403
maxtext authors committed Apr 3, 2024
2 parents 5575702 + 6820b5b commit eb4a173
Showing 3 changed files with 54 additions and 0 deletions.
13 changes: 13 additions & 0 deletions MaxText/max_utils.py
@@ -61,6 +61,19 @@ def calculate_num_params_from_pytree(params):
  assert total_parameters >= 0
  return total_parameters


def calculate_total_params_per_chip(params):
  def calculate_leaf_params_per_chip(arr):
    shard = arr.addressable_shards[0]
    return np.prod(shard.data.shape)

  params_sizes_per_chip = jax.tree_util.tree_map(calculate_leaf_params_per_chip, params)
  total_parameters_per_chip = jax.tree_util.tree_reduce(lambda x, y: x + y, params_sizes_per_chip)
  return total_parameters_per_chip


def calculate_bytes_from_pytree(params):
  params_bytes = jax.tree_util.tree_map(lambda x : x.nbytes, params)
  total_bytes = jax.tree_util.tree_reduce(lambda x, y: x + y, params_bytes)
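A minimal, hypothetical usage sketch (not part of this commit), assuming a single-host JAX process: each leaf of the params pytree exposes the data materialized on the local chip through `addressable_shards`, so summing the size of the first addressable shard per leaf gives the per-chip parameter count; on one unsharded device this equals the global count. The names `params_per_chip` and `toy_params` below are illustrative only.

import jax
import jax.numpy as jnp
import numpy as np

def params_per_chip(params):
  # Sum, over all leaves, the number of elements held in the locally addressable shard.
  sizes = jax.tree_util.tree_map(lambda arr: np.prod(arr.addressable_shards[0].data.shape), params)
  return jax.tree_util.tree_reduce(lambda x, y: x + y, sizes)

toy_params = {'dense': {'kernel': jnp.ones((4, 8)), 'bias': jnp.ones((8,))}}
print(params_per_chip(toy_params))  # 40 on a single, unsharded device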
39 changes: 39 additions & 0 deletions MaxText/maxtext_utils.py
@@ -18,6 +18,7 @@
"""Utils that are only interesting to MaxText. """

import jax
import max_utils
from jax.sharding import PartitionSpec as P
from jax.experimental.serialize_executable import deserialize_and_load

@@ -124,3 +125,41 @@ def calculate_tflops_prefill(num_model_parameters, prefill_length, config, log=T
          f'\t\tCausal attention TFLOPs: {causal_attention_tflops} ',
          f'({100 * causal_attention_tflops/total_tflops:.2f})% of Total')
  return total_tflops, learnable_weight_tflops, causal_attention_tflops


def assert_params_sufficiently_sharded(params, mesh, tolerance=0.01):
  """Checks whether most params are sharded across the sharding axes.

  This function determines whether the majority of parameters are distributed
  across the specified sharding axes with an acceptable tolerance. It compares
  the current distribution to a scenario where all parameters are fully sharded
  across the 'fsdp', 'fsdp_transpose', 'sequence', and 'tensor' axes.

  Args:
    params: params of the model state
    mesh: mesh constructed from config
    tolerance: float between 0.0 and 1.0 representing the allowed fraction of
      non-sharded parameters.

  Raises:
    AssertionError: if the parameters are less sharded than the tolerance allows.
  """
  total_num_params = max_utils.calculate_num_params_from_pytree(params)
  product_num_devices_for_weight_sharding = 1
  for axis in ['fsdp', 'fsdp_transpose', 'sequence', 'tensor']:
    product_num_devices_for_weight_sharding *= mesh.shape[axis]
  total_num_params_per_chip = max_utils.calculate_total_params_per_chip(params)
  perfectly_sharded_params_per_chip = (
      total_num_params / product_num_devices_for_weight_sharding
  )
  assert total_num_params_per_chip >= perfectly_sharded_params_per_chip, (
      'Number of parameters per chip must not be less than in the ideal sharded '
      'scenario across `fsdp`, `fsdp_transpose`, `sequence`, `tensor` axes.'
  )
  assert (
      total_num_params_per_chip / perfectly_sharded_params_per_chip - 1 < tolerance
  ), (f'Number of unsharded parameters exceeds tolerance {tolerance * 100}% '
      'of total parameters.')
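
A worked example of the check's arithmetic with hypothetical numbers (the axis sizes and parameter counts below are assumptions, not from this commit): with 1e9 total parameters and mesh axes fsdp=4, fsdp_transpose=1, sequence=1, tensor=2, the weight-sharding product is 8 and the ideal per-chip count is 125M; a measured 126M parameters per chip stays within the default 1% tolerance, while 140M would fail the second assertion.

# Illustrative numbers only; axis sizes and per-chip counts are assumptions.
total_num_params = 1_000_000_000
product_num_devices_for_weight_sharding = 4 * 1 * 1 * 2   # fsdp * fsdp_transpose * sequence * tensor
perfectly_sharded_params_per_chip = total_num_params / product_num_devices_for_weight_sharding  # 125e6

tolerance = 0.01
for measured_per_chip in (125_000_000, 126_000_000, 140_000_000):
  within_tolerance = measured_per_chip / perfectly_sharded_params_per_chip - 1 < tolerance
  print(measured_per_chip, within_tolerance)   # True, True, False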

2 changes: 2 additions & 0 deletions MaxText/train.py
@@ -360,6 +360,8 @@ def setup_train_loop(config):
  state, state_mesh_annotations, data_iterator = max_utils.setup_training_state(model, data_iterator,
                                                                                tx, config, init_rng, mesh, checkpoint_manager)

  maxtext_utils.assert_params_sufficiently_sharded(state.params, mesh)

  return ( init_rng, writer, checkpoint_manager, state_mesh_annotations, model,
           mesh, learning_rate_schedule, data_iterator, eval_data_iterator, state)

