
Commit a566160

Authored and committed by maxtext authors
Merge pull request #578 from google:nina/gpu_xlml_llama_test
PiperOrigin-RevId: 623614558
Parents: 24d24b6 + fc216ff

23 files changed: +1065 -0 lines

end_to_end/gpu/a3/test_llama2_7b.sh

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
#!/bin/bash

# This file is both an integration test that runs once a day on an A3 and documentation for how to get started with Llama2-7b.

# The flow of this file is as follows:
# 1. Download the checkpoint from Meta (https://llama.meta.com/llama-downloads/) into your local directory. Convert this PyTorch checkpoint into Orbax checkpoint format for use in MaxText.
# 2. Run training of Llama2-7b.
# 3. Run decoding from the trained checkpoint.


set -ex
idx=$(date +%Y-%m-%d-%H-%M)

# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own; this bucket will store all the files generated by MaxText during a run.
export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
export ASYNC_CHECKPOINTING=false

# We install the CPU build of torch because the checkpoint conversion script MaxText/llama_or_mistral_ckpt.py does not need a TPU/GPU.
pip install torch --index-url https://download.pytorch.org/whl/cpu

# We define a variable for the path to the Meta checkpoint. Non-Googlers please remember to point `META_CHECKPOINT_PATH` to the GCS bucket where you have your Meta checkpoint.
export META_CHECKPOINT_PATH=gs://maxtext-llama/llama2-7b/meta-ckpt

# In the following command, we copy Meta's checkpoint into the local directory `/tmp/`.
# You can use a different local directory than /tmp/; if you do so, please use the same local path for `base-model-path` when running `python3 MaxText/llama_or_mistral_ckpt.py`.
gcloud storage cp -r ${META_CHECKPOINT_PATH} /tmp/

# `CONVERTED_CHECKPOINT_PATH` is the path to the GCS bucket where we want to save our converted (Orbax) checkpoint. Non-Googlers please remember to point `CONVERTED_CHECKPOINT_PATH` to a GCS bucket that you own.
export CONVERTED_CHECKPOINT_PATH=gs://maxtext-llama/test/${idx}/decode-ckpt-maxtext-gpu

# Next, run the conversion script `MaxText/llama_or_mistral_ckpt.py` to convert Meta's PyTorch checkpoint in `base-model-path` and save the new converted (Orbax) checkpoint in `maxtext-model-path`.
python3 MaxText/llama_or_mistral_ckpt.py --base-model-path /tmp/meta-ckpt --model-size llama2-7b --maxtext-model-path ${CONVERTED_CHECKPOINT_PATH}

# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory inside `CONVERTED_CHECKPOINT_PATH`. This makes it easier to use this path in the `train.py` and `decode.py` commands.
export CONVERTED_CHECKPOINT=${CONVERTED_CHECKPOINT_PATH}/0/items

# Note that `CONVERTED_CHECKPOINT` is in a `scanned` format, which is great for training, but for efficient decoding performance we want the checkpoint in an `unscanned` format.
# We can do this by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`.
export DIRECT_PARAMETER_CHECKPOINT_RUN=direct_generate_param_only_checkpoint_${idx}
python3 MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${DIRECT_PARAMETER_CHECKPOINT_RUN} model_name='llama2-7b' hardware=gpu async_checkpointing=${ASYNC_CHECKPOINTING}

export RUN_NAME="llama-2-1vm-$(date +%Y-%m-%d-%H-%M)"

# Set environment variables from any KEY=VALUE arguments passed to this script.
for ARGUMENT in "$@"; do
    IFS='=' read -r KEY VALUE <<< "$ARGUMENT"
    export "$KEY"="$VALUE"
done

export XLA_PYTHON_CLIENT_MEM_FRACTION=0.85
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NVTE_FUSED_ATTN=1
export NCCL_DEBUG=VERSION

export XLA_FLAGS="--xla_dump_to=$BASE_OUTPUT_DIRECTORY/$RUN_NAME/HLO_dumps/
--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true
--xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false --xla_gpu_simplify_all_fp_conversions
--xla_gpu_graph_level=0 --xla_gpu_enable_async_all_reduce=true --xla_gpu_enable_highest_priority_async_stream=true
--xla_gpu_all_reduce_combine_threshold_bytes=134217728 --xla_gpu_all_gather_combine_threshold_bytes=134217728
--xla_gpu_reduce_scatter_combine_threshold_bytes=67108864 --xla_gpu_enable_pipelined_all_gather=true
--xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true
--xla_gpu_enable_while_loop_double_buffering=true --xla_gpu_enable_triton_softmax_fusion=false
--xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false
--xla_disable_hlo_passes=rematerialization"

python MaxText/train.py MaxText/configs/base.yml run_name=$RUN_NAME hardware=gpu steps=30 dcn_data_parallelism=1 ici_fsdp_parallelism=8 per_device_batch_size=4 max_target_length=4096 model_name=llama2-7b enable_checkpointing=true attention=cudnn_flash_te remat_policy=minimal_flash use_iota_embed=true scan_layers=false dataset_type=synthetic async_checkpointing=${ASYNC_CHECKPOINTING} base_output_directory=$BASE_OUTPUT_DIRECTORY enable_profiler=false

export XLA_PYTHON_CLIENT_MEM_FRACTION=0.65
export TF_FORCE_GPU_ALLOW_GROWTH=true

python3 MaxText/decode.py MaxText/configs/base.yml load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items run_name=runner_decode_finetuned_${idx} base_output_directory=${BASE_OUTPUT_DIRECTORY} per_device_batch_size=1 model_name='llama2-7b' ici_autoregressive_parallelism=4 max_prefill_predict_length=4 max_target_length=16 prompt="I love to" attention=dot_product scan_layers=false hardware=gpu async_checkpointing=${ASYNC_CHECKPOINTING}
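
As an illustration (not a line from this script): because the `for ARGUMENT in "$@"` loop exports any KEY=VALUE arguments after the defaults are set, variables that are only read after the loop, such as `RUN_NAME`, can be overridden from the command line. The run name below is a hypothetical placeholder; variables used earlier in the script (for example `BASE_OUTPUT_DIRECTORY`) still need to be edited in the file itself.

# Hypothetical invocation: RUN_NAME is picked up by the KEY=VALUE loop before train.py and decode.py run.
bash end_to_end/gpu/a3/test_llama2_7b.sh RUN_NAME=my-llama2-7b-a3-run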

end_to_end/tpu/eval_assert.py

Lines changed: 139 additions & 0 deletions
@@ -0,0 +1,139 @@
"""
Copyright 2023 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

# pylint: skip-file
"""Reads and asserts over target values"""
from absl import app
from typing import Sequence
from math import isclose
from google.cloud import storage
import json


def compute_avg_metric(metrics_file, target, start_line=10):
  """Reads the metrics file and computes the average of the target value.
  If start_line is negative, the last |start_line| lines are used, i.e. averaging starts at len(lines) + start_line."""

  avg = 0
  i = 0
  with open(metrics_file, 'r', encoding='utf8') as file:
    lines = file.readlines()
    if start_line < 0:
      start_line = len(lines) + start_line
    for line in lines:
      # Skip the first start_line lines for burn-in.
      if i >= start_line:
        vals = json.loads(line)
        avg += vals[target]
      i += 1
    avg /= (i - start_line)

  return avg


def assert_metric_average(metrics_file, threshold, target):
  avg_value = compute_avg_metric(metrics_file, target)
  # Checks for acceptable performance by asserting that the average metric (e.g. TFLOPs)
  # is greater than the threshold.
  print(f'avg value of target {target} is {avg_value}')
  assert avg_value >= float(threshold)
  print('assert metric average passed.')

def test_final_loss(metrics_file, target_loss):
  target_loss = float(target_loss)
  with open(metrics_file, 'r', encoding='utf8') as metrics:
    use_last_n_data = 10
    avg_final_loss = compute_avg_metric(metrics_file, 'learning/loss', start_line=-1 * use_last_n_data)
    print(f"Mean of last {use_last_n_data} losses is {avg_final_loss}")
    print(f"Target loss is {target_loss}")
    assert avg_final_loss < target_loss
    print('Final loss test passed.')

def test_checkpointing(metrics_file, target, dataset_type):
  """Asserts over loss values from a loaded checkpoint"""
  metrics_file_saved = 'saved_' + metrics_file
  metrics_file_restored = 'restored_' + metrics_file

  with open(metrics_file_saved, 'r', encoding='utf8') as saved,\
       open(metrics_file_restored, 'r', encoding='utf8') as restored:
    saved_loss = json.loads(saved.readlines()[-1])[target]
    restored_loss = json.loads(restored.readlines()[0])[target]
    # Checks that the checkpoint restore was successful by comparing the loss of the last
    # step in the saved run to the loss of the first step in the restored run.
    print("saved loss: ", saved_loss)
    print("restored loss: ", restored_loss)
    if dataset_type == 'c4':
      assert isclose(saved_loss, restored_loss, rel_tol=0.1)
    elif dataset_type == 'c4-array_record':
      assert saved_loss == restored_loss
    else:
      raise ValueError(f"Unknown dataset_type {dataset_type}. dataset_type must be c4 or c4-array_record")
    print('checkpointing test passed.')

def test_determinism(metrics_file, target):
  """Asserts that two runs produced the same loss values"""
  run_1 = 'run_1_' + metrics_file
  run_2 = 'run_2_' + metrics_file

  with open(run_1, 'r', encoding='utf8') as run_1_file,\
       open(run_2, 'r', encoding='utf8') as run_2_file:
    run_1_loss = json.loads(run_1_file.readlines()[-1])[target]
    run_2_loss = json.loads(run_2_file.readlines()[-1])[target]
    # Check that the two runs have the same loss.
    print(f"Run 1 loss: {run_1_loss}", flush=True)
    print(f"Run 2 loss: {run_2_loss}", flush=True)
    assert run_1_loss == run_2_loss
    print('determinism test passed.')

def test_vocab_creation(target):
  bucket_name = target.split("/")[2]
  vocab_path = "/".join(target.split("/")[3:])
  storage_client = storage.Client()
  assert storage.Blob(bucket=storage_client.bucket(bucket_name), name=vocab_path).exists(storage_client)
  print('vocab creation test passed.')

def test_start_step(metrics_file, start_step_target):
  with open(metrics_file, 'r', encoding='utf8') as metrics:
    start_step = json.loads(metrics.readlines()[0])["step"]
  print(f"Start step is {start_step}, start step target is {start_step_target}")
  assert start_step == float(start_step_target)
  print("Start step test passed.")

def main(argv: Sequence[str]) -> None:

  _, test_scenario, *test_vars = argv

  if test_scenario == 'metrics_average':
    assert_metric_average(*test_vars)
  elif test_scenario == 'checkpoint_save_restore':
    test_checkpointing(*test_vars, dataset_type='c4')
  elif test_scenario == 'grain_checkpoint_save_restore':
    test_checkpointing(*test_vars, dataset_type='c4-array_record')
  elif test_scenario == 'determinism':
    test_determinism(*test_vars)
  elif test_scenario == 'vocab_creation':
    test_vocab_creation(*test_vars)
  elif test_scenario == 'final_loss':
    test_final_loss(*test_vars)
  elif test_scenario == 'test_start_step':
    test_start_step(*test_vars)
  else:
    raise ValueError(f"Unrecognized test_scenario {test_scenario}")


if __name__ == "__main__":
  app.run(main)
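
For reference, eval_assert.py dispatches on its first positional argument (`test_scenario`) and forwards the remaining arguments to the matching test function, which reads a metrics file containing one JSON object per line (with keys such as 'learning/loss' and 'step'). The invocations below are a hedged sketch: the file name `metrics.txt`, the numeric thresholds, and the `perf/per_device_tflops_per_sec` key are hypothetical placeholders rather than values taken from this commit.

# Hypothetical invocations; file names and numeric targets are placeholders.
# metrics_average -> assert_metric_average(metrics_file, threshold, target): average of target must be >= threshold.
python3 end_to_end/tpu/eval_assert.py metrics_average metrics.txt 100 perf/per_device_tflops_per_sec
# final_loss -> test_final_loss(metrics_file, target_loss): mean of the last 10 'learning/loss' values must be below 2.5.
python3 end_to_end/tpu/eval_assert.py final_loss metrics.txt 2.5
# checkpoint_save_restore -> test_checkpointing(metrics_file, target): expects saved_metrics.txt and restored_metrics.txt.
python3 end_to_end/tpu/eval_assert.py checkpoint_save_restore metrics.txt learning/loss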

end_to_end/tpu/gemma/2b/test_gemma.sh

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
#!/bin/bash

# This file is both an integration test that runs once a day on a v4-8 and documentation for how to get started with Gemma-2b.

# The flow of this file is as follows:
# 1. Convert the checkpoint downloaded from Kaggle to make it compatible with MaxText.
# 2. Run decoding and finetuning of Gemma 2B with the converted checkpoint. Also, run pretraining of Gemma 2B.
# 3. Convert the scanned checkpoint from step 1 into unscanned checkpoint format and run more efficient decoding.
# 4. Run decoding from the finetuned checkpoint from step 2.
# 5. Run Ahead-of-Time Compilation for running Gemma 2B on v5e-256.


set -ex
idx=$(date +%Y-%m-%d-%H-%M)
export MODEL_VARIATION='2b'

# After downloading checkpoints, copy them to the GCS bucket at $CHKPT_BUCKET.
# Non-Googlers please remember to use separate GCS paths for uploading model weights from Kaggle ($CHKPT_BUCKET) and MaxText-compatible weights ($MODEL_BUCKET).
# Non-Googlers please remember to point these variables to GCS buckets that you own; this script uses internal buckets for testing.
export CHKPT_BUCKET=gs://maxtext-gemma/flax
export MODEL_BUCKET=gs://maxtext-gemma
python MaxText/convert_gemma_chkpt.py --base_model_path ${CHKPT_BUCKET}/${MODEL_VARIATION} --maxtext_model_path ${MODEL_BUCKET}/${MODEL_VARIATION}/${idx} --model_size ${MODEL_VARIATION}

# Non-Googlers please remember to point `DATASET_PATH` to the GCS bucket where you have your training data.
export DATASET_PATH=gs://maxtext-dataset
# Non-Googlers please remember to point `BASE_OUTPUT_DIRECTORY` to a GCS bucket that you own; this bucket will store all the files generated by MaxText during a run.
export BASE_OUTPUT_DIRECTORY=gs://runner-maxtext-logs
# We define `CONVERTED_CHECKPOINT` to refer to the checkpoint subdirectory. This makes it easier to use this path in the `train.py` and `decode.py` commands.
export CONVERTED_CHECKPOINT=${MODEL_BUCKET}/${MODEL_VARIATION}/${idx}/0/items
export RUN_NAME=unscanned_chkpt_${idx}
# Note that `CONVERTED_CHECKPOINT` is in a `scanned` format, which is great for training, but for efficient decoding performance we want the checkpoint in an `unscanned` format.
# We can do this by running `MaxText/generate_param_only_checkpoint.py` on `CONVERTED_CHECKPOINT` with `force_unroll=true`.
python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_parameters_path=${CONVERTED_CHECKPOINT} run_name=${RUN_NAME} model_name='gemma-2b' force_unroll=true

export UNSCANNED_CKPT_PATH=${BASE_OUTPUT_DIRECTORY}/${RUN_NAME}/checkpoints/0/items

# We run decoding on `UNSCANNED_CKPT_PATH` for efficient decoding on the unscanned version of the checkpoint. Note that this checkpoint only has parameters and no optimizer state.
# So, we use it by specifying `load_parameters_path=${UNSCANNED_CKPT_PATH}`.
# We compare our decoded results against golden outputs using `autoregressive_decode_assert`.
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${UNSCANNED_CKPT_PATH} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" travel and I love to write about it"

# We can also run decoding (albeit in a somewhat unoptimized way) by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Note again that this checkpoint only has parameters and no optimizer state. So, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`.
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false model_name=gemma-2b attention=dot_product prompt="I love to" autoregressive_decode_assert=" cook and bake. I love to eat"

# Alternatively, we can skip ahead to finetuning by using the scanned converted checkpoint located at `CONVERTED_CHECKPOINT`. Again, we use it by specifying `load_parameters_path=${CONVERTED_CHECKPOINT}`. Note that the scanned checkpoint helps with efficient finetuning.
export FINETUNE_RUN_NAME=runner_finetune_${idx}
python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma load_parameters_path=${CONVERTED_CHECKPOINT} per_device_batch_size=1 run_name=${FINETUNE_RUN_NAME} max_target_length=8192 steps=10 async_checkpointing=false model_name=gemma-2b checkpoint_period=5

# We also run pretraining; this is similar to the finetuning command except we don't pass any checkpoint directory to load parameters from.
python MaxText/train.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma per_device_batch_size=1 run_name=runner_pretrain_${idx} max_target_length=8192 steps=5 enable_checkpointing=false model_name=gemma-2b

# Note that the finetune run checkpoint contains the `full state`, which has both parameters and optimizer state. For decoding, we only need the parameters.
# So, we can use `MaxText/generate_param_only_checkpoint.py` to convert the full-state checkpoint into a parameter-only checkpoint for more efficient memory use. Note that the path provided to the flag `load_full_state_path` is the path to the checkpoint subdirectory inside the `BASE_OUTPUT_DIRECTORY` from our previous finetuning run.
# `force_unroll=true` converts the output parameter-only checkpoint into an unscanned format for efficient decoding.
export PARAM_RUN_NAME=param_chkpt_${idx}
python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_DIRECTORY} load_full_state_path=${BASE_OUTPUT_DIRECTORY}/${FINETUNE_RUN_NAME}/checkpoints/5/items run_name=${PARAM_RUN_NAME} model_name='gemma-2b' force_unroll=true

# Now, run decoding on the checkpoint generated from our finetune run.
python MaxText/decode.py MaxText/configs/base.yml tokenizer_path=assets/tokenizer.gemma load_parameters_path=${BASE_OUTPUT_DIRECTORY}/${PARAM_RUN_NAME}/checkpoints/0/items per_device_batch_size=1 run_name=runner_$(date +%Y-%m-%d-%H-%M) max_prefill_predict_length=8 max_target_length=16 dataset_type=synthetic steps=10 async_checkpointing=false scan_layers=false model_name=gemma-2b attention=dot_product prompt="I love to"

# We recommend training/finetuning Gemma on v5e-256 using the following sharding strategy to achieve optimal performance.
# The command below does Ahead-of-Time cross-compilation (https://github.com/google/maxtext?tab=readme-ov-file#ahead-of-time-compilation-aot) for our recommended v5e-256 configuration for Gemma 2B.
# To actually run it on a real v5e-256, simply replace train_compile.py with train.py and remove the compile_topology args.
python MaxText/train_compile.py MaxText/configs/base.yml model_name=gemma-2b ici_fsdp_transpose_parallelism=16 per_device_batch_size=2 compile_topology=v5e-256 compile_topology_num_slices=1
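
Following the comment above, a hedged sketch of what the corresponding run on a real v5e-256 might look like: train_compile.py is swapped for train.py and the compile_topology arguments are dropped. The run_name, steps, dataset, tokenizer, and output flags added here are illustrative assumptions, not part of this script.

# Hypothetical sketch of the actual v5e-256 training run described in the comment above.
python MaxText/train.py MaxText/configs/base.yml model_name=gemma-2b ici_fsdp_transpose_parallelism=16 per_device_batch_size=2 base_output_directory=${BASE_OUTPUT_DIRECTORY} dataset_path=${DATASET_PATH} tokenizer_path=assets/tokenizer.gemma run_name=runner_gemma_2b_v5e_256_${idx} steps=10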
