Fix subset of hosts dataloading for TPU v4 #586

Merged · 1 commit · Apr 19, 2024
3 changes: 3 additions & 0 deletions MaxText/configs/base.yml
@@ -261,3 +261,6 @@ vertex_tensorboard_project: ""
 # Region to create Vertex AI Tensorboard in for GCE, blank if running via XPK
 # Vertex AI supported regions: https://cloud.google.com/vertex-ai/docs/general/locations#available-regions
 vertex_tensorboard_region: ""
+
+# If set to True, MaxText will perform extra checks using jax.checkify. Note that this will affect performance.
+max_checkify: False
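Usage note: with MaxText's usual key=value command-line overrides, the flag can be toggled per run, e.g. `python3 MaxText/train.py MaxText/configs/base.yml max_checkify=True` (invocation shown for illustration; the override mechanism is MaxText's existing config plumbing, not part of this PR).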
59 changes: 26 additions & 33 deletions MaxText/input_pipeline/input_pipeline_interface.py
@@ -16,6 +16,8 @@

"""Input pipeline"""

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
import jax
import jax.numpy as jnp
@@ -25,7 +27,7 @@
 from input_pipeline import _grain_data_processing
 from input_pipeline import _tfds_data_processing_c4_mlperf
 import tokenizer
-
+import multihost_dataloading

 def get_tokenizer(tokenizer_path, add_bos=True, add_eos=True):
   # Load tokenizer
@@ -134,43 +136,35 @@ class BadSyntheticDataIterator:

   def __init__(self, config, mesh):
     self.mesh = mesh
-    self.config = config
-    data_pspec = P(*config.data_sharding)
-    data_pspec_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
-    self.data_generator = jax.jit(
-        BadSyntheticDataIterator.get_bad_synthetic_data, out_shardings=data_pspec_shardings, static_argnums=0
-    )
-
-  def __iter__(self):
-    return self
+    dataset = BadSyntheticDataIterator.get_bad_synthetic_data(config)
+    self.data_generator = multihost_dataloading.MultiHostDataLoadIterator(dataset, self.mesh)
+
+  def __iter__(self):
+    return self.data_generator

   def __next__(self):
-    with self.mesh:
-      return self.data_generator(self.config)
+    return next(self.data_generator)

   @staticmethod
   def get_bad_synthetic_data(config):
     """fill negative value in synthetic data"""
     output = {}
-    output["inputs"] = jax.numpy.full(
-        (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32
-    )
-    output["inputs_position"] = jax.numpy.full(
-        (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32
-    )
-    output["inputs_segmentation"] = jax.numpy.full(
-        (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32
-    )
-    output["targets"] = jax.numpy.full(
-        (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32
-    )
-    output["targets_position"] = jax.numpy.full(
-        (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32
-    )
-    output["targets_segmentation"] = jax.numpy.full(
-        (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32
-    )
-    return output
+    output['inputs'] = tf.data.Dataset.from_tensor_slices(np.full((1, config.max_target_length),
+                                                          -1, dtype=jax.numpy.int32))
+    output['inputs_position'] = tf.data.Dataset.from_tensor_slices(np.full((1, config.max_target_length),
+                                                          -1, dtype=jax.numpy.int32))
+    output['inputs_segmentation'] = tf.data.Dataset.from_tensor_slices(np.full((1, config.max_target_length),
+                                                          -1, dtype=jax.numpy.int32))
+    output['targets'] = tf.data.Dataset.from_tensor_slices(np.full((1, config.max_target_length),
+                                                          -1, dtype=jax.numpy.int32))
+    output['targets_position'] = tf.data.Dataset.from_tensor_slices(np.full((1, config.max_target_length),
+                                                          -1, dtype=jax.numpy.int32))
+    output['targets_segmentation'] = tf.data.Dataset.from_tensor_slices(np.full((1, config.max_target_length),
+                                                          -1, dtype=jax.numpy.int32))
+    dataset = tf.data.Dataset.zip((output))  # pytype: disable=wrong-arg-types
+    dataset = dataset.repeat()
+    dataset = dataset.batch(config.global_batch_size_to_load // jax.process_count())
+    return dataset


 def get_process_loading_real_data(config, mesh):
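To make the new data path concrete, here is a minimal sketch of the per-host tf.data pipeline that get_bad_synthetic_data now builds; the sizes are hypothetical stand-ins (max_target_length=8, per-host batch of 4), and MultiHostDataLoadIterator (MaxText's own helper, shown in the diff above) then stitches the per-host batches into globally sharded arrays:

```python
# Sketch of the per-host sentinel dataset; sizes are illustrative, not from the PR.
import numpy as np
import tensorflow as tf

max_target_length = 8   # stands in for config.max_target_length
per_host_batch = 4      # stands in for global_batch_size_to_load // jax.process_count()

features = ["inputs", "inputs_position", "inputs_segmentation",
            "targets", "targets_position", "targets_segmentation"]
# Each feature is a one-element dataset of all -1 sentinels, repeated forever.
output = {k: tf.data.Dataset.from_tensor_slices(np.full((1, max_target_length), -1, dtype=np.int32))
          for k in features}
ds = tf.data.Dataset.zip(output).repeat().batch(per_host_batch)

batch = next(iter(ds))
print(batch["inputs"].shape)  # (4, 8): every element is the -1 sentinel
```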
@@ -187,8 +181,7 @@

 def make_mixed_train_iterator_and_tokenizer(config, mesh, add_bos, add_eos):
   process_indices = get_process_loading_real_data(config, mesh)
-  print(len(process_indices), "hosts out of", jax.process_count(), "are loading real data")
-  if config.expansion_factor_real_data != -1: # assert number of hosts loading real data
+  if config.expansion_factor_real_data != -1:  # assert number of hosts loading real data
     assert len(process_indices) == jax.process_count() // config.expansion_factor_real_data
   if jax.process_index() in process_indices:
     if config.dataset_type == "c4":
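The assert above encodes the host-subsetting arithmetic; a worked example with hypothetical host counts:

```python
# Hypothetical topology: 8 hosts, expansion_factor_real_data=4.
# Only process_count // expansion_factor hosts load real data; the rest
# fall back to BadSyntheticDataIterator and emit -1 sentinel batches.
process_count = 8
expansion_factor_real_data = 4
hosts_loading_real_data = process_count // expansion_factor_real_data
assert hosts_loading_real_data == 2
```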
10 changes: 10 additions & 0 deletions MaxText/train.py
@@ -51,6 +51,7 @@
 import jax.numpy as jnp
 from jax import random
 from jax.sharding import Mesh
+from jax.experimental import checkify

 from cloud_tpu_diagnostics import diagnostic
 from cloud_tpu_diagnostics.configuration import debug_configuration
@@ -310,6 +311,14 @@ def record_goodput(recorder, config, step=None, job_start=False, job_end=False):
   if step is not None:
     recorder.record_step_start_time(step)

+def check_example_batch(config, example_batch):
+  if config.max_checkify:
+    jittable_f = checkify.checkify(
+        lambda x: checkify.check(jnp.any(x > -1), "Batch contains bad synthetic data!")
+    )
+    # Check if inputs in the batch contain bad synthetic data.
+    err, _ = jax.jit(jittable_f)(example_batch['inputs'][: config.global_batch_size_to_train_on, :])
+    err.throw()

 def setup_mesh_and_model(config):
   """Set up the mesh and the model for training
@@ -485,6 +494,7 @@ def train_loop(config, state=None):

     with jax.profiler.StepTraceAnnotation("train", step_num=step):
       example_batch = load_next_batch(data_iterator, example_batch, config)
+      check_example_batch(config, example_batch=example_batch)
       nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
       record_goodput(recorder, config, step=step)
       with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):