Merge pull request #529 from google:mohit/test_gemma
PiperOrigin-RevId: 621236206
Showing 8 changed files with 189 additions and 82 deletions.
@@ -0,0 +1,64 @@
#!/bin/bash

# This file is both an integration test that runs once a day on a v4-8 and documentation for how to get started with Gemma-2b.

# The flow of this file is as follows:
# 1. Convert the checkpoint downloaded from Kaggle to make it compatible with MaxText.
# 2. Run decoding and finetuning of Gemma 2B with the converted checkpoint. Also, run pretraining of Gemma 2B.
# 3. Convert the scanned checkpoint from step 1 into the unscanned checkpoint format and run more efficient decoding.
# 4. Run decoding from the finetuned checkpoint from step 2.
# 5. Run Ahead-of-Time Compilation for running Gemma 2B on v5e-256.

set -ex
idx=$(date +%Y-%m-%d-%H-%M)
export MODEL_VARIATION='2b'

# After downloading checkpoints, copy them to the GCS bucket at $CHKPT_BUCKET.
# Non-Googlers, please remember to use separate GCS paths for uploading model weights from Kaggle ($CHKPT_BUCKET) and MaxText-compatible weights ($MODEL_BUCKET).
# Non-Googlers, please remember to point these variables to GCS buckets that you own; this script uses internal buckets for testing.
export CHKPT_BUCKET=gs://maxtext-gemma/flax
export MODEL_BUCKET=gs://maxtext-gemma
python MaxText/convert_gemma_chkpt.py --base_model_path ${CHKPT_BUCKET}/${MODEL_VARIATION} --maxtext_model_path ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx} --model_size ${MODEL_VARIATION}
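# (Optional, illustrative check.) Assuming gsutil is installed and authenticated, you can confirm that
# the converted checkpoint landed in the `0/items` subdirectory that the rest of this script relies on:
# gsutil ls ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items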

# Non-Googlers, please remember to point `DATASET_PATH` to the GCS bucket where you have your training data.
export DATASET_PATH=gs://maxtext-dataset
# Non-Googlers, please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own; this bucket will store all the files generated by MaxText during a run.
export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. This makes it easier to use this path in the `train.py` and `decode.py` commands.
export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items
export RUN_NAME=unscanned_chkpt_${idx}
# Note that `CONVERTED_CHECKPOINT` is in a `scanned` format, which is great for training, but for efficient decoding performance we want the checkpoint in an `unscanned` format.
# We can convert it by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`.
python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='gemma-2b' force_unroll=true
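# (Background, roughly speaking.) A `scanned` checkpoint stacks the per-layer weights along a leading
# layer axis so training can iterate over them with `jax.lax.scan`; `force_unroll=true` writes an
# `unscanned` copy with each layer stored as a separate parameter, which is the layout the efficient
# decoding path expects.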

export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items

# We run decoding on `UNSCANNED_CKPT_PATH` for efficient decoding with the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
# So, we use it by specifying `load_parameters_path=${UNSCANNED_CKPT_PATH}`.
# We compare our decoded results against golden outputs using `autoregressive_decode_assert`.
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" travel and I love to write about it"

# We can also run decoding (albeit in a somewhat unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state, so we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`.
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" cook and bake. I love to eat"

# Alternatively, we can skip ahead to finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that the scanned checkpoint helps with efficient finetuning.
export FINETUNE_RUN_NAME=runner_finetune_${idx}
python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma-2b checkpoint_period=5
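# (Note.) With `steps=10` and `checkpoint_period=5`, the finetune run above writes a checkpoint at
# step 5; that is the `${FINETUNE_RUN_NAME}/checkpoints/5/items` path loaded further below.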

# We also run pre-training; this is similar to the finetuning command except that we don't pass any checkpoint directory to load parameters from.
python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma per_device_batch_size=1 run_name=runner_pretrain_${idx} max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-2b

# Note that the finetune run generates a `full state` checkpoint, which has both parameters and optimizer state. For decoding, we only need the parameters.
# So, we can use `MaxText/generate_param_only_checkpoint.py` to convert the full-state checkpoint into a parameter-only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside `BASE_OUTPUT_DIRECTORY` from our previous finetuning run.
# `force_unroll=true` converts the output parameter-only checkpoint into an unscanned format for efficient decoding.
export PARAM_RUN_NAME=param_chkpt_${idx}
python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma-2b' force_unroll=true

# Now, run decoding on the checkpoint generated from our finetune run.
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to"

# We recommend training/finetuning Gemma on v5e-256 using the following sharding strategy to achieve optimal performance.
# The command below does Ahead-of-Time Cross Compilation (https://github.com/google/maxtext?tab=readme-ov-file#ahead-of-time-compilation-aot) for our recommended v5e-256 configuration for Gemma 2B.
# To actually run it on a real v5e-256, simply replace train_compile.py with train.py and remove the compile_topology args.
python MaxText/train_compile.py MaxText/configs/base.yml model_name=gemma-2b ici_fsdp_transpose_parallelism=16 per_device_batch_size=2 compile_topology=v5e-256 compile_topology_num_slices=1
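# (Illustrative sketch only, kept commented out; not part of the daily test.) Following the note above,
# an actual v5e-256 run would look roughly like the following, with train_compile.py swapped for
# train.py, the compile_topology args dropped, and your own output directory/dataset supplied:
# python MaxText/train.py MaxText/configs/base.yml model_name=gemma-2b ici_fsdp_transpose_parallelism=16 per_device_batch_size=2 base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} run_name=runner_v5e_${idx}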
@@ -0,0 +1,31 @@
#!/bin/bash

# This file, combined with step 2 in the same directory, demonstrates converting a Gemma checkpoint from Kaggle and running various MaxText operations on it.
# This step is tested nightly on an ordinary CPU VM.

# The flow of this file is as follows:
# 1. Pull the checkpoint from a GCS bucket and upload the new MaxText-compatible checkpoint to the destination GCS bucket.
# 2. Convert the scanned checkpoint from step 1 into the unscanned checkpoint format for more efficient decoding.

# Example usage: bash end_to_end/gemma/7b/1_test_gemma.sh
set -ex
idx=$(date +%Y-%m-%d-%H-%M)
MODEL_VARIATION='7b'

# Non-Googlers, please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own; this bucket will store all the files generated by MaxText during a run.
export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
# After downloading checkpoints, copy them to the GCS bucket at $CHKPT_BUCKET.
# Please use separate GCS paths for uploading model weights from Kaggle ($CHKPT_BUCKET) and MaxText-compatible weights ($MODEL_BUCKET).
# Non-Googlers, please remember to point these variables to GCS buckets that you own; this script uses internal buckets for testing.
export CHKPT_BUCKET=gs://maxtext-gemma/flax
export MODEL_BUCKET=gs://maxtext-gemma
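# (Note.) `JAX_PLATFORMS=cpu` below forces JAX onto its CPU backend, which is what lets this
# conversion step run on an ordinary CPU VM with no accelerator attached.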
JAX_PLATFORMS=cpu python MaxText/convert_gemma_chkpt.py --base_model_path ${CHKPT_BUCKET}/${MODEL_VARIATION} --maxtext_model_path ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx} --model_size ${MODEL_VARIATION}
echo "Written MaxText compatible checkpoint to ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}"

# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory.
export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items
# Note that `CONVERTED_CHECKPOINT` is in a `scanned` format, which is great for training, but for efficient decoding performance we want the checkpoint in an `unscanned` format.
# We can convert it by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`.
export RUN_NAME=unscanned_chkpt_${idx}
JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='gemma-7b' force_unroll=true
echo "Written MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items"
@@ -0,0 +1,55 @@
#!/bin/bash

# This file is both an integration test that runs once a day on a v4-16 and documentation for how to get started with Gemma-7b.
# Please make sure you have run end_to_end/gemma/7b/1_test_gemma.sh before running commands from this file.

# The flow of this file is as follows:
# 1. Run decoding and finetuning of Gemma 7B with the converted checkpoint obtained from end_to_end/gemma/7b/1_test_gemma.sh. Also, run pretraining of Gemma 7B.
# 2. Convert the scanned checkpoint from step 1 into the unscanned checkpoint format and run more efficient decoding.
# 3. Run decoding from the finetuned checkpoint from step 1.
# 4. Run Ahead-of-Time Compilation for running Gemma 7B on v5e-256.

set -ex
idx=$(date +%Y-%m-%d-%H-%M)
export MODEL_VARIATION='7b'

# Non-Googlers, please remember to point `MODEL_BUCKET` to a GCS bucket that you own; this script uses internal buckets for testing.
export MODEL_BUCKET=gs://maxtext-gemma
# Non-Googlers, please remember to point `DATASET_PATH` to the GCS bucket where you have your training data.
export DATASET_PATH=gs://maxtext-dataset
# Non-Googlers, please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own; this bucket will store all the files generated by MaxText during a run.
export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. This makes it easier to use this path in the `train.py` and `decode.py` commands.
export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items
export RUN_NAME=unscanned_chkpt_${idx}
# We define the path to the unscanned checkpoint created in 1_test_gemma.sh.
export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items

# We run decoding on `UNSCANNED_CKPT_PATH` for efficient decoding with the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
# So, we use it by specifying `load_parameters_path=${UNSCANNED_CKPT_PATH}`.
# We compare our decoded results against golden outputs using `autoregressive_decode_assert`.
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to" autoregressive_decode_assert=" see the look on people’s faces"

# We can also run decoding (albeit in a somewhat unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state, so we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`.
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-7b attention=dot_product prompt="I love to" autoregressive_decode_assert=" see the look on people's faces"

# Alternatively, we can skip ahead to finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that the scanned checkpoint helps with efficient finetuning.
export FINETUNE_RUN_NAME=runner_finetune_${idx}
python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma-7b checkpoint_period=5

# We also run pre-training; this is similar to the finetuning command except that we don't pass any checkpoint directory to load parameters from.
python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-7b

# Note that the finetune run generates a `full state` checkpoint, which has both parameters and optimizer state. For decoding, we only need the parameters.
# So, we can use `MaxText/generate_param_only_checkpoint.py` to convert the full-state checkpoint into a parameter-only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside `BASE_OUTPUT_DIRECTORY` from our previous finetuning run.
# `force_unroll=true` converts the output parameter-only checkpoint into an unscanned format for efficient decoding.
export PARAM_RUN_NAME=param_chkpt_${idx}
python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma-7b' force_unroll=true

# Now, run decoding on the checkpoint generated from our finetune run.
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to"

# We recommend training/finetuning Gemma on v5e-256 using the following sharding strategy to achieve optimal performance.
# The command below does Ahead-of-Time Cross Compilation (https://github.com/google/maxtext?tab=readme-ov-file#ahead-of-time-compilation-aot) for our recommended v5e-256 configuration for Gemma 7B.
# To actually run it on a real v5e-256, simply replace train_compile.py with train.py and remove the compile_topology args.
python MaxText/train_compile.py MaxText/configs/base.yml model_name=gemma-7b ici_fsdp_transpose_parallelism=16 per_device_batch_size=2 compile_topology=v5e-256 compile_topology_num_slices=1
@@ -0,0 +1,31 @@
<!--
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->

# Gemma
[Gemma](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open models built from the research and technology that we used to create the Gemini models.

Following the instructions at [kaggle](https://www.kaggle.com/models/google/gemma/frameworks/maxText) will let you download the Gemma model weights. You will have to consent to the license for Gemma using your Kaggle account's [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials).
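
As a rough sketch of the upload step (the bucket name and local path below are placeholders, and it assumes the Flax weights have already been downloaded and `gsutil` is authenticated):

```bash
# The Kaggle API reads credentials from ~/.kaggle/kaggle.json, or from these environment variables.
export KAGGLE_USERNAME=<your_kaggle_username>
export KAGGLE_KEY=<your_kaggle_api_key>

# After downloading and extracting the Flax weights for a variation (e.g. 2b), copy them to the
# GCS bucket that the conversion scripts read from ($CHKPT_BUCKET).
export CHKPT_BUCKET=gs://<your-checkpoint-bucket>/flax
gsutil -m cp -r /path/to/downloaded/gemma-2b ${CHKPT_BUCKET}/2b
```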

After downloading the weights, run [test_convert_chkpt.sh](https://github.com/google/maxtext/blob/main/end_to_end/gemma/test_convert_chkpt.sh), which converts the checkpoint to be compatible with MaxText and uploads it to a GCS bucket. You can then run decoding and finetuning using the instructions in the test scripts at [end_to_end/gemma](https://github.com/google/maxtext/blob/main/end_to_end/gemma).
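
For reference, the conversion step boils down to the following call (shown here for the 2b variation with placeholder buckets; the test scripts above wrap this same command):

```bash
export CHKPT_BUCKET=gs://<bucket-with-kaggle-flax-weights>/flax
export MODEL_BUCKET=gs://<bucket-for-maxtext-weights>
export MODEL_VARIATION=2b
idx=$(date +%Y-%m-%d-%H-%M)
python MaxText/convert_gemma_chkpt.py \
  --base_model_path ${CHKPT_BUCKET}/${MODEL_VARIATION} \
  --maxtext_model_path ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx} \
  --model_size ${MODEL_VARIATION}
```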

## MaxText supports pretraining and finetuning with high performance

Model Flop utilization for training on v5e and v5p TPUs:

| Model    | v5e-256 (bf16) | v5p-128 (bf16) | v5e-256 (int8) | v5p-128 (int8) |
| -------- | -------------- | -------------- | -------------- | -------------- |
| Gemma-2b | 58%            | 55%            | 64%            | 68%            |
| Gemma-7b | 58%            | 60%            | 70%            | 70%            |