Skip to content

Commit

Permalink
Use jax.jit in Maxtext and remove pjit by converting PartitionSpec to…
Browse files Browse the repository at this point in the history
… NamedShardings.

Update unit test for decode.py
  • Loading branch information
michelle-yooh committed Oct 4, 2023
1 parent ca0cc34 commit 688c698
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/UnitTests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ jobs:
- name: Test decode.py
run: |
source venv/bin/activate
python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2
python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4
- name: Test int8_training
run: |
source venv/bin/activate
Expand Down
10 changes: 5 additions & 5 deletions MaxText/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import jax
import jax.numpy as jnp
from jax import random
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh

Expand Down Expand Up @@ -144,11 +143,12 @@ def decode_loop(config, state=None):

state, state_mesh_annotations = max_utils.setup_initial_state(model, tx, config, rng, mesh, checkpoint_manager)

p_predict_step = pjit(
state_mesh_shardings = jax.tree_map(
lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
replicated_sharding = jax.sharding.NamedSharding(mesh, P(None, None))
p_predict_step = jax.jit(
functools.partial(predict_step, model=model, config=config),
in_shardings=(P(None, None),
state_mesh_annotations,
None),
in_shardings=(replicated_sharding, state_mesh_shardings, None),
out_shardings=None
)

Expand Down
9 changes: 4 additions & 5 deletions MaxText/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import tensorflow as tf
import tensorflow_datasets as tfds
import jax
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P

import tokenizer
Expand Down Expand Up @@ -288,10 +287,10 @@ def __init__(self, config, mesh):
self.mesh = mesh
self.config = config
data_pspec = P(*config.data_sharding)
with self.mesh:
self.data_generator = pjit(SyntheticDataIterator.raw_generate_synthetic_data,
in_shardings=None,
out_shardings=data_pspec,
data_pspec_shardings = jax.tree_map(
lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
self.data_generator = jax.jit(SyntheticDataIterator.raw_generate_synthetic_data,
out_shardings=data_pspec_shardings,
static_argnums=0)
def __call__(self):
with self.mesh:
Expand Down
10 changes: 5 additions & 5 deletions MaxText/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
import jax.numpy as jnp
from jax.experimental import mesh_utils

from jax.experimental.pjit import pjit

import json
import flax
from flax.training import train_state
Expand Down Expand Up @@ -205,7 +203,7 @@ def setup_initial_state(model, tx, config, rng, mesh, checkpoint_manager):
unboxed_abstract_state = unbox_logicallypartioned_trainstate(abstract_state)

# Initialization
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
with nn_partitioning.axis_rules(config.logical_axis_rules):
state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations)
state, raw_params = checkpointing.load_state_if_possible(checkpoint_manager,
config.load_parameters_path,
Expand All @@ -215,11 +213,13 @@ def setup_initial_state(model, tx, config, rng, mesh, checkpoint_manager):
mesh,
state_mesh_annotations)

state_mesh_shardings = jax.tree_map(
lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
if not state:
state = pjit(
state = jax.jit(
init_train_state_partial,
in_shardings=None,
out_shardings=state_mesh_annotations
out_shardings=state_mesh_shardings
)(rng)
if raw_params: # If we loaded a partial state, we need to merge it.
state = state.replace(params = raw_params)
Expand Down
13 changes: 7 additions & 6 deletions MaxText/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@

import jax.numpy as jnp
from jax import random
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh

Expand Down Expand Up @@ -251,12 +250,14 @@ def train_loop(config, state=None):
per_device_tflops = calculate_training_tflops(num_model_parameters, config)

# Define compiled top-level functions.
p_train_step = pjit(
state_mesh_shardings = jax.tree_map(
lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
data_sharding = jax.tree_map(
lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
p_train_step = jax.jit(
train_step,
in_shardings=(state_mesh_annotations,
data_pspec,
None),
out_shardings=(state_mesh_annotations, None, None),
in_shardings=(state_mesh_shardings, data_sharding, None),
out_shardings=(state_mesh_shardings, None, None),
static_argnums=(0,1,),
donate_argnums=2)

Expand Down
3 changes: 1 addition & 2 deletions end_to_end/test_decode.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ fi
#Train
python3 MaxText/decode.py MaxText/configs/base.yml run_name=$RUN_NAME\
steps=50 enable_checkpointing=False metrics_file='metrics.txt'\
base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH \
ici_tensor_parallelism=4
base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH

python3 end_to_end/eval_assert.py metrics_average metrics.txt $NUM_TOKEN_THRESHOLD num_tokens
35 changes: 24 additions & 11 deletions pedagogical_examples/shardings.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import jax
from jax.sharding import PartitionSpec
from jax.experimental.pjit import pjit
from jax.sharding import Mesh
from jax.experimental import mesh_utils
from jax.experimental.compilation_cache import compilation_cache as cc
Expand Down Expand Up @@ -222,32 +221,46 @@ def training_step(in_act, in_layers):

print("finished includes ", flush = True)

pjit_func = pjit(
replicated_sharding = jax.sharding.NamedSharding(mesh, data_sharding)

parameter_mesh_shardings = jax.tree_map(
lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)

data_pspec_shardings = jax.tree_map(
lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)

jit_func = jax.jit(
training_step,
in_shardings=(data_sharding, parameter_sharding),
out_shardings=parameter_sharding,
in_shardings=(replicated_sharding, parameter_mesh_shardings),
out_shardings=data_pspec_shardings,
)

pjit_gen_data = pjit(
data_mesh_shardings = jax.tree_map(
lambda p: jax.sharding.NamedSharding(mesh, p), data_sharding)

jit_gen_data = jax.jit(
gen_data,
in_shardings=None,
out_shardings=data_sharding
out_shardings=data_mesh_shardings
)

pjit_gen_layers = pjit(
parameter_mesh_shardings = jax.tree_map(
lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)

jit_gen_layers = jax.jit(
gen_layers,
in_shardings=None,
out_shardings=parameter_sharding
out_shardings=parameter_mesh_shardings
)

# starting the profiler outside `with` statement,
# will call it right before the computation once b/301309635 is resolved
activate_profiler(args.profiler_path)
with Mesh(mesh.devices, mesh.axis_names):
key = jax.random.PRNGKey(0)
presharded_X = jax.block_until_ready(pjit_gen_data(key))
presharded_layers = jax.block_until_ready(pjit_gen_layers(key))
presharded_X = jax.block_until_ready(jit_gen_data(key))
presharded_layers = jax.block_until_ready(jit_gen_layers(key))
TFLOPs_per_device = parameters * 6 * BATCH / 10**12 / len(jax.devices())
time = simple_timeit(lambda : jax.block_until_ready(pjit_func(presharded_X, presharded_layers)))
time = simple_timeit(lambda : jax.block_until_ready(jit_func(presharded_X, presharded_layers)))
print(f"time is {time} seconds, TFLOP is {TFLOPs_per_device}, TFLOP/s is {TFLOPs_per_device/time}", flush = True)
deactivate_profiler(args.profiler_path)

0 comments on commit 688c698

Please sign in to comment.