Enable pygrain v3 #330
Conversation
Let's talk live to figure out a path forwards here. This adds too much duplication and complexity to MaxText IMO.
@@ -15,12 +15,14 @@
"""

"""Create an Orbax CheckpointManager with specified (Async or not) Checkpointer."""
# pylint: disable=line-too-long
shouldn't do this!
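For context, the creation path this docstring covers looks roughly like the following with the Orbax API of this era. This is a sketch, not the exact MaxText code; the helper name and argument list are assumptions, only the orbax.checkpoint calls themselves are real API.

```python
import orbax.checkpoint
from etils import epath

def create_orbax_checkpoint_manager(checkpoint_dir, use_async, save_interval_steps):
  """Create a CheckpointManager backed by an async or blocking checkpointer."""
  if use_async:
    checkpointer = orbax.checkpoint.AsyncCheckpointer(
        orbax.checkpoint.PyTreeCheckpointHandler())
  else:
    checkpointer = orbax.checkpoint.Checkpointer(
        orbax.checkpoint.PyTreeCheckpointHandler())
  options = orbax.checkpoint.CheckpointManagerOptions(
      create=True, save_interval_steps=save_interval_steps)
  return orbax.checkpoint.CheckpointManager(
      epath.Path(checkpoint_dir), checkpointer, options)
```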
  )
  max_logging.log("Checkpoint manager created!")
  return mngr
def create_orbax_checkpoint_manager_pygrain( |
I'm very worried about this level of duplication making MaxText harder to understand and use. I think we have to hold back on this change until we're ready to always recommend Grain based on just this added mental load for users.
(I do think we can land lots of things separately.)
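A hypothetical refactor sketch of how the two creation paths could collapse into one factory keyed on dataset_type, rather than a parallel `create_orbax_checkpoint_manager_pygrain` copy. Names beyond the public Orbax and Grain APIs are assumptions, not the PR's actual code.

```python
import grain.python as grain
import orbax.checkpoint
from etils import epath

def create_checkpoint_manager(checkpoint_dir, use_async, save_interval_steps,
                              dataset_type="c4"):
  """Single factory for both pipelines instead of a *_pygrain duplicate."""
  ctor = (orbax.checkpoint.AsyncCheckpointer if use_async
          else orbax.checkpoint.Checkpointer)
  checkpointers = {"state": ctor(orbax.checkpoint.PyTreeCheckpointHandler())}
  if dataset_type == "c4-array_record":
    # pygrain iterators are checkpointable, so register a handler for them.
    checkpointers["iter"] = orbax.checkpoint.Checkpointer(
        grain.PyGrainCheckpointHandler())
  options = orbax.checkpoint.CheckpointManagerOptions(
      create=True, save_interval_steps=save_interval_steps)
  return orbax.checkpoint.CheckpointManager(
      epath.Path(checkpoint_dir), checkpointers, options)
```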
    first_checkpoint_path: if there is no checkpoint in the checkpoint manager,
      return the Params from the first_checkpoint_path if they exist. This
      enables loading just the parameters and is intended for finetuning.
    load_parameters_path: This enables loading just the parameters and is intended
Thanks for fixing this name. It was driving me crazy the last time I read the code.
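The fallback the docstring describes might look like this sketch (function and argument names are assumptions): restore the full state when the run directory has a checkpoint, otherwise load only the parameter PyTree from load_parameters_path for finetuning.

```python
import orbax.checkpoint

def load_state(checkpoint_manager, abstract_unboxed_pre_state,
               load_parameters_path=""):
  """Restore full state if a checkpoint exists, else just params for finetuning."""
  latest_step = checkpoint_manager.latest_step()
  if latest_step is not None:
    return checkpoint_manager.restore(latest_step,
                                      items=abstract_unboxed_pre_state)
  if load_parameters_path:
    # No checkpoint in this run's directory: load only the parameter PyTree
    # (no optimizer state, no data iterator) from the given path.
    return orbax.checkpoint.PyTreeCheckpointer().restore(load_parameters_path)
  return None  # caller initializes from scratch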
max_logging.log(f"restoring state from this run's directory latest step \ | ||
{latest_step}") | ||
return checkpoint_manager.restore(latest_step, abstract_unboxed_pre_state, | ||
# Set restore_args based whether to load data iterator |
I'm very stressed by this code and don't feel comfortable pushing this in our "simple" reference for customers.
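Condensed, the branch at issue amounts to the following sketch (names assumed): when the pipeline is pygrain, the data iterator's position is restored alongside the train state; otherwise only the state is restored.

```python
def restore_latest(checkpoint_manager, latest_step,
                   abstract_unboxed_pre_state, data_iterator=None):
  """Restore the train state, plus the Grain iterator when one is passed."""
  items = {"state": abstract_unboxed_pre_state}
  if data_iterator is not None:
    # c4-array_record / pygrain path: the iterator position is part of the
    # checkpoint, so it is restored together with the model state.
    items["iter"] = data_iterator
  return checkpoint_manager.restore(latest_step, items=items)
```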
@@ -118,13 +117,22 @@ dataset_path: ""
vocab_size: 32_768 # powers of 2 for sharding
assets_path: "assets"
vocab_relative_path: "tokenizer" # Assumes we're allowed
dataset_name: 'c4/en:3.0.1'
# When using c4-array_record dataset_type, use subfolder path as dataset_name
# array_record files should be located in <dataset_path>/<dataset_name>/*.array_record*
Is there a script for generating this data?
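No script appears in this diff; a conversion along these lines should work. This is a sketch that assumes the C4 split already exists as TFRecord shards, and the file paths are placeholders; only the ArrayRecordWriter and tf.data calls are real API.

```python
import glob
import tensorflow as tf
from array_record.python.array_record_module import ArrayRecordWriter

for shard in glob.glob("/data/c4/en/3.0.1/c4-train.tfrecord-*"):
  writer = ArrayRecordWriter(shard + ".array_record", "group_size:1")
  for record in tf.data.TFRecordDataset(shard):
    writer.write(record.numpy())  # copy the serialized tf.Example bytes
  writer.close()
```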
@@ -80,6 +85,37 @@ def _normalize_features(features):
                        num_parallel_calls=AUTOTUNE)


def length_trim(ds, max_len):
  """Trim to max length."""
Thank you -- this should merge on a separate CR!
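A minimal sketch of what the trimming step could look like in full (the flat feature-dict layout is an assumption based on the surrounding _normalize_features code):

```python
import tensorflow as tf

def length_trim(ds, max_len):
  """Trim each example to max_len tokens instead of dropping it (#274)."""
  def _trim(features):
    return {k: v[:max_len] for k, v in features.items()}
  return ds.map(_trim, num_parallel_calls=tf.data.AUTOTUNE)
```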
Note that grain uses an optimized packing algorithm to select samples to pack, resulting in denser packing, so we see higher loss overall compared with the original tfds pipeline. In the convergence test with grain, loss started above 11 and went down to ~3: https://pantheon.corp.google.com/kubernetes/service/us-east5/v5e-256-bodaborg/default/aireen-v5e256-0104-1257/logs?e=13802955&mods=allow_workbench_image_override&project=tpu-prod-env-multipod
In the convergence test on the same branch but with the tfds pipeline, loss started at ~10 and went down to ~2.6: https://pantheon.corp.google.com/kubernetes/service/us-east5/v5e-256-bodaborg/default/aireen-v5e256-0104-1236/logs?e=13802955&mods=allow_workbench_image_override&project=tpu-prod-env-multipod
The request to trim long sequences instead of dropping them is also addressed (#274).
Configs for these use cases:
Items I will work on in follow-up PRs: