diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml
index 6024553a2..72b6bb53d 100644
--- a/.github/workflows/UnitTests.yml
+++ b/.github/workflows/UnitTests.yml
@@ -55,7 +55,7 @@ jobs:
     - name: Test decode.py
       run: |
         source venv/bin/activate
-        python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2
+        python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4
     - name: Test int8_training
       run: |
         source venv/bin/activate
diff --git a/MaxText/decode.py b/MaxText/decode.py
index 2792edf56..4f55dc3e2 100644
--- a/MaxText/decode.py
+++ b/MaxText/decode.py
@@ -36,7 +36,6 @@
 import jax
 import jax.numpy as jnp
 from jax import random
-from jax.experimental.pjit import pjit
 from jax.sharding import PartitionSpec as P
 from jax.sharding import Mesh
 
@@ -144,11 +143,12 @@ def decode_loop(config, state=None):
   state, state_mesh_annotations = max_utils.setup_initial_state(model, tx, config, rng, mesh, checkpoint_manager)
 
-  p_predict_step = pjit(
+  state_mesh_shardings = jax.tree_map(
+      lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
+  replicated_sharding = jax.sharding.NamedSharding(mesh, P(None, None))
+  p_predict_step = jax.jit(
     functools.partial(predict_step, model=model, config=config),
-    in_shardings=(P(None, None),
-                  state_mesh_annotations,
-                  None),
+    in_shardings=(replicated_sharding, state_mesh_shardings, None),
     out_shardings=None
   )
diff --git a/MaxText/input_pipeline.py b/MaxText/input_pipeline.py
index ba9cd971d..67890a0c2 100644
--- a/MaxText/input_pipeline.py
+++ b/MaxText/input_pipeline.py
@@ -24,7 +24,6 @@
 import tensorflow as tf
 import tensorflow_datasets as tfds
 import jax
-from jax.experimental.pjit import pjit
 from jax.sharding import PartitionSpec as P
 
 import tokenizer
@@ -288,10 +287,10 @@ def __init__(self, config, mesh):
     self.mesh = mesh
     self.config = config
     data_pspec = P(*config.data_sharding)
-    with self.mesh:
-      self.data_generator = pjit(SyntheticDataIterator.raw_generate_synthetic_data,
-                                 in_shardings=None,
-                                 out_shardings=data_pspec,
+    data_pspec_shardings = jax.tree_map(
+        lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
+    self.data_generator = jax.jit(SyntheticDataIterator.raw_generate_synthetic_data,
+                                  out_shardings=data_pspec_shardings,
                                   static_argnums=0)
 
   def __call__(self):
     with self.mesh:
diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py
index d4f45438b..57b595de3 100644
--- a/MaxText/max_utils.py
+++ b/MaxText/max_utils.py
@@ -26,8 +26,6 @@
 import jax.numpy as jnp
 from jax.experimental import mesh_utils
 
-from jax.experimental.pjit import pjit
-
 import json
 import flax
 from flax.training import train_state
@@ -205,7 +203,7 @@ def setup_initial_state(model, tx, config, rng, mesh, checkpoint_manager):
   unboxed_abstract_state = unbox_logicallypartioned_trainstate(abstract_state)
 
   # Initialization
-  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+  with nn_partitioning.axis_rules(config.logical_axis_rules):
     state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations)
 
   state, raw_params = checkpointing.load_state_if_possible(checkpoint_manager,
                                                            config.load_parameters_path,
@@ -215,11 +213,13 @@
                                                            mesh,
                                                            state_mesh_annotations)
 
+  state_mesh_shardings = jax.tree_map(
+      lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
   if not state:
-    state = pjit(
+    state = jax.jit(
       init_train_state_partial,
       in_shardings=None,
-      out_shardings=state_mesh_annotations
+      out_shardings=state_mesh_shardings
     )(rng)
     if raw_params: # If we loaded a partial state, we need to merge it.
       state = state.replace(params = raw_params)
diff --git a/MaxText/train.py b/MaxText/train.py
index ad96e348f..6b37faed8 100644
--- a/MaxText/train.py
+++ b/MaxText/train.py
@@ -45,7 +45,6 @@
 import jax.numpy as jnp
 from jax import random
-from jax.experimental.pjit import pjit
 from jax.sharding import PartitionSpec as P
 from jax.sharding import Mesh
 
@@ -251,12 +250,14 @@ def train_loop(config, state=None):
   per_device_tflops = calculate_training_tflops(num_model_parameters, config)
 
   # Define compiled top-level functions.
-  p_train_step = pjit(
+  state_mesh_shardings = jax.tree_map(
+      lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
+  data_sharding = jax.tree_map(
+      lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
+  p_train_step = jax.jit(
     train_step,
-    in_shardings=(state_mesh_annotations,
-                  data_pspec,
-                  None),
-    out_shardings=(state_mesh_annotations, None, None),
+    in_shardings=(state_mesh_shardings, data_sharding, None),
+    out_shardings=(state_mesh_shardings, None, None),
     static_argnums=(0,1,),
     donate_argnums=2)
diff --git a/end_to_end/test_decode.sh b/end_to_end/test_decode.sh
index 094963b0f..c05d40ef6 100644
--- a/end_to_end/test_decode.sh
+++ b/end_to_end/test_decode.sh
@@ -17,7 +17,6 @@ fi
 
 #Train
 python3 MaxText/decode.py MaxText/configs/base.yml run_name=$RUN_NAME\
   steps=50 enable_checkpointing=False metrics_file='metrics.txt'\
-  base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH \
-  ici_tensor_parallelism=4
+  base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH
 
 python3 end_to_end/eval_assert.py metrics_average metrics.txt $NUM_TOKEN_THRESHOLD num_tokens
diff --git a/pedagogical_examples/shardings.py b/pedagogical_examples/shardings.py
index 2c1275631..2aefd4538 100644
--- a/pedagogical_examples/shardings.py
+++ b/pedagogical_examples/shardings.py
@@ -20,7 +20,6 @@
 import jax
 from jax.sharding import PartitionSpec
-from jax.experimental.pjit import pjit
 from jax.sharding import Mesh
 from jax.experimental import mesh_utils
 from jax.experimental.compilation_cache import compilation_cache as cc
@@ -222,22 +221,36 @@ def training_step(in_act, in_layers):
 
 print("finished includes ", flush = True)
 
-pjit_func = pjit(
+replicated_sharding = jax.sharding.NamedSharding(mesh, data_sharding)
+
+parameter_mesh_shardings = jax.tree_map(
+    lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)
+
+data_pspec_shardings = jax.tree_map(
+    lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)
+
+jit_func = jax.jit(
   training_step,
-  in_shardings=(data_sharding, parameter_sharding),
-  out_shardings=parameter_sharding,
+  in_shardings=(replicated_sharding, parameter_mesh_shardings),
+  out_shardings=data_pspec_shardings,
 )
 
-pjit_gen_data = pjit(
+data_mesh_shardings = jax.tree_map(
+    lambda p: jax.sharding.NamedSharding(mesh, p), data_sharding)
+
+jit_gen_data = jax.jit(
   gen_data,
   in_shardings=None,
-  out_shardings=data_sharding
+  out_shardings=data_mesh_shardings
 )
 
-pjit_gen_layers = pjit(
+parameter_mesh_shardings = jax.tree_map(
+    lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)
+
+jit_gen_layers = jax.jit(
   gen_layers,
   in_shardings=None,
-  out_shardings=parameter_sharding
+  out_shardings=parameter_mesh_shardings
 )
 
 # starting the profiler outside `with` statement,
@@ -245,9 +258,9 @@ def training_step(in_act, in_layers):
 activate_profiler(args.profiler_path)
 with Mesh(mesh.devices, mesh.axis_names):
   key = jax.random.PRNGKey(0)
-  presharded_X = jax.block_until_ready(pjit_gen_data(key))
-  presharded_layers = jax.block_until_ready(pjit_gen_layers(key))
+  presharded_X = jax.block_until_ready(jit_gen_data(key))
+  presharded_layers = jax.block_until_ready(jit_gen_layers(key))
   TFLOPs_per_device = parameters * 6 * BATCH / 10**12 / len(jax.devices())
-  time = simple_timeit(lambda : jax.block_until_ready(pjit_func(presharded_X, presharded_layers)))
+  time = simple_timeit(lambda : jax.block_until_ready(jit_func(presharded_X, presharded_layers)))
   print(f"time is {time} seconds, TFLOP is {TFLOPs_per_device}, TFLOP/s is {TFLOPs_per_device/time}", flush = True)
 deactivate_profiler(args.profiler_path)
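
The migration pattern is the same at every call site in this patch: pjit resolved bare PartitionSpecs against whatever mesh was active in the surrounding `with mesh:` context, while jax.jit takes explicit Sharding objects, so each spec tree is wrapped into NamedShardings via jax.tree_map and the mesh context managers around the jit definitions become unnecessary (hence the dropped `with self.mesh:` in input_pipeline.py and the `mesh` removed from the context manager in max_utils.py). Below is a minimal, self-contained sketch of that pattern. It is illustrative only: `demo_step`, the parameter tree, the ('data', 'model') axis names, and the 4x2 mesh shape are assumptions for the sketch, not taken from MaxText.

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed topology: 8 devices arranged 4x2; adjust to your hardware.
devices = mesh_utils.create_device_mesh((4, 2))
mesh = Mesh(devices, axis_names=('data', 'model'))

# Bare PartitionSpecs, as pjit used to consume them directly.
param_pspecs = {'w': P('data', 'model'), 'b': P('model')}

# The bridge used throughout this diff: wrap every leaf spec into a
# NamedSharding bound to the mesh, so jax.jit can accept it.
param_shardings = jax.tree_map(lambda p: NamedSharding(mesh, p), param_pspecs)
replicated_sharding = NamedSharding(mesh, P(None, None))  # fully replicated 2-D input

def demo_step(x, params):
  # Stand-in for predict_step/train_step: a single dense layer.
  return x @ params['w'] + params['b']

jit_step = jax.jit(
    demo_step,
    in_shardings=(replicated_sharding, param_shardings),
    out_shardings=None,  # let the compiler choose the output layout
)

x = jnp.ones((8, 128))
params = {'w': jnp.ones((128, 64)), 'b': jnp.zeros((64,))}
print(jit_step(x, params).shape)  # (8, 64)

Because each NamedSharding carries the mesh with it, the compiled function can be defined and called without an ambient `with mesh:` block, which is what lets the patch delete those wrappers at the jit sites.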