diff --git a/README.md b/README.md
index 0be06cef6..9dafc9e05 100644
--- a/README.md
+++ b/README.md
@@ -212,10 +212,3 @@
 base_output_directory=gs://my-output-bucket dataset_path=gs://my-dataset-bucket
 ```
 In the save step of example 2 above we included exporting the compiler flag `LIBTPU_INIT_ARGS` and `learning_rate` because those affect the compiled object `my_compiled_train.pickle`. The sizes of the model (e.g. `global_parameter_scale`, `max_sequence_length` and `per_device_batch`) are fixed when you initially compile via `compile_train.py`; you will see a size error if you try to run the saved compiled object with different sizes than you compiled with. However, a subtle note is that the **learning rate schedule** is also fixed when you run `compile_train` - it is determined by both `steps` and `learning_rate`. The optimizer parameters such as `adam_b1` are passed only as shaped objects to the compiler - thus their real values are determined when you run `train.py`, not during compilation. If you do pass in different shapes (e.g. `per_device_batch`), you will get a clear error message reporting that the compiled signature has different expected shapes than what was input. If you attempt to run on different hardware than the compilation targets requested via `compile_topology`, you will get an error saying there is a failure to map the devices from the compiled topology to your real devices. Using different XLA flags or a different LIBTPU than what you compiled with will probably run silently without error; however, there is no guaranteed behavior in this case, so you should run in the same environment you compiled in.
-
-## Supported Open Models
-
-MaxText supports training and inference of various open models. Follow user guides under [getting started](https://github.com/google/maxtext/tree/main/getting_started) section to know more.
-
-* [Gemma](https://ai.google.dev/gemma): a family of open-weights Large Language Model (LLM) by [Google DeepMind](https://deepmind.google/), based on Gemini research and technology. \
-You can run decode and finetuning using instructions mentioned [here](https://github.com/google/maxtext/blob/main/getting_started/Run_Gemma.md).
diff --git a/end_to_end/gemma/2b/test_gemma.sh b/end_to_end/gemma/2b/test_gemma.sh
new file mode 100644
index 000000000..74d776951
--- /dev/null
+++ b/end_to_end/gemma/2b/test_gemma.sh
@@ -0,0 +1,64 @@
+#!/bin/bash
+
+# This file is both an integration test that runs once a day on a v4-8 and documentation for how to get started with Gemma-2b.
+
+# The flow of this file is as follows:
+# 1. Convert the checkpoint downloaded from Kaggle to make it compatible with MaxText
+# 2. Run decoding and finetuning of Gemma 2B with the converted checkpoint. Also, run pretraining of Gemma 2B
+# 3. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding.
+# 4. Run decoding from the finetuned checkpoint from step 2
+# 5. Ahead of Time Compilation for running Gemma 2B on v5e-256
+
+
+set -ex
+idx=$(date +%Y-%m-%d-%H-%M)
+export MODEL_VARIATION='2b'
+
+# After downloading checkpoints, copy them to the GCS bucket at $CHKPT_BUCKET.
+# Non-Googlers please remember to use separate GCS paths for uploading model weights from Kaggle ($CHKPT_BUCKET) and MaxText-compatible weights ($MODEL_BUCKET).
+# Non-Googlers please remember to point these variables to GCS buckets that you own; this script uses internal buckets for testing.
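+# As a rough sketch of that copy step (illustrative only: the local `2b/` directory is wherever you extracted the
+# Kaggle download, so adjust the source path to your setup), the upload might look like:
+#   gsutil -m cp -r 2b/* ${CHKPT_BUCKET}/2b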
+export CHKPT_BUCKET=gs://maxtext-gemma/flax
+export MODEL_BUCKET=gs://maxtext-gemma
+python MaxText/convert_gemma_chkpt.py --base_model_path ${CHKPT_BUCKET}/${MODEL_VARIATION} --maxtext_model_path ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx} --model_size ${MODEL_VARIATION}
+
+# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
+export DATASET_PATH=gs://maxtext-dataset
+# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own; this bucket will store all the files generated by MaxText during a run
+export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
+# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. This way it is easier to use this path in the `train.py` and `decode.py` commands
+export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items
+export RUN_NAME=unscanned_chkpt_${idx}
+# Note that `CONVERTED_CHECKPOINT` is in a `scanned` format, which is great for training, but for efficient decoding performance we want the checkpoint in an `unscanned` format.
+# We can do this by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`.
+python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='gemma-2b' force_unroll=true
+
+export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items
+
+# We run decoding on `UNSCANNED_CKPT_PATH` for efficient decoding with the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
+# So, we use it by specifying `load_parameters_path=${UNSCANNED_CKPT_PATH}`.
+# We compare our decoded results by asserting against golden outputs using `autoregressive_decode_assert`.
+python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" travel and I love to write about it"
+
+# We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`.
+python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" cook and bake. I love to eat"
+
+# Alternatively, we can skip straight to finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`.
+# Note that the scanned checkpoint helps with efficient finetuning.
+export FINETUNE_RUN_NAME=runner_finetune_${idx}
+python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma-2b checkpoint_period=5
+
+# We also run pre-training; this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from.
+python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma per_device_batch_size=1 run_name=runner_pretrain_${idx} max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-2b
+
+# Note that the finetune run generates a `full state` checkpoint which has both parameters and optimizer state. For decoding, we only need the parameters.
+# So, we can use `MaxText/generate_param_only_checkpoint.py` to convert the full-state checkpoint into a parameter-only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_DIRECTORY` from our previous finetuning run.
+# `force_unroll=true` converts the output parameter-only checkpoint into an unscanned format for efficient decoding.
+export PARAM_RUN_NAME=param_chkpt_${idx}
+python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma-2b' force_unroll=true
+
+# Now, run decoding on the checkpoint generated from our finetune run.
+python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to"
+
+# We recommend training/finetuning Gemma on v5e-256 using the following sharding strategy to achieve optimal performance.
+# The command below does Ahead Of Time Cross Compilation (https://github.com/google/maxtext?tab=readme-ov-file#ahead-of-time-compilation-aot) for our recommended v5e-256 configuration for Gemma 2B.
+# To actually run it on a real v5e-256, simply replace train_compile.py with train.py and remove the compile_topology args.
+python MaxText/train_compile.py MaxText/configs/base.yml model_name=gemma-2b ici_fsdp_transpose_parallelism=16 per_device_batch_size=2 compile_topology=v5e-256 compile_topology_num_slices=1
diff --git a/end_to_end/gemma/7b/1_test_gemma.sh b/end_to_end/gemma/7b/1_test_gemma.sh
new file mode 100644
index 000000000..a718e844b
--- /dev/null
+++ b/end_to_end/gemma/7b/1_test_gemma.sh
@@ -0,0 +1,31 @@
+#!/bin/bash
+
+# This file, combined with step 2 in the same directory, demonstrates converting a Gemma checkpoint from Kaggle and running various MaxText operations on it.
+# This step is tested nightly on an ordinary CPU VM.
+
+# The flow of this file is as follows:
+# 1. Pull the checkpoint from a GCS bucket and upload the new MaxText-compatible checkpoint to the destination GCS bucket.
+# 2. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding.
+
+# Example Usage: bash end_to_end/gemma/7b/1_test_gemma.sh
+set -ex
+idx=$(date +%Y-%m-%d-%H-%M)
+MODEL_VARIATION='7b'
+
+# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own; this bucket will store all the files generated by MaxText during a run
+export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
+# After downloading checkpoints, copy them to the GCS bucket at $CHKPT_BUCKET.
+# Please use separate GCS paths for uploading model weights from Kaggle ($CHKPT_BUCKET) and MaxText-compatible weights ($MODEL_BUCKET).
+# Non-Googlers please remember to point these variables to GCS buckets that you own; this script uses internal buckets for testing.
+export CHKPT_BUCKET=gs://maxtext-gemma/flax
+export MODEL_BUCKET=gs://maxtext-gemma
+JAX_PLATFORMS=cpu python MaxText/convert_gemma_chkpt.py --base_model_path ${CHKPT_BUCKET}/${MODEL_VARIATION} --maxtext_model_path ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx} --model_size ${MODEL_VARIATION}
+echo "Written MaxText compatible checkpoint to ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}"
+
+# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory.
+export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items
+# Note that `CONVERTED_CHECKPOINT` is in a `scanned` format, which is great for training, but for efficient decoding performance we want the checkpoint in an `unscanned` format.
+# We can do this by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`.
+export RUN_NAME=unscanned_chkpt_${idx}
+JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='gemma-7b' force_unroll=true
+echo "Written MaxText compatible unscanned checkpoint to ${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items"
diff --git a/end_to_end/gemma/7b/2_test_gemma.sh b/end_to_end/gemma/7b/2_test_gemma.sh
new file mode 100644
index 000000000..6f8e37b79
--- /dev/null
+++ b/end_to_end/gemma/7b/2_test_gemma.sh
@@ -0,0 +1,55 @@
+#!/bin/bash
+
+# This file is both an integration test that runs once a day on a v4-16 and documentation for how to get started with Gemma-7b.
+# Please make sure you have run end_to_end/gemma/7b/1_test_gemma.sh before running commands from this file.
+
+# The flow of this file is as follows:
+# 1. Run decoding and finetuning of Gemma 7B with the converted checkpoint obtained from end_to_end/gemma/7b/1_test_gemma.sh. Also, run pretraining of Gemma 7B
+# 2. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding.
+# 3. Run decoding from the finetuned checkpoint from step 1
+# 4. Ahead of Time Compilation for running Gemma 7B on v5e-256
+
+set -ex
+idx=$(date +%Y-%m-%d-%H-%M)
+export MODEL_VARIATION='7b'
+
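+# One assumption worth calling out when running this file by hand (not stated in the script itself): `idx` is
+# regenerated above, while `CONVERTED_CHECKPOINT` and `UNSCANNED_CKPT_PATH` below are derived from it, so they only
+# resolve if `idx` matches the timestamp used when 1_test_gemma.sh ran. A hypothetical override might look like:
+#   idx=<timestamp-from-your-1_test_gemma.sh-run>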
+# Non-Googlers please remember to point `MODEL_BUCKET` to a GCS bucket that you own; this script uses internal buckets for testing.
+export MODEL_BUCKET=gs://maxtext-gemma
+# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data
+export DATASET_PATH=gs://maxtext-dataset
+# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own; this bucket will store all the files generated by MaxText during a run
+export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
+# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. This way it is easier to use this path in the `train.py` and `decode.py` commands
+export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items
+export RUN_NAME=unscanned_chkpt_${idx}
+# We define the path to the unscanned checkpoint created in 1_test_gemma.sh
+export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items
+
+# We run decoding on `UNSCANNED_CKPT_PATH` for efficient decoding with the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
+# So, we use it by specifying `load_parameters_path=${UNSCANNED_CKPT_PATH}`.
+# We compare our decoded results by asserting against golden outputs using `autoregressive_decode_assert`.
+python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to" autoregressive_decode_assert=" see the look on people’s faces"
+
+# We can also run decoding (albeit in a bit unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`.
+python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-7b attention=dot_product prompt="I love to" autoregressive_decode_assert=" see the look on people's faces"
+
+# Alternatively, we can skip straight to finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`.
+# Note that the scanned checkpoint helps with efficient finetuning.
+export FINETUNE_RUN_NAME=runner_finetune_${idx}
+python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma-7b checkpoint_period=5
+
+# We also run pre-training; this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from.
+python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-7b
+
+# Note that the finetune run generates a `full state` checkpoint which has both parameters and optimizer state. For decoding, we only need the parameters.
+# So, we can use `MaxText/generate_param_only_checkpoint.py` to convert the full-state checkpoint into a parameter-only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_DIRECTORY` from our previous finetuning run.
+# `force_unroll=true` converts the output parameter-only checkpoint into an unscanned format for efficient decoding.
+export PARAM_RUN_NAME=param_chkpt_${idx}
+python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma-7b' force_unroll=true
+
+# Now, run decoding on the checkpoint generated from our finetune run.
+python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-7b attention=dot_product prompt="I love to"
+
+# We recommend training/finetuning Gemma on v5e-256 using the following sharding strategy to achieve optimal performance.
+# The command below does Ahead Of Time Cross Compilation (https://github.com/google/maxtext?tab=readme-ov-file#ahead-of-time-compilation-aot) for our recommended v5e-256 configuration for Gemma 7B.
+# To actually run it on a real v5e-256, simply replace train_compile.py with train.py and remove the compile_topology args.
+python MaxText/train_compile.py MaxText/configs/base.yml model_name=gemma-7b ici_fsdp_transpose_parallelism=16 per_device_batch_size=2 compile_topology=v5e-256 compile_topology_num_slices=1
diff --git a/end_to_end/gemma/Run_Gemma.md b/end_to_end/gemma/Run_Gemma.md
new file mode 100644
index 000000000..b099cf883
--- /dev/null
+++ b/end_to_end/gemma/Run_Gemma.md
@@ -0,0 +1,31 @@
+
+
+# Gemma
+[Gemma](https://ai.google.dev/gemma) is a family of lightweight, state-of-the-art open models built from research and technology that we used to create the Gemini models.
+
+Following the instructions at [kaggle](https://www.kaggle.com/models/google/gemma/frameworks/maxText) will let you download Gemma model weights.
+You will have to consent to the license for Gemma using your Kaggle account's [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials).
+
+After downloading the weights, run [test_convert_chkpt.sh](https://github.com/google/maxtext/blob/main/end_to_end/gemma/test_convert_chkpt.sh), which converts the checkpoint to be compatible with MaxText and uploads it to a GCS bucket. You can run decoding and finetuning using the instructions in the test scripts at [end_to_end/gemma](https://github.com/google/maxtext/blob/main/end_to_end/gemma).
+
+## MaxText supports pretraining and finetuning with high performance
+
+Model Flop utilization for training on v5e and v5p TPUs.
+
+| Model | v5e-256 (bf16) | v5p-128 (bf16) | v5e-256 (int8) | v5p-128 (int8) |
+| -------- | -------------- | -------------- | -------------- | -------------- |
+| Gemma-2b | 58% | 55% | 64% | 68% |
+| Gemma-7b | 58% | 60% | 70% | 70% |
diff --git a/end_to_end/test_decode.sh b/end_to_end/test_decode.sh
index 4337339d2..fc432e025 100644
--- a/end_to_end/test_decode.sh
+++ b/end_to_end/test_decode.sh
@@ -26,3 +26,11 @@ python3 MaxText/decode.py MaxText/configs/base.yml run_name=$RUN_NAME\
     steps=50 enable_checkpointing=False metrics_file=/tmp/${RUN_NAME}_metrics.txt \
     base_output_directory=$OUTPUT_PATH dataset_path=$DATASET_PATH \
     attention=dot_product ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM}
+
+
+# Get the latest converted Gemma 2B checkpoint from the internal GCS bucket
+export GEMMA_2B_CKPT_PATH=$(gsutil ls gs://maxtext-gemma/2b | sort -r | head -1)
+# Decode with different sampling strategies.
+python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product decode_sampling_strategy=weighted decode_sampling_temperature=.00001 prompt="I love to" autoregressive_decode_assert=" cook and bake. I love to eat"
+python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product decode_sampling_strategy=nucleus decode_sampling_nucleus_p=0 prompt="I love to" autoregressive_decode_assert=" cook and bake. I love to eat"
+python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${GEMMA_2B_CKPT_PATH}/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product decode_sampling_strategy=topk decode_sampling_top_k=1 prompt="I love to" autoregressive_decode_assert=" cook and bake.
I love to eat" diff --git a/end_to_end/test_gemma.sh b/end_to_end/test_gemma.sh deleted file mode 100644 index 897fda4f2..000000000 --- a/end_to_end/test_gemma.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash -set -ex -idx=$(date +%Y-%m-%d-%H-%M) -# convert 2.5B checkpoint -export base_model_path=gs://maxtext-gemma/flax/2b -export maxtext_model_path=gs://maxtext-gemma/2b/${idx} -python MaxText/convert_gemma_chkpt.py --base_model_path ${base_model_path} --maxtext_model_path ${maxtext_model_path} --model_size 2b -# Test Gemma 2.5B decode -python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${maxtext_model_path}/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" travel and I love to write. I" -python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${maxtext_model_path}/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" travel and I love to write. I" decode_sampling_strategy=weighted decode_sampling_temperature=.00001 -python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${maxtext_model_path}/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" travel and I love to write. I" decode_sampling_strategy=nucleus decode_sampling_nucleus_p=0 -python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${maxtext_model_path}/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" travel and I love to write. I" decode_sampling_strategy=topk decode_sampling_top_k=1 - -# convert 7B checkpoint -export base_model_path=gs://maxtext-gemma/flax/7b -export maxtext_model_path=gs://maxtext-gemma/7b/${idx} -python MaxText/convert_gemma_chkpt.py --base_model_path ${base_model_path} --maxtext_model_path ${maxtext_model_path} --model_size 7b -# Test Gemma 7B decode -python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${maxtext_model_path}/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-7b attention=dot_product prompt="I love to" autoregressive_decode_assert=" use this product in my hair. It" diff --git a/getting_started/Run_Gemma.md b/getting_started/Run_Gemma.md deleted file mode 100644 index 6ca618acb..000000000 --- a/getting_started/Run_Gemma.md +++ /dev/null @@ -1,56 +0,0 @@ - - -## About Gemma - -Gemma is a family of lightweight, state-of-the art open models built from research and technology that we used to create the Gemini models. 
To get started on decoding and finetuning of Gemma, you will first need to download weights from [kaggle](https://www.kaggle.com/models/google/gemma?rvi=1) - -Following commands will let you download Gemma-2B model weights along with its tokenizer, convert the orbax checkpoints to be compatible with MaxText and upload it to a GCS bucket. \ -Values for environment variables $KAGGLE_USERNAME and $KAGGLE_KEY can be set using your kaggle account's [API credentials](https://github.com/Kaggle/kaggle-api?tab=readme-ov-file#api-credentials). \ -Please use seperate GCS buckets for uploading model weights from kaggle ($MODEL_BUCKET) and MaxText compatible weights ($CHKPT_BUCKET). -``` -wget https://www.kaggle.com/api/v1/models/google/gemma/maxtext/2b/1/download --user=$KAGGLE_USERNAME --password=$KAGGLE_KEY --auth-no-challenge -# Extract downloaded model -tar -xf download -# export variables $CHKPT_BUCKET and $MODEL_BUCKET which are google cloud buckets to store weights -gsutil -m cp -r 2b/* $CHKPT_BUCKET/2b -gsutil -m cp tokenizer.model $CHKPT_BUCKET/tokenizer.model - -python MaxText/convert_gemma_chkpt.py --base_model_path $CHKPT_BUCKET/2b --maxtext_model_path $MODEL_BUCKET/2b --model_size 2b -``` - -### Run `decode.py`. - -``` -python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=$CHKPT_BUCKET/tokenizer.model load_parameters_path=$MODEL_BUCKET/{MODEL_VARIATION}/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=64 dataset_type=synthetic steps=10 async_checkpointing=false attention=dot_product model_name=gemma-2b prompt="Kaggle is good for" -``` - -### MaxText supports fine-tuning with high performance. - -Command for training Gemma-2b from scratch on 1 slice of v5e-256. -``` -python MaxText/train.py MaxText/configs/base.yml base_output_directory=$BASE_OUTPUT_DIR model_name=gemma-2b dataset_path=$DATASET_PATH enable_checkpointing=false tokenizer_path=$CHKPT_BUCKET/tokenizer.model steps=10 ici_fsdp_transpose_parallelism=16 per_device_batch_size=2 remat_policy=minimal max_target_length=8192 -``` - -### Performance - -Model Flop utilization for training on v5e and v5p TPUs. - - -| Model | v5e-256 (bf16) | v5p-128 (bf16) | v5e-256 (int8) | v5p-128 (int8) | -| -------- | -------------- | -------------- | -------------- | -------------- | -| Gemma-2b | 58.21% | 55.36% | 64.68% | 67.80% | -| Gemma-7b | 57.70% | 60.16% | 70.31% | 70.12% |