
Enable pygrain v3 #330

Closed
wants to merge 23 commits

Conversation

Collaborator

@aireenmei aireenmei commented Jan 12, 2024

Note that Grain uses an optimized packing algorithm to select which samples to pack together, resulting in denser packing, so we see higher loss overall compared with the original tfds pipeline. In a convergence test with Grain, the loss started above 11 and went down to ~3: https://pantheon.corp.google.com/kubernetes/service/us-east5/v5e-256-bodaborg/default/aireen-v5e256-0104-1257/logs?e=13802955&mods=allow_workbench_image_override&project=tpu-prod-env-multipod
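
(Grain's actual packing code is not part of this description; purely to illustrate what "selecting samples to pack" into fixed-length examples means, here is a toy first-fit sketch with made-up sequence lengths, as opposed to packing samples strictly in arrival order.)

```python
# Toy sketch only, not Grain's implementation: first-fit packing of
# variable-length sequences into fixed-length training examples.
def pack_first_fit(seq_lengths, max_len):
  """Return the number of packed examples needed under a first-fit policy."""
  remaining = []  # leftover capacity of each packed example
  for length in seq_lengths:
    for i, cap in enumerate(remaining):
      if length <= cap:
        remaining[i] -= length  # place it in the first example where it fits
        break
    else:
      remaining.append(max_len - length)  # open a new packed example
  return len(remaining)

print(pack_first_fit([700, 400, 300, 900, 100, 600], max_len=1024))
```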

In a convergence test on the same branch but with the tfds pipeline, the loss started at ~10 and went down to ~2.6: https://pantheon.corp.google.com/kubernetes/service/us-east5/v5e-256-bodaborg/default/aireen-v5e256-0104-1236/logs?e=13802955&mods=allow_workbench_image_override&project=tpu-prod-env-multipod

The request to trim long sequences instead of dropping them (#274) is also addressed.

Configs for these use cases (a restore-logic sketch follows this list):

  1. Pretraining with Grain (end_to_end/test_convergence_1b_params_grain.sh): run gcsfuse_setup.sh to mount the GCS bucket, then train with "dataset_type=c4-array_record". When "dataset_type=c4-array_record" is set, newly saved training checkpoints always contain the data iterator.
  2. When resuming training from a checkpoint, the user can choose whether to restore the data iterator by setting load_data_iterator_from_checkpoint; an error is raised if no data iterator is found in the checkpoint.
  3. When fine-tuning, the user sets load_parameters_path to load only the parameters; the data iterator is not loaded.
  4. When decoding, the user sets load_parameters_path to load only the parameters; the data iterator is not loaded, and no data iterator is saved in the decode checkpoint.
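
The following is a minimal, self-contained sketch of how these three settings interact. It is not MaxText's actual checkpointing code; the helper and the dict-based "checkpoint" are hypothetical stand-ins for the real Orbax restore calls.

```python
def read_params(path):
  # Hypothetical stub: in MaxText this would be an Orbax restore of parameters only.
  return f"parameters restored from {path}"

def restore_state(config, checkpoint):
  """checkpoint: dict with 'params' and optionally 'data_iter', or None."""
  if config.get("load_parameters_path"):
    # Fine-tuning / decoding: load only parameters; the data iterator is never loaded.
    return read_params(config["load_parameters_path"]), None

  if checkpoint is None:
    return None, None  # fresh pretraining run

  if config.get("load_data_iterator_from_checkpoint"):
    # Resuming training together with the Grain data iterator state.
    if "data_iter" not in checkpoint:
      raise ValueError("No data iterator found in checkpoint.")
    return checkpoint["params"], checkpoint["data_iter"]

  return checkpoint["params"], None

# Example: resume training and also restore the data iterator state.
ckpt = {"params": "...", "data_iter": "grain-iterator-state"}
print(restore_state({"load_data_iterator_from_checkpoint": True}, ckpt))
```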

Items I will work on in the next PRs:

  1. Add tests (unit tests and a GitHub workflow).
  2. To support Llama2, add_bos and add_eos need to be added to the way Grain loads the tokenizer, but doing so currently causes an OOM in the GPU test "Test train.py with per_device_batch_size < 1". I will add this in a later PR and would appreciate any suggestions (one possible approach is sketched after this list).
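
For item 2, a rough sketch of one possible approach, assuming Grain's python MapTransform API and the SentencePiece Python bindings; this is not the code in this PR and does not address the OOM issue:

```python
import grain.python as grain
import sentencepiece as spm

class TokenizeWithBosEos(grain.MapTransform):
  """Tokenize the 'text' feature, optionally adding BOS/EOS for Llama2-style models."""

  def __init__(self, model_path, add_bos=True, add_eos=True):
    self._sp = spm.SentencePieceProcessor(model_file=model_path)
    self._add_bos = add_bos
    self._add_eos = add_eos

  def map(self, features):
    features["input_ids"] = self._sp.encode(
        features["text"], add_bos=self._add_bos, add_eos=self._add_eos)
    return features
```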

@aireenmei aireenmei marked this pull request as ready for review January 12, 2024 16:18
@aireenmei aireenmei requested a review from rwitten as a code owner January 12, 2024 16:18
Collaborator

@rwitten rwitten left a comment

Let's talk live to figure out a path forwards here. This adds too much duplication and complexity to MaxText IMO.

@@ -15,12 +15,14 @@
"""

"""Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""
# pylint: disable=line-too-long
Collaborator

shouldn't do this!

)
max_logging.log("Checkpoint manager created!")
return mngr

def create_orbax_checkpoint_manager_pygrain(
Collaborator

I'm very worried about this level of duplication making MaxText harder to understand and use. I think we have to hold back on this change until we're ready to always recommend Grain based on just this added mental load for users.

(I do think we can land lots of things separately.)

first_checkpoint_path: if there is no checkpoint in the checkpoint manager,
return the Params from the first_checkpoint_path if they exist. This
enables loading just the parameters and is intended for finetuning.
load_parameters_path: This enables loading just the parameters and is intended
Collaborator

Thanks for fixing this name. It was driving me crazy the last time I read the code.

max_logging.log(f"restoring state from this run's directory latest step \
{latest_step}")
return checkpoint_manager.restore(latest_step, abstract_unboxed_pre_state,
# Set restore_args based on whether to load the data iterator
Collaborator

I'm very stressed by this code and don't feel comfortable pushing this in our "simple" reference for customers.

@@ -118,13 +117,22 @@ dataset_path: ""
vocab_size: 32_768 # powers of 2 for sharding
assets_path: "assets"
vocab_relative_path: "tokenizer" # Assumes we're allowed
dataset_name: 'c4/en:3.0.1'
# When using c4-array_record dataset_type, use subfolder path as dataset_name
# array_record files should be located in <dataset_path>/<dataset_name>/*.array_record*
Collaborator

Is there a script for generating this data?
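
(For context, a minimal sketch of one way such files could be generated, assuming the array_record Python library and TFDS; this is not necessarily the script used for this PR.)

```python
import tensorflow as tf
import tensorflow_datasets as tfds
from array_record.python.array_record_module import ArrayRecordWriter

def write_c4_split_to_array_record(out_path, split="train"):
  """Serialize a TFDS c4 split into a single .array_record file."""
  ds = tfds.load("c4/en:3.0.1", split=split)
  writer = ArrayRecordWriter(out_path, "group_size:1")
  for example in ds.as_numpy_iterator():
    feature = {"text": tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[example["text"]]))}
    record = tf.train.Example(features=tf.train.Features(feature=feature))
    writer.write(record.SerializeToString())
  writer.close()

# Files would then live under <dataset_path>/<dataset_name>/, e.g.:
# write_c4_split_to_array_record("c4/en/c4-train.array_record-00000-of-00001")
```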

@@ -80,6 +85,37 @@ def _normalize_features(features):
num_parallel_calls=AUTOTUNE)


def length_trim(ds, max_len):
""""Trim to Max length"""
Collaborator

Thank you -- this should merge on a separate CR!
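
For reference, a sketch of what such a trim might look like in a tf.data pipeline; the function body is not shown in this excerpt, so this is an assumption rather than the PR's actual implementation:

```python
import tensorflow as tf

def length_trim(ds, max_len):
  """Trim every sequence feature to at most max_len tokens."""
  def _trim(features):
    return {k: v[:max_len] for k, v in features.items()}
  return ds.map(_trim, num_parallel_calls=tf.data.AUTOTUNE)
```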

@rwitten rwitten removed their assignment Jan 18, 2024
@aireenmei aireenmei closed this Feb 19, 2024