Replace pjit with jax.jit by converting PartitionSpec to NamedShardings. #193

Merged 1 commit on Oct 4, 2023
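Every file below follows the same recipe: pjit consumed bare PartitionSpecs and resolved them against the ambient mesh (hence the `with mesh:` scopes), while jax.jit takes NamedShardings, which pair a PartitionSpec with its Mesh explicitly. A minimal sketch of the conversion pattern, with a made-up mesh, axis name, and specs (nothing here is taken from MaxText):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

# Toy 1-D mesh over all local devices; the axis name "data" is arbitrary.
mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ("data",))

# A pytree of PartitionSpecs becomes a pytree of NamedShardings with one
# tree_map; each NamedSharding carries its mesh, so no `with mesh:` is needed.
specs = {"x": P("data", None)}
shardings = jax.tree_map(lambda p: NamedSharding(mesh, p), specs)

double = jax.jit(lambda t: {"x": t["x"] * 2},
                 in_shardings=(shardings,), out_shardings=shardings)
out = double({"x": jnp.ones((jax.device_count() * 2, 4))})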
2 changes: 1 addition & 1 deletion .github/workflows/UnitTests.yml
@@ -55,7 +55,7 @@ jobs:
       - name: Test decode.py
         run: |
           source venv/bin/activate
-          python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2
+          python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4
       - name: Test int8_training
         run: |
           source venv/bin/activate
10 changes: 5 additions & 5 deletions MaxText/decode.py
@@ -36,7 +36,6 @@
 import jax
 import jax.numpy as jnp
 from jax import random
-from jax.experimental.pjit import pjit
 from jax.sharding import PartitionSpec as P
 from jax.sharding import Mesh
 
@@ -144,11 +143,12 @@ def decode_loop(config, state=None):
 
   state, state_mesh_annotations = max_utils.setup_initial_state(model, tx, config, rng, mesh, checkpoint_manager)
 
-  p_predict_step = pjit(
+  state_mesh_shardings = jax.tree_map(
+      lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
+  replicated_sharding = jax.sharding.NamedSharding(mesh, P(None, None))
+  p_predict_step = jax.jit(
     functools.partial(predict_step, model=model, config=config),
-    in_shardings=(P(None, None),
-                  state_mesh_annotations,
-                  None),
+    in_shardings=(replicated_sharding, state_mesh_shardings, None),
     out_shardings=None
   )
 
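Here P(None, None) is first bound to the mesh as replicated_sharding rather than being passed raw. A standalone check, not part of the PR (toy mesh and array), that such a sharding fully replicates its operand:

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ("data",))
x = jax.device_put(np.zeros((4, 4)), NamedSharding(mesh, P(None, None)))
print(x.sharding.is_fully_replicated)  # True: each device holds the whole array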
9 changes: 4 additions & 5 deletions MaxText/input_pipeline.py
@@ -24,7 +24,6 @@
 import tensorflow as tf
 import tensorflow_datasets as tfds
 import jax
-from jax.experimental.pjit import pjit
 from jax.sharding import PartitionSpec as P
 
 import tokenizer
@@ -288,10 +287,10 @@ def __init__(self, config, mesh):
     self.mesh = mesh
     self.config = config
     data_pspec = P(*config.data_sharding)
-    with self.mesh:
-      self.data_generator = pjit(SyntheticDataIterator.raw_generate_synthetic_data,
-          in_shardings=None,
-          out_shardings=data_pspec,
+    data_pspec_shardings = jax.tree_map(
+        lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
+    self.data_generator = jax.jit(SyntheticDataIterator.raw_generate_synthetic_data,
+        out_shardings=data_pspec_shardings,
         static_argnums=0)
   def __call__(self):
     with self.mesh:
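The `with self.mesh:` wrapper around the jit construction is gone because a NamedSharding already knows its mesh; the context manager survives only in __call__. A minimal sketch of the same point (toy generator and axis name, assuming the leading dimension divides the device count):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ("data",))
gen = jax.jit(lambda: jnp.zeros((jax.device_count(), 4)),
              out_shardings=NamedSharding(mesh, P("data", None)))
batch = gen()  # built and run without an ambient `with mesh:` scope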
10 changes: 5 additions & 5 deletions MaxText/max_utils.py
@@ -26,8 +26,6 @@
 import jax.numpy as jnp
 from jax.experimental import mesh_utils
 
-from jax.experimental.pjit import pjit
-
 import json
 import flax
 from flax.training import train_state
@@ -205,7 +203,7 @@ def setup_initial_state(model, tx, config, rng, mesh, checkpoint_manager):
   unboxed_abstract_state = unbox_logicallypartioned_trainstate(abstract_state)
 
   # Initialization
-  with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
+  with nn_partitioning.axis_rules(config.logical_axis_rules):
     state_mesh_annotations = nn.logical_to_mesh(state_logical_annotations)
   state, raw_params = checkpointing.load_state_if_possible(checkpoint_manager,
                                                            config.load_parameters_path,
@@ -215,11 +213,13 @@ def setup_initial_state(model, tx, config, rng, mesh, checkpoint_manager):
                                                            mesh,
                                                            state_mesh_annotations)
 
+  state_mesh_shardings = jax.tree_map(
+      lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
   if not state:
-    state = pjit(
+    state = jax.jit(
       init_train_state_partial,
       in_shardings=None,
-      out_shardings=state_mesh_annotations
+      out_shardings=state_mesh_shardings
     )(rng)
   if raw_params: # If we loaded a partial state, we need to merge it.
     state = state.replace(params = raw_params)
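Passing out_shardings to the jitted initializer materializes the train state directly in its target layout instead of creating it replicated and resharding afterwards. A toy version of the move (hypothetical pytree and shapes standing in for state_mesh_annotations; assumes the device count divides 8):

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
from jax.experimental import mesh_utils

mesh = Mesh(mesh_utils.create_device_mesh((jax.device_count(),)), ("model",))
annotations = {"w": P("model", None), "b": P(None)}  # stand-in annotations
shardings = jax.tree_map(lambda p: NamedSharding(mesh, p), annotations)

def init(rng):
  kw, _ = jax.random.split(rng)
  return {"w": jax.random.normal(kw, (8, 4)), "b": jnp.zeros((8,))}

state = jax.jit(init, out_shardings=shardings)(jax.random.PRNGKey(0))
print(jax.tree_map(lambda a: a.sharding, state))  # each leaf already sharded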
13 changes: 7 additions & 6 deletions MaxText/train.py
@@ -45,7 +45,6 @@
 
 import jax.numpy as jnp
 from jax import random
-from jax.experimental.pjit import pjit
 from jax.sharding import PartitionSpec as P
 from jax.sharding import Mesh
 
@@ -251,12 +250,14 @@ def train_loop(config, state=None):
   per_device_tflops = calculate_training_tflops(num_model_parameters, config)
 
   # Define compiled top-level functions.
-  p_train_step = pjit(
+  state_mesh_shardings = jax.tree_map(
+      lambda p: jax.sharding.NamedSharding(mesh, p), state_mesh_annotations)
+  data_sharding = jax.tree_map(
+      lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
+  p_train_step = jax.jit(
     train_step,
-    in_shardings=(state_mesh_annotations,
-                  data_pspec,
-                  None),
-    out_shardings=(state_mesh_annotations, None, None),
+    in_shardings=(state_mesh_shardings, data_sharding, None),
+    out_shardings=(state_mesh_shardings, None, None),
     static_argnums=(0,1,),
     donate_argnums=2)
 
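static_argnums and donate_argnums carry over from pjit to jax.jit unchanged: static arguments must be hashable and are baked into the compiled computation (changing them triggers a recompile), while a donated argument's buffers may be reused for the outputs. A toy illustration (made-up step function, nothing from train.py):

import jax
import jax.numpy as jnp

def step(lr, state, grad):  # lr is static, state is donated
  return state - lr * grad

p_step = jax.jit(step, static_argnums=0, donate_argnums=1)
state = jnp.ones((4,))
state = p_step(0.1, state, jnp.ones((4,)))  # old state buffer may be reused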
3 changes: 1 addition & 2 deletions end_to_end/test_decode.sh
@@ -17,7 +17,6 @@ fi
 #Train
 python3 MaxText/decode.py MaxText/configs/base.yml run_name=$RUN_NAME\
     steps=50 enable_checkpointing=False metrics_file='metrics.txt'\
-    base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH \
-    ici_tensor_parallelism=4
+    base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH
 
 python3 end_to_end/eval_assert.py metrics_average metrics.txt $NUM_TOKEN_THRESHOLD num_tokens
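Note: the ici_tensor_parallelism=4 coverage is not dropped; the CI decode test in .github/workflows/UnitTests.yml above now passes that flag instead of this script.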
35 changes: 24 additions & 11 deletions pedagogical_examples/shardings.py
@@ -20,7 +20,6 @@
 
 import jax
 from jax.sharding import PartitionSpec
-from jax.experimental.pjit import pjit
 from jax.sharding import Mesh
 from jax.experimental import mesh_utils
 from jax.experimental.compilation_cache import compilation_cache as cc
@@ -222,32 +221,46 @@ def training_step(in_act, in_layers):
 
 print("finished includes ", flush = True)
 
-pjit_func = pjit(
+replicated_sharding = jax.sharding.NamedSharding(mesh, data_sharding)
+
+parameter_mesh_shardings = jax.tree_map(
+    lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)
+
+data_pspec_shardings = jax.tree_map(
+    lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)
+
+jit_func = jax.jit(
   training_step,
-  in_shardings=(data_sharding, parameter_sharding),
-  out_shardings=parameter_sharding,
+  in_shardings=(replicated_sharding, parameter_mesh_shardings),
+  out_shardings=data_pspec_shardings,
 )
 
-pjit_gen_data = pjit(
+data_mesh_shardings = jax.tree_map(
+    lambda p: jax.sharding.NamedSharding(mesh, p), data_sharding)
+
+jit_gen_data = jax.jit(
   gen_data,
   in_shardings=None,
-  out_shardings=data_sharding
+  out_shardings=data_mesh_shardings
 )
 
-pjit_gen_layers = pjit(
+parameter_mesh_shardings = jax.tree_map(
+    lambda p: jax.sharding.NamedSharding(mesh, p), parameter_sharding)
+
+jit_gen_layers = jax.jit(
   gen_layers,
   in_shardings=None,
-  out_shardings=parameter_sharding
+  out_shardings=parameter_mesh_shardings
 )
 
 # starting the profiler outside `with` statement,
 # will call it right before the computation once b/301309635 is resolved
 activate_profiler(args.profiler_path)
 with Mesh(mesh.devices, mesh.axis_names):
   key = jax.random.PRNGKey(0)
-  presharded_X = jax.block_until_ready(pjit_gen_data(key))
-  presharded_layers = jax.block_until_ready(pjit_gen_layers(key))
+  presharded_X = jax.block_until_ready(jit_gen_data(key))
+  presharded_layers = jax.block_until_ready(jit_gen_layers(key))
   TFLOPs_per_device = parameters * 6 * BATCH / 10**12 / len(jax.devices())
-  time = simple_timeit(lambda : jax.block_until_ready(pjit_func(presharded_X, presharded_layers)))
+  time = simple_timeit(lambda : jax.block_until_ready(jit_func(presharded_X, presharded_layers)))
   print(f"time is {time} seconds, TFLOP is {TFLOPs_per_device}, TFLOP/s is {TFLOPs_per_device/time}", flush = True)
 deactivate_profiler(args.profiler_path)
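To confirm the placements after this rewrite, the shardings can be inspected directly, for example (assuming presharded_X is at most 2-D, which jax.debug.visualize_array_sharding requires):

jax.debug.visualize_array_sharding(presharded_X)  # ASCII map of device placement
print(presharded_X.sharding)                      # the underlying NamedSharding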