Merge pull request #586 from google:mohit/subset_v4
PiperOrigin-RevId: 626387060
maxtext authors committed Apr 19, 2024
2 parents f52e6f7 + 20f2a0d commit 6ec7556
Showing 3 changed files with 39 additions and 33 deletions.
3 changes: 3 additions & 0 deletions MaxText/configs/base.yml
@@ -261,3 +261,6 @@ vertex_tensorboard_project: ""
 # Region to create Vertex AI Tensorboard in for GCE, blank if running via XPK
 # Vertex AI supported regions: https://cloud.google.com/vertex-ai/docs/general/locations#available-regions
 vertex_tensorboard_region: ""
+
+# If set to True, MaxText will perform extra checks using jax.checkify. Note that this will affect performance.
+max_checkify: False
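For reference: the flag defaults to False, and the check_example_batch helper added to train.py below only does work when it is set. Assuming MaxText's usual key=value overrides of base.yml, it can also be enabled for a single run by appending max_checkify=True to the train.py command line instead of editing the config.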
59 changes: 26 additions & 33 deletions MaxText/input_pipeline/input_pipeline_interface.py
@@ -16,6 +16,8 @@

 """Input pipeline"""

+import numpy as np
+import tensorflow as tf
 import tensorflow_datasets as tfds
 import jax
 import jax.numpy as jnp
@@ -25,7 +27,7 @@
 from input_pipeline import _grain_data_processing
 from input_pipeline import _tfds_data_processing_c4_mlperf
 import tokenizer
-
+import multihost_dataloading

 def get_tokenizer(tokenizer_path, add_bos=True, add_eos=True):
   # Load tokenizer
@@ -134,43 +136,35 @@ class BadSyntheticDataIterator:

   def __init__(self, config, mesh):
     self.mesh = mesh
-    self.config = config
-    data_pspec = P(*config.data_sharding)
-    data_pspec_shardings = jax.tree_map(lambda p: jax.sharding.NamedSharding(mesh, p), data_pspec)
-    self.data_generator = jax.jit(
-        BadSyntheticDataIterator.get_bad_synthetic_data, out_shardings=data_pspec_shardings, static_argnums=0
-    )
+    dataset = BadSyntheticDataIterator.get_bad_synthetic_data(config)
+    self.data_generator = multihost_dataloading.MultiHostDataLoadIterator(dataset, self.mesh)

   def __iter__(self):
-    return self
+    return self.data_generator

   def __next__(self):
-    with self.mesh:
-      return self.data_generator(self.config)
+    return next(self.data_generator)

   @staticmethod
   def get_bad_synthetic_data(config):
     """fill negative value in synthetic data"""
     output = {}
-    output["inputs"] = jax.numpy.full(
-        (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32
-    )
-    output["inputs_position"] = jax.numpy.full(
-        (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32
-    )
-    output["inputs_segmentation"] = jax.numpy.full(
-        (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32
-    )
-    output["targets"] = jax.numpy.full(
-        (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32
-    )
-    output["targets_position"] = jax.numpy.full(
-        (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32
-    )
-    output["targets_segmentation"] = jax.numpy.full(
-        (config.global_batch_size_to_load, config.max_target_length), -1, dtype=jax.numpy.int32
-    )
-    return output
+    output['inputs'] = tf.data.Dataset.from_tensor_slices(
+        np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32))
+    output['inputs_position'] = tf.data.Dataset.from_tensor_slices(
+        np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32))
+    output['inputs_segmentation'] = tf.data.Dataset.from_tensor_slices(
+        np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32))
+    output['targets'] = tf.data.Dataset.from_tensor_slices(
+        np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32))
+    output['targets_position'] = tf.data.Dataset.from_tensor_slices(
+        np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32))
+    output['targets_segmentation'] = tf.data.Dataset.from_tensor_slices(
+        np.full((1, config.max_target_length), -1, dtype=jax.numpy.int32))
+    dataset = tf.data.Dataset.zip((output))  # pytype: disable=wrong-arg-types
+    dataset = dataset.repeat()
+    dataset = dataset.batch(config.global_batch_size_to_load // jax.process_count())
+    return dataset


@@ -187,8 +181,7 @@ def get_process_loading_real_data(config, mesh):

 def make_mixed_train_iterator_and_tokenizer(config, mesh, add_bos, add_eos):
   process_indices = get_process_loading_real_data(config, mesh)
-  print(len(process_indices), "hosts out of", jax.process_count(), "are loading real data")
-  if config.expansion_factor_real_data != -1: # assert number of hosts loading real data
+  if config.expansion_factor_real_data != -1:  # assert number of hosts loading real data
     assert len(process_indices) == jax.process_count() // config.expansion_factor_real_data
   if jax.process_index() in process_indices:
     if config.dataset_type == "c4":
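To illustrate the new get_bad_synthetic_data: instead of jit-generating full (global_batch_size_to_load, max_target_length) arrays, it now builds a tf.data pipeline of single -1-filled rows per feature, repeats it, and batches to the per-host size so the result can be wrapped in multihost_dataloading.MultiHostDataLoadIterator like the real-data pipelines. A minimal standalone sketch of that tf.data pattern, using hypothetical sizes and leaving out MaxText's config and MultiHostDataLoadIterator:

import numpy as np
import tensorflow as tf

max_target_length = 8    # hypothetical sequence length
per_host_batch_size = 4  # stands in for global_batch_size_to_load // jax.process_count()

# One -1-filled row per feature, zipped into a dict-structured dataset.
features = ["inputs", "inputs_position", "inputs_segmentation",
            "targets", "targets_position", "targets_segmentation"]
output = {
    name: tf.data.Dataset.from_tensor_slices(np.full((1, max_target_length), -1, dtype=np.int32))
    for name in features
}
dataset = tf.data.Dataset.zip(output).repeat().batch(per_host_batch_size)

batch = next(iter(dataset))
print(batch["inputs"].shape)  # (4, 8): one per-host batch of all -1 tokens

Every value in such a batch is -1, which is exactly the condition the new check_example_batch in train.py looks for.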
10 changes: 10 additions & 0 deletions MaxText/train.py
@@ -51,6 +51,7 @@
 import jax.numpy as jnp
 from jax import random
 from jax.sharding import Mesh
+from jax.experimental import checkify

 from cloud_tpu_diagnostics import diagnostic
 from cloud_tpu_diagnostics.configuration import debug_configuration
@@ -310,6 +311,14 @@ def record_goodput(recorder, config, step=None, job_start=False, job_end=False):
   if step is not None:
     recorder.record_step_start_time(step)

+def check_example_batch(config, example_batch):
+  if config.max_checkify:
+    jittable_f = checkify.checkify(
+        lambda x: checkify.check(jnp.any(x > -1), "Batch contains bad synthetic data!")
+    )
+    # Check whether the inputs in the batch contain bad synthetic data.
+    err, _ = jax.jit(jittable_f)(example_batch['inputs'][: config.global_batch_size_to_train_on, :])
+    err.throw()

 def setup_mesh_and_model(config):
   """Set up the mesh and the model for training
Expand Down Expand Up @@ -485,6 +494,7 @@ def train_loop(config, state=None):

with jax.profiler.StepTraceAnnotation("train", step_num=step):
example_batch = load_next_batch(data_iterator, example_batch, config)
check_example_batch(config, example_batch=example_batch)
nextrng = jax.jit(jax.random.fold_in)(init_rng, step)
record_goodput(recorder, config, step=step)
with mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
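To illustrate the checkify pattern used by check_example_batch: checkify.check records a functional error instead of raising inside jit, checkify.checkify returns that error alongside the result, and err.throw() re-raises it on the host. A minimal sketch on a toy batch; the shapes and values below are made up and only mimic the -1-filled output of BadSyntheticDataIterator:

import jax
import jax.numpy as jnp
from jax.experimental import checkify

def assert_not_bad_synthetic(inputs):
  # Recorded as a functional error, so the wrapped function stays jittable.
  checkify.check(jnp.any(inputs > -1), "Batch contains bad synthetic data!")

checked = jax.jit(checkify.checkify(assert_not_bad_synthetic))

good_batch = jnp.zeros((4, 8), dtype=jnp.int32)    # toy batch of real-looking tokens
bad_batch = jnp.full((4, 8), -1, dtype=jnp.int32)  # mimics BadSyntheticDataIterator output

err, _ = checked(good_batch)
err.throw()  # no error was recorded, so this is a no-op

err, _ = checked(bad_batch)
err.throw()  # raises with "Batch contains bad synthetic data!"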
